From 968706823c86ff96a934f2ea1a4fe6303b831b64 Mon Sep 17 00:00:00 2001 From: "kayos@tcp.direct" Date: Sun, 8 Jan 2023 10:16:04 -0800 Subject: [PATCH] Pass tests for user db (likely still broken somehow) --- internal/data/users.go | 281 ++++++++++++++++++++++-------------- internal/data/users_test.go | 76 ++++++---- 2 files changed, 215 insertions(+), 142 deletions(-) diff --git a/internal/data/users.go b/internal/data/users.go index ecbf0cc..043e544 100644 --- a/internal/data/users.go +++ b/internal/data/users.go @@ -6,16 +6,45 @@ import ( "errors" "sync" - "github.com/davecgh/go-spew/spew" + "git.tcp.direct/kayos/common/entropy" "github.com/rs/zerolog/log" ) +type StringMapper interface { + Map() map[string]string +} + +type AuthMethod interface { + json.Marshaler + StringMapper + // Authenticate authenticates the user. + Authenticate() error + // Name returns the name of the authentication method. + Name() string +} + +func AuthMethodFromMap(m map[string]string) AuthMethod { + switch m["type"] { + case "password": + return &UserPass{ + Username: m["pass_username"], + Password: m["password"], + } + case "publickey": + return &PubKey{ + Username: m["pub_username"], + Pub: []byte(m["pubkey"]), + } + } + return nil +} + var ErrAccessDenied = errors.New("access denied") type User struct { - Username string `json:"username"` - AuthMethods []any `json:"auth_methods"` - authMethods []AuthMethod + Username string `json:"username"` + AuthMethods []map[string]string `json:"auth_methods"` + *sync.Mutex } type UserPass struct { @@ -23,210 +52,235 @@ type UserPass struct { Password string `json:"password"` } -func NewUserPass(username, password string) *UserPass { - hash, err := HashPassword(password) - if err != nil { - panic(err) - } - return &UserPass{ - Username: username, - Password: hash, - } -} - -func (u *UserPass) Name() string { +func (up *UserPass) Name() string { return "password" } -func (u *UserPass) Authenticate() error { - user, err := GetUser(u.Username) +func (up *UserPass) Map() map[string]string { + return map[string]string{ + "type": "password", + "pass_username": up.Username, + "password": up.Password, + } +} + +func (up *UserPass) MarshalJSON() ([]byte, error) { + return json.Marshal(up.Map()) +} + +func NewUserPass(hashIt bool, username, password string) *UserPass { + var input = password + var err error + if hashIt { + input, err = HashPassword(password) + if err != nil { + panic(err) + } + } + return &UserPass{ + Username: username, + Password: input, + } +} + +func (up *UserPass) Authenticate() error { + user, err := GetUser(up.Username) if err != nil { - log.Warn().Err(err).Str("username", u.Username).Msg("error getting user") + log.Warn().Err(err).Str("username", up.Username).Msg("error getting user") FakeCycle() return ErrAccessDenied } - for _, method := range user.authMethods { - if method.Name() == "password" { - userPass := method.(*UserPass) - CheckPasswordHash(u.Password, userPass.Password) + for _, method := range user.AuthMethods { + switch method["type"] { + case "password": + if method["pass_username"] == up.Username && CheckPasswordHash(up.Password, method["password"]) { + return nil + } + default: + continue } } return ErrAccessDenied } type PubKey struct { - Username string `json:"pubusername"` + Username string `json:"pub_username"` Pub []byte `json:"pubkey"` } -func (p PubKey) Name() string { +func (pk *PubKey) Name() string { return "publickey" } -func (p PubKey) Authenticate() error { - u, err := GetUser(p.Username) +func (pk *PubKey) Map() map[string]string { + return map[string]string{ + "type": "publickey", + "pub_username": pk.Username, + "pubkey": string(pk.Pub), + } +} + +func (pk *PubKey) MarshalJSON() ([]byte, error) { + return json.Marshal(pk.Map()) +} + +func (pk *PubKey) Authenticate() error { + user, err := GetUser(pk.Username) if err != nil { - log.Warn().Err(err).Str("username", p.Username).Msg("error getting user") + log.Warn().Err(err).Str("username", pk.Username).Msg("error getting user") FakeCycle() return ErrAccessDenied } - spew.Dump(u) - for _, method := range u.authMethods { - if method.Name() == "publickey" { - pubKey := method.(*PubKey) - if bytes.Equal(pubKey.Pub, p.Pub) { + for _, method := range user.AuthMethods { + switch method["type"] { + case "publickey": + if method["pub_username"] == pk.Username && bytes.Equal([]byte(method["pubkey"]), pk.Pub) { return nil } + default: + continue } } return ErrAccessDenied } -type AuthMethod interface { - // Name returns the name of the authentication method. - Name() string - // Authenticate authenticates the user. - Authenticate() error -} - func GetUser(username string) (*User, error) { res, err := db.With("users").Get([]byte(username)) if err != nil { return nil, err } var user User - if err := json.Unmarshal(res, &user); err != nil { + if err = json.Unmarshal(res, &user); err != nil { return nil, err } - for _, method := range user.AuthMethods { - var up UserPass - var pk PubKey - jm, err := json.Marshal(method) - if err != nil { - return nil, err - } - if uperr := json.Unmarshal(jm, &up); uperr == nil { - if up.Username == "" || up.Password == "" { - continue - } - user.authMethods = append(user.authMethods, &up) - user.AuthMethods = append(user.AuthMethods, &up) - } - if pkerr := json.Unmarshal(jm, &pk); pkerr == nil { - if pk.Username == "" || len(pk.Pub) == 0 { - continue - } - user.authMethods = append(user.authMethods, &pk) - user.AuthMethods = append(user.AuthMethods, &pk) - } - } + user.Mutex = &sync.Mutex{} return &user, nil } -func NewUser(username string, authMethods ...AuthMethod) error { +func NewUser(username string, authMethods ...AuthMethod) (*User, error) { if len(username) == 0 { - return errors.New("username cannot be empty") + return nil, errors.New("username cannot be empty") } if len(authMethods) == 0 { - return errors.New("at least one authentication method must be provided") + return nil, errors.New("at least one authentication method must be provided") } - var methods []AuthMethod - var jsonMethods []any + var methods []map[string]string for _, method := range authMethods { if method == nil { - return errors.New("authentication method cannot be nil") + return nil, errors.New("authentication method cannot be nil") } - methods = append(methods, method) - jsonMethods = append(jsonMethods, method) + switch method.Name() { + case "password": + usableMethod := method.(*UserPass) + if len(usableMethod.Username) == 0 { + return nil, errors.New("username cannot be empty") + } + if len(usableMethod.Password) == 0 { + return nil, errors.New("password cannot be empty") + } + methods = append(methods, method.Map()) + case "publickey": + usableMethod := method.(*PubKey) + if len(usableMethod.Username) == 0 { + return nil, errors.New("username cannot be empty") + } + if len(usableMethod.Pub) == 0 { + return nil, errors.New("public key cannot be empty") + } + methods = append(methods, method.Map()) + } + } + if len(methods) == 0 { + return nil, errors.New("at least one authentication method must be provided") } user := &User{ Username: username, - AuthMethods: jsonMethods, + AuthMethods: methods, + Mutex: &sync.Mutex{}, } b, err := json.Marshal(user) if err != nil { - return err + return nil, err } - spew.Dump(b) - return db.With("users").Put([]byte(username), b) + return user, db.With("users").Put([]byte(username), b) } -func (user *User) AddAuthMethod(method AuthMethod) error { +func (user *User) AddAuthMethod(method AuthMethod) (*User, error) { + user.Lock() + defer user.Unlock() if method == nil { - return errors.New("authentication method cannot be nil") + return user, errors.New("authentication method cannot be nil") } - user.authMethods = append(user.authMethods, method) - user.AuthMethods = append(user.AuthMethods, method) + user.AuthMethods = append(user.AuthMethods, method.Map()) b, err := json.Marshal(user) if err != nil { - return err + return user, err } - return db.With("users").Put([]byte(user.Username), b) + return user, db.With("users").Put([]byte(user.Username), b) } func DelUser(username string) error { return db.With("users").Delete([]byte(username)) } -func (user *User) DelPubKey(pubkey []byte) error { +func (user *User) DelPubKey(pubkey []byte) (*User, error) { + user.Lock() + defer user.Unlock() var found = false - var jsonMethods []any - var methods []AuthMethod - for _, method := range user.authMethods { - if method.Name() == "publickey" { - pubKey := method.(*PubKey) + var methods []map[string]string + for _, method := range user.AuthMethods { + m := AuthMethodFromMap(method) + if m.Name() == "publickey" { + pubKey := m.(*PubKey) if bytes.Equal(pubKey.Pub, pubkey) { found = true continue } } methods = append(methods, method) - jsonMethods = append(jsonMethods, method) } if !found { - return errors.New("public key not found") + return user, errors.New("public key not found") } - user.AuthMethods = jsonMethods - user.authMethods = methods + user.AuthMethods = methods if b, err := json.Marshal(user); err == nil { - return db.With("users").Put([]byte(user.Username), b) + return user, db.With("users").Put([]byte(user.Username), b) } else { - return err + return user, err } } -func (user *User) ChangePassword(newPassword string) error { +func (user *User) ChangePassword(newPassword string) (*User, error) { + user.Lock() + defer user.Unlock() var ponce = &sync.Once{} - var methods []any - var authMethods []AuthMethod - for _, method := range user.authMethods { - if method.Name() == "password" { + var methods []map[string]string + for _, method := range user.AuthMethods { + m := AuthMethodFromMap(method) + if m.Name() == "password" { ponce.Do(func() { - method.(*UserPass).Password = newPassword + hashed, err := HashPassword(newPassword) + if err != nil { + panic(err) + } + m.(*UserPass).Password = hashed }) } - methods = append(methods, method) - authMethods = append(authMethods, method) + methods = append(methods, m.Map()) } user.AuthMethods = methods - user.authMethods = authMethods b, err := json.Marshal(user) if err != nil { - return err + return user, err } - return db.With("users").Put([]byte(user.Username), b) + return user, db.With("users").Put([]byte(user.Username), b) } func provisionFakeUser() *User { - err := NewUser("0", &UserPass{Password: "0"}) + user, err := NewUser("0", NewUserPass(true, "0", entropy.RandStrWithUpper(32))) if err != nil { log.Panic().Err(err).Msg("error creating fake user") } - var user *User - user, err = GetUser("0") - if err != nil { - log.Panic().Err(err).Msg("error getting user") - } return user } @@ -237,7 +291,10 @@ func FakeCycle() { user = provisionFakeUser() } - for _, method := range user.authMethods { - _ = method.Authenticate() + for n, method := range user.AuthMethods { + if n > 2 { + break + } + _ = AuthMethodFromMap(method).Authenticate() } } diff --git a/internal/data/users_test.go b/internal/data/users_test.go index 77f466b..8f6a214 100644 --- a/internal/data/users_test.go +++ b/internal/data/users_test.go @@ -14,90 +14,106 @@ func TestUsers(t *testing.T) { } }) t.Run("NewUser", func(t *testing.T) { - if err := NewUser("test"); err == nil { + if _, err := NewUser("test1"); err == nil { t.Fatal("expected error creating user with no auth method") } - if _, err := GetUser("test"); err == nil { + if _, err := GetUser("test1"); err == nil { t.Fatal("expected error getting user with no auth method") } - if err := NewUser("test", NewUserPass("test", "test")); err != nil { + if _, err := NewUser("test1", NewUserPass(true, "test", "test")); err != nil { t.Fatal(err) } - tu, err := GetUser("test") + tu, err := GetUser("test1") if err != nil { t.Fatal(err) } - if len(tu.authMethods) != 1 { - t.Fatalf("expected 1 auth method, got %d", len(tu.authMethods)) + if len(tu.AuthMethods) != 1 { + t.Fatalf("expected 1 auth method, got %d", len(tu.AuthMethods)) } - if tu.authMethods[0].Name() != "password" { - t.Fatalf("expected auth method to be 'password', got '%s'", tu.authMethods[0].Name()) + if tu.AuthMethods[0]["type"] != "password" { + t.Fatalf("expected auth method to be 'password', got '%s'", tu.AuthMethods[0]["type"]) } - if tu.Username != "test" { + if tu.Username != "test1" { t.Fatalf("expected username to be 'test', got '%s'", tu.Username) } }) t.Run("AddAuthMethod", func(t *testing.T) { - user, err := GetUser("test") + user, err := NewUser("test2", NewUserPass(true, "test2", "test2")) if err != nil { t.Fatal(err) } - if err = user.AddAuthMethod(nil); err == nil { + if user, err = user.AddAuthMethod(nil); err == nil { t.Fatal("expected error adding nil auth method") } - if err = user.AddAuthMethod(&PubKey{Username: "test", Pub: []byte("test")}); err != nil { + if user == nil { + t.Fatal("expected user to not be nil") + } + if user, err = user.AddAuthMethod(&PubKey{Username: "test2", Pub: []byte("pub")}); err != nil { t.Fatal(err) } - if len(user.authMethods) != 2 { - t.Fatalf("expected 2 auth methods, got %d", len(user.authMethods)) + if len(user.AuthMethods) != 2 { + t.Fatalf("expected 2 auth methods, got %d", len(user.AuthMethods)) } - if user.authMethods[0].Name() != "password" { - t.Fatalf("expected auth method to be 'password', got '%s'", user.authMethods[0].Name()) + pk := &PubKey{Username: "test2", Pub: []byte("pub")} + if err = pk.Authenticate(); err != nil { + t.Fatal("expected pub key to authenticate") } - if user.authMethods[1].Name() != "publickey" { - t.Fatalf("expected auth method to be 'publickey', got '%s'", user.authMethods[1].Name()) + if user, err = user.AddAuthMethod(&PubKey{Username: "test2", Pub: []byte("pub2")}); err != nil { + t.Fatal(err) + } + if len(user.AuthMethods) != 3 { + t.Fatalf("expected 2 auth methods, got %d", len(user.AuthMethods)) + } + if user.AuthMethods[0]["type"] != "password" { + t.Fatalf("expected auth method to be 'password', got '%s'", user.AuthMethods[0]["type"]) + } + if user.AuthMethods[1]["type"] != "publickey" { + t.Fatalf("expected auth method to be 'publickey', got '%s'", user.AuthMethods[1]["type"]) } auth := &PubKey{ - Username: "test", - Pub: []byte("test"), + Username: "test2", + Pub: []byte("pub"), } if err = auth.Authenticate(); err != nil { t.Fatalf("expected auth to succeed, got: %v", err) } - auth.Pub = []byte("test2") + auth.Pub = []byte("asdjfas") if err = auth.Authenticate(); err == nil { t.Fatal("expected auth to fail") } }) t.Run("DelPubKey", func(t *testing.T) { - user, err := GetUser("test") + user, err := GetUser("test2") if err != nil { t.Fatal(err) } - if err = user.DelPubKey([]byte("test2")); err == nil { + if user, err = user.DelPubKey([]byte("fdsafdas")); err == nil { t.Fatal("expected error deleting non-existent key") } - if err = user.DelPubKey([]byte("test")); err != nil { + if user == nil { + t.Fatal("expected user to not be nil") + } + if user, err = user.DelPubKey([]byte("pub2")); err != nil { t.Fatal(err) } - auth := NewUserPass("test", "test") - if err := auth.Authenticate(); err != nil { + auth := NewUserPass(false, "test2", "test2") + if err = auth.Authenticate(); err != nil { t.Fatalf("expected userpass to still be there after deleting public key, got: %v", err) } }) t.Run("ChangePassword", func(t *testing.T) { - user, err := GetUser("test") + user, err := GetUser("test2") if err != nil { t.Fatal(err) } - if err = user.ChangePassword("test2"); err != nil { + if user, err = user.ChangePassword("test5"); err != nil { t.Fatal(err) } - auth := NewUserPass("test", "test") + auth := NewUserPass(false, "test2", "test2") if err = auth.Authenticate(); err == nil { t.Fatal("expected auth to fail using old password") } - auth = NewUserPass("test", "test2") + auth = NewUserPass(false, "test2", "test5") if err = auth.Authenticate(); err != nil { t.Fatalf("expected auth to succeed using new password, got: %v", err) }