Settle for caps using traditional locks and maps

This commit is contained in:
kayos@tcp.direct 2022-03-19 15:36:46 -07:00
parent 86271f76fa
commit eee810320c
Signed by: kayos
GPG Key ID: 4B841471B4BEE979
6 changed files with 39 additions and 55 deletions

44
cap.go

@ -9,8 +9,6 @@ import (
"strconv" "strconv"
"strings" "strings"
"time" "time"
cmap "github.com/orcaman/concurrent-map"
) )
// Something not in the list? Depending on the type of capability, you can // Something not in the list? Depending on the type of capability, you can
@ -51,9 +49,7 @@ var possibleCap = map[string][]string{
const capServerTimeFormat = "2006-01-02T15:04:05.999Z" const capServerTimeFormat = "2006-01-02T15:04:05.999Z"
func (c *Client) listCAP() { func (c *Client) listCAP() {
if !c.Config.disableTracking { c.write(&Event{Command: CAP, Params: []string{CAP_LS, "302"}})
c.write(&Event{Command: CAP, Params: []string{CAP_LS, "302"}})
}
} }
func possibleCapList(c *Client) map[string][]string { func possibleCapList(c *Client) map[string][]string {
@ -120,16 +116,14 @@ func parseCap(raw string) map[string]map[string]string {
// This will lock further registration until we have acknowledged (or denied) // This will lock further registration until we have acknowledged (or denied)
// the capabilities. // the capabilities.
func handleCAP(c *Client, e Event) { func handleCAP(c *Client, e Event) {
c.Handlers.mu.Lock()
defer c.Handlers.mu.Unlock()
c.state.Lock() c.state.Lock()
defer c.state.Unlock() defer c.state.Unlock()
if len(e.Params) >= 2 && e.Params[1] == CAP_DEL { if len(e.Params) >= 2 && e.Params[1] == CAP_DEL {
caps := parseCap(e.Last()) caps := parseCap(e.Last())
for capab := range caps { for cap := range caps {
// TODO: test the deletion. // TODO: test the deletion.
c.state.enabledCap.Remove(capab) delete(c.state.enabledCap, cap)
} }
return return
} }
@ -152,7 +146,7 @@ func handleCAP(c *Client, e Event) {
} }
if len(possible[capName]) == 0 || len(caps[capName]) == 0 { if len(possible[capName]) == 0 || len(caps[capName]) == 0 {
c.state.tmpCap.Set(capName, caps[capName]) c.state.tmpCap[capName] = caps[capName]
continue continue
} }
@ -173,7 +167,7 @@ func handleCAP(c *Client, e Event) {
continue continue
} }
c.state.tmpCap.Set(capName, caps[capName]) c.state.tmpCap[capName] = caps[capName]
} }
// Indicates if this is a multi-line LS. (3 args means it's the // Indicates if this is a multi-line LS. (3 args means it's the
@ -186,15 +180,10 @@ func handleCAP(c *Client, e Event) {
} }
// Let them know which ones we'd like to enable. // Let them know which ones we'd like to enable.
reqKeys := make([]string, len(c.state.tmpCap.Keys())) reqKeys := make([]string, len(c.state.tmpCap))
i := 0 i := 0
for k := range c.state.tmpCap.IterBuffered() { for k := range c.state.tmpCap {
kv := k.Val.(map[string]string) reqKeys[i] = k
var index = 0
for _, value := range kv {
reqKeys[index] = value
index++
}
i++ i++
} }
c.write(&Event{Command: CAP, Params: []string{CAP_REQ, strings.Join(reqKeys, " ")}}) c.write(&Event{Command: CAP, Params: []string{CAP_REQ, strings.Join(reqKeys, " ")}})
@ -204,12 +193,10 @@ func handleCAP(c *Client, e Event) {
if len(e.Params) == 3 && e.Params[1] == CAP_ACK { if len(e.Params) == 3 && e.Params[1] == CAP_ACK {
enabled := strings.Split(e.Last(), " ") enabled := strings.Split(e.Last(), " ")
for _, capab := range enabled { for _, capab := range enabled {
val, ok := c.state.tmpCap.Get(capab) if val, ok := c.state.tmpCap[capab]; ok {
if ok { c.state.enabledCap[capab] = val
val = val.(map[string]string)
c.state.enabledCap.Set(capab, val)
} else { } else {
c.state.enabledCap.Set(capab, nil) c.state.enabledCap[capab] = nil
} }
} }
@ -218,10 +205,8 @@ func handleCAP(c *Client, e Event) {
// Handle STS, and only if it's something specifically we enabled (client // Handle STS, and only if it's something specifically we enabled (client
// may choose to disable girc automatic STS, and do it themselves). // may choose to disable girc automatic STS, and do it themselves).
stsi, sok := c.state.enabledCap.Get("sts") if sts, sok := c.state.enabledCap["sts"]; sok && !c.Config.DisableSTS {
if sok && !c.Config.DisableSTS {
var isError bool var isError bool
sts := stsi.(map[string]string)
// Some things are updated in the policy depending on if the current // Some things are updated in the policy depending on if the current
// connection is over tls or not. // connection is over tls or not.
var hasTLSConnection bool var hasTLSConnection bool
@ -298,10 +283,9 @@ func handleCAP(c *Client, e Event) {
// Re-initialize the tmpCap, so if we get multiple 'CAP LS' requests // Re-initialize the tmpCap, so if we get multiple 'CAP LS' requests
// due to cap-notify, we can re-evaluate what we can support. // due to cap-notify, we can re-evaluate what we can support.
c.state.tmpCap = cmap.New() c.state.tmpCap = make(map[string]map[string]string)
_, ok := c.state.enabledCap.Get("sasl") if _, ok := c.state.enabledCap["sasl"]; ok && c.Config.SASL != nil {
if ok && c.Config.SASL != nil {
c.write(&Event{Command: AUTHENTICATE, Params: []string{c.Config.SASL.Method()}}) c.write(&Event{Command: AUTHENTICATE, Params: []string{c.Config.SASL.Method()}})
// Don't "CAP END", since we want to authenticate. // Don't "CAP END", since we want to authenticate.
return return

@ -794,13 +794,15 @@ func (c *Client) HasCapability(name string) (has bool) {
name = strings.ToLower(name) name = strings.ToLower(name)
for _, key := range c.state.enabledCap.Keys() { c.state.RLock()
for key := range c.state.enabledCap {
key = strings.ToLower(key) key = strings.ToLower(key)
if key == name { if key == name {
has = true has = true
break break
} }
} }
c.state.RUnlock()
return has return has
} }

@ -335,7 +335,9 @@ startConn:
c.listCAP() c.listCAP()
// Then nickname. // Then nickname.
c.state.RLock()
c.write(&Event{Command: NICK, Params: []string{c.Config.Nick}}) c.write(&Event{Command: NICK, Params: []string{c.Config.Nick}})
c.state.RUnlock()
// Then username and realname. // Then username and realname.
if c.Config.Name == "" { if c.Config.Name == "" {
@ -507,15 +509,15 @@ func (c *Client) sendLoop(ctx context.Context, errs chan error, working *int32)
// Check if tags exist on the event. If they do, and message-tags // Check if tags exist on the event. If they do, and message-tags
// isn't a supported capability, remove them from the event. // isn't a supported capability, remove them from the event.
if event.Tags != nil { if event.Tags != nil {
// c.state.RLock()
var in bool var in bool
for i := 0; i < len(c.state.enabledCap); i++ { for i := 0; i < len(c.state.enabledCap); i++ {
if _, ok := c.state.enabledCap.Get("message-tags"); ok { if _, ok := c.state.enabledCap["message-tags"]; ok {
in = true in = true
break break
} }
} }
// c.state.RUnlock()
if !in { if !in {
event.Tags = Tags{} event.Tags = Tags{}

@ -134,9 +134,6 @@ func newCTCP() *CTCP {
// call executes the necessary CTCP handler for the incoming event/CTCP // call executes the necessary CTCP handler for the incoming event/CTCP
// command. // command.
func (c *CTCP) call(client *Client, event *CTCPEvent) { func (c *CTCP) call(client *Client, event *CTCPEvent) {
c.mu.RLock()
defer c.mu.RUnlock()
// If they want to catch any panics, add to defer stack. // If they want to catch any panics, add to defer stack.
if client.Config.RecoverFunc != nil && event.Origin != nil { if client.Config.RecoverFunc != nil && event.Origin != nil {
defer recoverHandlerPanic(client, event.Origin, "ctcp-"+strings.ToLower(event.Command), 3) defer recoverHandlerPanic(client, event.Origin, "ctcp-"+strings.ToLower(event.Command), 3)

@ -20,9 +20,6 @@ import (
// RunHandlers manually runs handlers for a given event. // RunHandlers manually runs handlers for a given event.
func (c *Client) RunHandlers(event *Event) { func (c *Client) RunHandlers(event *Event) {
c.mu.RLock()
defer c.mu.RUnlock()
if event == nil { if event == nil {
c.debug.Print("nil event") c.debug.Print("nil event")
return return
@ -134,7 +131,7 @@ func (nest *nestedHandlers) getAllHandlersFor(s string) (handlers chan handlerTu
// Caller manages internal and external (user facing) handlers. // Caller manages internal and external (user facing) handlers.
type Caller struct { type Caller struct {
// mu is the mutex that should be used when accessing handlers. // mu is the mutex that should be used when accessing handlers.
mu sync.RWMutex mu *sync.RWMutex
parent *Client parent *Client
@ -158,6 +155,7 @@ func newCaller(parent *Client, debugOut *log.Logger) *Caller {
internal: newNestedHandlers(), internal: newNestedHandlers(),
debug: debugOut, debug: debugOut,
parent: parent, parent: parent,
mu: &sync.RWMutex{},
} }
return c return c
@ -216,7 +214,6 @@ type execStack struct {
func (c *Caller) exec(command string, bg bool, client *Client, event *Event) { func (c *Caller) exec(command string, bg bool, client *Client, event *Event) {
c.mu.RLock() c.mu.RLock()
defer c.mu.RUnlock() defer c.mu.RUnlock()
// Build a stack of handlers which can be executed concurrently. // Build a stack of handlers which can be executed concurrently.
var stack []execStack var stack []execStack
@ -259,12 +256,12 @@ func (c *Caller) exec(command string, bg bool, client *Client, event *Event) {
// execution speed. // execution speed.
var working int32 var working int32
atomic.AddInt32(&working, int32(len(stack))) atomic.AddInt32(&working, int32(len(stack)))
c.debug.Printf("starting %d jobs", atomic.LoadInt32(&working)) // c.debug.Printf("starting %d jobs", atomic.LoadInt32(&working))
for i := 0; i < len(stack); i++ { for i := 0; i < len(stack); i++ {
go func(index int) { go func(index int) {
c.debug.Printf("(%s) [%d/%d] exec %s => %s", c.parent.Config.Nick, // c.debug.Printf("(%s) [%d/%d] exec %s => %s", c.parent.Config.Nick,
index+1, len(stack), stack[index].cuid, command) // index+1, len(stack), stack[index].cuid, command)
start := time.Now() // start := time.Now()
if bg { if bg {
go func() { go func() {
@ -273,8 +270,8 @@ func (c *Caller) exec(command string, bg bool, client *Client, event *Event) {
defer recoverHandlerPanic(client, event, stack[index].cuid, 3) defer recoverHandlerPanic(client, event, stack[index].cuid, 3)
} }
stack[index].Handler.Execute(client, *event) stack[index].Handler.Execute(client, *event)
c.debug.Printf("(%s) done %s == %s", c.parent.Config.Nick, // c.debug.Printf("(%s) done %s == %s", c.parent.Config.Nick,
stack[index].cuid, time.Since(start)) // stack[index].cuid, time.Since(start))
}() }()
return return
} }
@ -285,14 +282,14 @@ func (c *Caller) exec(command string, bg bool, client *Client, event *Event) {
} }
stack[index].Handler.Execute(client, *event) stack[index].Handler.Execute(client, *event)
c.debug.Printf("(%s) done %s == %s", c.parent.Config.Nick, stack[index].cuid, time.Since(start)) // c.debug.Printf("(%s) done %s == %s", c.parent.Config.Nick, stack[index].cuid, time.Since(start))
}(i) }(i)
// new events from becoming ahead of old handlers. // new events from becoming ahead of old handlers.
c.debug.Printf("(%s) atomic.CompareAndSwap: %d jobs running", c.parent.Config.Nick, atomic.LoadInt32(&working)) // c.debug.Printf("(%s) atomic.CompareAndSwap: %d jobs running", c.parent.Config.Nick, atomic.LoadInt32(&working))
if atomic.CompareAndSwapInt32(&working, 0, -1) { if atomic.CompareAndSwapInt32(&working, 0, -1) {
c.debug.Printf("(%s) exec stack completed", c.parent.Config.Nick) // c.debug.Printf("(%s) exec stack completed", c.parent.Config.Nick)
return return
} }
} }
@ -360,6 +357,8 @@ func (c *Caller) remove(cuid string) (ok bool) {
// sregister is much like Caller.register(), except that it safely locks // sregister is much like Caller.register(), except that it safely locks
// the Caller mutex. // the Caller mutex.
func (c *Caller) sregister(internal, bg bool, cmd string, handler Handler) (cuid string) { func (c *Caller) sregister(internal, bg bool, cmd string, handler Handler) (cuid string) {
c.mu.Lock()
defer c.mu.Unlock()
cuid = c.register(internal, bg, cmd, handler) cuid = c.register(internal, bg, cmd, handler)
return cuid return cuid
} }

@ -28,11 +28,11 @@ type state struct {
// users map[string]*User // users map[string]*User
users cmap.ConcurrentMap users cmap.ConcurrentMap
// enabledCap are the capabilities which are enabled for this connection. // enabledCap are the capabilities which are enabled for this connection.
enabledCap cmap.ConcurrentMap enabledCap map[string]map[string]string
// tmpCap are the capabilities which we share with the server during the // tmpCap are the capabilities which we share with the server during the
// last capability check. These will get sent once we have received the // last capability check. These will get sent once we have received the
// last capability list command from the server. // last capability list command from the server.
tmpCap cmap.ConcurrentMap tmpCap map[string]map[string]string
// serverOptions are the standard capabilities and configurations // serverOptions are the standard capabilities and configurations
// supported by the server at connection time. This also includes // supported by the server at connection time. This also includes
// RPL_ISUPPORT entries. // RPL_ISUPPORT entries.
@ -69,8 +69,8 @@ func (s *state) reset(initial bool) {
} }
} }
s.enabledCap = cmap.New() s.enabledCap = make(map[string]map[string]string)
s.tmpCap = cmap.New() s.tmpCap = make(map[string]map[string]string)
s.motd = "" s.motd = ""
if initial { if initial {