Run all callbacks in parallel

This commit is contained in:
James McGuire 2018-05-14 11:45:57 -07:00
parent edafec0fc7
commit fc944ef429
5 changed files with 79 additions and 45 deletions

14
irc.go

@ -21,7 +21,6 @@ package irc
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@ -80,10 +79,6 @@ func (irc *Connection) readLoop() {
irc.lastMessageMutex.Unlock() irc.lastMessageMutex.Unlock()
event, err := parseToEvent(msg) event, err := parseToEvent(msg)
event.Connection = irc event.Connection = irc
event.Ctx = context.Background()
if irc.CallbackTimeout != 0 {
event.Ctx, _ = context.WithTimeout(event.Ctx, irc.CallbackTimeout)
}
if err == nil { if err == nil {
/* XXX: len(args) == 0: args should be empty */ /* XXX: len(args) == 0: args should be empty */
irc.RunCallbacks(event) irc.RunCallbacks(event)
@ -236,7 +231,9 @@ func (irc *Connection) Loop() {
errChan := irc.ErrorChan() errChan := irc.ErrorChan()
for !irc.isQuitting() { for !irc.isQuitting() {
err := <-errChan err := <-errChan
close(irc.end) if irc.end != nil {
close(irc.end)
}
irc.Wait() irc.Wait()
for !irc.isQuitting() { for !irc.isQuitting() {
irc.Log.Printf("Error, disconnected: %s\n", err) irc.Log.Printf("Error, disconnected: %s\n", err)
@ -400,13 +397,14 @@ func (irc *Connection) Disconnect() {
close(irc.end) close(irc.end)
} }
irc.Wait()
irc.end = nil irc.end = nil
if irc.pwrite != nil { if irc.pwrite != nil {
close(irc.pwrite) close(irc.pwrite)
} }
irc.Wait()
if irc.socket != nil { if irc.socket != nil {
irc.socket.Close() irc.socket.Close()
} }
@ -473,7 +471,7 @@ func (irc *Connection) Connect(server string) error {
irc.Log.Printf("Connected to %s (%s)\n", irc.Server, irc.socket.RemoteAddr()) irc.Log.Printf("Connected to %s (%s)\n", irc.Server, irc.socket.RemoteAddr())
irc.pwrite = make(chan string, 10) irc.pwrite = make(chan string, 10)
irc.Error = make(chan error, 2) irc.Error = make(chan error, 10)
irc.Add(3) irc.Add(3)
go irc.readLoop() go irc.readLoop()
go irc.writeLoop() go irc.writeLoop()

@ -1,7 +1,7 @@
package irc package irc
import ( import (
"fmt" "context"
"reflect" "reflect"
"runtime" "runtime"
"strconv" "strconv"
@ -129,17 +129,20 @@ func (irc *Connection) RunCallbacks(event *Event) {
} }
irc.eventsMutex.Lock() irc.eventsMutex.Lock()
callbacks := []func(*Event){} callbacks := make(map[int]func(*Event))
eventCallbacks, ok := irc.events[event.Code] eventCallbacks, ok := irc.events[event.Code]
id := 0
if ok { if ok {
for _, callback := range eventCallbacks { for _, callback := range eventCallbacks {
callbacks = append(callbacks, callback) callbacks[id] = callback
id++
} }
} }
allCallbacks, ok := irc.events["*"] allCallbacks, ok := irc.events["*"]
if ok { if ok {
for _, callback := range allCallbacks { for _, callback := range allCallbacks {
callbacks = append(callbacks, callback) callbacks[id] = callback
id++
} }
} }
irc.eventsMutex.Unlock() irc.eventsMutex.Unlock()
@ -148,36 +151,42 @@ func (irc *Connection) RunCallbacks(event *Event) {
irc.Log.Printf("%v (%v) >> %#v\n", event.Code, len(callbacks), event) irc.Log.Printf("%v (%v) >> %#v\n", event.Code, len(callbacks), event)
} }
done := make(chan bool) event.Ctx = context.Background()
possibleLogs := []string{} if irc.CallbackTimeout != 0 {
for i, callback := range callbacks { event.Ctx, _ = context.WithTimeout(event.Ctx, irc.CallbackTimeout)
go func(done chan bool) { }
callback(event)
done <- true done := make(chan int)
}(done) for id, callback := range callbacks {
callbackName := getFunctionName(callback) go func(id int, done chan<- int, cb func(*Event), event *Event) {
start := time.Now() start := time.Now()
cb(event)
select {
case done <- id:
case <-event.Ctx.Done(): // If we timed out, report how long until we eventually finished
irc.Log.Printf("Canceled callback %s finished in %s >> %#v\n",
getFunctionName(cb),
time.Since(start),
event,
)
}
}(id, done, callback, event)
}
for len(callbacks) > 0 {
select { select {
case jobID := <-done:
delete(callbacks, jobID)
case <-event.Ctx.Done(): // context timed out! case <-event.Ctx.Done(): // context timed out!
irc.Log.Printf("TIMEOUT: %s timeout expired while executing %s, abandoning remaining callbacks", irc.CallbackTimeout, callbackName) timedOutCallbacks := []string{}
for _, cb := range callbacks { // Everything left here did not finish
// If we timed out let's include context for how long each previous handler took timedOutCallbacks = append(timedOutCallbacks, getFunctionName(cb))
for _, logItem := range possibleLogs {
irc.Log.Println(logItem)
} }
irc.Log.Printf("Callback %s ran for %s prior to timeout", callbackName, time.Since(start)) irc.Log.Printf("Timeout while waiting for %d callback(s) to finish (%s)\n",
if len(callbacks) > i { len(callbacks),
for _, callback := range callbacks[i+1:] { strings.Join(timedOutCallbacks, ", "),
irc.Log.Printf("Callback %s did not run", getFunctionName(callback)) )
}
}
// At this point our context has expired and it's not safe to execute anything else, lets bail.
return return
case <-done:
elapsed := time.Since(start)
logMsg := fmt.Sprintf("Callback %s took %s", getFunctionName(callback), elapsed)
possibleLogs = append(possibleLogs, logMsg)
} }
} }
} }

@ -1,7 +1,6 @@
package irc package irc
import ( import (
"fmt"
"testing" "testing"
) )
@ -37,7 +36,7 @@ func TestParseTags(t *testing.T) {
t.Fatal("Parse PRIVMSG with tags failed") t.Fatal("Parse PRIVMSG with tags failed")
} }
checkResult(t, event) checkResult(t, event)
fmt.Printf("%s", event.Tags) t.Logf("%s", event.Tags)
if _, ok := event.Tags["tag"]; !ok { if _, ok := event.Tags["tag"]; !ok {
t.Fatal("Parsing value-less tag failed") t.Fatal("Parsing value-less tag failed")
} }

@ -16,6 +16,9 @@ func TestConnectionSASL(t *testing.T) {
if SASLLogin == "" { if SASLLogin == "" {
t.Skip("Define SASLLogin and SASLPasword environment varables to test SASL") t.Skip("Define SASLLogin and SASLPasword environment varables to test SASL")
} }
if testing.Short() {
t.Skip("skipping test in short mode.")
}
irccon := IRC("go-eventirc", "go-eventirc") irccon := IRC("go-eventirc", "go-eventirc")
irccon.VerboseCallbackHandler = true irccon.VerboseCallbackHandler = true
irccon.Debug = true irccon.Debug = true

@ -3,6 +3,7 @@ package irc
import ( import (
"crypto/tls" "crypto/tls"
"math/rand" "math/rand"
"sort"
"testing" "testing"
"time" "time"
) )
@ -115,7 +116,7 @@ func TestRemoveCallback(t *testing.T) {
results = append(results, <-done) results = append(results, <-done)
results = append(results, <-done) results = append(results, <-done)
if len(results) != 2 || results[0] == 2 || results[1] == 2 { if !compareResults(results, 1, 3) {
t.Error("Callback 2 not removed") t.Error("Callback 2 not removed")
} }
} }
@ -138,7 +139,7 @@ func TestWildcardCallback(t *testing.T) {
results = append(results, <-done) results = append(results, <-done)
results = append(results, <-done) results = append(results, <-done)
if len(results) != 2 || !(results[0] == 1 && results[1] == 2) { if !compareResults(results, 1, 2) {
t.Error("Wildcard callback not called") t.Error("Wildcard callback not called")
} }
} }
@ -164,7 +165,7 @@ func TestClearCallback(t *testing.T) {
results = append(results, <-done) results = append(results, <-done)
results = append(results, <-done) results = append(results, <-done)
if len(results) != 2 || !(results[0] == 2 && results[1] == 3) { if !compareResults(results, 2, 3) {
t.Error("Callbacks not cleared") t.Error("Callbacks not cleared")
} }
} }
@ -185,6 +186,9 @@ func TestIRCemptyUser(t *testing.T) {
} }
} }
func TestConnection(t *testing.T) { func TestConnection(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
ircnick1 := randStr(8) ircnick1 := randStr(8)
ircnick2 := randStr(8) ircnick2 := randStr(8)
@ -266,8 +270,12 @@ func TestConnection(t *testing.T) {
} }
func TestReconnect(t *testing.T) { func TestReconnect(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
ircnick1 := randStr(8) ircnick1 := randStr(8)
irccon := IRC(ircnick1, "IRCTestRe") irccon := IRC(ircnick1, "IRCTestRe")
irccon.PingFreq = time.Second * 3
debugTest(irccon) debugTest(irccon)
connects := 0 connects := 0
@ -277,11 +285,11 @@ func TestReconnect(t *testing.T) {
connects += 1 connects += 1
if connects > 2 { if connects > 2 {
irccon.Privmsgf(channel, "Connection nr %d (test done)\n", connects) irccon.Privmsgf(channel, "Connection nr %d (test done)\n", connects)
irccon.Quit() go irccon.Quit()
} else { } else {
irccon.Privmsgf(channel, "Connection nr %d\n", connects) irccon.Privmsgf(channel, "Connection nr %d\n", connects)
time.Sleep(100) //Need to let the thraed actually send before closing socket time.Sleep(100) //Need to let the thraed actually send before closing socket
irccon.Disconnect() go irccon.Disconnect()
} }
}) })
@ -298,6 +306,9 @@ func TestReconnect(t *testing.T) {
} }
func TestConnectionSSL(t *testing.T) { func TestConnectionSSL(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
ircnick1 := randStr(8) ircnick1 := randStr(8)
irccon := IRC(ircnick1, "IRCTestSSL") irccon := IRC(ircnick1, "IRCTestSSL")
debugTest(irccon) debugTest(irccon)
@ -333,3 +344,17 @@ func debugTest(irccon *Connection) *Connection {
irccon.Debug = debug_tests irccon.Debug = debug_tests
return irccon return irccon
} }
func compareResults(received []int, desired ...int) bool {
if len(desired) != len(received) {
return false
}
sort.IntSlice(desired).Sort()
sort.IntSlice(received).Sort()
for i := 0; i < len(desired); i++ {
if desired[i] != received[i] {
return false
}
}
return true
}