Implement cmaps and break everything

This commit is contained in:
kayos@tcp.direct 2022-03-16 04:30:59 -07:00
parent fa2aba1ef2
commit 23cea998f1
Signed by: kayos
GPG Key ID: 4B841471B4BEE979
16 changed files with 356 additions and 366 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
.idea
*.save

View File

@ -7,7 +7,6 @@ package girc
import ( import (
"strconv" "strconv"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/araddon/dateparse" "github.com/araddon/dateparse"
@ -110,7 +109,7 @@ func handleConnect(c *Client, e Event) {
split := strings.Split(e.Params[1], " ") split := strings.Split(e.Params[1], " ")
if strings.HasPrefix(e.Params[1], "Welcome to the") && len(split) > 3 { if strings.HasPrefix(e.Params[1], "Welcome to the") && len(split) > 3 {
if len(split[3]) > 0 { if len(split[3]) > 0 {
c.state.network.Store(split[3]) c.state.network = split[3]
c.IRCd.Network = split[3] c.IRCd.Network = split[3]
} }
} }
@ -153,9 +152,6 @@ func handleJOIN(c *Client, e Event) {
channelName := e.Params[0] channelName := e.Params[0]
c.state.Lock()
defer c.state.Unlock()
channel := c.state.lookupChannel(channelName) channel := c.state.lookupChannel(channelName)
if channel == nil { if channel == nil {
if ok := c.state.createChannel(channelName); !ok { if ok := c.state.createChannel(channelName); !ok {
@ -167,7 +163,7 @@ func handleJOIN(c *Client, e Event) {
user := c.state.lookupUser(e.Source.Name) user := c.state.lookupUser(e.Source.Name)
if user == nil { if user == nil {
if ok := c.state.createUser(e.Source); !ok { if _, ok := c.state.createUser(e.Source); !ok {
return return
} }
user = c.state.lookupUser(e.Source.Name) user = c.state.lookupUser(e.Source.Name)
@ -225,15 +221,14 @@ func handlePART(c *Client, e Event) {
defer c.state.notify(c, UPDATE_STATE) defer c.state.notify(c, UPDATE_STATE)
if e.Source.ID() == c.GetID() { if e.Source.ID() == c.GetID() {
c.state.Lock()
c.state.deleteChannel(channel) c.state.deleteChannel(channel)
c.state.Unlock()
return return
} }
c.state.Lock()
c.state.deleteUser(channel, e.Source.ID()) c.state.deleteUser(channel, e.Source.ID())
c.state.Unlock()
} }
// handleCREATIONTIME handles incoming TOPIC events and keeps channel tracking info // handleCREATIONTIME handles incoming TOPIC events and keeps channel tracking info
@ -250,9 +245,6 @@ func handleCREATIONTIME(c *Client, e Event) {
break break
} }
c.state.Lock()
defer c.state.Unlock()
channel := c.state.lookupChannel(name) channel := c.state.lookupChannel(name)
if channel == nil { if channel == nil {
return return
@ -275,15 +267,14 @@ func handleTOPIC(c *Client, e Event) {
name = e.Params[1] name = e.Params[1]
} }
c.state.Lock()
channel := c.state.lookupChannel(name) channel := c.state.lookupChannel(name)
if channel == nil { if channel == nil {
c.state.Unlock()
return return
} }
channel.Topic = e.Last() channel.Topic = e.Last()
c.state.Unlock()
c.state.notify(c, UPDATE_STATE) c.state.notify(c, UPDATE_STATE)
} }
@ -328,15 +319,12 @@ func handleWHO(c *Client, e Event) {
} }
} }
c.state.Lock()
defer c.state.Unlock()
user := c.state.lookupUser(nick) user := c.state.lookupUser(nick)
if user == nil { if user == nil {
c.state.createUser(&Source{nick, ident, host}) usr, _ := c.state.createUser(&Source{nick, ident, host})
c.state.users[nick].Extras.Name = realname usr.Extras.Name = realname
if account != "0" { if account != "0" {
c.state.users[nick].Extras.Account = account usr.Extras.Account = account
} }
c.state.notify(c, UPDATE_STATE) c.state.notify(c, UPDATE_STATE)
return return
@ -365,16 +353,16 @@ func handleKICK(c *Client, e Event) {
defer c.state.notify(c, UPDATE_STATE) defer c.state.notify(c, UPDATE_STATE)
if e.Params[1] == c.GetNick() { if e.Params[1] == c.GetNick() {
c.state.Lock()
c.state.deleteChannel(e.Params[0]) c.state.deleteChannel(e.Params[0])
c.state.Unlock()
return return
} }
// Assume it's just another user. // Assume it's just another user.
c.state.Lock()
c.state.deleteUser(e.Params[0], e.Params[1]) c.state.deleteUser(e.Params[0], e.Params[1])
c.state.Unlock()
} }
// handleNICK ensures that users are renamed in state, or the client name is // handleNICK ensures that users are renamed in state, or the client name is
@ -384,12 +372,11 @@ func handleNICK(c *Client, e Event) {
return return
} }
c.state.Lock()
// renameUser updates the LastActive time automatically. // renameUser updates the LastActive time automatically.
if len(e.Params) >= 1 { if len(e.Params) >= 1 {
c.state.renameUser(e.Source.ID(), e.Last()) c.state.renameUser(e.Source.ID(), e.Last())
} }
c.state.Unlock()
c.state.notify(c, UPDATE_STATE) c.state.notify(c, UPDATE_STATE)
} }
@ -403,9 +390,8 @@ func handleQUIT(c *Client, e Event) {
return return
} }
c.state.Lock()
c.state.deleteUser("", e.Source.ID()) c.state.deleteUser("", e.Source.ID())
c.state.Unlock()
c.state.notify(c, UPDATE_STATE) c.state.notify(c, UPDATE_STATE)
} }
@ -418,8 +404,7 @@ func handleGLOBALUSERS(c *Client, e Event) {
if err != nil { if err != nil {
return return
} }
c.state.Lock()
defer c.state.Unlock()
c.IRCd.UserCount = cusers c.IRCd.UserCount = cusers
c.IRCd.MaxUserCount = musers c.IRCd.MaxUserCount = musers
} }
@ -433,8 +418,7 @@ func handleLOCALUSERS(c *Client, e Event) {
if err != nil { if err != nil {
return return
} }
c.state.Lock()
defer c.state.Unlock()
c.IRCd.LocalUserCount = cusers c.IRCd.LocalUserCount = cusers
c.IRCd.LocalMaxUserCount = musers c.IRCd.LocalMaxUserCount = musers
} }
@ -444,8 +428,7 @@ func handleLUSERCHANNELS(c *Client, e Event) {
if err != nil { if err != nil {
return return
} }
c.state.Lock()
defer c.state.Unlock()
c.IRCd.ChannelCount = ccount c.IRCd.ChannelCount = ccount
} }
@ -454,8 +437,7 @@ func handleLUSEROP(c *Client, e Event) {
if err != nil { if err != nil {
return return
} }
c.state.Lock()
defer c.state.Unlock()
c.IRCd.OperCount = ocount c.IRCd.OperCount = ocount
} }
@ -480,9 +462,9 @@ func handleCREATED(c *Client, e Event) {
if err != nil { if err != nil {
return return
} }
c.state.Lock()
c.IRCd.Compiled = compiled c.IRCd.Compiled = compiled
c.state.Unlock()
c.state.notify(c, UPDATE_GENERAL) c.state.notify(c, UPDATE_GENERAL)
} }
@ -502,10 +484,10 @@ func handleYOURHOST(c *Client, e Event) {
if len(host)+len(ver) == 0 { if len(host)+len(ver) == 0 {
return return
} }
c.state.Lock()
c.IRCd.Host = host c.IRCd.Host = host
c.IRCd.Version = ver c.IRCd.Version = ver
c.state.Unlock()
c.state.notify(c, UPDATE_GENERAL) c.state.notify(c, UPDATE_GENERAL)
} }
@ -529,29 +511,20 @@ func handleISUPPORT(c *Client, e Event) {
split := strings.Split(e.Params[i], "=") split := strings.Split(e.Params[i], "=")
if len(split) != 2 { if len(split) != 2 {
c.mu.Lock() c.state.serverOptions.Set(e.Params[i], "")
c.state.serverOptions[e.Params[i]] = &atomic.Value{}
c.mu.Unlock()
c.state.serverOptions[e.Params[i]].Store("")
continue continue
} }
if len(split[0]) < 1 || len(split[1]) < 1 { if len(split[0]) < 1 || len(split[1]) < 1 {
c.mu.Lock() c.state.serverOptions.Set(e.Params[i], "")
c.state.serverOptions[e.Params[i]] = &atomic.Value{}
c.mu.Unlock()
c.state.serverOptions[e.Params[i]].Store("")
continue continue
} }
if split[0] == "NETWORK" { if split[0] == "NETWORK" {
c.state.network.Store(split[1]) c.state.network = split[1]
} }
c.mu.Lock() c.state.serverOptions.Set(split[0], split[1])
c.state.serverOptions[split[0]] = &atomic.Value{}
c.mu.Unlock()
c.state.serverOptions[split[0]].Store(split[1])
} }
c.state.notify(c, UPDATE_GENERAL) c.state.notify(c, UPDATE_GENERAL)
@ -560,7 +533,6 @@ func handleISUPPORT(c *Client, e Event) {
// handleMOTD handles incoming MOTD messages and buffers them up for use with // handleMOTD handles incoming MOTD messages and buffers them up for use with
// Client.ServerMOTD(). // Client.ServerMOTD().
func handleMOTD(c *Client, e Event) { func handleMOTD(c *Client, e Event) {
c.state.Lock()
defer c.state.notify(c, UPDATE_GENERAL) defer c.state.notify(c, UPDATE_GENERAL)
@ -568,7 +540,6 @@ func handleMOTD(c *Client, e Event) {
if e.Command == RPL_MOTDSTART { if e.Command == RPL_MOTDSTART {
c.state.motd = "" c.state.motd = ""
c.state.Unlock()
return return
} }
@ -577,7 +548,7 @@ func handleMOTD(c *Client, e Event) {
c.state.motd += "\n" c.state.motd += "\n"
} }
c.state.motd += e.Last() c.state.motd += e.Last()
c.state.Unlock()
} }
// handleNAMES handles incoming NAMES queries, of which lists all users in // handleNAMES handles incoming NAMES queries, of which lists all users in
@ -598,7 +569,6 @@ func handleNAMES(c *Client, e Event) {
var modes, nick string var modes, nick string
var ok bool var ok bool
c.state.Lock()
for i := 0; i < len(parts); i++ { for i := 0; i < len(parts); i++ {
modes, nick, ok = parseUserPrefix(parts[i]) modes, nick, ok = parseUserPrefix(parts[i])
if !ok { if !ok {
@ -638,7 +608,7 @@ func handleNAMES(c *Client, e Event) {
perms.set(modes, false) perms.set(modes, false)
user.Perms.set(channel.Name, perms) user.Perms.set(channel.Name, perms)
} }
c.state.Unlock()
c.state.notify(c, UPDATE_STATE) c.state.notify(c, UPDATE_STATE)
} }
@ -651,9 +621,6 @@ func updateLastActive(c *Client, e Event) {
return return
} }
c.state.Lock()
defer c.state.Unlock()
// Update the users last active time, if they exist. // Update the users last active time, if they exist.
user := c.state.lookupUser(e.Source.Name) user := c.state.lookupUser(e.Source.Name)
if user == nil { if user == nil {

8
cap.go
View File

@ -118,8 +118,6 @@ func parseCap(raw string) map[string]map[string]string {
// This will lock further registration until we have acknowledged (or denied) // This will lock further registration until we have acknowledged (or denied)
// the capabilities. // the capabilities.
func handleCAP(c *Client, e Event) { func handleCAP(c *Client, e Event) {
c.state.Lock()
defer c.state.Unlock()
if len(e.Params) >= 2 && e.Params[1] == CAP_DEL { if len(e.Params) >= 2 && e.Params[1] == CAP_DEL {
caps := parseCap(e.Last()) caps := parseCap(e.Last())
@ -309,13 +307,11 @@ func handleCHGHOST(c *Client, e Event) {
return return
} }
c.state.Lock()
user := c.state.lookupUser(e.Source.Name) user := c.state.lookupUser(e.Source.Name)
if user != nil { if user != nil {
user.Ident = e.Params[0] user.Ident = e.Params[0]
user.Host = e.Params[1] user.Host = e.Params[1]
} }
c.state.Unlock()
c.state.notify(c, UPDATE_STATE) c.state.notify(c, UPDATE_STATE)
} }
@ -323,12 +319,12 @@ func handleCHGHOST(c *Client, e Event) {
// handleAWAY handles incoming IRCv3 AWAY events, for which are sent both // handleAWAY handles incoming IRCv3 AWAY events, for which are sent both
// when users are no longer away, or when they are away. // when users are no longer away, or when they are away.
func handleAWAY(c *Client, e Event) { func handleAWAY(c *Client, e Event) {
c.state.Lock()
user := c.state.lookupUser(e.Source.Name) user := c.state.lookupUser(e.Source.Name)
if user != nil { if user != nil {
user.Extras.Away = e.Last() user.Extras.Away = e.Last()
} }
c.state.Unlock()
c.state.notify(c, UPDATE_STATE) c.state.notify(c, UPDATE_STATE)
} }

View File

@ -24,12 +24,11 @@ func handleTags(c *Client, e Event) {
return return
} }
c.state.Lock()
user := c.state.lookupUser(e.Source.ID()) user := c.state.lookupUser(e.Source.ID())
if user != nil { if user != nil {
user.Extras.Account = account user.Extras.Account = account
} }
c.state.Unlock()
c.state.notify(c, UPDATE_STATE) c.state.notify(c, UPDATE_STATE)
} }

View File

@ -16,6 +16,7 @@ func TestCapSupported(t *testing.T) {
User: "user", User: "user",
SASL: &SASLPlain{User: "test", Pass: "example"}, SASL: &SASLPlain{User: "test", Pass: "example"},
SupportedCaps: map[string][]string{"example": nil}, SupportedCaps: map[string][]string{"example": nil},
Debug: newDebugWriter(t),
}) })
var ok bool var ok bool

View File

@ -21,6 +21,8 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
cmap "github.com/orcaman/concurrent-map"
) )
// Client contains all of the information necessary to run a single IRC // Client contains all of the information necessary to run a single IRC
@ -392,12 +394,10 @@ var ErrConnNotTLS = errors.New("underlying connection is not tls")
// safe to call multiple times. See Connect()'s documentation on how // safe to call multiple times. See Connect()'s documentation on how
// handlers and goroutines are handled when disconnected from the server. // handlers and goroutines are handled when disconnected from the server.
func (c *Client) Close() { func (c *Client) Close() {
c.mu.RLock()
if c.stop != nil { if c.stop != nil {
c.debug.Print("requesting client to stop") c.debug.Print("requesting client to stop")
c.stop() c.stop()
} }
c.mu.RUnlock()
} }
// Quit sends a QUIT message to the server with a given reason to close the // Quit sends a QUIT message to the server with a given reason to close the
@ -481,9 +481,7 @@ func (c *Client) DisableTracking() {
c.Config.disableTracking = true c.Config.disableTracking = true
c.Handlers.clearInternal() c.Handlers.clearInternal()
c.state.Lock() c.state.channels.Clear()
c.state.channels = nil
c.state.Unlock()
c.state.notify(c, UPDATE_STATE) c.state.notify(c, UPDATE_STATE)
c.registerBuiltins() c.registerBuiltins()
@ -598,12 +596,12 @@ func (c *Client) GetHost() (host string) {
func (c *Client) ChannelList() []string { func (c *Client) ChannelList() []string {
c.panicIfNotTracking() c.panicIfNotTracking()
c.state.RLock()
channels := make([]string, 0, len(c.state.channels)) channels := make([]string, 0, len(c.state.channels))
for channel := range c.state.channels { for channel := range c.state.channels.IterBuffered() {
channels = append(channels, c.state.channels[channel].Name) chn := channel.Val.(*Channel)
channels = append(channels, chn.Name)
} }
c.state.RUnlock()
sort.Strings(channels) sort.Strings(channels)
return channels return channels
} }
@ -613,12 +611,11 @@ func (c *Client) ChannelList() []string {
func (c *Client) Channels() []*Channel { func (c *Client) Channels() []*Channel {
c.panicIfNotTracking() c.panicIfNotTracking()
c.state.RLock()
channels := make([]*Channel, 0, len(c.state.channels)) channels := make([]*Channel, 0, len(c.state.channels))
for channel := range c.state.channels { for channel := range c.state.channels.IterBuffered() {
channels = append(channels, c.state.channels[channel].Copy()) chn := channel.Val.(*Channel)
channels = append(channels, chn.Copy())
} }
c.state.RUnlock()
sort.Slice(channels, func(i, j int) bool { sort.Slice(channels, func(i, j int) bool {
return channels[i].Name < channels[j].Name return channels[i].Name < channels[j].Name
@ -631,12 +628,12 @@ func (c *Client) Channels() []*Channel {
func (c *Client) UserList() []string { func (c *Client) UserList() []string {
c.panicIfNotTracking() c.panicIfNotTracking()
c.state.RLock()
users := make([]string, 0, len(c.state.users)) users := make([]string, 0, len(c.state.users))
for user := range c.state.users { for user := range c.state.users.IterBuffered() {
users = append(users, c.state.users[user].Nick) usr := user.Val.(*User)
users = append(users, usr.Nick)
} }
c.state.RUnlock()
sort.Strings(users) sort.Strings(users)
return users return users
} }
@ -646,12 +643,11 @@ func (c *Client) UserList() []string {
func (c *Client) Users() []*User { func (c *Client) Users() []*User {
c.panicIfNotTracking() c.panicIfNotTracking()
c.state.RLock()
users := make([]*User, 0, len(c.state.users)) users := make([]*User, 0, len(c.state.users))
for user := range c.state.users { for user := range c.state.users.IterBuffered() {
users = append(users, c.state.users[user].Copy()) usr := user.Val.(*User)
users = append(users, usr.Copy())
} }
c.state.RUnlock()
sort.Slice(users, func(i, j int) bool { sort.Slice(users, func(i, j int) bool {
return users[i].Nick < users[j].Nick return users[i].Nick < users[j].Nick
@ -667,9 +663,8 @@ func (c *Client) LookupChannel(name string) (channel *Channel) {
return nil return nil
} }
c.state.RLock()
channel = c.state.lookupChannel(name).Copy() channel = c.state.lookupChannel(name).Copy()
c.state.RUnlock()
return channel return channel
} }
@ -681,20 +676,17 @@ func (c *Client) LookupUser(nick string) (user *User) {
return nil return nil
} }
c.state.RLock()
user = c.state.lookupUser(nick).Copy() user = c.state.lookupUser(nick).Copy()
c.state.RUnlock()
return user return user
} }
// IsInChannel returns true if the client is in channel. Panics if tracking // IsInChannel returns true if the client is in channel. Panics if tracking
// is disabled. // is disabled.
// TODO: make sure this still works.
func (c *Client) IsInChannel(channel string) (in bool) { func (c *Client) IsInChannel(channel string) (in bool) {
c.panicIfNotTracking() c.panicIfNotTracking()
_, in = c.state.channels.Get(ToRFC1459(channel))
c.state.RLock()
_, in = c.state.channels[ToRFC1459(channel)]
c.state.RUnlock()
return in return in
} }
@ -707,15 +699,13 @@ func (c *Client) IsInChannel(channel string) (in bool) {
func (c *Client) GetServerOption(key string) (result string, ok bool) { func (c *Client) GetServerOption(key string) (result string, ok bool) {
c.panicIfNotTracking() c.panicIfNotTracking()
c.mu.RLock() oi, ok := c.state.serverOptions.Get(key)
if _, ok := c.state.serverOptions[key]; !ok { if !ok {
c.mu.RUnlock()
return "", ok return "", ok
} }
c.mu.RUnlock() result = oi.(string)
result = c.state.serverOptions[key].Load().(string)
if len(result) > 0 { if len(result) > 0 {
ok = true ok = true
} }
@ -726,23 +716,9 @@ func (c *Client) GetServerOption(key string) (result string, ok bool) {
// GetAllServerOption retrieves all of a server's capability settings that were retrieved // GetAllServerOption retrieves all of a server's capability settings that were retrieved
// during client connection. This is also known as ISUPPORT (or RPL_PROTOCTL). // during client connection. This is also known as ISUPPORT (or RPL_PROTOCTL).
// Will panic if used when tracking has been disabled. // Will panic if used when tracking has been disabled.
func (c *Client) GetAllServerOption() (map[string]string, error) { func (c *Client) GetAllServerOption() <-chan cmap.Tuple {
c.panicIfNotTracking() c.panicIfNotTracking()
return c.state.serverOptions.IterBuffered()
c.mu.RLock()
defer c.mu.RUnlock()
if len(c.state.serverOptions) > 0 {
copied := make(map[string]string)
for k, av := range c.state.serverOptions {
if v := av.Load(); v != nil {
copied[k] = av.Load().(string)
}
}
return copied, nil
} else {
return nil, errors.New("server options is empty")
}
} }
// NetworkName returns the network identifier. E.g. "EsperNet", "ByteIRC". // NetworkName returns the network identifier. E.g. "EsperNet", "ByteIRC".
@ -752,11 +728,8 @@ func (c *Client) NetworkName() (name string) {
c.panicIfNotTracking() c.panicIfNotTracking()
var ok bool var ok bool
if c.state.network.Load() != nil { if len(c.state.network) > 0 {
name = c.state.network.Load().(string) return
if len(name) > 0 {
return
}
} }
name, ok = c.GetServerOption("NETWORK") name, ok = c.GetServerOption("NETWORK")
@ -786,8 +759,7 @@ func (c *Client) ServerVersion() (version string) {
// it upon connect. Will panic if used when tracking has been disabled. // it upon connect. Will panic if used when tracking has been disabled.
func (c *Client) ServerMOTD() (motd string) { func (c *Client) ServerMOTD() (motd string) {
c.panicIfNotTracking() c.panicIfNotTracking()
c.state.RLock()
defer c.state.RUnlock()
return c.state.motd return c.state.motd
} }

View File

@ -10,6 +10,19 @@ import (
"time" "time"
) )
type debugWriter struct {
t *testing.T
}
func newDebugWriter(t *testing.T) debugWriter {
return debugWriter{t: t}
}
func (d debugWriter) Write(p []byte) (n int, err error) {
go d.t.Logf("%v", string(p))
return len(p), nil
}
func TestDisableTracking(t *testing.T) { func TestDisableTracking(t *testing.T) {
client := New(Config{ client := New(Config{
Server: "dummy.int", Server: "dummy.int",
@ -17,21 +30,19 @@ func TestDisableTracking(t *testing.T) {
Nick: "test", Nick: "test",
User: "test", User: "test",
Name: "Testing123", Name: "Testing123",
Debug: newDebugWriter(t),
}) })
if len(client.Handlers.internal) < 1 { if client.Handlers.internal.len() < 1 {
t.Fatal("Client.Handlers empty, though just initialized") t.Fatal("Client.Handlers empty, though just initialized")
} }
client.DisableTracking() client.DisableTracking()
if _, ok := client.Handlers.internal[CAP]; ok { if _, ok := client.Handlers.internal.cm.Get(CAP); ok {
t.Fatal("Client.Handlers contains capability tracking handlers, though disabled") t.Fatal("Client.Handlers contains capability tracking handlers, though disabled")
} }
client.state.Lock() if len(client.state.channels.Keys()) > 0 {
defer client.state.Unlock()
if client.state.channels != nil {
t.Fatal("Client.DisableTracking() called but channel state still exists") t.Fatal("Client.DisableTracking() called but channel state still exists")
} }
} }
@ -85,6 +96,7 @@ func TestClientLifetime(t *testing.T) {
Nick: "test", Nick: "test",
User: "test", User: "test",
Name: "Testing123", Name: "Testing123",
Debug: newDebugWriter(t),
}) })
tm := client.Lifetime() tm := client.Lifetime()
@ -95,7 +107,7 @@ func TestClientLifetime(t *testing.T) {
} }
func TestClientUptime(t *testing.T) { func TestClientUptime(t *testing.T) {
c, conn, server := genMockConn() c, conn, server := genMockConn(t)
defer conn.Close() defer conn.Close()
defer server.Close() defer server.Close()
go mockReadBuffer(conn) go mockReadBuffer(conn)
@ -140,7 +152,7 @@ func TestClientUptime(t *testing.T) {
} }
func TestClientGet(t *testing.T) { func TestClientGet(t *testing.T) {
c, conn, server := genMockConn() c, conn, server := genMockConn(t)
defer conn.Close() defer conn.Close()
defer server.Close() defer server.Close()
go mockReadBuffer(conn) go mockReadBuffer(conn)
@ -171,7 +183,7 @@ func TestClientGet(t *testing.T) {
} }
func TestClientClose(t *testing.T) { func TestClientClose(t *testing.T) {
c, conn, server := genMockConn() c, conn, server := genMockConn(t)
defer server.Close() defer server.Close()
defer conn.Close() defer conn.Close()
go mockReadBuffer(conn) go mockReadBuffer(conn)

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"regexp" "regexp"
"strings" "strings"
"sync"
"github.com/yunginnanet/girc-atomic" "github.com/yunginnanet/girc-atomic"
) )
@ -72,7 +71,6 @@ type CmdHandler struct {
prefix string prefix string
re *regexp.Regexp re *regexp.Regexp
mu sync.Mutex
cmds map[string]*Command cmds map[string]*Command
} }
@ -116,9 +114,6 @@ func (ch *CmdHandler) Add(cmd *Command) error {
cmd.MinArgs = 0 cmd.MinArgs = 0
} }
ch.mu.Lock()
defer ch.mu.Unlock()
if _, ok := ch.cmds[cmd.Name]; ok { if _, ok := ch.cmds[cmd.Name]; ok {
return fmt.Errorf("command already registered: %s", cmd.Name) return fmt.Errorf("command already registered: %s", cmd.Name)
} }
@ -154,9 +149,6 @@ func (ch *CmdHandler) Execute(client *girc.Client, event girc.Event) {
args = []string{} args = []string{}
} }
ch.mu.Lock()
defer ch.mu.Unlock()
if invCmd == "help" { if invCmd == "help" {
if len(args) == 0 { if len(args) == 0 {
client.Cmd.ReplyTo(event, girc.Fmt("type '{b}!help {blue}<command>{c}{b}' to optionally get more info about a specific command.")) client.Cmd.ReplyTo(event, girc.Fmt("type '{b}!help {blue}<command>{c}{b}' to optionally get more info about a specific command."))

11
conn.go
View File

@ -469,9 +469,6 @@ func (c *Client) Send(event *Event) {
// write is the lower level function to write an event. It does not have a // write is the lower level function to write an event. It does not have a
// write-delay when sending events. // write-delay when sending events.
func (c *Client) write(event *Event) { func (c *Client) write(event *Event) {
c.mu.RLock()
defer c.mu.RUnlock()
if c.conn == nil { if c.conn == nil {
// Drop the event if disconnected. // Drop the event if disconnected.
c.debugLogEvent(event, true) c.debugLogEvent(event, true)
@ -515,7 +512,7 @@ func (c *Client) sendLoop(ctx context.Context, errs chan error, wg *sync.WaitGro
// Check if tags exist on the event. If they do, and message-tags // Check if tags exist on the event. If they do, and message-tags
// isn't a supported capability, remove them from the event. // isn't a supported capability, remove them from the event.
if event.Tags != nil { if event.Tags != nil {
// c.state.RLock() //
var in bool var in bool
for i := 0; i < len(c.state.enabledCap); i++ { for i := 0; i < len(c.state.enabledCap); i++ {
if _, ok := c.state.enabledCap["message-tags"]; ok { if _, ok := c.state.enabledCap["message-tags"]; ok {
@ -523,7 +520,7 @@ func (c *Client) sendLoop(ctx context.Context, errs chan error, wg *sync.WaitGro
break break
} }
} }
// c.state.RUnlock() //
if !in { if !in {
event.Tags = Tags{} event.Tags = Tags{}
@ -583,9 +580,9 @@ type ErrTimedOut struct {
func (ErrTimedOut) Error() string { return "timed out waiting for a requested PING response" } func (ErrTimedOut) Error() string { return "timed out waiting for a requested PING response" }
func (c *Client) pingLoop(ctx context.Context, errs chan error, wg *sync.WaitGroup) { func (c *Client) pingLoop(ctx context.Context, errs chan error, wg *sync.WaitGroup) {
defer wg.Done()
// Don't run the pingLoop if they want to disable it. // Don't run the pingLoop if they want to disable it.
if c.Config.PingDelay <= 0 { if c.Config.PingDelay <= 0 {
wg.Done()
return return
} }
@ -624,7 +621,6 @@ func (c *Client) pingLoop(ctx context.Context, errs chan error, wg *sync.WaitGro
Delay: c.Config.PingDelay, Delay: c.Config.PingDelay,
} }
wg.Done()
return return
} }
@ -632,7 +628,6 @@ func (c *Client) pingLoop(ctx context.Context, errs chan error, wg *sync.WaitGro
c.Cmd.Ping(fmt.Sprintf("%d", time.Now().UnixNano())) c.Cmd.Ping(fmt.Sprintf("%d", time.Now().UnixNano()))
case <-ctx.Done(): case <-ctx.Done():
wg.Done()
return return
} }
} }

View File

@ -93,13 +93,14 @@ func TestRate(t *testing.T) {
return return
} }
func genMockConn() (client *Client, clientConn net.Conn, serverConn net.Conn) { func genMockConn(t *testing.T) (client *Client, clientConn net.Conn, serverConn net.Conn) {
client = New(Config{ client = New(Config{
Server: "dummy.int", Server: "dummy.int",
Port: 6667, Port: 6667,
Nick: "test", Nick: "test",
User: "test", User: "test",
Name: "Testing123", Name: "Testing123",
Debug: newDebugWriter(t),
}) })
conn1, conn2 := net.Pipe() conn1, conn2 := net.Pipe()
@ -107,14 +108,19 @@ func genMockConn() (client *Client, clientConn net.Conn, serverConn net.Conn) {
return client, conn1, conn2 return client, conn1, conn2
} }
func mockReadBuffer(conn net.Conn) { func mockReadBuffer(conn net.Conn) error {
// Accept all outgoing writes from the client. // Accept all outgoing writes from the client.
b := bufio.NewReader(conn) b := bufio.NewReader(conn)
for { for {
conn.SetReadDeadline(time.Now().Add(10 * time.Second)) err := conn.SetReadDeadline(time.Now().Add(10 * time.Second))
_, err := b.ReadString(byte('\n'))
if err != nil { if err != nil {
return return err
}
var str string
str, err = b.ReadString(byte('\n'))
println(str)
if err != nil {
return err
} }
} }
} }

5
go.mod
View File

@ -2,4 +2,7 @@ module github.com/yunginnanet/girc-atomic
go 1.17 go 1.17
require github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de require (
github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de
github.com/orcaman/concurrent-map v1.0.0
)

2
go.sum
View File

@ -3,6 +3,8 @@ github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de/go.mod h1:DCaWoU
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/mattn/go-runewidth v0.0.10/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= github.com/mattn/go-runewidth v0.0.10/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
github.com/orcaman/concurrent-map v1.0.0 h1:I/2A2XPCb4IuQWcQhBhSwGfiuybl/J0ev9HDbW65HOY=
github.com/orcaman/concurrent-map v1.0.0/go.mod h1:Lu3tH6HLW3feq74c2GC+jIMS/K2CFcDWnWD9XkenwhI=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=

View File

@ -13,11 +13,14 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/orcaman/concurrent-map"
) )
// RunHandlers manually runs handlers for a given event. // RunHandlers manually runs handlers for a given event.
func (c *Client) RunHandlers(event *Event) { func (c *Client) RunHandlers(event *Event) {
if event == nil { if event == nil {
c.debug.Print("nil event")
return return
} }
@ -68,6 +71,62 @@ func (f HandlerFunc) Execute(client *Client, event Event) {
f(client, event) f(client, event)
} }
// nestedHandlers consists of a nested concurrent map.
//
// ( cmap.ConcurrentMap[command]cmap.ConcurrentMap[cuid]Handler )
//
// command and cuid are both strings.
type nestedHandlers struct {
cm cmap.ConcurrentMap
}
type handlerTuple struct {
cuid string
handler Handler
}
func newNestedHandlers() *nestedHandlers {
return &nestedHandlers{cm: cmap.New()}
}
func (nest *nestedHandlers) len() (total int) {
for hs := range nest.cm.IterBuffered() {
hndlrs := hs.Val.(cmap.ConcurrentMap)
total += len(hndlrs.Keys())
}
return
}
func (nest *nestedHandlers) lenFor(cmd string) (total int) {
cmd = strings.ToUpper(cmd)
hs, ok := nest.cm.Get(cmd)
if !ok {
return 0
}
hndlrs := hs.(cmap.ConcurrentMap)
return len(hndlrs.Keys())
}
func (nest *nestedHandlers) getAllHandlersFor(s string) (handlers chan handlerTuple, ok bool) {
var h interface{}
h, ok = nest.cm.Get(s)
if !ok {
return
}
hm := h.(cmap.ConcurrentMap)
handlers = make(chan handlerTuple, 5)
go func() {
for hi := range hm.IterBuffered() {
ht := handlerTuple{
hi.Key,
hi.Val.(Handler),
}
handlers <- ht
}
}()
return
}
// Caller manages internal and external (user facing) handlers. // Caller manages internal and external (user facing) handlers.
type Caller struct { type Caller struct {
// mu is the mutex that should be used when accessing handlers. // mu is the mutex that should be used when accessing handlers.
@ -80,9 +139,10 @@ type Caller struct {
// Also of note: "COMMAND" should always be uppercase for normalization. // Also of note: "COMMAND" should always be uppercase for normalization.
// external is a map of user facing handlers. // external is a map of user facing handlers.
external map[string]map[string]Handler external *nestedHandlers
// external map[string]map[string]Handler
// internal is a map of internally used handlers for the client. // internal is a map of internally used handlers for the client.
internal map[string]map[string]Handler internal *nestedHandlers
// debug is the clients logger used for debugging. // debug is the clients logger used for debugging.
debug *log.Logger debug *log.Logger
} }
@ -90,8 +150,8 @@ type Caller struct {
// newCaller creates and initializes a new handler. // newCaller creates and initializes a new handler.
func newCaller(parent *Client, debugOut *log.Logger) *Caller { func newCaller(parent *Client, debugOut *log.Logger) *Caller {
c := &Caller{ c := &Caller{
external: map[string]map[string]Handler{}, external: newNestedHandlers(),
internal: map[string]map[string]Handler{}, internal: newNestedHandlers(),
debug: debugOut, debug: debugOut,
parent: parent, parent: parent,
} }
@ -101,45 +161,17 @@ func newCaller(parent *Client, debugOut *log.Logger) *Caller {
// Len returns the total amount of user-entered registered handlers. // Len returns the total amount of user-entered registered handlers.
func (c *Caller) Len() int { func (c *Caller) Len() int {
var total int return c.external.len()
// c.mu.RLock()
for command := range c.external {
total += len(c.external[command])
}
// c.mu.RUnlock()
return total
} }
// Count is much like Caller.Len(), however it counts the number of // Count is much like Caller.Len(), however it counts the number of
// registered handlers for a given command. // registered handlers for a given command.
func (c *Caller) Count(cmd string) int { func (c *Caller) Count(cmd string) int {
var total int return c.external.lenFor(cmd)
cmd = strings.ToUpper(cmd)
// c.mu.RLock()
for command := range c.external {
if command == cmd {
total += len(c.external[command])
}
}
// c.mu.RUnlock()
return total
} }
func (c *Caller) String() string { func (c *Caller) String() string {
var total int return fmt.Sprintf("<Caller external:%d internal:%d>", c.Len(), c.internal.len())
c.mu.RLock()
for cmd := range c.internal {
total += len(c.internal[cmd])
}
c.mu.RUnlock()
return fmt.Sprintf("<Caller external:%d internal:%d>", c.Len(), total)
} }
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
@ -166,77 +198,57 @@ func (c *Caller) cuidToID(input string) (cmd, uid string) {
return input[:i], input[i+1:] return input[:i], input[i+1:]
} }
type execStack struct {
Handler
cuid string
}
// exec executes all handlers pertaining to specified event. Internal first, // exec executes all handlers pertaining to specified event. Internal first,
// then external. // then external.
// //
// Please note that there is no specific order/priority for which the handlers // Please note that there is no specific order/priority for which the handlers
// are executed. // are executed.
func (c *Caller) exec(command string, bg bool, client *Client, event *Event) { func (c *Caller) exec(command string, bg bool, client *Client, event *Event) {
handle := func(wgr *sync.WaitGroup, h handlerTuple) {
// Build a stack of handlers which can be executed concurrently. c.debug.Printf("(%s) exec %s => %s", c.parent.Config.Nick, command, h.cuid)
var stack []execStack start := time.Now()
c.mu.RLock() if bg {
// Get internal handlers first. go func() {
if _, ok := c.internal[command]; ok { defer wgr.Done()
for cuid := range c.internal[command] { if client.Config.RecoverFunc != nil {
if (strings.HasSuffix(cuid, ":bg") && !bg) || (!strings.HasSuffix(cuid, ":bg") && bg) { defer recoverHandlerPanic(client, event, h.cuid, 3)
continue }
} h.handler.Execute(client, *event)
stack = append(stack, execStack{c.internal[command][cuid], cuid}) c.debug.Printf("(%s) done %s == %s", c.parent.Config.Nick,
h.cuid, time.Since(start))
}()
return
} }
}
c.mu.RUnlock()
c.mu.RLock() if client.Config.RecoverFunc != nil {
// Then external handlers. defer recoverHandlerPanic(client, event, h.cuid, 3)
if _, ok := c.external[command]; ok {
for cuid := range c.external[command] {
if (strings.HasSuffix(cuid, ":bg") && !bg) || (!strings.HasSuffix(cuid, ":bg") && bg) {
continue
}
stack = append(stack, execStack{c.external[command][cuid], cuid})
} }
h.handler.Execute(client, *event)
c.debug.Printf("(%s) done %s == %s", c.parent.Config.Nick, h.cuid, time.Since(start))
wgr.Done()
} }
c.mu.RUnlock()
// Run all handlers concurrently across the same event. This should // Run all handlers concurrently across the same event. This should
// still help prevent mis-ordered events, while speeding up the // still help prevent mis-ordered events, while speeding up the
// execution speed. // execution speed.
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(len(stack))
for i := 0; i < len(stack); i++ {
go func(index int) {
defer wg.Done()
c.debug.Printf("(%s) [%d/%d] exec %s => %s", c.parent.Config.Nick,
index+1, len(stack), stack[index].cuid, command)
start := time.Now()
if bg { internals, iok := c.internal.getAllHandlersFor(command)
go func() { if iok {
if client.Config.RecoverFunc != nil { for h := range internals {
defer recoverHandlerPanic(client, event, stack[index].cuid, 3) wg.Add(1)
} go handle(&wg, h)
stack[index].Execute(client, *event) }
c.debug.Printf("(%s) [%d/%d] done %s == %s", c.parent.Config.Nick, }
index+1, len(stack), stack[index].cuid, time.Since(start)) externals, eok := c.external.getAllHandlersFor(command)
}() if eok {
for h := range externals {
return wg.Add(1)
} go handle(&wg, h)
}
if client.Config.RecoverFunc != nil {
defer recoverHandlerPanic(client, event, stack[index].cuid, 3)
}
stack[index].Execute(client, *event)
c.debug.Printf("(%s) [%d/%d] done %s == %s", c.parent.Config.Nick, index+1, len(stack), stack[index].cuid, time.Since(start))
}(i)
} }
// Wait for all of the handlers to complete. Not doing this may cause // Wait for all of the handlers to complete. Not doing this may cause
@ -248,20 +260,14 @@ func (c *Caller) exec(command string, bg bool, client *Client, event *Event) {
// ClearAll clears all external handlers currently setup within the client. // ClearAll clears all external handlers currently setup within the client.
// This ignores internal handlers. // This ignores internal handlers.
func (c *Caller) ClearAll() { func (c *Caller) ClearAll() {
c.mu.Lock() c.external.cm.Clear()
c.external = map[string]map[string]Handler{}
c.mu.Unlock()
c.debug.Print("cleared all external handlers") c.debug.Print("cleared all external handlers")
} }
// clearInternal clears all internal handlers currently setup within the // clearInternal clears all internal handlers currently setup within the
// client. // client.
func (c *Caller) clearInternal() { func (c *Caller) clearInternal() {
c.mu.Lock() c.internal.cm.Clear()
c.internal = map[string]map[string]Handler{}
c.mu.Unlock()
c.debug.Print("cleared all internal handlers") c.debug.Print("cleared all internal handlers")
} }
@ -269,13 +275,7 @@ func (c *Caller) clearInternal() {
// This ignores internal handlers. // This ignores internal handlers.
func (c *Caller) Clear(cmd string) { func (c *Caller) Clear(cmd string) {
cmd = strings.ToUpper(cmd) cmd = strings.ToUpper(cmd)
c.external.cm.Remove(cmd)
c.mu.Lock()
delete(c.external, cmd)
c.mu.Unlock()
c.debug.Printf("(%s) cleared external handlers for %s", c.parent.Config.Nick, cmd) c.debug.Printf("(%s) cleared external handlers for %s", c.parent.Config.Nick, cmd)
} }
@ -292,23 +292,27 @@ func (c *Caller) Remove(cuid string) (success bool) {
// remove is much like Remove, however is NOT concurrency safe. Lock Caller.mu // remove is much like Remove, however is NOT concurrency safe. Lock Caller.mu
// on your own. // on your own.
func (c *Caller) remove(cuid string) (success bool) { func (c *Caller) remove(cuid string) (ok bool) {
cmd, uid := c.cuidToID(cuid) cmd, uid := c.cuidToID(cuid)
if len(cmd) == 0 || len(uid) == 0 { if len(cmd) == 0 || len(uid) == 0 {
return false return false
} }
// Check if the irc command/event has any handlers on it. // Check if the irc command/event has any handlers on it.
if _, ok := c.external[cmd]; !ok { var h interface{}
return false h, ok = c.external.cm.Get(cmd)
if !ok {
return
} }
hs := h.(cmap.ConcurrentMap)
// Check to see if it's actually a registered handler. // Check to see if it's actually a registered handler.
if _, ok := c.external[cmd][uid]; !ok { if _, ok = hs.Get(cuid); !ok {
return false return
} }
delete(c.external[cmd], uid) hs.Remove(uid)
c.debug.Printf("removed handler %s", cuid) c.debug.Printf("removed handler %s", cuid)
// Assume success. // Assume success.
@ -318,10 +322,7 @@ func (c *Caller) remove(cuid string) (success bool) {
// sregister is much like Caller.register(), except that it safely locks // sregister is much like Caller.register(), except that it safely locks
// the Caller mutex. // the Caller mutex.
func (c *Caller) sregister(internal, bg bool, cmd string, handler Handler) (cuid string) { func (c *Caller) sregister(internal, bg bool, cmd string, handler Handler) (cuid string) {
c.mu.Lock()
cuid = c.register(internal, bg, cmd, handler) cuid = c.register(internal, bg, cmd, handler)
c.mu.Unlock()
return cuid return cuid
} }
@ -338,21 +339,31 @@ func (c *Caller) register(internal, bg bool, cmd string, handler Handler) (cuid
cuid += ":bg" cuid += ":bg"
} }
var (
parent *nestedHandlers
chandlers cmap.ConcurrentMap
ei interface{}
ok bool
)
if internal { if internal {
if _, ok := c.internal[cmd]; !ok { parent = c.internal
c.internal[cmd] = map[string]Handler{}
}
c.internal[cmd][uid] = handler
} else { } else {
if _, ok := c.external[cmd]; !ok { parent = c.external
c.external[cmd] = map[string]Handler{}
}
c.external[cmd][uid] = handler
} }
_, file, line, _ := runtime.Caller(3) ei, ok = parent.cm.Get(cmd)
if ok {
chandlers = ei.(cmap.ConcurrentMap)
} else {
chandlers = cmap.New()
}
parent.cm.SetIfAbsent(cmd, chandlers)
chandlers.Set(uid, handler)
_, file, line, _ := runtime.Caller(2)
c.debug.Printf("reg %q => %s [int:%t bg:%t] %s:%d", uid, cmd, internal, bg, file, line) c.debug.Printf("reg %q => %s [int:%t bg:%t] %s:%d", uid, cmd, internal, bg, file, line)

View File

@ -333,10 +333,9 @@ func handleMODE(c *Client, e Event) {
return return
} }
c.state.RLock()
channel := c.state.lookupChannel(e.Params[0]) channel := c.state.lookupChannel(e.Params[0])
if channel == nil { if channel == nil {
c.state.RUnlock()
return return
} }
@ -363,15 +362,14 @@ func handleMODE(c *Client, e Event) {
} }
} }
c.state.RUnlock()
c.state.notify(c, UPDATE_STATE) c.state.notify(c, UPDATE_STATE)
} }
// chanModes returns the ISUPPORT list of server-supported channel modes, // chanModes returns the ISUPPORT list of server-supported channel modes,
// alternatively falling back to ModeDefaults. // alternatively falling back to ModeDefaults.
func (s *state) chanModes() string { func (s *state) chanModes() string {
if validmodes, ok := s.serverOptions["CHANMODES"]; ok { if validmodes, ok := s.serverOptions.Get("CHANMODES"); ok {
modes := validmodes.Load().(string) modes := validmodes.(string)
if IsValidChannelMode(modes) { if IsValidChannelMode(modes) {
return modes return modes
} }
@ -384,8 +382,8 @@ func (s *state) chanModes() string {
// This includes mode characters, as well as user prefix symbols. Falls back // This includes mode characters, as well as user prefix symbols. Falls back
// to DefaultPrefixes if not server-supported. // to DefaultPrefixes if not server-supported.
func (s *state) userPrefixes() string { func (s *state) userPrefixes() string {
if atomicprefix, ok := s.serverOptions["PREFIX"]; ok { if pi, ok := s.serverOptions.Get("PREFIX"); ok {
prefix := atomicprefix.Load().(string) prefix := pi.(string)
if isValidUserPrefix(prefix) { if isValidUserPrefix(prefix) {
return prefix return prefix
} }

104
state.go
View File

@ -10,19 +10,23 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
cmap "github.com/orcaman/concurrent-map"
) )
// state represents the actively-changing variables within the client // state represents the actively-changing variables within the client
// runtime. Note that everything within the state should be guarded by the // runtime. Note that everything within the state should be guarded by the
// embedded sync.RWMutex. // embedded sync.RWMutex.
type state struct { type state struct {
sync.RWMutex *sync.RWMutex
// nick, ident, and host are the internal trackers for our user. // nick, ident, and host are the internal trackers for our user.
nick, ident, host atomic.Value nick, ident, host atomic.Value
// channels represents all channels we're active in. // channels represents all channels we're active in.
channels map[string]*Channel // channels map[string]*Channel
channels cmap.ConcurrentMap
// users represents all of users that we're tracking. // users represents all of users that we're tracking.
users map[string]*User // users map[string]*User
users cmap.ConcurrentMap
// enabledCap are the capabilities which are enabled for this connection. // enabledCap are the capabilities which are enabled for this connection.
enabledCap map[string]map[string]string enabledCap map[string]map[string]string
// tmpCap are the capabilties which we share with the server during the // tmpCap are the capabilties which we share with the server during the
@ -32,10 +36,10 @@ type state struct {
// serverOptions are the standard capabilities and configurations // serverOptions are the standard capabilities and configurations
// supported by the server at connection time. This also includes // supported by the server at connection time. This also includes
// RPL_ISUPPORT entries. // RPL_ISUPPORT entries.
serverOptions map[string]*atomic.Value serverOptions cmap.ConcurrentMap
// network is an alternative way to store and retrieve the NETWORK server option. // network is an alternative way to store and retrieve the NETWORK server option.
network atomic.Value network string
// motd is the servers message of the day. // motd is the servers message of the day.
motd string motd string
@ -53,13 +57,18 @@ type state struct {
// reset resets the state back to it's original form. // reset resets the state back to it's original form.
func (s *state) reset(initial bool) { func (s *state) reset(initial bool) {
s.Lock()
s.nick.Store("") s.nick.Store("")
s.ident.Store("") s.ident.Store("")
s.host.Store("") s.host.Store("")
s.channels = make(map[string]*Channel) var cmaps = []*cmap.ConcurrentMap{&s.channels, &s.users, &s.serverOptions}
s.users = make(map[string]*User) for _, cm := range cmaps {
s.serverOptions = make(map[string]*atomic.Value) if initial {
*cm = cmap.New()
} else {
cm.Clear()
}
}
s.enabledCap = make(map[string]map[string]string) s.enabledCap = make(map[string]map[string]string)
s.tmpCap = make(map[string]map[string]string) s.tmpCap = make(map[string]map[string]string)
s.motd = "" s.motd = ""
@ -67,7 +76,6 @@ func (s *state) reset(initial bool) {
if initial { if initial {
s.sts.reset() s.sts.reset()
} }
s.Unlock()
} }
// User represents an IRC user and the state attached to them. // User represents an IRC user and the state attached to them.
@ -141,14 +149,12 @@ func (u User) Channels(c *Client) []*Channel {
var channels []*Channel var channels []*Channel
c.state.RLock()
for i := 0; i < len(u.ChannelList); i++ { for i := 0; i < len(u.ChannelList); i++ {
ch := c.state.lookupChannel(u.ChannelList[i]) ch := c.state.lookupChannel(u.ChannelList[i])
if ch != nil { if ch != nil {
channels = append(channels, ch) channels = append(channels, ch)
} }
} }
c.state.RUnlock()
return channels return channels
} }
@ -261,14 +267,12 @@ func (ch Channel) Users(c *Client) []*User {
var users []*User var users []*User
c.state.RLock()
for i := 0; i < len(ch.UserList); i++ { for i := 0; i < len(ch.UserList); i++ {
user := c.state.lookupUser(ch.UserList[i]) user := c.state.lookupUser(ch.UserList[i])
if user != nil { if user != nil {
users = append(users, user) users = append(users, user)
} }
} }
c.state.RUnlock()
return users return users
} }
@ -282,7 +286,6 @@ func (ch Channel) Trusted(c *Client) []*User {
var users []*User var users []*User
c.state.RLock()
for i := 0; i < len(ch.UserList); i++ { for i := 0; i < len(ch.UserList); i++ {
user := c.state.lookupUser(ch.UserList[i]) user := c.state.lookupUser(ch.UserList[i])
if user == nil { if user == nil {
@ -294,7 +297,6 @@ func (ch Channel) Trusted(c *Client) []*User {
users = append(users, user) users = append(users, user)
} }
} }
c.state.RUnlock()
return users return users
} }
@ -309,7 +311,6 @@ func (ch Channel) Admins(c *Client) []*User {
var users []*User var users []*User
c.state.RLock()
for i := 0; i < len(ch.UserList); i++ { for i := 0; i < len(ch.UserList); i++ {
user := c.state.lookupUser(ch.UserList[i]) user := c.state.lookupUser(ch.UserList[i])
if user == nil { if user == nil {
@ -321,7 +322,6 @@ func (ch Channel) Admins(c *Client) []*User {
users = append(users, user) users = append(users, user)
} }
} }
c.state.RUnlock()
return users return users
} }
@ -400,17 +400,17 @@ func (s *state) createChannel(name string) (ok bool) {
supported := s.chanModes() supported := s.chanModes()
prefixes, _ := parsePrefixes(s.userPrefixes()) prefixes, _ := parsePrefixes(s.userPrefixes())
if _, ok := s.channels[ToRFC1459(name)]; ok { if _, ok := s.channels.Get(ToRFC1459(name)); ok {
return false return false
} }
s.channels[ToRFC1459(name)] = &Channel{ s.channels.Set(ToRFC1459(name), &Channel{
Name: name, Name: name,
UserList: []string{}, UserList: []string{},
Joined: time.Now(), Joined: time.Now(),
Network: s.client.NetworkName(), Network: s.client.NetworkName(),
Modes: NewCModes(supported, prefixes), Modes: NewCModes(supported, prefixes),
} })
return true return true
} }
@ -419,37 +419,51 @@ func (s *state) createChannel(name string) (ok bool) {
func (s *state) deleteChannel(name string) { func (s *state) deleteChannel(name string) {
name = ToRFC1459(name) name = ToRFC1459(name)
_, ok := s.channels[name] c, ok := s.channels.Get(name)
if !ok { if !ok {
return return
} }
for _, user := range s.channels[name].UserList { chn := c.(*Channel)
s.users[user].deleteChannel(name)
for _, user := range chn.UserList {
ui, _ := s.users.Get(user)
usr := ui.(*User)
usr.deleteChannel(name)
} }
delete(s.channels, name) s.channels.Remove(name)
} }
// lookupChannel returns a reference to a channel, nil returned if no results // lookupChannel returns a reference to a channel, nil returned if no results
// found. // found.
func (s *state) lookupChannel(name string) *Channel { func (s *state) lookupChannel(name string) *Channel {
return s.channels[ToRFC1459(name)] ci, cok := s.channels.Get(ToRFC1459(name))
chn, ok := ci.(*Channel)
if !ok || !cok {
return nil
}
return chn
} }
// lookupUser returns a reference to a user, nil returned if no results // lookupUser returns a reference to a user, nil returned if no results
// found. // found.
func (s *state) lookupUser(name string) *User { func (s *state) lookupUser(name string) *User {
return s.users[ToRFC1459(name)] ui, uok := s.users.Get(ToRFC1459(name))
usr, ok := ui.(*User)
if !ok || !uok {
return nil
}
return usr
} }
func (s *state) createUser(src *Source) (ok bool) { func (s *state) createUser(src *Source) (u *User, ok bool) {
if _, ok := s.users[src.ID()]; ok { if _, ok := s.users.Get(src.ID()); ok {
// User already exists. // User already exists.
return false return nil, false
} }
s.users[src.ID()] = &User{ u = &User{
Nick: src.Name, Nick: src.Name,
Host: src.Host, Host: src.Host,
Ident: src.Ident, Ident: src.Ident,
@ -460,7 +474,8 @@ func (s *state) createUser(src *Source) (ok bool) {
Perms: &UserPerms{channels: make(map[string]Perms)}, Perms: &UserPerms{channels: make(map[string]Perms)},
} }
return true s.users.Set(src.ID(), u)
return u, true
} }
// deleteUser removes the user from channel state. // deleteUser removes the user from channel state.
@ -472,10 +487,12 @@ func (s *state) deleteUser(channelName, nick string) {
if channelName == "" { if channelName == "" {
for i := 0; i < len(user.ChannelList); i++ { for i := 0; i < len(user.ChannelList); i++ {
s.channels[user.ChannelList[i]].deleteUser(nick) ci, _ := s.channels.Get(user.ChannelList[i])
chn := ci.(*Channel)
chn.deleteUser(nick)
} }
delete(s.users, ToRFC1459(nick)) s.users.Remove(ToRFC1459(nick))
return return
} }
@ -491,7 +508,7 @@ func (s *state) deleteUser(channelName, nick string) {
// This means they are no longer in any channels we track, delete // This means they are no longer in any channels we track, delete
// them from state. // them from state.
delete(s.users, ToRFC1459(nick)) s.users.Remove(ToRFC1459(nick))
} }
} }
@ -509,18 +526,19 @@ func (s *state) renameUser(from, to string) {
return return
} }
delete(s.users, from) s.users.Remove(from)
user.Nick = to user.Nick = to
user.LastActive = time.Now() user.LastActive = time.Now()
s.users[ToRFC1459(to)] = user s.users.Set(ToRFC1459(to), user)
for i := 0; i < len(user.ChannelList); i++ { for chanchan := range s.channels.IterBuffered() {
for j := 0; j < len(s.channels[user.ChannelList[i]].UserList); j++ { chi := chanchan.Val
if s.channels[user.ChannelList[i]].UserList[j] == from { chn := chi.(*Channel)
s.channels[user.ChannelList[i]].UserList[j] = ToRFC1459(to) for i := range chn.UserList {
if chn.UserList[i] == from {
sort.Strings(s.channels[user.ChannelList[i]].UserList) chn.UserList[i] = ToRFC1459(to)
sort.Strings(chn.UserList)
break break
} }
} }

View File

@ -5,6 +5,7 @@
package girc package girc
import ( import (
"log"
"reflect" "reflect"
"testing" "testing"
"time" "time"
@ -56,7 +57,8 @@ const mockConnEndState = `:nick2!nick2@other.int QUIT :example reason
` `
func TestState(t *testing.T) { func TestState(t *testing.T) {
c, conn, server := genMockConn() c, conn, server := genMockConn(t)
defer c.Close() defer c.Close()
go mockReadBuffer(conn) go mockReadBuffer(conn)
@ -71,50 +73,57 @@ func TestState(t *testing.T) {
finishStart := make(chan bool, 1) finishStart := make(chan bool, 1)
go debounce(250*time.Millisecond, bounceStart, func() { go debounce(250*time.Millisecond, bounceStart, func() {
if motd := c.ServerMOTD(); motd != "example motd" { if motd := c.ServerMOTD(); motd != "example motd" {
t.Fatalf("Client.ServerMOTD() returned invalid MOTD: %q", motd) t.Errorf("Client.ServerMOTD() returned invalid MOTD: %q", motd)
} }
if network := c.NetworkName(); network != "DummyIRC" && network != "DUMMY" { if network := c.NetworkName(); network != "DummyIRC" && network != "DUMMY" {
t.Fatalf("User.Network == %q, want \"DummyIRC\" or \"DUMMY\"", network) t.Errorf("User.Network == %q, want \"DummyIRC\" or \"DUMMY\"", network)
} }
if caseExample, ok := c.GetServerOption("NICKLEN"); !ok || caseExample != "20" { if caseExample, ok := c.GetServerOption("NICKLEN"); !ok || caseExample != "20" {
t.Fatalf("Client.GetServerOptions returned invalid ISUPPORT variable: %q", caseExample) t.Errorf("Client.GetServerOptions returned invalid ISUPPORT variable: %q", caseExample)
} }
t.Logf("getting user list")
users := c.UserList() users := c.UserList()
t.Logf("getting channel list")
channels := c.ChannelList() channels := c.ChannelList()
if !reflect.DeepEqual(users, []string{"fhjones", "nick2"}) { if !reflect.DeepEqual(users, []string{"fhjones", "nick2"}) {
// This could fail too, if sorting isn't occurring. // This could fail too, if sorting isn't occurring.
t.Fatalf("got state users %#v, wanted: %#v", users, []string{"fhjones", "nick2"}) t.Errorf("got state users %#v, wanted: %#v", users, []string{"fhjones", "nick2"})
} }
if !reflect.DeepEqual(channels, []string{"#channel", "#channel2"}) { if !reflect.DeepEqual(channels, []string{"#channel", "#channel2"}) {
// This could fail too, if sorting isn't occurring. // This could fail too, if sorting isn't occurring.
t.Fatalf("got state channels %#v, wanted: %#v", channels, []string{"#channel", "#channel2"}) t.Errorf("got state channels %#v, wanted: %#v", channels, []string{"#channel", "#channel2"})
} }
fullChannels := c.Channels() fullChannels := c.Channels()
for i := 0; i < len(fullChannels); i++ { for i := 0; i < len(fullChannels); i++ {
if fullChannels[i].Name != channels[i] { if fullChannels[i].Name != channels[i] {
t.Fatalf("fullChannels name doesn't map to same name in ChannelsList: %q :: %#v", fullChannels[i].Name, channels) t.Errorf("fullChannels name doesn't map to same name in ChannelsList: %q :: %#v", fullChannels[i].Name, channels)
} }
} }
fullUsers := c.Users() fullUsers := c.Users()
for i := 0; i < len(fullUsers); i++ { for i := 0; i < len(fullUsers); i++ {
if fullUsers[i].Nick != users[i] { if fullUsers[i].Nick != users[i] {
t.Fatalf("fullUsers nick doesn't map to same nick in UsersList: %q :: %#v", fullUsers[i].Nick, users) t.Errorf("fullUsers nick doesn't map to same nick in UsersList: %q :: %#v", fullUsers[i].Nick, users)
} }
} }
ch := c.LookupChannel("#channel") ch := c.LookupChannel("#channel")
if ch == nil { if ch == nil {
t.Fatal("Client.LookupChannel returned nil on existing channel") t.Error("Client.LookupChannel returned nil on existing channel")
return
} }
adm := ch.Admins(c) adm := ch.Admins(c)
if adm == nil {
t.Errorf("admin list is nil")
t.Fail()
}
admList := []string{} admList := []string{}
for i := 0; i < len(adm); i++ { for i := 0; i < len(adm); i++ {
admList = append(admList, adm[i].Nick) admList = append(admList, adm[i].Nick)
@ -126,87 +135,93 @@ func TestState(t *testing.T) {
} }
if !reflect.DeepEqual(admList, []string{"nick2"}) { if !reflect.DeepEqual(admList, []string{"nick2"}) {
t.Fatalf("got Channel.Admins() == %#v, wanted %#v", admList, []string{"nick2"}) t.Errorf("got Channel.Admins() == %#v, wanted %#v", admList, []string{"nick2"})
} }
if !reflect.DeepEqual(trustedList, []string{"nick2"}) { if !reflect.DeepEqual(trustedList, []string{"nick2"}) {
t.Fatalf("got Channel.Trusted() == %#v, wanted %#v", trustedList, []string{"nick2"}) t.Errorf("got Channel.Trusted() == %#v, wanted %#v", trustedList, []string{"nick2"})
} }
if topic := ch.Topic; topic != "example topic" { if topic := ch.Topic; topic != "example topic" {
t.Fatalf("Channel.Topic == %q, want \"example topic\"", topic) t.Errorf("Channel.Topic == %q, want \"example topic\"", topic)
} }
if ch.Network != "DummyIRC" && ch.Network != "DUMMY" { if ch.Network != "DummyIRC" && ch.Network != "DUMMY" {
t.Fatalf("Channel.Network == %q, want \"DummyIRC\" or \"DUMMY\"", ch.Network) t.Errorf("Channel.Network == %q, want \"DummyIRC\" or \"DUMMY\"", ch.Network)
} }
if in := ch.UserIn("fhjones"); !in { if in := ch.UserIn("fhjones"); !in {
t.Fatalf("Channel.UserIn == %t, want %t", in, true) t.Errorf("Channel.UserIn == %t, want %t", in, true)
} }
if users := ch.Users(c); len(users) != 2 { if users := ch.Users(c); len(users) != 2 {
t.Fatalf("Channel.Users == %#v, wanted length of 2", users) t.Errorf("Channel.Users == %#v, wanted length of 2", users)
} }
if h := c.GetHost(); h != "local.int" { if h := c.GetHost(); h != "local.int" {
t.Fatalf("Client.GetHost() == %q, want local.int", h) t.Errorf("Client.GetHost() == %q, want local.int", h)
} }
if nick := c.GetNick(); nick != "fhjones" { if nick := c.GetNick(); nick != "fhjones" {
t.Fatalf("Client.GetNick() == %q, want nick", nick) t.Errorf("Client.GetNick() == %q, want nick", nick)
} }
if ident := c.GetIdent(); ident != "~user" { if ident := c.GetIdent(); ident != "~user" {
t.Fatalf("Client.GetIdent() == %q, want ~user", ident) t.Errorf("Client.GetIdent() == %q, want ~user", ident)
} }
user := c.LookupUser("fhjones") user := c.LookupUser("fhjones")
if user == nil { if user == nil {
t.Fatal("Client.LookupUser() returned nil on existing user") t.Errorf("Client.LookupUser() returned nil on existing user")
return
} }
if !reflect.DeepEqual(user.ChannelList, []string{"#channel", "#channel2"}) { if !reflect.DeepEqual(user.ChannelList, []string{"#channel", "#channel2"}) {
t.Fatalf("User.ChannelList == %#v, wanted %#v", user.ChannelList, []string{"#channel", "#channel2"}) t.Errorf("User.ChannelList == %#v, wanted %#v", user.ChannelList, []string{"#channel", "#channel2"})
} }
if count := len(user.Channels(c)); count != 2 { if count := len(user.Channels(c)); count != 2 {
t.Fatalf("len(User.Channels) == %d, want 2", count) t.Errorf("len(User.Channels) == %d, want 2", count)
} }
if user.Nick != "fhjones" { if user.Nick != "fhjones" {
t.Fatalf("User.Nick == %q, wanted \"nick\"", user.Nick) t.Errorf("User.Nick == %q, wanted \"nick\"", user.Nick)
} }
if user.Extras.Name != "realname" { if user.Extras.Name != "realname" {
t.Fatalf("User.Extras.Name == %q, wanted \"realname\"", user.Extras.Name) t.Errorf("User.Extras.Name == %q, wanted \"realname\"", user.Extras.Name)
} }
if user.Host != "local.int" { if user.Host != "local.int" {
t.Fatalf("User.Host == %q, wanted \"local.int\"", user.Host) t.Errorf("User.Host == %q, wanted \"local.int\"", user.Host)
} }
if user.Ident != "~user" { if user.Ident != "~user" {
t.Fatalf("User.Ident == %q, wanted \"~user\"", user.Ident) t.Errorf("User.Ident == %q, wanted \"~user\"", user.Ident)
} }
if user.Network != "DummyIRC" && user.Network != "DUMMY" { if user.Network != "DummyIRC" && user.Network != "DUMMY" {
t.Fatalf("User.Network == %q, want \"DummyIRC\" or \"DUMMY\"", user.Network) t.Errorf("User.Network == %q, want \"DummyIRC\" or \"DUMMY\"", user.Network)
} }
if !user.InChannel("#channel2") { if !user.InChannel("#channel2") {
t.Fatal("User.InChannel() returned false for existing channel") t.Error("User.InChannel() returned false for existing channel")
return
} }
finishStart <- true finishStart <- true
}) })
cuid := c.Handlers.AddBg(UPDATE_STATE, func(c *Client, e Event) { cuid := c.Handlers.AddBg(UPDATE_STATE, func(c *Client, e Event) {
println(e.String())
bounceStart <- true bounceStart <- true
}) })
conn.SetDeadline(time.Now().Add(5 * time.Second)) err := conn.SetDeadline(time.Now().Add(5 * time.Second))
_, err := conn.Write([]byte(mockConnStartState)) if err != nil {
log.Fatalf(err.Error())
}
_, err = conn.Write([]byte(mockConnStartState))
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -222,11 +237,11 @@ func TestState(t *testing.T) {
finishEnd := make(chan bool, 1) finishEnd := make(chan bool, 1)
go debounce(250*time.Millisecond, bounceEnd, func() { go debounce(250*time.Millisecond, bounceEnd, func() {
if !reflect.DeepEqual(c.ChannelList(), []string{"#channel"}) { if !reflect.DeepEqual(c.ChannelList(), []string{"#channel"}) {
t.Fatalf("Client.ChannelList() == %#v, wanted %#v", c.ChannelList(), []string{"#channel"}) t.Errorf("Client.ChannelList() == %#v, wanted %#v", c.ChannelList(), []string{"#channel"})
} }
if !reflect.DeepEqual(c.UserList(), []string{"notjones"}) { if !reflect.DeepEqual(c.UserList(), []string{"notjones"}) {
t.Fatalf("Client.UserList() == %#v, wanted %#v", c.UserList(), []string{"notjones"}) t.Errorf("Client.UserList() == %#v, wanted %#v", c.UserList(), []string{"notjones"})
} }
user := c.LookupUser("notjones") user := c.LookupUser("notjones")
@ -235,18 +250,19 @@ func TestState(t *testing.T) {
} }
if !reflect.DeepEqual(user.ChannelList, []string{"#channel"}) { if !reflect.DeepEqual(user.ChannelList, []string{"#channel"}) {
t.Fatalf("user.ChannelList == %q, wanted %q", user.ChannelList, []string{"#channel"}) t.Errorf("user.ChannelList == %q, wanted %q", user.ChannelList, []string{"#channel"})
} }
channel := c.LookupChannel("#channel") channel := c.LookupChannel("#channel")
if channel == nil { if channel == nil {
t.Fatal("Client.LookupChannel() returned nil for existing channel") t.Error("Client.LookupChannel() returned nil for existing channel")
} }
if !reflect.DeepEqual(channel.UserList, []string{"notjones"}) { if !reflect.DeepEqual(channel.UserList, []string{"notjones"}) {
t.Fatalf("channel.UserList == %q, wanted %q", channel.UserList, []string{"notjones"}) t.Errorf("channel.UserList == %q, wanted %q", channel.UserList, []string{"notjones"})
} }
t.Logf(c.String())
finishEnd <- true finishEnd <- true
}) })