Settle for caps using traditional locks and maps

kayos@tcp.direct 2022-03-19 15:36:46 -07:00
parent 86271f76fa
commit eee810320c
Signed by: kayos
GPG Key ID: 4B841471B4BEE979
6 changed files with 39 additions and 55 deletions

cap.go

@@ -9,8 +9,6 @@ import (
 	"strconv"
 	"strings"
 	"time"
-
-	cmap "github.com/orcaman/concurrent-map"
 )

 // Something not in the list? Depending on the type of capability, you can
@@ -51,9 +49,7 @@ var possibleCap = map[string][]string{
 const capServerTimeFormat = "2006-01-02T15:04:05.999Z"

 func (c *Client) listCAP() {
-	if !c.Config.disableTracking {
-		c.write(&Event{Command: CAP, Params: []string{CAP_LS, "302"}})
-	}
+	c.write(&Event{Command: CAP, Params: []string{CAP_LS, "302"}})
 }

 func possibleCapList(c *Client) map[string][]string {
@@ -120,16 +116,14 @@ func parseCap(raw string) map[string]map[string]string {
 // This will lock further registration until we have acknowledged (or denied)
 // the capabilities.
 func handleCAP(c *Client, e Event) {
-	c.Handlers.mu.Lock()
-	defer c.Handlers.mu.Unlock()
+	c.state.Lock()
+	defer c.state.Unlock()
 	if len(e.Params) >= 2 && e.Params[1] == CAP_DEL {
 		caps := parseCap(e.Last())
-		for capab := range caps {
+		for cap := range caps {
 			// TODO: test the deletion.
-			c.state.enabledCap.Remove(capab)
+			delete(c.state.enabledCap, cap)
 		}
 		return
 	}
@@ -152,7 +146,7 @@ func handleCAP(c *Client, e Event) {
 		}

 		if len(possible[capName]) == 0 || len(caps[capName]) == 0 {
-			c.state.tmpCap.Set(capName, caps[capName])
+			c.state.tmpCap[capName] = caps[capName]
 			continue
 		}
@@ -173,7 +167,7 @@ func handleCAP(c *Client, e Event) {
 			continue
 		}

-		c.state.tmpCap.Set(capName, caps[capName])
+		c.state.tmpCap[capName] = caps[capName]
 	}

 	// Indicates if this is a multi-line LS. (3 args means it's the
@@ -186,15 +180,10 @@ func handleCAP(c *Client, e Event) {
 	}

 	// Let them know which ones we'd like to enable.
-	reqKeys := make([]string, len(c.state.tmpCap.Keys()))
+	reqKeys := make([]string, len(c.state.tmpCap))
 	i := 0
-	for k := range c.state.tmpCap.IterBuffered() {
-		kv := k.Val.(map[string]string)
-		var index = 0
-		for _, value := range kv {
-			reqKeys[index] = value
-			index++
-		}
+	for k := range c.state.tmpCap {
+		reqKeys[i] = k
 		i++
 	}

 	c.write(&Event{Command: CAP, Params: []string{CAP_REQ, strings.Join(reqKeys, " ")}})
@@ -204,12 +193,10 @@ func handleCAP(c *Client, e Event) {
 	if len(e.Params) == 3 && e.Params[1] == CAP_ACK {
 		enabled := strings.Split(e.Last(), " ")
 		for _, capab := range enabled {
-			val, ok := c.state.tmpCap.Get(capab)
-			if ok {
-				val = val.(map[string]string)
-				c.state.enabledCap.Set(capab, val)
+			if val, ok := c.state.tmpCap[capab]; ok {
+				c.state.enabledCap[capab] = val
 			} else {
-				c.state.enabledCap.Set(capab, nil)
+				c.state.enabledCap[capab] = nil
 			}
 		}
@@ -218,10 +205,8 @@ func handleCAP(c *Client, e Event) {

 	// Handle STS, and only if it's something specifically we enabled (client
 	// may choose to disable girc automatic STS, and do it themselves).
-	stsi, sok := c.state.enabledCap.Get("sts")
-	if sok && !c.Config.DisableSTS {
+	if sts, sok := c.state.enabledCap["sts"]; sok && !c.Config.DisableSTS {
 		var isError bool
-		sts := stsi.(map[string]string)

 		// Some things are updated in the policy depending on if the current
 		// connection is over tls or not.
 		var hasTLSConnection bool
@@ -298,10 +283,9 @@ func handleCAP(c *Client, e Event) {
 	// Re-initialize the tmpCap, so if we get multiple 'CAP LS' requests
 	// due to cap-notify, we can re-evaluate what we can support.
-	c.state.tmpCap = cmap.New()
+	c.state.tmpCap = make(map[string]map[string]string)

-	_, ok := c.state.enabledCap.Get("sasl")
-	if ok && c.Config.SASL != nil {
+	if _, ok := c.state.enabledCap["sasl"]; ok && c.Config.SASL != nil {
 		c.write(&Event{Command: AUTHENTICATE, Params: []string{c.Config.SASL.Method()}})
 		// Don't "CAP END", since we want to authenticate.
 		return
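
Every hunk in cap.go above follows the same translation: concurrent-map method calls (Get/Set/Remove, plus the interface{} type assertions they force) become plain map operations guarded by the state mutex, which handleCAP now takes once up front and releases via defer. A minimal standalone sketch of that mapping, using a hypothetical capState type rather than girc's actual struct:

package main

import (
	"fmt"
	"sync"
)

// capState stands in for girc's state: a plain nested map guarded by
// an embedded RWMutex. (Illustrative type, not the library's own.)
type capState struct {
	sync.RWMutex
	enabledCap map[string]map[string]string
}

func main() {
	s := &capState{enabledCap: make(map[string]map[string]string)}

	// cmap s.enabledCap.Set(k, v) becomes a locked assignment.
	s.Lock()
	s.enabledCap["sts"] = map[string]string{"port": "6697"}
	s.Unlock()

	// cmap v, ok := s.enabledCap.Get(k) becomes a comma-ok lookup under
	// the read lock; no interface{} assertion is needed afterwards.
	s.RLock()
	sts, ok := s.enabledCap["sts"]
	s.RUnlock()
	fmt.Println(ok, sts["port"])

	// cmap s.enabledCap.Remove(k) becomes a locked delete().
	s.Lock()
	delete(s.enabledCap, "sts")
	s.Unlock()
}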

client.go

@@ -794,13 +794,15 @@ func (c *Client) HasCapability(name string) (has bool) {
 	name = strings.ToLower(name)

-	for _, key := range c.state.enabledCap.Keys() {
+	c.state.RLock()
+	for key := range c.state.enabledCap {
 		key = strings.ToLower(key)
 		if key == name {
 			has = true
 			break
 		}
 	}
+	c.state.RUnlock()

 	return has
 }
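
Ranging over a plain map while another goroutine writes to it is a data race (and can panic at runtime), so the read path above must hold RLock for the whole loop; cmap's Keys() returned a safe snapshot instead. If enabledCap keys were guaranteed to be stored lowercase, the loop could shrink to a single lookup. A hedged sketch under that assumption, which the code above does not promise; illustrative only, not the commit's code:

// hasCapabilityFast assumes enabledCap keys are stored lowercase.
func (c *Client) hasCapabilityFast(name string) bool {
	c.state.RLock()
	defer c.state.RUnlock()
	_, has := c.state.enabledCap[strings.ToLower(name)]
	return has
}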

conn.go

@@ -335,7 +335,9 @@ startConn:
 	c.listCAP()

 	// Then nickname.
+	c.state.RLock()
 	c.write(&Event{Command: NICK, Params: []string{c.Config.Nick}})
+	c.state.RUnlock()

 	// Then username and realname.
 	if c.Config.Name == "" {
@@ -507,15 +509,15 @@ func (c *Client) sendLoop(ctx context.Context, errs chan error, working *int32)
 		// Check if tags exist on the event. If they do, and message-tags
 		// isn't a supported capability, remove them from the event.
 		if event.Tags != nil {
-			//
+			c.state.RLock()
 			var in bool
 			for i := 0; i < len(c.state.enabledCap); i++ {
-				if _, ok := c.state.enabledCap.Get("message-tags"); ok {
+				if _, ok := c.state.enabledCap["message-tags"]; ok {
 					in = true
 					break
 				}
 			}
-			//
+			c.state.RUnlock()

 			if !in {
 				event.Tags = Tags{}
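
One quirk survives the sendLoop change: concurrent-map's ConcurrentMap type is a slice of shards, so len() on it compiles and returns the shard count, which is what the retained for loop used to iterate over. Against a plain map, every iteration now performs the identical lookup and breaks on the first hit, so the loop is redundant. A possible follow-up simplification, sketched standalone with a stand-in state type rather than girc's real one:

package main

import (
	"fmt"
	"sync"
)

// state is a stand-in for girc's internal struct, not the real one.
type state struct {
	sync.RWMutex
	enabledCap map[string]map[string]string
}

// hasMessageTags collapses the shard loop to one comma-ok lookup
// under the read lock.
func hasMessageTags(s *state) bool {
	s.RLock()
	defer s.RUnlock()
	_, in := s.enabledCap["message-tags"]
	return in
}

func main() {
	s := &state{enabledCap: map[string]map[string]string{"message-tags": nil}}
	fmt.Println(hasMessageTags(s)) // true: nil value, but the key exists
}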

ctcp.go

@@ -134,9 +134,6 @@ func newCTCP() *CTCP {
 // call executes the necessary CTCP handler for the incoming event/CTCP
 // command.
 func (c *CTCP) call(client *Client, event *CTCPEvent) {
-	c.mu.RLock()
-	defer c.mu.RUnlock()
-
 	// If they want to catch any panics, add to defer stack.
 	if client.Config.RecoverFunc != nil && event.Origin != nil {
 		defer recoverHandlerPanic(client, event.Origin, "ctcp-"+strings.ToLower(event.Command), 3)

handler.go

@@ -20,9 +20,6 @@ import (
 // RunHandlers manually runs handlers for a given event.
 func (c *Client) RunHandlers(event *Event) {
-	c.mu.RLock()
-	defer c.mu.RUnlock()
-
 	if event == nil {
 		c.debug.Print("nil event")
 		return
@@ -134,7 +131,7 @@ func (nest *nestedHandlers) getAllHandlersFor(s string) (handlers chan handlerTuple) {
 // Caller manages internal and external (user facing) handlers.
 type Caller struct {
 	// mu is the mutex that should be used when accessing handlers.
-	mu sync.RWMutex
+	mu *sync.RWMutex

 	parent *Client
@@ -158,6 +155,7 @@ func newCaller(parent *Client, debugOut *log.Logger) *Caller {
 		internal: newNestedHandlers(),
 		debug:    debugOut,
 		parent:   parent,
+		mu:       &sync.RWMutex{},
 	}

 	return c
@@ -216,7 +214,6 @@ type execStack struct {
 func (c *Caller) exec(command string, bg bool, client *Client, event *Event) {
 	c.mu.RLock()
 	defer c.mu.RUnlock()
-
 	// Build a stack of handlers which can be executed concurrently.
 	var stack []execStack
@@ -259,12 +256,12 @@ func (c *Caller) exec(command string, bg bool, client *Client, event *Event) {
 	// execution speed.
 	var working int32
 	atomic.AddInt32(&working, int32(len(stack)))
-	c.debug.Printf("starting %d jobs", atomic.LoadInt32(&working))
+	// c.debug.Printf("starting %d jobs", atomic.LoadInt32(&working))
 	for i := 0; i < len(stack); i++ {
 		go func(index int) {
-			c.debug.Printf("(%s) [%d/%d] exec %s => %s", c.parent.Config.Nick,
-				index+1, len(stack), stack[index].cuid, command)
-			start := time.Now()
+			// c.debug.Printf("(%s) [%d/%d] exec %s => %s", c.parent.Config.Nick,
+			// 	index+1, len(stack), stack[index].cuid, command)
+			// start := time.Now()

 			if bg {
 				go func() {
@@ -273,8 +270,8 @@ func (c *Caller) exec(command string, bg bool, client *Client, event *Event) {
 					defer recoverHandlerPanic(client, event, stack[index].cuid, 3)
 				}
 				stack[index].Handler.Execute(client, *event)
-				c.debug.Printf("(%s) done %s == %s", c.parent.Config.Nick,
-					stack[index].cuid, time.Since(start))
+				// c.debug.Printf("(%s) done %s == %s", c.parent.Config.Nick,
+				// 	stack[index].cuid, time.Since(start))
 			}()

 			return
 		}
@@ -285,14 +282,14 @@ func (c *Caller) exec(command string, bg bool, client *Client, event *Event) {
 			}
 			stack[index].Handler.Execute(client, *event)
-			c.debug.Printf("(%s) done %s == %s", c.parent.Config.Nick, stack[index].cuid, time.Since(start))
+			// c.debug.Printf("(%s) done %s == %s", c.parent.Config.Nick, stack[index].cuid, time.Since(start))
 		}(i)

 		// Wait for the handlers to complete, to keep
 		// new events from becoming ahead of old handlers.
-		c.debug.Printf("(%s) atomic.CompareAndSwap: %d jobs running", c.parent.Config.Nick, atomic.LoadInt32(&working))
+		// c.debug.Printf("(%s) atomic.CompareAndSwap: %d jobs running", c.parent.Config.Nick, atomic.LoadInt32(&working))
 		if atomic.CompareAndSwapInt32(&working, 0, -1) {
-			c.debug.Printf("(%s) exec stack completed", c.parent.Config.Nick)
+			// c.debug.Printf("(%s) exec stack completed", c.parent.Config.Nick)
 			return
 		}
 	}
@@ -360,6 +357,8 @@ func (c *Caller) remove(cuid string) (ok bool) {
 // sregister is much like Caller.register(), except that it safely locks
 // the Caller mutex.
 func (c *Caller) sregister(internal, bg bool, cmd string, handler Handler) (cuid string) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
 	cuid = c.register(internal, bg, cmd, handler)

 	return cuid
 }
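
Two handler.go details are worth noting. Caller.mu becomes a pointer (*sync.RWMutex) initialized in newCaller, which avoids accidentally copying the lock along with the struct; and sregister now takes the write lock around the lock-free register, the classic safe-wrapper split. The RLock/RUnlock pairs removed from RunHandlers and CTCP.call appear to have become redundant once locking is concentrated in Caller.exec and the state mutex. A condensed sketch of the safe-wrapper pattern with illustrative names, not girc's full Caller:

package main

import (
	"fmt"
	"sync"
)

type caller struct {
	mu       *sync.RWMutex // pointer, as in the commit's Caller.mu
	handlers map[string]func()
}

func newCaller() *caller {
	return &caller{mu: &sync.RWMutex{}, handlers: make(map[string]func())}
}

// register assumes the caller already holds mu.
func (c *caller) register(cmd string, h func()) string {
	c.handlers[cmd] = h
	return cmd
}

// sregister safely locks around register, mirroring Caller.sregister.
func (c *caller) sregister(cmd string, h func()) string {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.register(cmd, h)
}

func main() {
	c := newCaller()
	fmt.Println(c.sregister("PRIVMSG", func() {}))
}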

state.go

@@ -28,11 +28,11 @@ type state struct {
 	// users map[string]*User
 	users cmap.ConcurrentMap

 	// enabledCap are the capabilities which are enabled for this connection.
-	enabledCap cmap.ConcurrentMap
+	enabledCap map[string]map[string]string

 	// tmpCap are the capabilities which we share with the server during the
 	// last capability check. These will get sent once we have received the
 	// last capability list command from the server.
-	tmpCap cmap.ConcurrentMap
+	tmpCap map[string]map[string]string

 	// serverOptions are the standard capabilities and configurations
 	// supported by the server at connection time. This also includes
 	// RPL_ISUPPORT entries.
@@ -69,8 +69,8 @@ func (s *state) reset(initial bool) {
 		}
 	}

-	s.enabledCap = cmap.New()
-	s.tmpCap = cmap.New()
+	s.enabledCap = make(map[string]map[string]string)
+	s.tmpCap = make(map[string]map[string]string)
 	s.motd = ""

 	if initial {
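
reset swaps cmap.New() for make() because a nil Go map can be read but not written: assigning to one panics at runtime. The make() calls keep the CAP_ACK path in cap.go, which stores c.state.enabledCap[capab] = nil, safe. A standalone demonstration:

package main

import "fmt"

func main() {
	var nilMap map[string]map[string]string
	fmt.Println(len(nilMap), nilMap["sasl"]) // reads on a nil map are fine
	// nilMap["sasl"] = nil // would panic: assignment to entry in nil map

	m := make(map[string]map[string]string) // what reset() now does
	m["sasl"] = nil                         // fine: key present, value nil
	_, ok := m["sasl"]
	fmt.Println(ok) // true, matching the comma-ok checks in cap.go
}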