Merge pull request #42 from slingamn/isupport_and_302.1

refactor callback/protocol handling
This commit is contained in:
Shivaram Lingamneni 2021-03-01 16:56:28 -05:00 committed by GitHub
commit a1d30e7a26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 365 additions and 135 deletions

@ -36,7 +36,13 @@ func main() {
SASLPassword: saslPassword,
}
irc.AddCallback("001", func(e ircevent.Event) { irc.Join(channel) })
irc.AddConnectCallback(func(e ircevent.Event) {
// attempt to set the BOT mode on ourself:
if botMode := irc.ISupport()["BOT"]; botMode != "" {
irc.Send("MODE", irc.CurrentNick(), "+"+botMode)
}
irc.Join(channel)
})
irc.AddCallback("JOIN", func(e ircevent.Event) {}) // TODO try to rejoin if we *don't* get this
irc.AddCallback("PRIVMSG", func(e ircevent.Event) {
if len(e.Params) < 2 {

@ -446,13 +446,23 @@ func (irc *Connection) setCurrentNick(nick string) {
irc.currentNick = nick
}
// Return IRCv3 CAPs actually enabled on the connection.
func (irc *Connection) AcknowledgedCaps(result []string) {
// Return IRCv3 CAPs actually enabled on the connection, together
// with their values if applicable. The resulting map is shared,
// so do not modify it.
func (irc *Connection) AcknowledgedCaps() (result map[string]string) {
irc.stateMutex.Lock()
defer irc.stateMutex.Unlock()
result = make([]string, len(irc.acknowledgedCaps))
copy(result[:], irc.acknowledgedCaps[:])
return
return irc.capsAcked
}
// Returns the 005 RPL_ISUPPORT tokens sent by the server when the
// connection was initiated, parsed into key-value form as a map.
// The resulting map is shared, so do not modify it.
func (irc *Connection) ISupport() (result map[string]string) {
irc.stateMutex.Lock()
defer irc.stateMutex.Unlock()
// XXX modifications to isupport are not permitted after registration
return irc.isupport
}
// Returns true if the connection is connected to an IRC server.
@ -498,7 +508,6 @@ func (irc *Connection) Connect() (err error) {
}
// mark Server as stopped since there can be an error during connect
irc.acknowledgedCaps = nil
irc.running = false
irc.socket = nil
irc.currentNick = ""
@ -588,6 +597,11 @@ func (irc *Connection) Connect() (err error) {
irc.capsChan = make(chan capResult, len(irc.RequestCaps))
irc.saslChan = make(chan saslResult, 1)
irc.welcomeChan = make(chan empty, 1)
irc.registered = false
irc.isupportPartial = make(map[string]string)
irc.isupport = nil
irc.capsAcked = make(map[string]string)
irc.capsAdvertised = nil
irc.stateMutex.Unlock()
go irc.readLoop()
@ -642,10 +656,12 @@ func (irc *Connection) negotiateCaps() error {
defer func() {
irc.stateMutex.Lock()
defer irc.stateMutex.Unlock()
irc.acknowledgedCaps = acknowledgedCaps
for _, c := range acknowledgedCaps {
irc.capsAcked[c] = irc.capsAdvertised[c]
}
}()
irc.Send("CAP", "LS")
irc.Send("CAP", "LS", "302")
defer func() {
irc.Send("CAP", "END")
}()
@ -670,6 +686,11 @@ func (irc *Connection) negotiateCaps() error {
}
if irc.UseSASL {
if !sliceContains("sasl", acknowledgedCaps) {
return SASLFailed
} else {
irc.Send("AUTHENTICATE", irc.SASLMech)
}
select {
case res := <-irc.saslChan:
if res.Failed {

@ -4,11 +4,17 @@ import (
"fmt"
"log"
"runtime/debug"
"strconv"
"strings"
"github.com/goshuirc/irc-go/ircmsg"
)
const (
// fake events that we manage specially
registrationEvent = "*REGISTRATION"
)
// Tuple type for uniquely identifying callbacks
type CallbackID struct {
eventCode string
@ -16,27 +22,44 @@ type CallbackID struct {
}
// Register a callback to a connection and event code. A callback is a function
// which takes only an Event pointer as parameter. Valid event codes are all
// IRC/CTCP commands and error/response codes. To register a callback for all
// events pass "*" as the event code. This function returns the ID of the
// which takes only an Event object as parameter. Valid event codes are all
// IRC/CTCP commands and error/response codes. This function returns the ID of the
// registered callback for later management.
func (irc *Connection) AddCallback(eventCode string, callback func(Event)) CallbackID {
return irc.addCallback(eventCode, Callback(callback), false, 0)
}
func (irc *Connection) addCallback(eventCode string, callback Callback, prepend bool, idNum uint64) CallbackID {
eventCode = strings.ToUpper(eventCode)
if eventCode == "" || strings.HasPrefix(eventCode, "*") {
return CallbackID{}
}
irc.eventsMutex.Lock()
defer irc.eventsMutex.Unlock()
if irc.events == nil {
irc.events = make(map[string]map[uint64]Callback)
irc.events = make(map[string][]callbackPair)
}
_, ok := irc.events[eventCode]
if !ok {
irc.events[eventCode] = make(map[uint64]Callback)
if idNum == 0 {
idNum = irc.callbackCounter
irc.callbackCounter++
}
id := CallbackID{eventCode: eventCode, id: irc.idCounter}
irc.idCounter++
irc.events[eventCode][id.id] = Callback(callback)
id := CallbackID{eventCode: eventCode, id: idNum}
newPair := callbackPair{id: id.id, callback: callback}
current := irc.events[eventCode]
newList := make([]callbackPair, len(current)+1)
start := 0
if prepend {
newList[start] = newPair
start++
}
copy(newList[start:], current)
if !prepend {
newList[len(newList)-1] = newPair
}
irc.events[eventCode] = newList
return id
}
@ -44,7 +67,27 @@ func (irc *Connection) AddCallback(eventCode string, callback func(Event)) Callb
func (irc *Connection) RemoveCallback(id CallbackID) {
irc.eventsMutex.Lock()
defer irc.eventsMutex.Unlock()
delete(irc.events[id.eventCode], id.id)
switch id.eventCode {
case registrationEvent:
irc.removeCallbackNoMutex(RPL_ENDOFMOTD, id.id)
irc.removeCallbackNoMutex(ERR_NOMOTD, id.id)
default:
irc.removeCallbackNoMutex(id.eventCode, id.id)
}
}
func (irc *Connection) removeCallbackNoMutex(code string, id uint64) {
current := irc.events[code]
if len(current) == 0 {
return
}
newList := make([]callbackPair, 0, len(current)-1)
for _, p := range current {
if p.id != id {
newList = append(newList, p)
}
}
irc.events[code] = newList
}
// Remove all callbacks from a given event code.
@ -61,33 +104,32 @@ func (irc *Connection) ReplaceCallback(id CallbackID, callback func(Event)) bool
irc.eventsMutex.Lock()
defer irc.eventsMutex.Unlock()
if _, ok := irc.events[id.eventCode][id.id]; ok {
irc.events[id.eventCode][id.id] = callback
return true
list := irc.events[id.eventCode]
for i, p := range list {
if p.id == id.id {
list[i] = callbackPair{id: id.id, callback: callback}
return true
}
}
return false
}
func (irc *Connection) getCallbacks(code string) (result []Callback) {
// Convenience function to add a callback that will be called once the
// connection is completed (this is traditionally referred to as "connection
// registration").
func (irc *Connection) AddConnectCallback(callback func(Event)) (id CallbackID) {
// XXX: forcibly use the same ID number for both copies of the callback
id376 := irc.AddCallback(RPL_ENDOFMOTD, callback)
irc.addCallback(ERR_NOMOTD, callback, false, id376.id)
return CallbackID{eventCode: registrationEvent, id: id376.id}
}
func (irc *Connection) getCallbacks(code string) (result []callbackPair) {
code = strings.ToUpper(code)
irc.eventsMutex.Lock()
defer irc.eventsMutex.Unlock()
cMap := irc.events[code]
starMap := irc.events["*"]
length := len(cMap) + len(starMap)
if length == 0 {
return
}
result = make([]Callback, 0, length)
for _, c := range cMap {
result = append(result, c)
}
for _, c := range starMap {
result = append(result, c)
}
return
return irc.events[code]
}
// Execute all callbacks associated with a given event.
@ -106,12 +148,12 @@ func (irc *Connection) runCallbacks(msg ircmsg.IRCMessage) {
eventRewriteCTCP(&event)
}
callbacks := irc.getCallbacks(event.Command)
callbackPairs := irc.getCallbacks(event.Command)
// just run the callbacks in serial, since it's not safe for them
// to take a long time to execute in any case
for _, callback := range callbacks {
callback(event)
for _, pair := range callbackPairs {
pair.callback(event)
}
}
@ -136,12 +178,15 @@ func (irc *Connection) setupCallbacks() {
// 433: ERR_NICKNAMEINUSE "<nick> :Nickname is already in use"
// 437: ERR_UNAVAILRESOURCE "<nick/channel> :Nick/channel is temporarily unavailable"
irc.AddCallback("433", irc.handleUnavailableNick)
irc.AddCallback("437", irc.handleUnavailableNick)
irc.AddCallback(ERR_NICKNAMEINUSE, irc.handleUnavailableNick)
irc.AddCallback(ERR_UNAVAILRESOURCE, irc.handleUnavailableNick)
// 1: RPL_WELCOME "Welcome to the Internet Relay Network <nick>!<user>@<host>"
// 001: RPL_WELCOME "Welcome to the Internet Relay Network <nick>!<user>@<host>"
// Set irc.currentNick to the actually used nick in this connection.
irc.AddCallback("001", irc.handleRplWelcome)
irc.AddCallback(RPL_WELCOME, irc.handleRplWelcome)
// 005: RPL_ISUPPORT, conveys supported server features
irc.AddCallback(RPL_ISUPPORT, irc.handleISupport)
// respond to NICK from the server (in response to our own NICK, or sent unprompted)
irc.AddCallback("NICK", func(e Event) {
@ -156,36 +201,7 @@ func (irc *Connection) setupCallbacks() {
}
})
irc.AddCallback("CAP", func(e Event) {
if len(e.Params) != 3 {
return
}
command := e.Params[1]
capsChan := irc.capsChan
// TODO this assumes all the caps on one line
// TODO support CAP LS 302
if command == "LS" {
capsList := strings.Fields(e.Params[2])
for _, capName := range irc.RequestCaps {
if sliceContains(capName, capsList) {
irc.Send("CAP", "REQ", capName)
} else {
select {
case capsChan <- capResult{capName, false}:
default:
}
}
}
} else if command == "ACK" || command == "NAK" {
for _, capName := range strings.Fields(e.Params[2]) {
select {
case capsChan <- capResult{capName, command == "ACK"}:
default:
}
}
}
})
irc.AddCallback("CAP", irc.handleCAP)
if irc.UseSASL {
irc.setupSASLCallbacks()
@ -194,6 +210,11 @@ func (irc *Connection) setupCallbacks() {
if irc.EnableCTCP {
irc.setupCTCPCallbacks()
}
// prepend our own callbacks for the end of registration,
// so they happen before any client-added callbacks
irc.addCallback(RPL_ENDOFMOTD, irc.handleRegistration, true, 0)
irc.addCallback(ERR_NOMOTD, irc.handleRegistration, true, 0)
}
func (irc *Connection) handleRplWelcome(e Event) {
@ -204,12 +225,29 @@ func (irc *Connection) handleRplWelcome(e Event) {
if len(e.Params) > 0 {
irc.currentNick = e.Params[0]
}
}
func (irc *Connection) handleRegistration(e Event) {
// wake up Connect() if applicable
select {
case irc.welcomeChan <- empty{}:
default:
defer func() {
select {
case irc.welcomeChan <- empty{}:
default:
}
}()
irc.stateMutex.Lock()
defer irc.stateMutex.Unlock()
if irc.registered {
return
}
irc.registered = true
// mark the isupport complete
irc.isupport = irc.isupportPartial
irc.isupportPartial = nil
}
func (irc *Connection) handleUnavailableNick(e Event) {
@ -229,3 +267,126 @@ func (irc *Connection) handleUnavailableNick(e Event) {
irc.Send("NICK", nickToTry)
}
}
func (irc *Connection) handleISupport(e Event) {
irc.stateMutex.Lock()
defer irc.stateMutex.Unlock()
// TODO handle 005 changes after registration
if irc.isupportPartial == nil {
return
}
if len(e.Params) < 3 {
return
}
for _, token := range e.Params[1 : len(e.Params)-1] {
equalsIdx := strings.IndexByte(token, '=')
if equalsIdx == -1 {
irc.isupportPartial[token] = "" // no value
} else {
irc.isupportPartial[token[:equalsIdx]] = unescapeISupportValue(token[equalsIdx+1:])
}
}
}
func unescapeISupportValue(in string) (out string) {
if strings.IndexByte(in, '\\') == -1 {
return in
}
var buf strings.Builder
for i := 0; i < len(in); {
if in[i] == '\\' && i+3 < len(in) && in[i+1] == 'x' {
hex := in[i+2 : i+4]
if octet, err := strconv.ParseInt(hex, 16, 8); err == nil {
buf.WriteByte(byte(octet))
i += 4
continue
}
}
buf.WriteByte(in[i])
i++
}
return buf.String()
}
func (irc *Connection) handleCAP(e Event) {
if len(e.Params) < 3 {
return
}
ack := false
// CAP <NICK | * > <SUBCOMMAND> PARAMS...
switch e.Params[1] {
case "LS":
irc.handleCAPLS(e.Params[2:])
case "ACK":
ack = true
fallthrough
case "NAK":
for _, token := range strings.Fields(e.Params[2]) {
name, _ := splitCAPToken(token)
if sliceContains(name, irc.RequestCaps) {
select {
case irc.capsChan <- capResult{capName: name, ack: ack}:
default:
}
}
}
}
}
func (irc *Connection) handleCAPLS(params []string) {
var capsToReq, capsNotFound []string
defer func() {
for _, c := range capsToReq {
irc.Send("CAP", "REQ", c)
}
for _, c := range capsNotFound {
select {
case irc.capsChan <- capResult{capName: c, ack: false}:
default:
}
}
}()
irc.stateMutex.Lock()
defer irc.stateMutex.Unlock()
if irc.registered {
// TODO server could probably trick us into panic here by sending
// additional LS before the end of registration
return
}
if irc.capsAdvertised == nil {
irc.capsAdvertised = make(map[string]string)
}
// multiline responses to CAP LS 302 start with a 4-parameter form:
// CAP * LS * :account-notify away-notify [...]
// and end with a 3-parameter form:
// CAP * LS :userhost-in-names znc.in/playback [...]
final := len(params) == 1
for _, token := range strings.Fields(params[len(params)-1]) {
name, value := splitCAPToken(token)
irc.capsAdvertised[name] = value
}
if final {
for _, c := range irc.RequestCaps {
if _, ok := irc.capsAdvertised[c]; ok {
capsToReq = append(capsToReq, c)
} else {
capsNotFound = append(capsNotFound, c)
}
}
}
}
func splitCAPToken(token string) (name, value string) {
equalIdx := strings.IndexByte(token, '=')
if equalIdx == -1 {
return token, ""
} else {
return token[:equalIdx], token[equalIdx+1:]
}
}

@ -18,3 +18,19 @@ func TestParse(t *testing.T) {
t.Fatal("Parse failed: host")
}
}
func assertEqual(found, expected string, t *testing.T) {
if found != expected {
t.Errorf("expected `%s`, got `%s`\n", expected, found)
}
}
func TestUnescapeIsupport(t *testing.T) {
assertEqual(unescapeISupportValue(""), "", t)
assertEqual(unescapeISupportValue("a"), "a", t)
assertEqual(unescapeISupportValue(`\x20`), " ", t)
assertEqual(unescapeISupportValue(`\x20b`), " b", t)
assertEqual(unescapeISupportValue(`a\x20`), "a ", t)
assertEqual(unescapeISupportValue(`a\x20b`), "a b", t)
assertEqual(unescapeISupportValue(`\x20\x20`), " ", t)
}

@ -4,7 +4,6 @@ import (
"encoding/base64"
"errors"
"fmt"
"strings"
)
type saslResult struct {
@ -21,16 +20,6 @@ func sliceContains(str string, list []string) bool {
return false
}
// Check if a space-separated list of arguments contains a value.
func listContains(list string, value string) bool {
for _, arg_name := range strings.Split(strings.TrimSpace(list), " ") {
if arg_name == value {
return true
}
}
return false
}
func (irc *Connection) submitSASLResult(r saslResult) {
select {
case irc.saslChan <- r:
@ -39,43 +28,35 @@ func (irc *Connection) submitSASLResult(r saslResult) {
}
func (irc *Connection) setupSASLCallbacks() {
irc.AddCallback("CAP", func(e Event) {
if len(e.Params) == 3 {
if e.Params[1] == "LS" {
if !listContains(e.Params[2], "sasl") {
irc.submitSASLResult(saslResult{true, errors.New("no SASL capability " + e.Params[2])})
}
}
if e.Params[1] == "ACK" && listContains(e.Params[2], "sasl") {
irc.Send("AUTHENTICATE", irc.SASLMech)
}
}
})
irc.AddCallback("AUTHENTICATE", func(e Event) {
str := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s\x00%s\x00%s", irc.SASLLogin, irc.SASLLogin, irc.SASLPassword)))
irc.Send("AUTHENTICATE", str)
})
irc.AddCallback("901", func(e Event) {
irc.AddCallback(RPL_LOGGEDOUT, func(e Event) {
irc.SendRaw("CAP END")
irc.SendRaw("QUIT")
irc.submitSASLResult(saslResult{true, errors.New(e.Params[1])})
})
irc.AddCallback("902", func(e Event) {
irc.AddCallback(ERR_NICKLOCKED, func(e Event) {
irc.SendRaw("CAP END")
irc.SendRaw("QUIT")
irc.submitSASLResult(saslResult{true, errors.New(e.Params[1])})
})
irc.AddCallback("903", func(e Event) {
irc.AddCallback(RPL_SASLSUCCESS, func(e Event) {
irc.submitSASLResult(saslResult{false, nil})
})
irc.AddCallback("904", func(e Event) {
irc.AddCallback(ERR_SASLFAIL, func(e Event) {
irc.SendRaw("CAP END")
irc.SendRaw("QUIT")
irc.submitSASLResult(saslResult{true, errors.New(e.Params[1])})
})
// this could potentially happen with auto-login via certfp?
irc.AddCallback(ERR_SASLALREADY, func(e Event) {
irc.submitSASLResult(saslResult{false, nil})
})
}

@ -51,7 +51,7 @@ func runCAPTest(caps []string, useSASL bool, t *testing.T) {
irccon.AddCallback("001", func(e Event) { irccon.Join("#go-eventirc") })
irccon.AddCallback("366", func(e Event) {
irccon.Privmsg("#go-eventirc", "Test Message SASL\n")
irccon.Privmsg("#go-eventirc", "Test Message SASL")
irccon.Quit()
})
@ -89,3 +89,18 @@ func TestConnectionNonexistentCAPs(t *testing.T) {
func TestConnectionGoodCAPs(t *testing.T) {
runCAPTest([]string{"server-time", "message-tags"}, false, t)
}
func TestSASLFail(t *testing.T) {
irccon := connForTesting("go-eventirc", "go-eventirc", true)
irccon.Debug = true
irccon.UseTLS = true
setSaslTestCreds(irccon, t)
irccon.TLSConfig = &tls.Config{InsecureSkipVerify: true}
irccon.AddCallback("001", func(e Event) { irccon.Join("#go-eventirc") })
// intentionally break the password
irccon.SASLPassword = irccon.SASLPassword + "_"
err := irccon.Connect()
if err == nil {
t.Errorf("successfully connected with invalid password")
}
}

@ -19,6 +19,11 @@ type empty struct{}
type Callback func(Event)
type callbackPair struct {
id uint64
callback Callback
}
type capResult struct {
capName string
ack bool
@ -62,9 +67,13 @@ type Connection struct {
pingSent bool // we sent PING and are waiting for PONG
// IRC protocol connection state
currentNick string // nickname assigned by the server, empty before registration
acknowledgedCaps []string
nickCounter int
currentNick string // nickname assigned by the server, empty before registration
capsAdvertised map[string]string
capsAcked map[string]string
isupport map[string]string
isupportPartial map[string]string
nickCounter int
registered bool
// Connect() builds these with sufficient capacity to receive all expected
// responses during negotiation. Sends to them are nonblocking, so anything
// sent outside of negotiation will not cause the relevant callbacks to block.
@ -73,9 +82,13 @@ type Connection struct {
capsChan chan capResult // transmits the final status of each CAP negotiated
// callback state
eventsMutex sync.Mutex
events map[string]map[uint64]Callback
idCounter uint64 // assign unique IDs to callbacks
eventsMutex sync.Mutex
events map[string][]callbackPair
// we assign ID numbers to callbacks so they can be removed. normally
// the ID number is globally unique (generated by incrementing this counter).
// if we add a callback in two places we might reuse the number (XXX)
callbackCounter uint64
// did we initialize the callbacks needed for the library itself?
hasBaseCallbacks bool
Log *log.Logger

@ -56,27 +56,6 @@ func TestRemoveCallback(t *testing.T) {
}
}
func TestWildcardCallback(t *testing.T) {
irccon := connForTesting("go-eventirc", "go-eventirc", false)
debugTest(irccon)
done := make(chan int, 10)
irccon.AddCallback("TEST", func(e Event) { done <- 1 })
irccon.AddCallback("*", func(e Event) { done <- 2 })
irccon.runCallbacks(mockEvent("TEST"))
var results []int
results = append(results, <-done)
results = append(results, <-done)
if !compareResults(results, 1, 2) {
t.Error("Wildcard callback not called")
}
}
func TestClearCallback(t *testing.T) {
irccon := connForTesting("go-eventirc", "go-eventirc", false)
debugTest(irccon)
@ -332,3 +311,24 @@ func TestConnectionNickInUse(t *testing.T) {
}
t.Errorf("expected %s and a suffixed version, got %s and %s", ircnick, nick1, nick2)
}
func TestConnectionCallbacks(t *testing.T) {
rand.Seed(time.Now().UnixNano())
ircnick := randStr(8)
irccon1 := connForTesting(ircnick, "IRCTest1", false)
debugTest(irccon1)
resultChan := make(chan map[string]string, 1)
irccon1.AddConnectCallback(func(e Event) {
resultChan <- irccon1.ISupport()
})
err := irccon1.Connect()
if err != nil {
panic(err)
}
go irccon1.Loop()
isupport := <-resultChan
if casemapping := isupport["CASEMAPPING"]; casemapping == "" {
t.Errorf("casemapping not detected in 005 RPL_ISUPPORT output; this is unheard of")
}
irccon1.Quit()
}

17
ircevent/numerics.go Normal file

@ -0,0 +1,17 @@
package ircevent
const (
RPL_WELCOME = "001"
RPL_ISUPPORT = "005"
RPL_ENDOFMOTD = "376"
ERR_NOMOTD = "422"
ERR_NICKNAMEINUSE = "433"
ERR_UNAVAILRESOURCE = "437"
// SASL
RPL_LOGGEDIN = "900"
RPL_LOGGEDOUT = "901"
ERR_NICKLOCKED = "902"
RPL_SASLSUCCESS = "903"
ERR_SASLFAIL = "904"
ERR_SASLALREADY = "907"
)