sh3lly/cmd/sh3lly/cmd.go

188 lines
4.0 KiB
Go

package main
import (
	"fmt"
	"io"
	_ "net/http/pprof"
	"os"
	"os/signal"
	"strings"
	"syscall"

	"github.com/rs/zerolog"
	"golang.org/x/crypto/ssh"
	"golang.org/x/term"

	"git.tcp.direct/kayos/sh3lly/auth"
	"git.tcp.direct/kayos/sh3lly/bus"
	"git.tcp.direct/kayos/sh3lly/chat/message"
	"git.tcp.direct/kayos/sh3lly/config"
	"git.tcp.direct/kayos/sh3lly/host"
	"git.tcp.direct/kayos/sh3lly/shells"
	"git.tcp.direct/kayos/sh3lly/sshd"
)
var (
	// log is the process-wide logger; init assigns it from config.StartLogger.
	log zerolog.Logger
	// authdb is the user database; init opens it from config.DataPath.
	authdb *auth.UserDB
)
// init prepares process-wide state before main runs. It initializes the
// configuration, starts the logger, opens the user database at
// config.DataPath, handles administrative command-line flags (checkArgs,
// which may terminate the process) and finally starts the shells subsystem.
//
// If the user database cannot be opened, the error is written to stderr and
// the process exits with status 1. This replaces a panic, so operators see a
// plain message instead of a goroutine stack trace.
func init() {
	config.Init()
	log = config.StartLogger()

	var err error
	authdb, err = auth.NewUserDB(config.DataPath)
	if err != nil {
		// check prints to stderr regardless of where the logger is configured
		// to write, then exits 1.
		check(fmt.Errorf("cannot open user database at %v: %w", config.DataPath, err))
	}

	checkArgs()
	shells.Start()
}
// confirm writes the prompt "Continue? (y/N) " to stdout and reads a single
// line from stdin. It returns true only for "y" or "yes" (case-insensitive,
// surrounding whitespace ignored). Any other answer returns false, and so do
// an empty line (the advertised default "N"), EOF (e.g. Ctrl-D) and read errors.
func confirm() bool {
	fmt.Printf("Continue? (y/N) ")
	return isAffirmative(readLineUnbuffered(os.Stdin))
}

// readLineUnbuffered reads from r one byte at a time until it sees '\n',
// EOF, or any read error, and returns the bytes read without the newline.
// Reading byte-by-byte avoids consuming input past the end of the line, so a
// later raw read of the same descriptor (e.g. term.ReadPassword on stdin)
// still sees everything the user typed after this answer.
func readLineUnbuffered(r io.Reader) string {
	var line strings.Builder
	buf := make([]byte, 1)
	for {
		n, err := r.Read(buf)
		if n > 0 {
			if buf[0] == '\n' {
				break
			}
			line.WriteByte(buf[0])
		}
		if err != nil {
			break
		}
	}
	return line.String()
}

// isAffirmative reports whether answer is "y" or "yes", ignoring case and
// surrounding whitespace (including a trailing '\r' from CRLF input).
func isAffirmative(answer string) bool {
	switch strings.ToLower(strings.TrimSpace(answer)) {
	case "y", "yes":
		return true
	}
	return false
}
// check is a CLI helper that aborts on failure. A nil err is a no-op.
// Otherwise the error text is written to standard error (via the println
// builtin) and the process exits with status 1.
func check(err error) {
	if err == nil {
		return
	}
	println(err.Error())
	os.Exit(1)
}
func registerAdmin(username string) {
var usr *auth.RegisteredUser
println("registering " + username + "...")
if !confirm() {
os.Exit(0)
}
print("Enter Password: ")
//goland:noinspection VacuumOverwroteError
bytePassword, err := term.ReadPassword(syscall.Stdin)
check(err)
password := string(bytePassword)
usr, err = authdb.Register(username, password)
check(err)
check(authdb.SetPrivLevel(usr, auth.Admin))
println("successfully registered user: " + username)
os.Exit(0)
}
// deleteUser asks for confirmation and, if the answer is yes, removes username
// from authdb. It always terminates the process. The exit status is 0 after a
// successful delete or a declined prompt, and 1 (via check) if the delete fails.
func deleteUser(username string) {
	println("deleting " + username + "...")
	if confirm() {
		check(authdb.Delete(username))
	}
	os.Exit(0)
}
// checkArgs scans the command-line arguments (excluding the program name) for
// administrative flags and runs the matching action:
//
//	--register-admin <username>  register <username> with admin privileges
//	--delete-user <username>     delete <username> from the user database
//
// Both actions terminate the process, so only the first such flag is acted
// upon. If no flag is present, checkArgs returns and normal startup continues.
func checkArgs() {
	if len(os.Args) < 2 {
		return
	}
	args := os.Args[1:] // os.Args[0] is the program name, never a flag
	for i, arg := range args {
		switch arg {
		case "--register-admin":
			registerAdmin(requireUsernameArg(args, i))
		case "--delete-user":
			deleteUser(requireUsernameArg(args, i))
		}
	}
}

// requireUsernameArg returns the argument that follows the flag at args[i].
// It exits with status 1 if that argument is missing or unusable; see
// usernameArgAfter.
func requireUsernameArg(args []string, i int) string {
	name, ok := usernameArgAfter(args, i)
	if !ok {
		println("missing username")
		os.Exit(1)
	}
	return name
}

// usernameArgAfter returns args[i+1] and whether it can be used as a username.
// It is usable only if present, non-empty, and not starting with '-'. The last
// rule stops "--register-admin --delete-user" from registering a user named
// "--delete-user".
func usernameArgAfter(args []string, i int) (string, bool) {
	if i+1 >= len(args) {
		return "", false
	}
	name := args[i+1]
	if name == "" || strings.HasPrefix(name, "-") {
		return "", false
	}
	return name, true
}
// getSigner returns the SSH host key signer read from ./privkey.pem.
//
// If that file does not exist, a new 4096-bit key is generated into the
// current directory with sshd.Keygen and then read back. Every other failure
// is fatal (log.Fatal exits the process):
//   - an existing key that cannot be read or parsed;
//   - a key-generation error;
//   - an unreadable newly generated key.
//
// Regenerating over an existing but unreadable key would silently change the
// server's host identity, so that case is treated as fatal too.
func getSigner() ssh.Signer {
	// TODO: replace with config path
	const (
		keyDir  = "./"
		keyPath = "./privkey.pem"
		keyBits = 4096
	)

	signer, err := ReadPrivateKey(keyPath)
	if err == nil {
		return signer
	}

	// Only a missing key file justifies generating a new identity.
	if _, statErr := os.Stat(keyPath); !os.IsNotExist(statErr) {
		log.Fatal().Err(err).Str("path", keyPath).Msg("failed to read existing identity private key")
	}

	log.Warn().Err(err).Msg("Failed to read identity private key, generating new...")
	if err := sshd.Keygen(keyDir, keyBits); err != nil {
		log.Fatal().Err(err).Msg("failed to generate new keys")
	}
	signer, err = ReadPrivateKey(keyPath)
	if err != nil {
		log.Fatal().Err(err).Str("path", keyPath).Msg("failed to read newly generated identity private key")
	}
	return signer
}
// main runs the sh3lly SSH server. It:
//  1. loads (or generates) the host key via getSigner;
//  2. wraps authdb in the SSH authentication layer;
//  3. listens on config.BindAddr:config.BindPort;
//  4. serves connections in a background goroutine.
//
// main then blocks until SIGINT or SIGTERM arrives, syncs the user database,
// and returns (the deferred listener Close runs last).
func main() {
	signer := getSigner()
	authwrapper := auth.MakeAuth(authdb)
	authwrapper.AddHostKey(signer)
	authwrapper.ServerVersion = fmt.Sprintf("SSH-2.0-%s", config.SSHVersion)
	// FIXME: Should we be using config.NoClientAuth = true by default?

	listenstr := fmt.Sprintf("%s:%d", config.BindAddr, config.BindPort)
	s, err := sshd.ListenSSH(listenstr, authwrapper)
	if err != nil {
		log.Fatal().Err(err).Str("listenstr", listenstr).Msg("failed to listen")
	}
	defer s.Close()
	// s.RateLimit = sshd.NewInputLimiter

	_ = bus.NewMessenger("reg")
	fmt.Printf("Listening for SSH connections on %v\n", s.Addr().String())

	hostServer := host.NewHost(s, authdb)
	hostServer.SetTheme(message.Themes[0])
	hostServer.Version = config.SSHVersion
	/* var motd = false
	if motd {
		hostServer.GetMOTD = func() (string, error) {
			// TODO: replace with config string
			motd, err := os.ReadFile("motd")
			if err != nil {
				return "", err
			}
			motdString := string(motd)
			// hack to normalize line endings into \r\n
			motdString = strings.Replace(motdString, "\r\n", "\n", -1)
			motdString = strings.Replace(motdString, "\n", "\r\n", -1)
			return motdString, nil
		}
		motdString, err:= hostServer.GetMOTD()
		if err != nil {
			log.Fatal().Err(err).Msg("failed to load motd")
		}
		hostServer.SetMotd(motdString)
	}*/
	go hostServer.Serve()

	// Block until asked to stop. SIGTERM is what service managers and
	// container runtimes send; without it the database sync below is skipped.
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
	received := <-sig
	fmt.Fprintf(os.Stderr, "%v signal detected, shutting down.\n", received)
	authdb.DB.Sync()
}