diff --git a/_examples/ssh-remoteforward/portforward.go b/_examples/ssh-remoteforward/portforward.go new file mode 100644 index 0000000..33bb1e9 --- /dev/null +++ b/_examples/ssh-remoteforward/portforward.go @@ -0,0 +1,31 @@ +package main + +import ( + "io" + "log" + + "github.com/gliderlabs/ssh" +) + +func main() { + + log.Println("starting ssh server on port 2222...") + + server := ssh.Server{ + LocalPortForwardingCallback: ssh.LocalPortForwardingCallback(func(ctx ssh.Context, dhost string, dport uint32) bool { + log.Println("Accepted forward", dhost, dport) + return true + }), + Addr: ":2222", + Handler: ssh.Handler(func(s ssh.Session) { + io.WriteString(s, "Remote forwarding available...\n") + select {} + }), + ReversePortForwardingCallback: ssh.ReversePortForwardingCallback(func(ctx ssh.Context, host string, port uint32) bool { + log.Println("attempt to bind", host, port, "granted") + return true + }), + } + + log.Fatal(server.ListenAndServe()) +} diff --git a/context.go b/context.go index 008ab5b..31531f4 100644 --- a/context.go +++ b/context.go @@ -48,7 +48,7 @@ var ( ContextKeyServer = &contextKey{"ssh-server"} // ContextKeyConn is a context key for use with Contexts in this package. - // The associated value will be of type gossh.Conn. + // The associated value will be of type gossh.ServerConn. ContextKeyConn = &contextKey{"ssh-conn"} // ContextKeyPublicKey is a context key for use with Contexts in this package. diff --git a/doc.go b/doc.go index 0361fb9..5a10393 100644 --- a/doc.go +++ b/doc.go @@ -1,5 +1,4 @@ /* - Package ssh wraps the crypto/ssh package with a higher-level API for building SSH servers. The goal of the API was to make it as simple as using net/http, so the API is very similar. @@ -42,6 +41,5 @@ exposed to you via the Session interface. The one big feature missing from the Session abstraction is signals. This was started, but not completed. Pull Requests welcome! - */ package ssh diff --git a/server.go b/server.go index 09739e1..31e3353 100644 --- a/server.go +++ b/server.go @@ -24,16 +24,18 @@ type Server struct { HostSigners []Signer // private keys for the host key, must have at least one Version string // server version to be sent before the initial handshake - PasswordHandler PasswordHandler // password authentication handler - PublicKeyHandler PublicKeyHandler // public key authentication handler - PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil - ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling - LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil + PasswordHandler PasswordHandler // password authentication handler + PublicKeyHandler PublicKeyHandler // public key authentication handler + PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil + ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling + LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil + ReversePortForwardingCallback ReversePortForwardingCallback //callback for allowing reverse port forwarding, denies all if nil IdleTimeout time.Duration // connection timeout when no activity, none if empty MaxTimeout time.Duration // absolute connection timeout, none if empty channelHandlers map[string]channelHandler + requestHandlers map[string]RequestHandler listenerWg sync.WaitGroup mu sync.Mutex @@ -42,6 +44,9 @@ type Server struct { connWg sync.WaitGroup doneChan chan struct{} } +type RequestHandler interface { + HandleRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) +} // internal for now type channelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) @@ -57,6 +62,19 @@ func (srv *Server) ensureHostSigner() error { return nil } +func (srv *Server) ensureHandlers() { + srv.mu.Lock() + defer srv.mu.Unlock() + srv.requestHandlers = map[string]RequestHandler{ + "tcpip-forward": forwardedTCPHandler{}, + "cancel-tcpip-forward": forwardedTCPHandler{}, + } + srv.channelHandlers = map[string]channelHandler{ + "session": sessionHandler, + "direct-tcpip": directTcpipHandler, + } +} + func (srv *Server) config(ctx Context) *gossh.ServerConfig { config := &gossh.ServerConfig{} for _, signer := range srv.HostSigners { @@ -144,6 +162,7 @@ func (srv *Server) Shutdown(ctx context.Context) error { // // Serve always returns a non-nil error. func (srv *Server) Serve(l net.Listener) error { + srv.ensureHandlers() defer l.Close() if err := srv.ensureHostSigner(); err != nil { return err @@ -217,7 +236,8 @@ func (srv *Server) handleConn(newConn net.Conn) { ctx.SetValue(ContextKeyConn, sshConn) applyConnMetadata(ctx, sshConn) - go gossh.DiscardRequests(reqs) + //go gossh.DiscardRequests(reqs) + go srv.handleRequests(ctx, reqs) for ch := range chans { handler, found := srv.channelHandlers[ch.ChannelType()] if !found { @@ -228,6 +248,22 @@ func (srv *Server) handleConn(newConn net.Conn) { } } +func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) { + for req := range in { + handler, found := srv.requestHandlers[req.Type] + if !found && req.WantReply { + req.Reply(false, nil) + continue + } + /*reqCtx, cancel := context.WithCancel(ctx) + defer cancel() */ + ret, payload := handler.HandleRequest(ctx, srv, req) + if req.WantReply { + req.Reply(ret, payload) + } + } +} + // ListenAndServe listens on the TCP network address srv.Addr and then calls // Serve to handle incoming connections. If srv.Addr is blank, ":22" is used. // ListenAndServe always returns a non-nil error. diff --git a/session.go b/session.go index b745477..c3db354 100644 --- a/session.go +++ b/session.go @@ -282,6 +282,7 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) { req.Reply(true, nil) default: // TODO: debug log + req.Reply(false, nil) } } } diff --git a/session_test.go b/session_test.go index f23a4fb..06c7724 100644 --- a/session_test.go +++ b/session_test.go @@ -11,6 +11,7 @@ import ( ) func (srv *Server) serveOnce(l net.Listener) error { + srv.ensureHandlers() if err := srv.ensureHostSigner(); err != nil { return err } diff --git a/ssh.go b/ssh.go index 0173775..88cf934 100644 --- a/ssh.go +++ b/ssh.go @@ -50,6 +50,9 @@ type ConnCallback func(conn net.Conn) net.Conn // LocalPortForwardingCallback is a hook for allowing port forwarding type LocalPortForwardingCallback func(ctx Context, destinationHost string, destinationPort uint32) bool +// ReversePortForwardingCallback is a hook for allowing reverse port forwarding +type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool + // Window represents the size of a PTY window. type Window struct { Width int diff --git a/tcpip.go b/tcpip.go index 3c8280f..64542ff 100644 --- a/tcpip.go +++ b/tcpip.go @@ -2,35 +2,41 @@ package ssh import ( "io" + "log" "net" "strconv" + "sync" gossh "golang.org/x/crypto/ssh" ) -// direct-tcpip data struct as specified in RFC4254, Section 7.2 -type forwardData struct { - DestinationHost string - DestinationPort uint32 +const ( + forwardedTCPChannelType = "forwarded-tcpip" +) - OriginatorHost string - OriginatorPort uint32 +// direct-tcpip data struct as specified in RFC4254, Section 7.2 +type localForwardChannelData struct { + DestAddr string + DestPort uint32 + + OriginAddr string + OriginPort uint32 } func directTcpipHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { - d := forwardData{} + d := localForwardChannelData{} if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil { newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error()) return } - if srv.LocalPortForwardingCallback == nil || !srv.LocalPortForwardingCallback(ctx, d.DestinationHost, d.DestinationPort) { + if srv.LocalPortForwardingCallback == nil || !srv.LocalPortForwardingCallback(ctx, d.DestAddr, d.DestPort) { newChan.Reject(gossh.Prohibited, "port forwarding is disabled") return } - dest := net.JoinHostPort(d.DestinationHost, strconv.FormatInt(int64(d.DestinationPort), 10)) - + dest := net.JoinHostPort(d.DestAddr, strconv.FormatInt(int64(d.DestPort), 10)) + var dialer net.Dialer dconn, err := dialer.DialContext(ctx, "tcp", dest) if err != nil { @@ -56,3 +62,127 @@ func directTcpipHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewCh io.Copy(dconn, ch) }() } + +type remoteForwardRequest struct { + BindAddr string + BindPort uint32 +} + +type remoteForwardSuccess struct { + BindPort uint32 +} + +type remoteForwardCancelRequest struct { + BindAddr string + BindPort uint32 +} + +type remoteForwardChannelData struct { + DestAddr string + DestPort uint32 + OriginAddr string + OriginPort uint32 +} + +type forwardedTCPHandler struct { + forwards map[string]net.Listener + sync.Mutex +} + +func (h forwardedTCPHandler) HandleRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { + h.Lock() + if h.forwards == nil { + h.forwards = make(map[string]net.Listener) + } + h.Unlock() + conn := ctx.Value(ContextKeyConn).(*gossh.ServerConn) + switch req.Type { + case "tcpip-forward": + var reqPayload remoteForwardRequest + if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { + // TODO: log parse failure + return false, []byte{} + } + if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, reqPayload.BindPort) { + return false, []byte("port forwarding is disabled") + } + addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort))) + ln, err := net.Listen("tcp", addr) + if err != nil { + // TODO: log listen failure + return false, []byte{} + } + _, destPortStr, _ := net.SplitHostPort(ln.Addr().String()) + destPort, _ := strconv.Atoi(destPortStr) + h.Lock() + h.forwards[addr] = ln + h.Unlock() + go func() { + <-ctx.Done() + h.Lock() + ln, ok := h.forwards[addr] + h.Unlock() + if ok { + ln.Close() + } + }() + go func() { + for { + c, err := ln.Accept() + if err != nil { + // TODO: log accept failure + break + } + originAddr, orignPortStr, _ := net.SplitHostPort(c.RemoteAddr().String()) + originPort, _ := strconv.Atoi(orignPortStr) + payload := gossh.Marshal(&remoteForwardChannelData{ + DestAddr: reqPayload.BindAddr, + DestPort: uint32(destPort), + OriginAddr: originAddr, + OriginPort: uint32(originPort), + }) + go func() { + ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload) + if err != nil { + // TODO: log failure to open channel + log.Println(err) + c.Close() + return + } + go gossh.DiscardRequests(reqs) + go func() { + defer ch.Close() + defer c.Close() + io.Copy(ch, c) + }() + go func() { + defer ch.Close() + defer c.Close() + io.Copy(c, ch) + }() + }() + } + h.Lock() + delete(h.forwards, addr) + h.Unlock() + }() + return true, gossh.Marshal(&remoteForwardSuccess{uint32(destPort)}) + + case "cancel-tcpip-forward": + var reqPayload remoteForwardCancelRequest + if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { + // TODO: log parse failure + return false, []byte{} + } + addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort))) + h.Lock() + ln, ok := h.forwards[addr] + h.Unlock() + if ok { + ln.Close() + } + return true, nil + default: + return false, nil + } +}