Pass tests for user db (likely still broken somehow)

This commit is contained in:
kayos@tcp.direct 2023-01-08 10:16:04 -08:00
parent 5935e8127d
commit 968706823c
Signed by: kayos
GPG Key ID: 4B841471B4BEE979
2 changed files with 215 additions and 142 deletions

@ -6,16 +6,45 @@ import (
"errors" "errors"
"sync" "sync"
"github.com/davecgh/go-spew/spew" "git.tcp.direct/kayos/common/entropy"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
type StringMapper interface {
Map() map[string]string
}
type AuthMethod interface {
json.Marshaler
StringMapper
// Authenticate authenticates the user.
Authenticate() error
// Name returns the name of the authentication method.
Name() string
}
func AuthMethodFromMap(m map[string]string) AuthMethod {
switch m["type"] {
case "password":
return &UserPass{
Username: m["pass_username"],
Password: m["password"],
}
case "publickey":
return &PubKey{
Username: m["pub_username"],
Pub: []byte(m["pubkey"]),
}
}
return nil
}
var ErrAccessDenied = errors.New("access denied") var ErrAccessDenied = errors.New("access denied")
type User struct { type User struct {
Username string `json:"username"` Username string `json:"username"`
AuthMethods []any `json:"auth_methods"` AuthMethods []map[string]string `json:"auth_methods"`
authMethods []AuthMethod *sync.Mutex
} }
type UserPass struct { type UserPass struct {
@ -23,210 +52,235 @@ type UserPass struct {
Password string `json:"password"` Password string `json:"password"`
} }
func NewUserPass(username, password string) *UserPass { func (up *UserPass) Name() string {
hash, err := HashPassword(password)
if err != nil {
panic(err)
}
return &UserPass{
Username: username,
Password: hash,
}
}
func (u *UserPass) Name() string {
return "password" return "password"
} }
func (u *UserPass) Authenticate() error { func (up *UserPass) Map() map[string]string {
user, err := GetUser(u.Username) return map[string]string{
"type": "password",
"pass_username": up.Username,
"password": up.Password,
}
}
func (up *UserPass) MarshalJSON() ([]byte, error) {
return json.Marshal(up.Map())
}
func NewUserPass(hashIt bool, username, password string) *UserPass {
var input = password
var err error
if hashIt {
input, err = HashPassword(password)
if err != nil { if err != nil {
log.Warn().Err(err).Str("username", u.Username).Msg("error getting user") panic(err)
}
}
return &UserPass{
Username: username,
Password: input,
}
}
func (up *UserPass) Authenticate() error {
user, err := GetUser(up.Username)
if err != nil {
log.Warn().Err(err).Str("username", up.Username).Msg("error getting user")
FakeCycle() FakeCycle()
return ErrAccessDenied return ErrAccessDenied
} }
for _, method := range user.authMethods { for _, method := range user.AuthMethods {
if method.Name() == "password" { switch method["type"] {
userPass := method.(*UserPass) case "password":
CheckPasswordHash(u.Password, userPass.Password) if method["pass_username"] == up.Username && CheckPasswordHash(up.Password, method["password"]) {
return nil
}
default:
continue
} }
} }
return ErrAccessDenied return ErrAccessDenied
} }
type PubKey struct { type PubKey struct {
Username string `json:"pubusername"` Username string `json:"pub_username"`
Pub []byte `json:"pubkey"` Pub []byte `json:"pubkey"`
} }
func (p PubKey) Name() string { func (pk *PubKey) Name() string {
return "publickey" return "publickey"
} }
func (p PubKey) Authenticate() error { func (pk *PubKey) Map() map[string]string {
u, err := GetUser(p.Username) return map[string]string{
"type": "publickey",
"pub_username": pk.Username,
"pubkey": string(pk.Pub),
}
}
func (pk *PubKey) MarshalJSON() ([]byte, error) {
return json.Marshal(pk.Map())
}
func (pk *PubKey) Authenticate() error {
user, err := GetUser(pk.Username)
if err != nil { if err != nil {
log.Warn().Err(err).Str("username", p.Username).Msg("error getting user") log.Warn().Err(err).Str("username", pk.Username).Msg("error getting user")
FakeCycle() FakeCycle()
return ErrAccessDenied return ErrAccessDenied
} }
spew.Dump(u) for _, method := range user.AuthMethods {
for _, method := range u.authMethods { switch method["type"] {
if method.Name() == "publickey" { case "publickey":
pubKey := method.(*PubKey) if method["pub_username"] == pk.Username && bytes.Equal([]byte(method["pubkey"]), pk.Pub) {
if bytes.Equal(pubKey.Pub, p.Pub) {
return nil return nil
} }
default:
continue
} }
} }
return ErrAccessDenied return ErrAccessDenied
} }
type AuthMethod interface {
// Name returns the name of the authentication method.
Name() string
// Authenticate authenticates the user.
Authenticate() error
}
func GetUser(username string) (*User, error) { func GetUser(username string) (*User, error) {
res, err := db.With("users").Get([]byte(username)) res, err := db.With("users").Get([]byte(username))
if err != nil { if err != nil {
return nil, err return nil, err
} }
var user User var user User
if err := json.Unmarshal(res, &user); err != nil { if err = json.Unmarshal(res, &user); err != nil {
return nil, err return nil, err
} }
for _, method := range user.AuthMethods { user.Mutex = &sync.Mutex{}
var up UserPass
var pk PubKey
jm, err := json.Marshal(method)
if err != nil {
return nil, err
}
if uperr := json.Unmarshal(jm, &up); uperr == nil {
if up.Username == "" || up.Password == "" {
continue
}
user.authMethods = append(user.authMethods, &up)
user.AuthMethods = append(user.AuthMethods, &up)
}
if pkerr := json.Unmarshal(jm, &pk); pkerr == nil {
if pk.Username == "" || len(pk.Pub) == 0 {
continue
}
user.authMethods = append(user.authMethods, &pk)
user.AuthMethods = append(user.AuthMethods, &pk)
}
}
return &user, nil return &user, nil
} }
func NewUser(username string, authMethods ...AuthMethod) error { func NewUser(username string, authMethods ...AuthMethod) (*User, error) {
if len(username) == 0 { if len(username) == 0 {
return errors.New("username cannot be empty") return nil, errors.New("username cannot be empty")
} }
if len(authMethods) == 0 { if len(authMethods) == 0 {
return errors.New("at least one authentication method must be provided") return nil, errors.New("at least one authentication method must be provided")
} }
var methods []AuthMethod var methods []map[string]string
var jsonMethods []any
for _, method := range authMethods { for _, method := range authMethods {
if method == nil { if method == nil {
return errors.New("authentication method cannot be nil") return nil, errors.New("authentication method cannot be nil")
} }
methods = append(methods, method) switch method.Name() {
jsonMethods = append(jsonMethods, method) case "password":
usableMethod := method.(*UserPass)
if len(usableMethod.Username) == 0 {
return nil, errors.New("username cannot be empty")
}
if len(usableMethod.Password) == 0 {
return nil, errors.New("password cannot be empty")
}
methods = append(methods, method.Map())
case "publickey":
usableMethod := method.(*PubKey)
if len(usableMethod.Username) == 0 {
return nil, errors.New("username cannot be empty")
}
if len(usableMethod.Pub) == 0 {
return nil, errors.New("public key cannot be empty")
}
methods = append(methods, method.Map())
}
}
if len(methods) == 0 {
return nil, errors.New("at least one authentication method must be provided")
} }
user := &User{ user := &User{
Username: username, Username: username,
AuthMethods: jsonMethods, AuthMethods: methods,
Mutex: &sync.Mutex{},
} }
b, err := json.Marshal(user) b, err := json.Marshal(user)
if err != nil { if err != nil {
return err return nil, err
} }
spew.Dump(b) return user, db.With("users").Put([]byte(username), b)
return db.With("users").Put([]byte(username), b)
} }
func (user *User) AddAuthMethod(method AuthMethod) error { func (user *User) AddAuthMethod(method AuthMethod) (*User, error) {
user.Lock()
defer user.Unlock()
if method == nil { if method == nil {
return errors.New("authentication method cannot be nil") return user, errors.New("authentication method cannot be nil")
} }
user.authMethods = append(user.authMethods, method) user.AuthMethods = append(user.AuthMethods, method.Map())
user.AuthMethods = append(user.AuthMethods, method)
b, err := json.Marshal(user) b, err := json.Marshal(user)
if err != nil { if err != nil {
return err return user, err
} }
return db.With("users").Put([]byte(user.Username), b) return user, db.With("users").Put([]byte(user.Username), b)
} }
func DelUser(username string) error { func DelUser(username string) error {
return db.With("users").Delete([]byte(username)) return db.With("users").Delete([]byte(username))
} }
func (user *User) DelPubKey(pubkey []byte) error { func (user *User) DelPubKey(pubkey []byte) (*User, error) {
user.Lock()
defer user.Unlock()
var found = false var found = false
var jsonMethods []any var methods []map[string]string
var methods []AuthMethod for _, method := range user.AuthMethods {
for _, method := range user.authMethods { m := AuthMethodFromMap(method)
if method.Name() == "publickey" { if m.Name() == "publickey" {
pubKey := method.(*PubKey) pubKey := m.(*PubKey)
if bytes.Equal(pubKey.Pub, pubkey) { if bytes.Equal(pubKey.Pub, pubkey) {
found = true found = true
continue continue
} }
} }
methods = append(methods, method) methods = append(methods, method)
jsonMethods = append(jsonMethods, method)
} }
if !found { if !found {
return errors.New("public key not found") return user, errors.New("public key not found")
} }
user.AuthMethods = jsonMethods user.AuthMethods = methods
user.authMethods = methods
if b, err := json.Marshal(user); err == nil { if b, err := json.Marshal(user); err == nil {
return db.With("users").Put([]byte(user.Username), b) return user, db.With("users").Put([]byte(user.Username), b)
} else { } else {
return err return user, err
} }
} }
func (user *User) ChangePassword(newPassword string) error { func (user *User) ChangePassword(newPassword string) (*User, error) {
user.Lock()
defer user.Unlock()
var ponce = &sync.Once{} var ponce = &sync.Once{}
var methods []any var methods []map[string]string
var authMethods []AuthMethod for _, method := range user.AuthMethods {
for _, method := range user.authMethods { m := AuthMethodFromMap(method)
if method.Name() == "password" { if m.Name() == "password" {
ponce.Do(func() { ponce.Do(func() {
method.(*UserPass).Password = newPassword hashed, err := HashPassword(newPassword)
if err != nil {
panic(err)
}
m.(*UserPass).Password = hashed
}) })
} }
methods = append(methods, method) methods = append(methods, m.Map())
authMethods = append(authMethods, method)
} }
user.AuthMethods = methods user.AuthMethods = methods
user.authMethods = authMethods
b, err := json.Marshal(user) b, err := json.Marshal(user)
if err != nil { if err != nil {
return err return user, err
} }
return db.With("users").Put([]byte(user.Username), b) return user, db.With("users").Put([]byte(user.Username), b)
} }
func provisionFakeUser() *User { func provisionFakeUser() *User {
err := NewUser("0", &UserPass{Password: "0"}) user, err := NewUser("0", NewUserPass(true, "0", entropy.RandStrWithUpper(32)))
if err != nil { if err != nil {
log.Panic().Err(err).Msg("error creating fake user") log.Panic().Err(err).Msg("error creating fake user")
} }
var user *User
user, err = GetUser("0")
if err != nil {
log.Panic().Err(err).Msg("error getting user")
}
return user return user
} }
@ -237,7 +291,10 @@ func FakeCycle() {
user = provisionFakeUser() user = provisionFakeUser()
} }
for _, method := range user.authMethods { for n, method := range user.AuthMethods {
_ = method.Authenticate() if n > 2 {
break
}
_ = AuthMethodFromMap(method).Authenticate()
} }
} }

@ -14,90 +14,106 @@ func TestUsers(t *testing.T) {
} }
}) })
t.Run("NewUser", func(t *testing.T) { t.Run("NewUser", func(t *testing.T) {
if err := NewUser("test"); err == nil { if _, err := NewUser("test1"); err == nil {
t.Fatal("expected error creating user with no auth method") t.Fatal("expected error creating user with no auth method")
} }
if _, err := GetUser("test"); err == nil { if _, err := GetUser("test1"); err == nil {
t.Fatal("expected error getting user with no auth method") t.Fatal("expected error getting user with no auth method")
} }
if err := NewUser("test", NewUserPass("test", "test")); err != nil { if _, err := NewUser("test1", NewUserPass(true, "test", "test")); err != nil {
t.Fatal(err) t.Fatal(err)
} }
tu, err := GetUser("test") tu, err := GetUser("test1")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(tu.authMethods) != 1 { if len(tu.AuthMethods) != 1 {
t.Fatalf("expected 1 auth method, got %d", len(tu.authMethods)) t.Fatalf("expected 1 auth method, got %d", len(tu.AuthMethods))
} }
if tu.authMethods[0].Name() != "password" { if tu.AuthMethods[0]["type"] != "password" {
t.Fatalf("expected auth method to be 'password', got '%s'", tu.authMethods[0].Name()) t.Fatalf("expected auth method to be 'password', got '%s'", tu.AuthMethods[0]["type"])
} }
if tu.Username != "test" { if tu.Username != "test1" {
t.Fatalf("expected username to be 'test', got '%s'", tu.Username) t.Fatalf("expected username to be 'test', got '%s'", tu.Username)
} }
}) })
t.Run("AddAuthMethod", func(t *testing.T) { t.Run("AddAuthMethod", func(t *testing.T) {
user, err := GetUser("test") user, err := NewUser("test2", NewUserPass(true, "test2", "test2"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err = user.AddAuthMethod(nil); err == nil { if user, err = user.AddAuthMethod(nil); err == nil {
t.Fatal("expected error adding nil auth method") t.Fatal("expected error adding nil auth method")
} }
if err = user.AddAuthMethod(&PubKey{Username: "test", Pub: []byte("test")}); err != nil { if user == nil {
t.Fatal("expected user to not be nil")
}
if user, err = user.AddAuthMethod(&PubKey{Username: "test2", Pub: []byte("pub")}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(user.authMethods) != 2 { if len(user.AuthMethods) != 2 {
t.Fatalf("expected 2 auth methods, got %d", len(user.authMethods)) t.Fatalf("expected 2 auth methods, got %d", len(user.AuthMethods))
} }
if user.authMethods[0].Name() != "password" { pk := &PubKey{Username: "test2", Pub: []byte("pub")}
t.Fatalf("expected auth method to be 'password', got '%s'", user.authMethods[0].Name()) if err = pk.Authenticate(); err != nil {
t.Fatal("expected pub key to authenticate")
} }
if user.authMethods[1].Name() != "publickey" { if user, err = user.AddAuthMethod(&PubKey{Username: "test2", Pub: []byte("pub2")}); err != nil {
t.Fatalf("expected auth method to be 'publickey', got '%s'", user.authMethods[1].Name()) t.Fatal(err)
}
if len(user.AuthMethods) != 3 {
t.Fatalf("expected 2 auth methods, got %d", len(user.AuthMethods))
}
if user.AuthMethods[0]["type"] != "password" {
t.Fatalf("expected auth method to be 'password', got '%s'", user.AuthMethods[0]["type"])
}
if user.AuthMethods[1]["type"] != "publickey" {
t.Fatalf("expected auth method to be 'publickey', got '%s'", user.AuthMethods[1]["type"])
} }
auth := &PubKey{ auth := &PubKey{
Username: "test", Username: "test2",
Pub: []byte("test"), Pub: []byte("pub"),
} }
if err = auth.Authenticate(); err != nil { if err = auth.Authenticate(); err != nil {
t.Fatalf("expected auth to succeed, got: %v", err) t.Fatalf("expected auth to succeed, got: %v", err)
} }
auth.Pub = []byte("test2") auth.Pub = []byte("asdjfas")
if err = auth.Authenticate(); err == nil { if err = auth.Authenticate(); err == nil {
t.Fatal("expected auth to fail") t.Fatal("expected auth to fail")
} }
}) })
t.Run("DelPubKey", func(t *testing.T) { t.Run("DelPubKey", func(t *testing.T) {
user, err := GetUser("test") user, err := GetUser("test2")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err = user.DelPubKey([]byte("test2")); err == nil { if user, err = user.DelPubKey([]byte("fdsafdas")); err == nil {
t.Fatal("expected error deleting non-existent key") t.Fatal("expected error deleting non-existent key")
} }
if err = user.DelPubKey([]byte("test")); err != nil { if user == nil {
t.Fatal("expected user to not be nil")
}
if user, err = user.DelPubKey([]byte("pub2")); err != nil {
t.Fatal(err) t.Fatal(err)
} }
auth := NewUserPass("test", "test") auth := NewUserPass(false, "test2", "test2")
if err := auth.Authenticate(); err != nil { if err = auth.Authenticate(); err != nil {
t.Fatalf("expected userpass to still be there after deleting public key, got: %v", err) t.Fatalf("expected userpass to still be there after deleting public key, got: %v", err)
} }
}) })
t.Run("ChangePassword", func(t *testing.T) { t.Run("ChangePassword", func(t *testing.T) {
user, err := GetUser("test") user, err := GetUser("test2")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err = user.ChangePassword("test2"); err != nil { if user, err = user.ChangePassword("test5"); err != nil {
t.Fatal(err) t.Fatal(err)
} }
auth := NewUserPass("test", "test") auth := NewUserPass(false, "test2", "test2")
if err = auth.Authenticate(); err == nil { if err = auth.Authenticate(); err == nil {
t.Fatal("expected auth to fail using old password") t.Fatal("expected auth to fail using old password")
} }
auth = NewUserPass("test", "test2") auth = NewUserPass(false, "test2", "test5")
if err = auth.Authenticate(); err != nil { if err = auth.Authenticate(); err != nil {
t.Fatalf("expected auth to succeed using new password, got: %v", err) t.Fatalf("expected auth to succeed using new password, got: %v", err)
} }