From e9d62eeee7589d1d3c7b19cef88d8b13d85d1adb Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Wed, 10 Mar 2021 02:21:53 -0500 Subject: [PATCH 1/6] support batch and labeled-response --- ircevent/irc.go | 40 ++++ ircevent/irc_callback.go | 283 ++++++++++++++++++++++++++- ircevent/irc_labeledresponse_test.go | 282 ++++++++++++++++++++++++++ ircevent/irc_parse_test.go | 22 ++- ircevent/irc_struct.go | 54 ++++- 5 files changed, 664 insertions(+), 17 deletions(-) create mode 100644 ircevent/irc_labeledresponse_test.go diff --git a/ircevent/irc.go b/ircevent/irc.go index d8f95c8..6071b67 100644 --- a/ircevent/irc.go +++ b/ircevent/irc.go @@ -55,6 +55,8 @@ var ( ServerDisconnected = errors.New("Disconnected by server") SASLFailed = errors.New("SASL setup timed out. Does the server support SASL?") + CapabilityNotNegotiated = errors.New("The IRCv3 capability required for this was not negotiated") + serverDidNotQuit = errors.New("server did not respond to QUIT") clientHasQuit = errors.New("client has called Quit()") ) @@ -110,6 +112,8 @@ func (irc *Connection) readLoop() { errChan := make(chan error) go readMsgLoop(irc.socket, irc.MaxLineLen, msgChan, errChan, irc.end) + lastExpireCheck := time.Now() + for { select { case <-irc.end: @@ -129,6 +133,11 @@ func (irc *Connection) readLoop() { irc.setError(err) return } + + if irc.labelNegotiated && time.Since(lastExpireCheck) > irc.Timeout { + irc.expireLabels(false) + lastExpireCheck = time.Now() + } } } @@ -297,6 +306,8 @@ func (irc *Connection) waitForStop() { if irc.socket != nil { irc.socket.Close() } + + irc.expireLabels(true) } // Quit the current connection and disconnect from the server @@ -359,6 +370,27 @@ func (irc *Connection) Send(command string, params ...string) error { return irc.SendWithTags(nil, command, params...) } +// SendWithLabel sends an IRC message using the IRCv3 labeled-response specification. +// Instead of being processed by normal event handlers, the server response to the +// command will be collected into a *Batch and passed to the provided callback. +// If the server fails to respond correctly, the callback will be invoked with `nil` +// as the argument. +func (irc *Connection) SendWithLabel(callback func(*Batch), tags map[string]string, command string, params ...string) error { + if !irc.labelNegotiated { + return CapabilityNotNegotiated + } + + label := irc.registerLabel(callback) + + msg := ircmsg.MakeMessage(tags, "", command, params...) + msg.SetTag("label", label) + err := irc.SendIRCMessage(msg) + if err != nil { + irc.unregisterLabel(label) + } + return err +} + // Send a raw string. func (irc *Connection) SendRaw(message string) error { mlen := len(message) @@ -603,6 +635,11 @@ func (irc *Connection) Connect() (err error) { irc.capsAcked = make(map[string]string) irc.capsAdvertised = nil irc.stateMutex.Unlock() + irc.batchMutex.Lock() + irc.batches = make(map[string]batchInProgress) + irc.labelCallbacks = make(map[int64]pendingLabel) + irc.labelCounter = 0 + irc.batchMutex.Unlock() go irc.readLoop() go irc.writeLoop() @@ -659,6 +696,9 @@ func (irc *Connection) negotiateCaps() error { for _, c := range acknowledgedCaps { irc.capsAcked[c] = irc.capsAdvertised[c] } + _, irc.batchNegotiated = irc.capsAcked["batch"] + _, labelNegotiated := irc.capsAcked["labeled-response"] + irc.labelNegotiated = irc.batchNegotiated && labelNegotiated }() irc.Send("CAP", "LS", "302") diff --git a/ircevent/irc_callback.go b/ircevent/irc_callback.go index ff4fe48..d90d344 100644 --- a/ircevent/irc_callback.go +++ b/ircevent/irc_callback.go @@ -1,11 +1,12 @@ package ircevent import ( + "errors" "fmt" - "log" "runtime/debug" "strconv" "strings" + "time" "github.com/goshuirc/irc-go/ircmsg" ) @@ -71,6 +72,8 @@ func (irc *Connection) RemoveCallback(id CallbackID) { case registrationEvent: irc.removeCallbackNoMutex(RPL_ENDOFMOTD, id.id) irc.removeCallbackNoMutex(ERR_NOMOTD, id.id) + case "BATCH": + irc.removeBatchCallbackNoMutex(id.id) default: irc.removeCallbackNoMutex(id.eventCode, id.id) } @@ -114,6 +117,39 @@ func (irc *Connection) ReplaceCallback(id CallbackID, callback func(Event)) bool return false } +// AddBatchCallback adds a callback for handling BATCH'ed server responses. +// All available BATCH callbacks will be invoked in an undefined order, +// stopping at the first one to return a value of true (indicating successful +// processing). If no batch callback returns true, the batch will be "flattened" +// (i.e., its messages will be processed individually by the normal event +// handlers). Batch callbacks can be removed as usual with RemoveCallback. +func (irc *Connection) AddBatchCallback(callback func(*Batch) bool) CallbackID { + irc.eventsMutex.Lock() + defer irc.eventsMutex.Unlock() + + idNum := irc.callbackCounter + irc.callbackCounter++ + nbc := make([]batchCallbackPair, len(irc.batchCallbacks)+1) + copy(nbc, irc.batchCallbacks) + nbc[len(nbc)-1] = batchCallbackPair{id: idNum, callback: callback} + irc.batchCallbacks = nbc + return CallbackID{eventCode: "BATCH", id: idNum} +} + +func (irc *Connection) removeBatchCallbackNoMutex(idNum uint64) { + current := irc.batchCallbacks + if len(current) == 0 { + return + } + newList := make([]batchCallbackPair, 0, len(current)-1) + for _, p := range current { + if p.id != idNum { + newList = append(newList, p) + } + } + irc.batchCallbacks = newList +} + // Convenience function to add a callback that will be called once the // connection is completed (this is traditionally referred to as "connection // registration"). @@ -132,18 +168,194 @@ func (irc *Connection) getCallbacks(code string) (result []callbackPair) { return irc.events[code] } +func (irc *Connection) getBatchCallbacks() (result []batchCallbackPair) { + irc.eventsMutex.Lock() + defer irc.eventsMutex.Unlock() + + return irc.batchCallbacks +} + +var ( + // ad-hoc internal errors for batch processing + // these indicate invalid data from the server (or else local corruption) + errorDuplicateBatchID = errors.New("found duplicate batch ID") + errorNoParentBatchID = errors.New("parent batch ID not found") + errorBatchNotOpen = errors.New("tried to close batch, but batch ID not found") + errorUnknownLabel = errors.New("received labeled response from server, but we don't recognize the label") +) + +func (irc *Connection) handleBatchCommand(msg ircmsg.IRCMessage) { + if len(msg.Params) < 1 || len(msg.Params[0]) < 2 { + irc.Log.Printf("Invalid BATCH command from server\n") + return + } + + start := msg.Params[0][0] == '+' + if !start && msg.Params[0][0] != '-' { + irc.Log.Printf("Invalid BATCH ID from server: %s\n", msg.Params[0]) + return + } + batchID := msg.Params[0][1:] + isNested, parentBatchID := msg.GetTag("batch") + var label int64 + if start { + if present, labelStr := msg.GetTag("label"); present { + label = deserializeLabel(labelStr) + } + } + + finishedBatch, callback, err := func() (finishedBatch *Batch, callback LabelCallback, err error) { + irc.batchMutex.Lock() + defer irc.batchMutex.Unlock() + + if start { + if _, ok := irc.batches[batchID]; ok { + err = errorDuplicateBatchID + return + } + batchObj := new(Batch) + batchObj.IRCMessage = msg + irc.batches[batchID] = batchInProgress{ + createdAt: time.Now(), + batch: batchObj, + label: label, + } + if isNested { + parentBip := irc.batches[parentBatchID] + if parentBip.batch == nil { + err = errorNoParentBatchID + return + } + parentBip.batch.Items = append(parentBip.batch.Items, batchObj) + } + } else { + bip := irc.batches[batchID] + if bip.batch == nil { + err = errorBatchNotOpen + return + } + delete(irc.batches, batchID) + if !isNested { + finishedBatch = bip.batch + if bip.label != 0 { + callback = irc.getLabelCallbackNoMutex(bip.label) + if callback == nil { + err = errorUnknownLabel + } + + } + } + } + return + }() + + if err != nil { + irc.Log.Printf("batch error: %v (batchID=`%s`, parentBatchID=`%s`)", err, batchID, parentBatchID) + } else if callback != nil { + callback(finishedBatch) + } else if finishedBatch != nil { + irc.HandleBatch(finishedBatch) + } +} + +func (irc *Connection) getLabelCallbackNoMutex(label int64) (callback LabelCallback) { + if lc, ok := irc.labelCallbacks[label]; ok { + callback = lc.callback + delete(irc.labelCallbacks, label) + } + return +} + +func (irc *Connection) getLabelCallback(label int64) (callback LabelCallback) { + irc.batchMutex.Lock() + defer irc.batchMutex.Unlock() + return irc.getLabelCallbackNoMutex(label) +} + +func (irc *Connection) HandleBatch(batch *Batch) { + if batch == nil { + return + } + + success := false + for _, bh := range irc.getBatchCallbacks() { + if bh.callback(batch) { + success = true + break + } + } + if !success { + irc.handleBatchNaively(batch) + } +} + +// recursively "flatten" the nested batch; process every command individually +func (irc *Connection) handleBatchNaively(batch *Batch) { + if batch.Command != "BATCH" { + irc.HandleEvent(Event{IRCMessage: batch.IRCMessage}) + } + for _, item := range batch.Items { + irc.handleBatchNaively(item) + } +} + +func (irc *Connection) handleBatchedCommand(msg ircmsg.IRCMessage, batchID string) { + irc.batchMutex.Lock() + defer irc.batchMutex.Unlock() + + bip := irc.batches[batchID] + if bip.batch == nil { + irc.Log.Printf("ignoring command with unknown batch ID %s\n", batchID) + return + } + bip.batch.Items = append(bip.batch.Items, &Batch{IRCMessage: msg}) +} + // Execute all callbacks associated with a given event. func (irc *Connection) runCallbacks(msg ircmsg.IRCMessage) { if !irc.AllowPanic { defer func() { if r := recover(); r != nil { - log.Printf("Caught panic in callback: %v\n%s", r, debug.Stack()) + irc.Log.Printf("Caught panic in callback: %v\n%s", r, debug.Stack()) } }() } - event := Event{IRCMessage: msg} + // handle batch start or end + if irc.batchNegotiated { + if msg.Command == "BATCH" { + irc.handleBatchCommand(msg) + return + } else if hasBatchTag, batchID := msg.GetTag("batch"); hasBatchTag { + irc.handleBatchedCommand(msg, batchID) + return + } + } + // handle labeled single command, or labeled ACK + if irc.labelNegotiated { + if hasLabel, labelStr := msg.GetTag("label"); hasLabel { + var labelCallback LabelCallback + if label := deserializeLabel(labelStr); label != 0 { + labelCallback = irc.getLabelCallback(label) + } + if labelCallback == nil { + irc.Log.Printf("received unrecognized label from server: %s\n", labelStr) + return + } else { + labelCallback(&Batch{ + IRCMessage: msg, + }) + } + return + } + } + + // OK, it's a normal IRC command + irc.HandleEvent(Event{IRCMessage: msg}) +} + +func (irc *Connection) HandleEvent(event Event) { if irc.EnableCTCP { eventRewriteCTCP(&event) } @@ -386,6 +598,56 @@ func (irc *Connection) handleCAPLS(params []string) { } } +// labeled-response + +func (irc *Connection) registerLabel(callback LabelCallback) string { + irc.batchMutex.Lock() + defer irc.batchMutex.Unlock() + irc.labelCounter++ // increment first: 0 is an invalid label + label := irc.labelCounter + irc.labelCallbacks[label] = pendingLabel{ + createdAt: time.Now(), + callback: callback, + } + return serializeLabel(label) +} + +func (irc *Connection) unregisterLabel(labelStr string) { + label := deserializeLabel(labelStr) + if label == 0 { + return + } + irc.batchMutex.Lock() + defer irc.batchMutex.Unlock() + delete(irc.labelCallbacks, label) +} + +func (irc *Connection) expireLabels(force bool) { + var failedCallbacks []LabelCallback + defer func() { + for _, bcb := range failedCallbacks { + bcb(nil) + } + }() + + irc.batchMutex.Lock() + defer irc.batchMutex.Unlock() + now := time.Now() + + for label, lcb := range irc.labelCallbacks { + if force || now.Sub(lcb.createdAt) > irc.KeepAlive { + failedCallbacks = append(failedCallbacks, lcb.callback) + delete(irc.labelCallbacks, label) + } + } + + for batchID, bip := range irc.batches { + if force || now.Sub(bip.createdAt) > irc.KeepAlive { + delete(irc.batches, batchID) + } + } +} + func splitCAPToken(token string) (name, value string) { equalIdx := strings.IndexByte(token, '=') if equalIdx == -1 { @@ -403,3 +665,18 @@ func (irc *Connection) handleStandardReplies(e Event) { irc.Log.Printf("Received error code from server: %s %s\n", e.Command, strings.Join(e.Params, " ")) } } + +const ( + labelBase = 32 +) + +func serializeLabel(label int64) string { + return strconv.FormatInt(label, labelBase) +} + +func deserializeLabel(str string) int64 { + if p, err := strconv.ParseInt(str, labelBase, 64); err == nil { + return p + } + return 0 +} diff --git a/ircevent/irc_labeledresponse_test.go b/ircevent/irc_labeledresponse_test.go new file mode 100644 index 0000000..118c296 --- /dev/null +++ b/ircevent/irc_labeledresponse_test.go @@ -0,0 +1,282 @@ +package ircevent + +import ( + "bytes" + "crypto/rand" + "encoding/hex" + "fmt" + "testing" +) + +const ( + multilineName = "draft/multiline" + chathistoryName = "draft/chathistory" + concatTag = "draft/multiline-concat" + playbackCap = "draft/event-playback" +) + +func TestLabeledResponse(t *testing.T) { + irccon := connForTesting("go-eventirc", "go-eventirc", false) + irccon.Debug = true + irccon.RequestCaps = []string{"message-tags", "batch", "labeled-response"} + irccon.RealName = "ecf61da38b58" + results := make(map[string]string) + irccon.AddConnectCallback(func(e Event) { + irccon.SendWithLabel(func(batch *Batch) { + if batch == nil { + return + } + for _, line := range batch.Items { + results[line.Command] = line.Params[len(line.Params)-1] + } + irccon.Quit() + }, nil, "WHOIS", irccon.CurrentNick()) + }) + err := irccon.Connect() + if err != nil { + t.Fatalf("labeled response connection failed: %s", err) + } + irccon.Loop() + + // RPL_WHOISUSER, last param is the realname + assertEqual(results["311"], "ecf61da38b58") + if _, ok := results["379"]; !ok { + t.Errorf("Expected 379 RPL_WHOISMODES in response, but not received") + } + assertEqual(len(irccon.batches), 0) +} + +func TestLabeledResponseNoCaps(t *testing.T) { + irccon := connForTesting("go-eventirc", "go-eventirc", false) + irccon.Debug = true + irccon.RequestCaps = []string{"message-tags"} + irccon.RealName = "ecf61da38b58" + + err := irccon.Connect() + if err != nil { + t.Fatalf("labeled response connection failed: %s", err) + } + go irccon.Loop() + + results := make(map[string]string) + err = irccon.SendWithLabel(func(batch *Batch) { + if batch == nil { + return + } + for _, line := range batch.Items { + results[line.Command] = line.Params[len(line.Params)-1] + } + irccon.Quit() + }, nil, "WHOIS", irccon.CurrentNick()) + if err != CapabilityNotNegotiated { + t.Errorf("expected capability negotiation error, got %v", err) + } + assertEqual(len(irccon.batches), 0) + irccon.Quit() +} + +// test labeled single-line response, and labeled ACK +func TestLabeledResponseSingleResponse(t *testing.T) { + irc := connForTesting("go-eventirc", "go-eventirc", false) + irc.Debug = true + irc.RequestCaps = []string{"message-tags", "batch", "labeled-response"} + + err := irc.Connect() + if err != nil { + t.Fatalf("labeled response connection failed: %s", err) + } + go irc.Loop() + + channel := fmt.Sprintf("#%s", randomString()) + irc.Join(channel) + event := make(chan empty) + err = irc.SendWithLabel(func(batch *Batch) { + if !(batch != nil && batch.Command == "PONG" && batch.Params[len(batch.Params)-1] == "asdf") { + t.Errorf("expected labeled PONG, got %#v", batch) + } + close(event) + }, nil, "PING", "asdf") + <-event + + // no-op JOIN will send labeled ACK + event = make(chan empty) + err = irc.SendWithLabel(func(batch *Batch) { + if !(batch != nil && batch.Command == "ACK") { + t.Errorf("expected labeled ACK, got %#v", batch) + } + close(event) + }, nil, "JOIN", channel) + <-event + + assertEqual(len(irc.batches), 0) + irc.Quit() +} + +func randomString() string { + buf := make([]byte, 8) + rand.Read(buf) + return hex.EncodeToString(buf) +} + +func TestNestedBatch(t *testing.T) { + irc := connForTesting("go-eventirc", "go-eventirc", false) + irc.Debug = true + irc.RequestCaps = []string{"message-tags", "batch", "labeled-response", "server-time", multilineName, chathistoryName, playbackCap} + channel := fmt.Sprintf("#%s", randomString()) + + irc.AddConnectCallback(func(e Event) { + irc.Join(channel) + irc.Privmsg(channel, "hi") + irc.Send("BATCH", "+123", "draft/multiline", channel) + irc.SendWithTags(map[string]string{"batch": "123"}, "PRIVMSG", channel, "hello") + irc.SendWithTags(map[string]string{"batch": "123"}, "PRIVMSG", channel, "") + irc.SendWithTags(map[string]string{"batch": "123", concatTag: ""}, "PRIVMSG", channel, "how is ") + irc.SendWithTags(map[string]string{"batch": "123"}, "PRIVMSG", channel, "everyone?") + irc.Send("BATCH", "-123") + }) + + err := irc.Connect() + if err != nil { + t.Fatalf("labeled response connection failed: %s", err) + } + go irc.Loop() + + var historyBatch *Batch + event := make(chan empty) + irc.SendWithLabel(func(batch *Batch) { + historyBatch = batch + close(event) + }, nil, "CHATHISTORY", "LATEST", channel, "*", "10") + + <-event + assertEqual(len(irc.labelCallbacks), 0) + + if historyBatch == nil { + t.Errorf("received nil history batch") + } + + // history should contain the JOIN, the PRIVMSG, and the multiline batch as a single item + if !(historyBatch.Command == "BATCH" && len(historyBatch.Items) == 3) { + t.Errorf("chathistory must send a real batch, got %#v", historyBatch) + } + var privmsg, multiline *Batch + for _, item := range historyBatch.Items { + switch item.Command { + case "PRIVMSG": + privmsg = item + case "BATCH": + multiline = item + } + } + if !(privmsg.Command == "PRIVMSG" && privmsg.Params[0] == channel && privmsg.Params[1] == "hi") { + t.Errorf("expected echo of individual privmsg, got %#v", privmsg) + } + if !(multiline.Command == "BATCH" && len(multiline.Items) == 4 && multiline.Items[3].Command == "PRIVMSG" && multiline.Items[3].Params[1] == "everyone?") { + t.Errorf("expected multiline in history, got %#v\n", multiline) + } + + assertEqual(len(irc.batches), 0) + irc.Quit() +} + +func TestBatchHandlers(t *testing.T) { + alice := connForTesting("alice", "go-eventirc", false) + alice.Debug = true + alice.RequestCaps = []string{"message-tags", "batch", "labeled-response", "server-time", "echo-message", multilineName, chathistoryName, playbackCap} + channel := fmt.Sprintf("#%s", randomString()) + + aliceUnderstandsBatches := true + var aliceBatchCount, alicePrivmsgCount int + alice.AddBatchCallback(func(batch *Batch) bool { + if aliceUnderstandsBatches { + aliceBatchCount++ + return true + } + return false + }) + alice.AddCallback("PRIVMSG", func(e Event) { + alicePrivmsgCount++ + }) + + err := alice.Connect() + if err != nil { + t.Fatalf("labeled response connection failed: %s", err) + } + go alice.Loop() + alice.Join(channel) + synchronize(alice) + + bob := connForTesting("bob", "go-eventirc", false) + bob.Debug = true + bob.RequestCaps = []string{"message-tags", "batch", "labeled-response", "server-time", "echo-message", multilineName, chathistoryName, playbackCap} + var buf bytes.Buffer + bob.AddBatchCallback(func(b *Batch) bool { + if !(len(b.Params) >= 3 && b.Params[1] == multilineName) { + return false + } + for i, item := range b.Items { + if item.Command == "PRIVMSG" { + buf.WriteString(item.Params[1]) + if !(item.HasTag(concatTag) || i == len(b.Items)-1) { + buf.WriteByte('\n') + } + } + } + return true + }) + + err = bob.Connect() + if err != nil { + t.Fatalf("labeled response connection failed: %s", err) + } + go bob.Loop() + bob.Join(channel) + synchronize(bob) + + sendMultiline := func() { + alice.Send("BATCH", "+123", "draft/multiline", channel) + alice.SendWithTags(map[string]string{"batch": "123"}, "PRIVMSG", channel, "hello") + alice.SendWithTags(map[string]string{"batch": "123"}, "PRIVMSG", channel, "") + alice.SendWithTags(map[string]string{"batch": "123", concatTag: ""}, "PRIVMSG", channel, "how is ") + alice.SendWithTags(map[string]string{"batch": "123"}, "PRIVMSG", channel, "everyone?") + alice.Send("BATCH", "-123") + synchronize(alice) + } + multilineMessageValue := "hello\n\nhow is everyone?" + + sendMultiline() + synchronize(alice) + synchronize(bob) + + assertEqual(alicePrivmsgCount, 0) + alicePrivmsgCount = 0 + assertEqual(aliceBatchCount, 1) + aliceBatchCount = 0 + + assertEqual(buf.String(), multilineMessageValue) + buf.Reset() + + aliceUnderstandsBatches = false + sendMultiline() + synchronize(alice) + synchronize(bob) + + // disabled alice's batch handler, she should see a flattened batch + assertEqual(alicePrivmsgCount, 4) + assertEqual(aliceBatchCount, 0) + + assertEqual(buf.String(), multilineMessageValue) + + assertEqual(len(alice.batches), 0) + assertEqual(len(bob.batches), 0) + alice.Quit() + bob.Quit() +} + +func synchronize(irc *Connection) { + event := make(chan empty) + irc.SendWithLabel(func(b *Batch) { + close(event) + }, nil, "PING", "!") + <-event +} diff --git a/ircevent/irc_parse_test.go b/ircevent/irc_parse_test.go index 1341ca9..2b9036e 100644 --- a/ircevent/irc_parse_test.go +++ b/ircevent/irc_parse_test.go @@ -1,6 +1,8 @@ package ircevent import ( + "fmt" + "reflect" "testing" ) @@ -19,18 +21,18 @@ func TestParse(t *testing.T) { } } -func assertEqual(found, expected string, t *testing.T) { - if found != expected { - t.Errorf("expected `%s`, got `%s`\n", expected, found) +func assertEqual(found, expected interface{}) { + if !reflect.DeepEqual(found, expected) { + panic(fmt.Sprintf("expected `%#v`, got `%#v`\n", expected, found)) } } func TestUnescapeIsupport(t *testing.T) { - assertEqual(unescapeISupportValue(""), "", t) - assertEqual(unescapeISupportValue("a"), "a", t) - assertEqual(unescapeISupportValue(`\x20`), " ", t) - assertEqual(unescapeISupportValue(`\x20b`), " b", t) - assertEqual(unescapeISupportValue(`a\x20`), "a ", t) - assertEqual(unescapeISupportValue(`a\x20b`), "a b", t) - assertEqual(unescapeISupportValue(`\x20\x20`), " ", t) + assertEqual(unescapeISupportValue(""), "") + assertEqual(unescapeISupportValue("a"), "a") + assertEqual(unescapeISupportValue(`\x20`), " ") + assertEqual(unescapeISupportValue(`\x20b`), " b") + assertEqual(unescapeISupportValue(`a\x20`), "a ") + assertEqual(unescapeISupportValue(`a\x20b`), "a b") + assertEqual(unescapeISupportValue(`\x20\x20`), " ") } diff --git a/ircevent/irc_struct.go b/ircevent/irc_struct.go index 685845a..10ec2e9 100644 --- a/ircevent/irc_struct.go +++ b/ircevent/irc_struct.go @@ -24,6 +24,15 @@ type callbackPair struct { callback Callback } +type BatchCallback func(*Batch) bool + +type batchCallbackPair struct { + id uint64 + callback BatchCallback +} + +type LabelCallback func(*Batch) + type capResult struct { capName string ack bool @@ -77,9 +86,11 @@ type Connection struct { // Connect() builds these with sufficient capacity to receive all expected // responses during negotiation. Sends to them are nonblocking, so anything // sent outside of negotiation will not cause the relevant callbacks to block. - welcomeChan chan empty // signals that we got 001 and we are now connected - saslChan chan saslResult // transmits the final outcome of SASL negotiation - capsChan chan capResult // transmits the final status of each CAP negotiated + welcomeChan chan empty // signals that we got 001 and we are now connected + saslChan chan saslResult // transmits the final outcome of SASL negotiation + capsChan chan capResult // transmits the final status of each CAP negotiated + batchNegotiated bool + labelNegotiated bool // callback state eventsMutex sync.Mutex @@ -89,16 +100,51 @@ type Connection struct { // if we add a callback in two places we might reuse the number (XXX) callbackCounter uint64 // did we initialize the callbacks needed for the library itself? + batchCallbacks []batchCallbackPair hasBaseCallbacks bool + batchMutex sync.Mutex + batches map[string]batchInProgress + labelCallbacks map[int64]pendingLabel + labelCounter int64 + Log *log.Logger } -// A struct to represent an event. +type batchInProgress struct { + createdAt time.Time + label int64 + // needs to be heap-allocated so we can append to batch.Items: + batch *Batch +} + +type pendingLabel struct { + createdAt time.Time + callback LabelCallback +} + +// Event represents an individual IRC line. type Event struct { ircmsg.IRCMessage } +// Batch represents an IRCv3 batch, or a line within one. There are +// two cases: +// 1. (Batch).Command == "BATCH". This indicates the start of an IRCv3 +// batch; the embedded IRCMessage is the initial BATCH command, which +// may contain tags that pertain to the batch as a whole. (Batch).Items +// contains zero or more *Batch elements, pointing to the contents of +// the batch in order. +// 2. (Batch).Command != "BATCH". This is an ordinary IRC line; its +// tags, command, and parameters are available as members of the embedded +// IRCMessage. +// In the context of labeled-response, there is a third case: a `nil` +// value of *Batch indicates that the server failed to respond in time. +type Batch struct { + ircmsg.IRCMessage + Items []*Batch +} + // Retrieve the last message from Event arguments. // This function leaves the arguments untouched and // returns an empty string if there are none. From 6d11cde149d9e634a2262690357cb16ac1838ef1 Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Wed, 10 Mar 2021 02:32:02 -0500 Subject: [PATCH 2/6] explain HandleBatch and HandleEvent --- ircevent/irc_callback.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ircevent/irc_callback.go b/ircevent/irc_callback.go index d90d344..1e7c716 100644 --- a/ircevent/irc_callback.go +++ b/ircevent/irc_callback.go @@ -272,6 +272,9 @@ func (irc *Connection) getLabelCallback(label int64) (callback LabelCallback) { return irc.getLabelCallbackNoMutex(label) } +// HandleBatch handles a *Batch using available handlers, "flattening" it if +// no handler succeeds. This can be used in a batch or labeled-response callback +// to process inner batches. func (irc *Connection) HandleBatch(batch *Batch) { if batch == nil { return @@ -355,6 +358,8 @@ func (irc *Connection) runCallbacks(msg ircmsg.IRCMessage) { irc.HandleEvent(Event{IRCMessage: msg}) } +// HandleEvent handles an IRC line using the available handlers. This can be +// used in a batch or labeled-response callback to process an individual line. func (irc *Connection) HandleEvent(event Event) { if irc.EnableCTCP { eventRewriteCTCP(&event) From b2a42a4234a3cb33dd4973cc3431a784f41db55b Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Wed, 10 Mar 2021 12:43:23 -0500 Subject: [PATCH 3/6] make sure non-label batches expire as well --- ircevent/irc.go | 6 +++--- ircevent/irc_callback.go | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/ircevent/irc.go b/ircevent/irc.go index 6071b67..6bd80c1 100644 --- a/ircevent/irc.go +++ b/ircevent/irc.go @@ -134,8 +134,8 @@ func (irc *Connection) readLoop() { return } - if irc.labelNegotiated && time.Since(lastExpireCheck) > irc.Timeout { - irc.expireLabels(false) + if irc.batchNegotiated && time.Since(lastExpireCheck) > irc.Timeout { + irc.expireBatches(false) lastExpireCheck = time.Now() } } @@ -307,7 +307,7 @@ func (irc *Connection) waitForStop() { irc.socket.Close() } - irc.expireLabels(true) + irc.expireBatches(true) } // Quit the current connection and disconnect from the server diff --git a/ircevent/irc_callback.go b/ircevent/irc_callback.go index 1e7c716..00c84db 100644 --- a/ircevent/irc_callback.go +++ b/ircevent/irc_callback.go @@ -627,7 +627,9 @@ func (irc *Connection) unregisterLabel(labelStr string) { delete(irc.labelCallbacks, label) } -func (irc *Connection) expireLabels(force bool) { +// expire open batches from the server that weren't closed in a +// timely fashion +func (irc *Connection) expireBatches(force bool) { var failedCallbacks []LabelCallback defer func() { for _, bcb := range failedCallbacks { From 05d8f4419826702f4a92528ac0f6b1258d330e1d Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Wed, 10 Mar 2021 13:22:31 -0500 Subject: [PATCH 4/6] clean up synchronization in tests --- ircevent/irc_labeledresponse_test.go | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/ircevent/irc_labeledresponse_test.go b/ircevent/irc_labeledresponse_test.go index 118c296..ca6b3f5 100644 --- a/ircevent/irc_labeledresponse_test.go +++ b/ircevent/irc_labeledresponse_test.go @@ -124,23 +124,21 @@ func TestNestedBatch(t *testing.T) { irc.RequestCaps = []string{"message-tags", "batch", "labeled-response", "server-time", multilineName, chathistoryName, playbackCap} channel := fmt.Sprintf("#%s", randomString()) - irc.AddConnectCallback(func(e Event) { - irc.Join(channel) - irc.Privmsg(channel, "hi") - irc.Send("BATCH", "+123", "draft/multiline", channel) - irc.SendWithTags(map[string]string{"batch": "123"}, "PRIVMSG", channel, "hello") - irc.SendWithTags(map[string]string{"batch": "123"}, "PRIVMSG", channel, "") - irc.SendWithTags(map[string]string{"batch": "123", concatTag: ""}, "PRIVMSG", channel, "how is ") - irc.SendWithTags(map[string]string{"batch": "123"}, "PRIVMSG", channel, "everyone?") - irc.Send("BATCH", "-123") - }) - err := irc.Connect() if err != nil { t.Fatalf("labeled response connection failed: %s", err) } go irc.Loop() + irc.Join(channel) + irc.Privmsg(channel, "hi") + irc.Send("BATCH", "+123", "draft/multiline", channel) + irc.SendWithTags(map[string]string{"batch": "123"}, "PRIVMSG", channel, "hello") + irc.SendWithTags(map[string]string{"batch": "123"}, "PRIVMSG", channel, "") + irc.SendWithTags(map[string]string{"batch": "123", concatTag: ""}, "PRIVMSG", channel, "how is ") + irc.SendWithTags(map[string]string{"batch": "123"}, "PRIVMSG", channel, "everyone?") + irc.Send("BATCH", "-123") + var historyBatch *Batch event := make(chan empty) irc.SendWithLabel(func(batch *Batch) { @@ -241,12 +239,11 @@ func TestBatchHandlers(t *testing.T) { alice.SendWithTags(map[string]string{"batch": "123"}, "PRIVMSG", channel, "everyone?") alice.Send("BATCH", "-123") synchronize(alice) + synchronize(bob) } multilineMessageValue := "hello\n\nhow is everyone?" sendMultiline() - synchronize(alice) - synchronize(bob) assertEqual(alicePrivmsgCount, 0) alicePrivmsgCount = 0 @@ -258,8 +255,6 @@ func TestBatchHandlers(t *testing.T) { aliceUnderstandsBatches = false sendMultiline() - synchronize(alice) - synchronize(bob) // disabled alice's batch handler, she should see a flattened batch assertEqual(alicePrivmsgCount, 4) From b26cd91715d3a434970772103a2d255ccce78018 Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Wed, 10 Mar 2021 14:05:12 -0500 Subject: [PATCH 5/6] tweak force-expiration behavior --- ircevent/irc_callback.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ircevent/irc_callback.go b/ircevent/irc_callback.go index 00c84db..dbe7266 100644 --- a/ircevent/irc_callback.go +++ b/ircevent/irc_callback.go @@ -628,7 +628,9 @@ func (irc *Connection) unregisterLabel(labelStr string) { } // expire open batches from the server that weren't closed in a -// timely fashion +// timely fashion. `force` expires all label callbacks regardless +// of time created (so they can be cleaned up when the connection +// fails). func (irc *Connection) expireBatches(force bool) { var failedCallbacks []LabelCallback defer func() { @@ -649,7 +651,7 @@ func (irc *Connection) expireBatches(force bool) { } for batchID, bip := range irc.batches { - if force || now.Sub(bip.createdAt) > irc.KeepAlive { + if now.Sub(bip.createdAt) > irc.KeepAlive { delete(irc.batches, batchID) } } From 8f78fbb4a257a348b59deae2c2684e65c6dcacc3 Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Wed, 10 Mar 2021 14:07:24 -0500 Subject: [PATCH 6/6] don't accept regular callbacks for BATCH --- ircevent/irc_callback.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ircevent/irc_callback.go b/ircevent/irc_callback.go index dbe7266..da72e6c 100644 --- a/ircevent/irc_callback.go +++ b/ircevent/irc_callback.go @@ -32,7 +32,7 @@ func (irc *Connection) AddCallback(eventCode string, callback func(Event)) Callb func (irc *Connection) addCallback(eventCode string, callback Callback, prepend bool, idNum uint64) CallbackID { eventCode = strings.ToUpper(eventCode) - if eventCode == "" || strings.HasPrefix(eventCode, "*") { + if eventCode == "" || strings.HasPrefix(eventCode, "*") || eventCode == "BATCH" { return CallbackID{} }