diff --git a/ircevent/irc.go b/ircevent/irc.go index 63fbfac..1b4ac5c 100644 --- a/ircevent/irc.go +++ b/ircevent/irc.go @@ -134,7 +134,7 @@ func (irc *Connection) readLoop() { return } - if irc.batchNegotiated && time.Since(lastExpireCheck) > irc.Timeout { + if irc.batchNegotiated() && time.Since(lastExpireCheck) > irc.Timeout { irc.expireBatches(false) lastExpireCheck = time.Now() } @@ -376,7 +376,7 @@ func (irc *Connection) Send(command string, params ...string) error { // If the server fails to respond correctly, the callback will be invoked with `nil` // as the argument. func (irc *Connection) SendWithLabel(callback func(*Batch), tags map[string]string, command string, params ...string) error { - if !irc.labelNegotiated { + if !irc.labelNegotiated() { return CapabilityNotNegotiated } @@ -691,14 +691,7 @@ func (irc *Connection) negotiateCaps() error { var acknowledgedCaps []string defer func() { - irc.stateMutex.Lock() - defer irc.stateMutex.Unlock() - for _, c := range acknowledgedCaps { - irc.capsAcked[c] = irc.capsAdvertised[c] - } - _, irc.batchNegotiated = irc.capsAcked["batch"] - _, labelNegotiated := irc.capsAcked["labeled-response"] - irc.labelNegotiated = irc.batchNegotiated && labelNegotiated + irc.processAckedCaps(acknowledgedCaps) }() irc.Send("CAP", "LS", "302") diff --git a/ircevent/irc_callback.go b/ircevent/irc_callback.go index 21a547b..42a1145 100644 --- a/ircevent/irc_callback.go +++ b/ircevent/irc_callback.go @@ -325,7 +325,7 @@ func (irc *Connection) runCallbacks(msg ircmsg.Message) { } // handle batch start or end - if irc.batchNegotiated { + if irc.batchNegotiated() { if msg.Command == "BATCH" { irc.handleBatchCommand(msg) return @@ -336,7 +336,7 @@ func (irc *Connection) runCallbacks(msg ircmsg.Message) { } // handle labeled single command, or labeled ACK - if irc.labelNegotiated { + if irc.labelNegotiated() { if hasLabel, labelStr := msg.GetTag("label"); hasLabel { var labelCallback LabelCallback if label := deserializeLabel(labelStr); label != 0 { diff --git a/ircevent/irc_struct.go b/ircevent/irc_struct.go index 3b2859b..5c0a238 100644 --- a/ircevent/irc_struct.go +++ b/ircevent/irc_struct.go @@ -10,6 +10,7 @@ import ( "net" "strings" "sync" + "sync/atomic" "time" "github.com/goshuirc/irc-go/ircmsg" @@ -86,11 +87,10 @@ type Connection struct { // Connect() builds these with sufficient capacity to receive all expected // responses during negotiation. Sends to them are nonblocking, so anything // sent outside of negotiation will not cause the relevant callbacks to block. - welcomeChan chan empty // signals that we got 001 and we are now connected - saslChan chan saslResult // transmits the final outcome of SASL negotiation - capsChan chan capResult // transmits the final status of each CAP negotiated - batchNegotiated bool - labelNegotiated bool + welcomeChan chan empty // signals that we got 001 and we are now connected + saslChan chan saslResult // transmits the final outcome of SASL negotiation + capsChan chan capResult // transmits the final status of each CAP negotiated + capFlags uint32 // callback state eventsMutex sync.Mutex @@ -140,6 +140,56 @@ type Batch struct { Items []*Batch } +const ( + capFlagBatch uint32 = 1 << iota + capFlagMessageTags + capFlagLabeledResponse + capFlagMultiline +) + +func (irc *Connection) processAckedCaps(acknowledgedCaps []string) { + irc.stateMutex.Lock() + defer irc.stateMutex.Unlock() + var hasBatch, hasLabel, hasTags, hasMultiline bool + for _, c := range acknowledgedCaps { + irc.capsAcked[c] = irc.capsAdvertised[c] + switch c { + case "batch": + hasBatch = true + case "labeled-response": + hasLabel = true + case "message-tags": + hasTags = true + case "draft/multiline", "multiline": + hasMultiline = true + } + } + + var capFlags uint32 + if hasBatch { + capFlags |= capFlagBatch + } + if hasBatch && hasLabel { + capFlags |= capFlagLabeledResponse + } + if hasTags { + capFlags |= capFlagMessageTags + } + if hasTags && hasBatch && hasMultiline { + capFlags |= capFlagMultiline + } + + atomic.StoreUint32(&irc.capFlags, capFlags) +} + +func (irc *Connection) batchNegotiated() bool { + return atomic.LoadUint32(&irc.capFlags)&capFlagBatch != 0 +} + +func (irc *Connection) labelNegotiated() bool { + return atomic.LoadUint32(&irc.capFlags)&capFlagLabeledResponse != 0 +} + func ExtractNick(source string) string { nick, _, _ := SplitNUH(source) return nick