188 lines
4.0 KiB
Go
188 lines
4.0 KiB
Go
package main
|
|
|
|
import (
|
|
"fmt"
|
|
_ "net/http/pprof"
|
|
"os"
|
|
"os/signal"
|
|
"strings"
|
|
"syscall"
|
|
|
|
"github.com/rs/zerolog"
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/term"
|
|
|
|
"git.tcp.direct/kayos/sh3lly/auth"
|
|
"git.tcp.direct/kayos/sh3lly/bus"
|
|
"git.tcp.direct/kayos/sh3lly/chat/message"
|
|
"git.tcp.direct/kayos/sh3lly/config"
|
|
"git.tcp.direct/kayos/sh3lly/host"
|
|
"git.tcp.direct/kayos/sh3lly/shells"
|
|
"git.tcp.direct/kayos/sh3lly/sshd"
|
|
)
|
|
|
|
var log zerolog.Logger
|
|
var authdb *auth.UserDB
|
|
|
|
func init() {
|
|
var err error
|
|
config.Init()
|
|
log = config.StartLogger()
|
|
|
|
authdb, err = auth.NewUserDB(config.DataPath)
|
|
if err != nil {
|
|
// TODO: probably not panic
|
|
panic(err)
|
|
}
|
|
|
|
checkArgs()
|
|
|
|
shells.Start()
|
|
}
|
|
|
|
// confirm writes the prompt "Continue? (y/N) " to stdout and reads a single
// line from stdin. It returns true only when the answer is "y" or "yes"
// (case-insensitive). Anything else is treated as "no", matching the default
// advertised by the prompt. That includes an empty line, extra words, EOF and
// read errors.
func confirm() bool {
	fmt.Print("Continue? (y/N) ")

	var answer string
	// Scanln stops at the end of the line. A bare Enter therefore yields an
	// error here (and so the default "no"), whereas fmt.Scan treats newlines
	// as spaces and would keep blocking until a non-blank token arrived. A
	// closed stdin likewise becomes "no" instead of a panic.
	if _, err := fmt.Scanln(&answer); err != nil {
		return false
	}

	switch strings.ToLower(answer) {
	case "y", "yes":
		return true
	default:
		return false
	}
}
|
|
|
|
// check is a fail-fast helper for the CLI paths. A nil err is a no-op;
// otherwise the error text is written with the builtin println (stderr) and
// the process exits with status 1.
func check(err error) {
	if err == nil {
		return
	}
	println(err.Error())
	os.Exit(1)
}
|
|
|
|
func registerAdmin(username string) {
|
|
var usr *auth.RegisteredUser
|
|
println("registering " + username + "...")
|
|
if !confirm() {
|
|
os.Exit(0)
|
|
}
|
|
print("Enter Password: ")
|
|
//goland:noinspection VacuumOverwroteError
|
|
bytePassword, err := term.ReadPassword(syscall.Stdin)
|
|
check(err)
|
|
password := string(bytePassword)
|
|
usr, err = authdb.Register(username, password)
|
|
check(err)
|
|
check(authdb.SetPrivLevel(usr, auth.Admin))
|
|
println("successfully registered user: " + username)
|
|
os.Exit(0)
|
|
}
|
|
|
|
// deleteUser asks for confirmation, removes username from authdb and exits
// the process. Declining the confirmation exits with status 0. A failed
// delete is passed to check, which prints it and exits with status 1.
func deleteUser(username string) {
	println("deleting " + username + "...")
	if !confirm() {
		os.Exit(0)
	}

	check(authdb.Delete(username))

	// Mirror registerAdmin so the operator gets explicit positive feedback.
	println("successfully deleted user: " + username)
	os.Exit(0)
}
|
|
|
|
func checkArgs() {
|
|
for i, arg := range os.Args {
|
|
getusername := func() string {
|
|
if len(os.Args) < i+2 {
|
|
println("missing username")
|
|
os.Exit(1)
|
|
}
|
|
return os.Args[i+1]
|
|
}
|
|
|
|
switch arg {
|
|
case "--register-admin":
|
|
registerAdmin(getusername())
|
|
case "--delete-user":
|
|
deleteUser(getusername())
|
|
}
|
|
}
|
|
}
|
|
|
|
func getSigner() ssh.Signer {
|
|
// TODO: replace with config path
|
|
signer, err := ReadPrivateKey("./privkey.pem")
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Failed to read identity private key, generating new...")
|
|
if err := sshd.Keygen("./", 4096); err != nil {
|
|
log.Fatal().Err(err).Msg("failed to generate new keys")
|
|
}
|
|
signer, err = ReadPrivateKey("./privkey.pem")
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
return signer
|
|
}
|
|
|
|
// main starts the SSH host. It:
//
//  1. loads the host key and configures authentication;
//  2. listens on config.BindAddr:config.BindPort;
//  3. serves connections in a background goroutine until SIGINT or SIGTERM
//     is received.
//
// On shutdown it calls authdb.DB.Sync(), and the deferred s.Close() closes
// the listener when main returns.
func main() {
	signer := getSigner()

	authwrapper := auth.MakeAuth(authdb)
	authwrapper.AddHostKey(signer)

	// TODO: Replace with config string
	authwrapper.ServerVersion = fmt.Sprintf("SSH-2.0-%s", config.SSHVersion)
	// FIXME: Should we be using config.NoClientAuth = true by default?

	// An IPv6 literal (e.g. "::") must be bracketed before ":port" is
	// appended; otherwise the listen address is ambiguous and rejected (this
	// is the same rule net.JoinHostPort applies).
	bindHost := config.BindAddr
	if strings.Contains(bindHost, ":") && !strings.HasPrefix(bindHost, "[") {
		bindHost = "[" + bindHost + "]"
	}
	listenstr := fmt.Sprintf("%s:%d", bindHost, config.BindPort)

	s, err := sshd.ListenSSH(listenstr, authwrapper)
	if err != nil {
		log.Fatal().Err(err).Str("listenstr", listenstr).Msg("failed to listen")
	}
	defer s.Close()
	// s.RateLimit = sshd.NewInputLimiter

	_ = bus.NewMessenger("reg")

	fmt.Printf("Listening for SSH connections on %v\n", s.Addr().String())

	hostServer := host.NewHost(s, authdb)
	hostServer.SetTheme(message.Themes[0])
	// TODO: replace with config string
	hostServer.Version = config.SSHVersion

	/* MOTD loading is currently disabled; kept for reference.
	var motd = false
	if motd {
		hostServer.GetMOTD = func() (string, error) {
			// TODO: replace with config string
			motd, err := ioutil.ReadFile("motd")
			if err != nil {
				return "", err
			}
			motdString := string(motd)
			// hack to normalize line endings into \r\n
			motdString = strings.Replace(motdString, "\r\n", "\n", -1)
			motdString = strings.Replace(motdString, "\n", "\r\n", -1)
			return motdString, nil
		}
		motdString, err := hostServer.GetMOTD()
		if err != nil {
			log.Fatal().Err(err).Msg("failed to load motd")
		}
		hostServer.SetMotd(motdString)
	}
	*/

	go hostServer.Serve()

	// Block until asked to stop. SIGTERM is included alongside ^C so that
	// service managers (systemd, docker stop, kill) also reach the database
	// sync below instead of terminating the process without it.
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, os.Interrupt, syscall.SIGTERM)

	<-sig
	fmt.Fprintln(os.Stderr, "Interrupt signal detected, shutting down.")
	// NOTE(review): any result returned by Sync is currently ignored.
	authdb.DB.Sync()
}
|