Run all callbacks in parallel

This commit is contained in:
James McGuire 2018-05-14 11:45:57 -07:00
parent edafec0fc7
commit fc944ef429
5 changed files with 79 additions and 45 deletions

14
irc.go

@ -21,7 +21,6 @@ package irc
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
@ -80,10 +79,6 @@ func (irc *Connection) readLoop() {
irc.lastMessageMutex.Unlock()
event, err := parseToEvent(msg)
event.Connection = irc
event.Ctx = context.Background()
if irc.CallbackTimeout != 0 {
event.Ctx, _ = context.WithTimeout(event.Ctx, irc.CallbackTimeout)
}
if err == nil {
/* XXX: len(args) == 0: args should be empty */
irc.RunCallbacks(event)
@ -236,7 +231,9 @@ func (irc *Connection) Loop() {
errChan := irc.ErrorChan()
for !irc.isQuitting() {
err := <-errChan
close(irc.end)
if irc.end != nil {
close(irc.end)
}
irc.Wait()
for !irc.isQuitting() {
irc.Log.Printf("Error, disconnected: %s\n", err)
@ -400,13 +397,14 @@ func (irc *Connection) Disconnect() {
close(irc.end)
}
irc.Wait()
irc.end = nil
if irc.pwrite != nil {
close(irc.pwrite)
}
irc.Wait()
if irc.socket != nil {
irc.socket.Close()
}
@ -473,7 +471,7 @@ func (irc *Connection) Connect(server string) error {
irc.Log.Printf("Connected to %s (%s)\n", irc.Server, irc.socket.RemoteAddr())
irc.pwrite = make(chan string, 10)
irc.Error = make(chan error, 2)
irc.Error = make(chan error, 10)
irc.Add(3)
go irc.readLoop()
go irc.writeLoop()

@ -1,7 +1,7 @@
package irc
import (
"fmt"
"context"
"reflect"
"runtime"
"strconv"
@ -129,17 +129,20 @@ func (irc *Connection) RunCallbacks(event *Event) {
}
irc.eventsMutex.Lock()
callbacks := []func(*Event){}
callbacks := make(map[int]func(*Event))
eventCallbacks, ok := irc.events[event.Code]
id := 0
if ok {
for _, callback := range eventCallbacks {
callbacks = append(callbacks, callback)
callbacks[id] = callback
id++
}
}
allCallbacks, ok := irc.events["*"]
if ok {
for _, callback := range allCallbacks {
callbacks = append(callbacks, callback)
callbacks[id] = callback
id++
}
}
irc.eventsMutex.Unlock()
@ -148,36 +151,42 @@ func (irc *Connection) RunCallbacks(event *Event) {
irc.Log.Printf("%v (%v) >> %#v\n", event.Code, len(callbacks), event)
}
done := make(chan bool)
possibleLogs := []string{}
for i, callback := range callbacks {
go func(done chan bool) {
callback(event)
done <- true
}(done)
callbackName := getFunctionName(callback)
start := time.Now()
event.Ctx = context.Background()
if irc.CallbackTimeout != 0 {
event.Ctx, _ = context.WithTimeout(event.Ctx, irc.CallbackTimeout)
}
done := make(chan int)
for id, callback := range callbacks {
go func(id int, done chan<- int, cb func(*Event), event *Event) {
start := time.Now()
cb(event)
select {
case done <- id:
case <-event.Ctx.Done(): // If we timed out, report how long until we eventually finished
irc.Log.Printf("Canceled callback %s finished in %s >> %#v\n",
getFunctionName(cb),
time.Since(start),
event,
)
}
}(id, done, callback, event)
}
for len(callbacks) > 0 {
select {
case jobID := <-done:
delete(callbacks, jobID)
case <-event.Ctx.Done(): // context timed out!
irc.Log.Printf("TIMEOUT: %s timeout expired while executing %s, abandoning remaining callbacks", irc.CallbackTimeout, callbackName)
// If we timed out let's include context for how long each previous handler took
for _, logItem := range possibleLogs {
irc.Log.Println(logItem)
timedOutCallbacks := []string{}
for _, cb := range callbacks { // Everything left here did not finish
timedOutCallbacks = append(timedOutCallbacks, getFunctionName(cb))
}
irc.Log.Printf("Callback %s ran for %s prior to timeout", callbackName, time.Since(start))
if len(callbacks) > i {
for _, callback := range callbacks[i+1:] {
irc.Log.Printf("Callback %s did not run", getFunctionName(callback))
}
}
// At this point our context has expired and it's not safe to execute anything else, lets bail.
irc.Log.Printf("Timeout while waiting for %d callback(s) to finish (%s)\n",
len(callbacks),
strings.Join(timedOutCallbacks, ", "),
)
return
case <-done:
elapsed := time.Since(start)
logMsg := fmt.Sprintf("Callback %s took %s", getFunctionName(callback), elapsed)
possibleLogs = append(possibleLogs, logMsg)
}
}
}

@ -1,7 +1,6 @@
package irc
import (
"fmt"
"testing"
)
@ -37,7 +36,7 @@ func TestParseTags(t *testing.T) {
t.Fatal("Parse PRIVMSG with tags failed")
}
checkResult(t, event)
fmt.Printf("%s", event.Tags)
t.Logf("%s", event.Tags)
if _, ok := event.Tags["tag"]; !ok {
t.Fatal("Parsing value-less tag failed")
}

@ -16,6 +16,9 @@ func TestConnectionSASL(t *testing.T) {
if SASLLogin == "" {
t.Skip("Define SASLLogin and SASLPasword environment varables to test SASL")
}
if testing.Short() {
t.Skip("skipping test in short mode.")
}
irccon := IRC("go-eventirc", "go-eventirc")
irccon.VerboseCallbackHandler = true
irccon.Debug = true

@ -3,6 +3,7 @@ package irc
import (
"crypto/tls"
"math/rand"
"sort"
"testing"
"time"
)
@ -115,7 +116,7 @@ func TestRemoveCallback(t *testing.T) {
results = append(results, <-done)
results = append(results, <-done)
if len(results) != 2 || results[0] == 2 || results[1] == 2 {
if !compareResults(results, 1, 3) {
t.Error("Callback 2 not removed")
}
}
@ -138,7 +139,7 @@ func TestWildcardCallback(t *testing.T) {
results = append(results, <-done)
results = append(results, <-done)
if len(results) != 2 || !(results[0] == 1 && results[1] == 2) {
if !compareResults(results, 1, 2) {
t.Error("Wildcard callback not called")
}
}
@ -164,7 +165,7 @@ func TestClearCallback(t *testing.T) {
results = append(results, <-done)
results = append(results, <-done)
if len(results) != 2 || !(results[0] == 2 && results[1] == 3) {
if !compareResults(results, 2, 3) {
t.Error("Callbacks not cleared")
}
}
@ -185,6 +186,9 @@ func TestIRCemptyUser(t *testing.T) {
}
}
func TestConnection(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
rand.Seed(time.Now().UnixNano())
ircnick1 := randStr(8)
ircnick2 := randStr(8)
@ -266,8 +270,12 @@ func TestConnection(t *testing.T) {
}
func TestReconnect(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
ircnick1 := randStr(8)
irccon := IRC(ircnick1, "IRCTestRe")
irccon.PingFreq = time.Second * 3
debugTest(irccon)
connects := 0
@ -277,11 +285,11 @@ func TestReconnect(t *testing.T) {
connects += 1
if connects > 2 {
irccon.Privmsgf(channel, "Connection nr %d (test done)\n", connects)
irccon.Quit()
go irccon.Quit()
} else {
irccon.Privmsgf(channel, "Connection nr %d\n", connects)
time.Sleep(100) //Need to let the thraed actually send before closing socket
irccon.Disconnect()
go irccon.Disconnect()
}
})
@ -298,6 +306,9 @@ func TestReconnect(t *testing.T) {
}
func TestConnectionSSL(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
ircnick1 := randStr(8)
irccon := IRC(ircnick1, "IRCTestSSL")
debugTest(irccon)
@ -333,3 +344,17 @@ func debugTest(irccon *Connection) *Connection {
irccon.Debug = debug_tests
return irccon
}
func compareResults(received []int, desired ...int) bool {
if len(desired) != len(received) {
return false
}
sort.IntSlice(desired).Sort()
sort.IntSlice(received).Sort()
for i := 0; i < len(desired); i++ {
if desired[i] != received[i] {
return false
}
}
return true
}