diff --git a/ircevent/examples/simple.go b/ircevent/examples/simple.go index b15367d..a6f3cfc 100644 --- a/ircevent/examples/simple.go +++ b/ircevent/examples/simple.go @@ -36,7 +36,13 @@ func main() { SASLPassword: saslPassword, } - irc.AddCallback("001", func(e ircevent.Event) { irc.Join(channel) }) + irc.AddConnectCallback(func(e ircevent.Event) { + // attempt to set the BOT mode on ourself: + if botMode := irc.ISupport()["BOT"]; botMode != "" { + irc.Send("MODE", irc.CurrentNick(), "+"+botMode) + } + irc.Join(channel) + }) irc.AddCallback("JOIN", func(e ircevent.Event) {}) // TODO try to rejoin if we *don't* get this irc.AddCallback("PRIVMSG", func(e ircevent.Event) { if len(e.Params) < 2 { diff --git a/ircevent/irc.go b/ircevent/irc.go index 5c1a120..d8f95c8 100644 --- a/ircevent/irc.go +++ b/ircevent/irc.go @@ -446,13 +446,23 @@ func (irc *Connection) setCurrentNick(nick string) { irc.currentNick = nick } -// Return IRCv3 CAPs actually enabled on the connection. -func (irc *Connection) AcknowledgedCaps(result []string) { +// Return IRCv3 CAPs actually enabled on the connection, together +// with their values if applicable. The resulting map is shared, +// so do not modify it. +func (irc *Connection) AcknowledgedCaps() (result map[string]string) { irc.stateMutex.Lock() defer irc.stateMutex.Unlock() - result = make([]string, len(irc.acknowledgedCaps)) - copy(result[:], irc.acknowledgedCaps[:]) - return + return irc.capsAcked +} + +// Returns the 005 RPL_ISUPPORT tokens sent by the server when the +// connection was initiated, parsed into key-value form as a map. +// The resulting map is shared, so do not modify it. +func (irc *Connection) ISupport() (result map[string]string) { + irc.stateMutex.Lock() + defer irc.stateMutex.Unlock() + // XXX modifications to isupport are not permitted after registration + return irc.isupport } // Returns true if the connection is connected to an IRC server. @@ -498,7 +508,6 @@ func (irc *Connection) Connect() (err error) { } // mark Server as stopped since there can be an error during connect - irc.acknowledgedCaps = nil irc.running = false irc.socket = nil irc.currentNick = "" @@ -588,6 +597,11 @@ func (irc *Connection) Connect() (err error) { irc.capsChan = make(chan capResult, len(irc.RequestCaps)) irc.saslChan = make(chan saslResult, 1) irc.welcomeChan = make(chan empty, 1) + irc.registered = false + irc.isupportPartial = make(map[string]string) + irc.isupport = nil + irc.capsAcked = make(map[string]string) + irc.capsAdvertised = nil irc.stateMutex.Unlock() go irc.readLoop() @@ -642,10 +656,12 @@ func (irc *Connection) negotiateCaps() error { defer func() { irc.stateMutex.Lock() defer irc.stateMutex.Unlock() - irc.acknowledgedCaps = acknowledgedCaps + for _, c := range acknowledgedCaps { + irc.capsAcked[c] = irc.capsAdvertised[c] + } }() - irc.Send("CAP", "LS") + irc.Send("CAP", "LS", "302") defer func() { irc.Send("CAP", "END") }() @@ -670,6 +686,11 @@ func (irc *Connection) negotiateCaps() error { } if irc.UseSASL { + if !sliceContains("sasl", acknowledgedCaps) { + return SASLFailed + } else { + irc.Send("AUTHENTICATE", irc.SASLMech) + } select { case res := <-irc.saslChan: if res.Failed { diff --git a/ircevent/irc_callback.go b/ircevent/irc_callback.go index 8d42ae9..b3c9f77 100644 --- a/ircevent/irc_callback.go +++ b/ircevent/irc_callback.go @@ -4,11 +4,17 @@ import ( "fmt" "log" "runtime/debug" + "strconv" "strings" "github.com/goshuirc/irc-go/ircmsg" ) +const ( + // fake events that we manage specially + registrationEvent = "*REGISTRATION" +) + // Tuple type for uniquely identifying callbacks type CallbackID struct { eventCode string @@ -16,27 +22,44 @@ type CallbackID struct { } // Register a callback to a connection and event code. A callback is a function -// which takes only an Event pointer as parameter. Valid event codes are all -// IRC/CTCP commands and error/response codes. To register a callback for all -// events pass "*" as the event code. This function returns the ID of the +// which takes only an Event object as parameter. Valid event codes are all +// IRC/CTCP commands and error/response codes. This function returns the ID of the // registered callback for later management. func (irc *Connection) AddCallback(eventCode string, callback func(Event)) CallbackID { + return irc.addCallback(eventCode, Callback(callback), false, 0) +} + +func (irc *Connection) addCallback(eventCode string, callback Callback, prepend bool, idNum uint64) CallbackID { eventCode = strings.ToUpper(eventCode) + if eventCode == "" || strings.HasPrefix(eventCode, "*") { + return CallbackID{} + } irc.eventsMutex.Lock() defer irc.eventsMutex.Unlock() if irc.events == nil { - irc.events = make(map[string]map[uint64]Callback) + irc.events = make(map[string][]callbackPair) } - _, ok := irc.events[eventCode] - if !ok { - irc.events[eventCode] = make(map[uint64]Callback) + if idNum == 0 { + idNum = irc.callbackCounter + irc.callbackCounter++ } - id := CallbackID{eventCode: eventCode, id: irc.idCounter} - irc.idCounter++ - irc.events[eventCode][id.id] = Callback(callback) + id := CallbackID{eventCode: eventCode, id: idNum} + newPair := callbackPair{id: id.id, callback: callback} + current := irc.events[eventCode] + newList := make([]callbackPair, len(current)+1) + start := 0 + if prepend { + newList[start] = newPair + start++ + } + copy(newList[start:], current) + if !prepend { + newList[len(newList)-1] = newPair + } + irc.events[eventCode] = newList return id } @@ -44,7 +67,27 @@ func (irc *Connection) AddCallback(eventCode string, callback func(Event)) Callb func (irc *Connection) RemoveCallback(id CallbackID) { irc.eventsMutex.Lock() defer irc.eventsMutex.Unlock() - delete(irc.events[id.eventCode], id.id) + switch id.eventCode { + case registrationEvent: + irc.removeCallbackNoMutex(RPL_ENDOFMOTD, id.id) + irc.removeCallbackNoMutex(ERR_NOMOTD, id.id) + default: + irc.removeCallbackNoMutex(id.eventCode, id.id) + } +} + +func (irc *Connection) removeCallbackNoMutex(code string, id uint64) { + current := irc.events[code] + if len(current) == 0 { + return + } + newList := make([]callbackPair, 0, len(current)-1) + for _, p := range current { + if p.id != id { + newList = append(newList, p) + } + } + irc.events[code] = newList } // Remove all callbacks from a given event code. @@ -61,33 +104,32 @@ func (irc *Connection) ReplaceCallback(id CallbackID, callback func(Event)) bool irc.eventsMutex.Lock() defer irc.eventsMutex.Unlock() - if _, ok := irc.events[id.eventCode][id.id]; ok { - irc.events[id.eventCode][id.id] = callback - return true + list := irc.events[id.eventCode] + for i, p := range list { + if p.id == id.id { + list[i] = callbackPair{id: id.id, callback: callback} + return true + } } return false } -func (irc *Connection) getCallbacks(code string) (result []Callback) { +// Convenience function to add a callback that will be called once the +// connection is completed (this is traditionally referred to as "connection +// registration"). +func (irc *Connection) AddConnectCallback(callback func(Event)) (id CallbackID) { + // XXX: forcibly use the same ID number for both copies of the callback + id376 := irc.AddCallback(RPL_ENDOFMOTD, callback) + irc.addCallback(ERR_NOMOTD, callback, false, id376.id) + return CallbackID{eventCode: registrationEvent, id: id376.id} +} + +func (irc *Connection) getCallbacks(code string) (result []callbackPair) { code = strings.ToUpper(code) irc.eventsMutex.Lock() defer irc.eventsMutex.Unlock() - - cMap := irc.events[code] - starMap := irc.events["*"] - length := len(cMap) + len(starMap) - if length == 0 { - return - } - result = make([]Callback, 0, length) - for _, c := range cMap { - result = append(result, c) - } - for _, c := range starMap { - result = append(result, c) - } - return + return irc.events[code] } // Execute all callbacks associated with a given event. @@ -106,12 +148,12 @@ func (irc *Connection) runCallbacks(msg ircmsg.IRCMessage) { eventRewriteCTCP(&event) } - callbacks := irc.getCallbacks(event.Command) + callbackPairs := irc.getCallbacks(event.Command) // just run the callbacks in serial, since it's not safe for them // to take a long time to execute in any case - for _, callback := range callbacks { - callback(event) + for _, pair := range callbackPairs { + pair.callback(event) } } @@ -136,12 +178,15 @@ func (irc *Connection) setupCallbacks() { // 433: ERR_NICKNAMEINUSE " :Nickname is already in use" // 437: ERR_UNAVAILRESOURCE " :Nick/channel is temporarily unavailable" - irc.AddCallback("433", irc.handleUnavailableNick) - irc.AddCallback("437", irc.handleUnavailableNick) + irc.AddCallback(ERR_NICKNAMEINUSE, irc.handleUnavailableNick) + irc.AddCallback(ERR_UNAVAILRESOURCE, irc.handleUnavailableNick) - // 1: RPL_WELCOME "Welcome to the Internet Relay Network !@" + // 001: RPL_WELCOME "Welcome to the Internet Relay Network !@" // Set irc.currentNick to the actually used nick in this connection. - irc.AddCallback("001", irc.handleRplWelcome) + irc.AddCallback(RPL_WELCOME, irc.handleRplWelcome) + + // 005: RPL_ISUPPORT, conveys supported server features + irc.AddCallback(RPL_ISUPPORT, irc.handleISupport) // respond to NICK from the server (in response to our own NICK, or sent unprompted) irc.AddCallback("NICK", func(e Event) { @@ -156,36 +201,7 @@ func (irc *Connection) setupCallbacks() { } }) - irc.AddCallback("CAP", func(e Event) { - if len(e.Params) != 3 { - return - } - command := e.Params[1] - capsChan := irc.capsChan - - // TODO this assumes all the caps on one line - // TODO support CAP LS 302 - if command == "LS" { - capsList := strings.Fields(e.Params[2]) - for _, capName := range irc.RequestCaps { - if sliceContains(capName, capsList) { - irc.Send("CAP", "REQ", capName) - } else { - select { - case capsChan <- capResult{capName, false}: - default: - } - } - } - } else if command == "ACK" || command == "NAK" { - for _, capName := range strings.Fields(e.Params[2]) { - select { - case capsChan <- capResult{capName, command == "ACK"}: - default: - } - } - } - }) + irc.AddCallback("CAP", irc.handleCAP) if irc.UseSASL { irc.setupSASLCallbacks() @@ -194,6 +210,11 @@ func (irc *Connection) setupCallbacks() { if irc.EnableCTCP { irc.setupCTCPCallbacks() } + + // prepend our own callbacks for the end of registration, + // so they happen before any client-added callbacks + irc.addCallback(RPL_ENDOFMOTD, irc.handleRegistration, true, 0) + irc.addCallback(ERR_NOMOTD, irc.handleRegistration, true, 0) } func (irc *Connection) handleRplWelcome(e Event) { @@ -204,12 +225,29 @@ func (irc *Connection) handleRplWelcome(e Event) { if len(e.Params) > 0 { irc.currentNick = e.Params[0] } +} +func (irc *Connection) handleRegistration(e Event) { // wake up Connect() if applicable - select { - case irc.welcomeChan <- empty{}: - default: + defer func() { + select { + case irc.welcomeChan <- empty{}: + default: + } + }() + + irc.stateMutex.Lock() + defer irc.stateMutex.Unlock() + + if irc.registered { + return } + irc.registered = true + + // mark the isupport complete + irc.isupport = irc.isupportPartial + irc.isupportPartial = nil + } func (irc *Connection) handleUnavailableNick(e Event) { @@ -229,3 +267,126 @@ func (irc *Connection) handleUnavailableNick(e Event) { irc.Send("NICK", nickToTry) } } + +func (irc *Connection) handleISupport(e Event) { + irc.stateMutex.Lock() + defer irc.stateMutex.Unlock() + + // TODO handle 005 changes after registration + if irc.isupportPartial == nil { + return + } + if len(e.Params) < 3 { + return + } + for _, token := range e.Params[1 : len(e.Params)-1] { + equalsIdx := strings.IndexByte(token, '=') + if equalsIdx == -1 { + irc.isupportPartial[token] = "" // no value + } else { + irc.isupportPartial[token[:equalsIdx]] = unescapeISupportValue(token[equalsIdx+1:]) + } + } +} + +func unescapeISupportValue(in string) (out string) { + if strings.IndexByte(in, '\\') == -1 { + return in + } + var buf strings.Builder + for i := 0; i < len(in); { + if in[i] == '\\' && i+3 < len(in) && in[i+1] == 'x' { + hex := in[i+2 : i+4] + if octet, err := strconv.ParseInt(hex, 16, 8); err == nil { + buf.WriteByte(byte(octet)) + i += 4 + continue + } + } + buf.WriteByte(in[i]) + i++ + } + return buf.String() +} + +func (irc *Connection) handleCAP(e Event) { + if len(e.Params) < 3 { + return + } + ack := false + // CAP PARAMS... + switch e.Params[1] { + case "LS": + irc.handleCAPLS(e.Params[2:]) + case "ACK": + ack = true + fallthrough + case "NAK": + for _, token := range strings.Fields(e.Params[2]) { + name, _ := splitCAPToken(token) + if sliceContains(name, irc.RequestCaps) { + select { + case irc.capsChan <- capResult{capName: name, ack: ack}: + default: + } + } + } + } +} + +func (irc *Connection) handleCAPLS(params []string) { + var capsToReq, capsNotFound []string + defer func() { + for _, c := range capsToReq { + irc.Send("CAP", "REQ", c) + } + for _, c := range capsNotFound { + select { + case irc.capsChan <- capResult{capName: c, ack: false}: + default: + } + } + }() + + irc.stateMutex.Lock() + defer irc.stateMutex.Unlock() + + if irc.registered { + // TODO server could probably trick us into panic here by sending + // additional LS before the end of registration + return + } + + if irc.capsAdvertised == nil { + irc.capsAdvertised = make(map[string]string) + } + + // multiline responses to CAP LS 302 start with a 4-parameter form: + // CAP * LS * :account-notify away-notify [...] + // and end with a 3-parameter form: + // CAP * LS :userhost-in-names znc.in/playback [...] + final := len(params) == 1 + for _, token := range strings.Fields(params[len(params)-1]) { + name, value := splitCAPToken(token) + irc.capsAdvertised[name] = value + } + + if final { + for _, c := range irc.RequestCaps { + if _, ok := irc.capsAdvertised[c]; ok { + capsToReq = append(capsToReq, c) + } else { + capsNotFound = append(capsNotFound, c) + } + } + } +} + +func splitCAPToken(token string) (name, value string) { + equalIdx := strings.IndexByte(token, '=') + if equalIdx == -1 { + return token, "" + } else { + return token[:equalIdx], token[equalIdx+1:] + } +} diff --git a/ircevent/irc_parse_test.go b/ircevent/irc_parse_test.go index a69eb58..1341ca9 100644 --- a/ircevent/irc_parse_test.go +++ b/ircevent/irc_parse_test.go @@ -18,3 +18,19 @@ func TestParse(t *testing.T) { t.Fatal("Parse failed: host") } } + +func assertEqual(found, expected string, t *testing.T) { + if found != expected { + t.Errorf("expected `%s`, got `%s`\n", expected, found) + } +} + +func TestUnescapeIsupport(t *testing.T) { + assertEqual(unescapeISupportValue(""), "", t) + assertEqual(unescapeISupportValue("a"), "a", t) + assertEqual(unescapeISupportValue(`\x20`), " ", t) + assertEqual(unescapeISupportValue(`\x20b`), " b", t) + assertEqual(unescapeISupportValue(`a\x20`), "a ", t) + assertEqual(unescapeISupportValue(`a\x20b`), "a b", t) + assertEqual(unescapeISupportValue(`\x20\x20`), " ", t) +} diff --git a/ircevent/irc_sasl.go b/ircevent/irc_sasl.go index 95fd3de..23074fb 100644 --- a/ircevent/irc_sasl.go +++ b/ircevent/irc_sasl.go @@ -4,7 +4,6 @@ import ( "encoding/base64" "errors" "fmt" - "strings" ) type saslResult struct { @@ -21,16 +20,6 @@ func sliceContains(str string, list []string) bool { return false } -// Check if a space-separated list of arguments contains a value. -func listContains(list string, value string) bool { - for _, arg_name := range strings.Split(strings.TrimSpace(list), " ") { - if arg_name == value { - return true - } - } - return false -} - func (irc *Connection) submitSASLResult(r saslResult) { select { case irc.saslChan <- r: @@ -39,43 +28,35 @@ func (irc *Connection) submitSASLResult(r saslResult) { } func (irc *Connection) setupSASLCallbacks() { - irc.AddCallback("CAP", func(e Event) { - if len(e.Params) == 3 { - if e.Params[1] == "LS" { - if !listContains(e.Params[2], "sasl") { - irc.submitSASLResult(saslResult{true, errors.New("no SASL capability " + e.Params[2])}) - } - } - if e.Params[1] == "ACK" && listContains(e.Params[2], "sasl") { - irc.Send("AUTHENTICATE", irc.SASLMech) - } - } - }) - irc.AddCallback("AUTHENTICATE", func(e Event) { str := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s\x00%s\x00%s", irc.SASLLogin, irc.SASLLogin, irc.SASLPassword))) irc.Send("AUTHENTICATE", str) }) - irc.AddCallback("901", func(e Event) { + irc.AddCallback(RPL_LOGGEDOUT, func(e Event) { irc.SendRaw("CAP END") irc.SendRaw("QUIT") irc.submitSASLResult(saslResult{true, errors.New(e.Params[1])}) }) - irc.AddCallback("902", func(e Event) { + irc.AddCallback(ERR_NICKLOCKED, func(e Event) { irc.SendRaw("CAP END") irc.SendRaw("QUIT") irc.submitSASLResult(saslResult{true, errors.New(e.Params[1])}) }) - irc.AddCallback("903", func(e Event) { + irc.AddCallback(RPL_SASLSUCCESS, func(e Event) { irc.submitSASLResult(saslResult{false, nil}) }) - irc.AddCallback("904", func(e Event) { + irc.AddCallback(ERR_SASLFAIL, func(e Event) { irc.SendRaw("CAP END") irc.SendRaw("QUIT") irc.submitSASLResult(saslResult{true, errors.New(e.Params[1])}) }) + + // this could potentially happen with auto-login via certfp? + irc.AddCallback(ERR_SASLALREADY, func(e Event) { + irc.submitSASLResult(saslResult{false, nil}) + }) } diff --git a/ircevent/irc_sasl_test.go b/ircevent/irc_sasl_test.go index 044b6cb..3ad8997 100644 --- a/ircevent/irc_sasl_test.go +++ b/ircevent/irc_sasl_test.go @@ -51,7 +51,7 @@ func runCAPTest(caps []string, useSASL bool, t *testing.T) { irccon.AddCallback("001", func(e Event) { irccon.Join("#go-eventirc") }) irccon.AddCallback("366", func(e Event) { - irccon.Privmsg("#go-eventirc", "Test Message SASL\n") + irccon.Privmsg("#go-eventirc", "Test Message SASL") irccon.Quit() }) @@ -89,3 +89,18 @@ func TestConnectionNonexistentCAPs(t *testing.T) { func TestConnectionGoodCAPs(t *testing.T) { runCAPTest([]string{"server-time", "message-tags"}, false, t) } + +func TestSASLFail(t *testing.T) { + irccon := connForTesting("go-eventirc", "go-eventirc", true) + irccon.Debug = true + irccon.UseTLS = true + setSaslTestCreds(irccon, t) + irccon.TLSConfig = &tls.Config{InsecureSkipVerify: true} + irccon.AddCallback("001", func(e Event) { irccon.Join("#go-eventirc") }) + // intentionally break the password + irccon.SASLPassword = irccon.SASLPassword + "_" + err := irccon.Connect() + if err == nil { + t.Errorf("successfully connected with invalid password") + } +} diff --git a/ircevent/irc_struct.go b/ircevent/irc_struct.go index 43349cb..685845a 100644 --- a/ircevent/irc_struct.go +++ b/ircevent/irc_struct.go @@ -19,6 +19,11 @@ type empty struct{} type Callback func(Event) +type callbackPair struct { + id uint64 + callback Callback +} + type capResult struct { capName string ack bool @@ -62,9 +67,13 @@ type Connection struct { pingSent bool // we sent PING and are waiting for PONG // IRC protocol connection state - currentNick string // nickname assigned by the server, empty before registration - acknowledgedCaps []string - nickCounter int + currentNick string // nickname assigned by the server, empty before registration + capsAdvertised map[string]string + capsAcked map[string]string + isupport map[string]string + isupportPartial map[string]string + nickCounter int + registered bool // Connect() builds these with sufficient capacity to receive all expected // responses during negotiation. Sends to them are nonblocking, so anything // sent outside of negotiation will not cause the relevant callbacks to block. @@ -73,9 +82,13 @@ type Connection struct { capsChan chan capResult // transmits the final status of each CAP negotiated // callback state - eventsMutex sync.Mutex - events map[string]map[uint64]Callback - idCounter uint64 // assign unique IDs to callbacks + eventsMutex sync.Mutex + events map[string][]callbackPair + // we assign ID numbers to callbacks so they can be removed. normally + // the ID number is globally unique (generated by incrementing this counter). + // if we add a callback in two places we might reuse the number (XXX) + callbackCounter uint64 + // did we initialize the callbacks needed for the library itself? hasBaseCallbacks bool Log *log.Logger diff --git a/ircevent/irc_test.go b/ircevent/irc_test.go index ecdcf15..8ad3208 100644 --- a/ircevent/irc_test.go +++ b/ircevent/irc_test.go @@ -56,27 +56,6 @@ func TestRemoveCallback(t *testing.T) { } } -func TestWildcardCallback(t *testing.T) { - irccon := connForTesting("go-eventirc", "go-eventirc", false) - debugTest(irccon) - - done := make(chan int, 10) - - irccon.AddCallback("TEST", func(e Event) { done <- 1 }) - irccon.AddCallback("*", func(e Event) { done <- 2 }) - - irccon.runCallbacks(mockEvent("TEST")) - - var results []int - - results = append(results, <-done) - results = append(results, <-done) - - if !compareResults(results, 1, 2) { - t.Error("Wildcard callback not called") - } -} - func TestClearCallback(t *testing.T) { irccon := connForTesting("go-eventirc", "go-eventirc", false) debugTest(irccon) @@ -332,3 +311,24 @@ func TestConnectionNickInUse(t *testing.T) { } t.Errorf("expected %s and a suffixed version, got %s and %s", ircnick, nick1, nick2) } + +func TestConnectionCallbacks(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + ircnick := randStr(8) + irccon1 := connForTesting(ircnick, "IRCTest1", false) + debugTest(irccon1) + resultChan := make(chan map[string]string, 1) + irccon1.AddConnectCallback(func(e Event) { + resultChan <- irccon1.ISupport() + }) + err := irccon1.Connect() + if err != nil { + panic(err) + } + go irccon1.Loop() + isupport := <-resultChan + if casemapping := isupport["CASEMAPPING"]; casemapping == "" { + t.Errorf("casemapping not detected in 005 RPL_ISUPPORT output; this is unheard of") + } + irccon1.Quit() +} diff --git a/ircevent/numerics.go b/ircevent/numerics.go new file mode 100644 index 0000000..a94376d --- /dev/null +++ b/ircevent/numerics.go @@ -0,0 +1,17 @@ +package ircevent + +const ( + RPL_WELCOME = "001" + RPL_ISUPPORT = "005" + RPL_ENDOFMOTD = "376" + ERR_NOMOTD = "422" + ERR_NICKNAMEINUSE = "433" + ERR_UNAVAILRESOURCE = "437" + // SASL + RPL_LOGGEDIN = "900" + RPL_LOGGEDOUT = "901" + ERR_NICKLOCKED = "902" + RPL_SASLSUCCESS = "903" + ERR_SASLFAIL = "904" + ERR_SASLALREADY = "907" +)