Compare commits
10 Commits
master
...
remote-for
Author | SHA1 | Date | |
---|---|---|---|
|
21408e9087 | ||
|
bedd02c8ff | ||
|
bb40c420c7 | ||
|
d411603a3a | ||
|
537a1bee09 | ||
|
98813d13df | ||
|
824322f98c | ||
|
c6ea04edc4 | ||
|
9f705628f8 | ||
|
53fc9a232a |
31
_examples/ssh-remoteforward/portforward.go
Normal file
31
_examples/ssh-remoteforward/portforward.go
Normal file
@ -0,0 +1,31 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
log.Println("starting ssh server on port 2222...")
|
||||
|
||||
server := ssh.Server{
|
||||
LocalPortForwardingCallback: ssh.LocalPortForwardingCallback(func(ctx ssh.Context, dhost string, dport uint32) bool {
|
||||
log.Println("Accepted forward", dhost, dport)
|
||||
return true
|
||||
}),
|
||||
Addr: ":2222",
|
||||
Handler: ssh.Handler(func(s ssh.Session) {
|
||||
io.WriteString(s, "Remote forwarding available...\n")
|
||||
select {}
|
||||
}),
|
||||
ReversePortForwardingCallback: ssh.ReversePortForwardingCallback(func(ctx ssh.Context, host string, port uint32) bool {
|
||||
log.Println("attempt to bind", host, port, "granted")
|
||||
return true
|
||||
}),
|
||||
}
|
||||
|
||||
log.Fatal(server.ListenAndServe())
|
||||
}
|
@ -48,7 +48,7 @@ var (
|
||||
ContextKeyServer = &contextKey{"ssh-server"}
|
||||
|
||||
// ContextKeyConn is a context key for use with Contexts in this package.
|
||||
// The associated value will be of type gossh.Conn.
|
||||
// The associated value will be of type gossh.ServerConn.
|
||||
ContextKeyConn = &contextKey{"ssh-conn"}
|
||||
|
||||
// ContextKeyPublicKey is a context key for use with Contexts in this package.
|
||||
|
2
doc.go
2
doc.go
@ -1,5 +1,4 @@
|
||||
/*
|
||||
|
||||
Package ssh wraps the crypto/ssh package with a higher-level API for building
|
||||
SSH servers. The goal of the API was to make it as simple as using net/http, so
|
||||
the API is very similar.
|
||||
@ -42,6 +41,5 @@ exposed to you via the Session interface.
|
||||
|
||||
The one big feature missing from the Session abstraction is signals. This was
|
||||
started, but not completed. Pull Requests welcome!
|
||||
|
||||
*/
|
||||
package ssh
|
||||
|
48
server.go
48
server.go
@ -24,16 +24,18 @@ type Server struct {
|
||||
HostSigners []Signer // private keys for the host key, must have at least one
|
||||
Version string // server version to be sent before the initial handshake
|
||||
|
||||
PasswordHandler PasswordHandler // password authentication handler
|
||||
PublicKeyHandler PublicKeyHandler // public key authentication handler
|
||||
PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
|
||||
ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling
|
||||
LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
|
||||
PasswordHandler PasswordHandler // password authentication handler
|
||||
PublicKeyHandler PublicKeyHandler // public key authentication handler
|
||||
PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
|
||||
ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling
|
||||
LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
|
||||
ReversePortForwardingCallback ReversePortForwardingCallback //callback for allowing reverse port forwarding, denies all if nil
|
||||
|
||||
IdleTimeout time.Duration // connection timeout when no activity, none if empty
|
||||
MaxTimeout time.Duration // absolute connection timeout, none if empty
|
||||
|
||||
channelHandlers map[string]channelHandler
|
||||
requestHandlers map[string]RequestHandler
|
||||
|
||||
listenerWg sync.WaitGroup
|
||||
mu sync.Mutex
|
||||
@ -42,6 +44,9 @@ type Server struct {
|
||||
connWg sync.WaitGroup
|
||||
doneChan chan struct{}
|
||||
}
|
||||
type RequestHandler interface {
|
||||
HandleRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
|
||||
}
|
||||
|
||||
// internal for now
|
||||
type channelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
|
||||
@ -57,6 +62,19 @@ func (srv *Server) ensureHostSigner() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (srv *Server) ensureHandlers() {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
srv.requestHandlers = map[string]RequestHandler{
|
||||
"tcpip-forward": forwardedTCPHandler{},
|
||||
"cancel-tcpip-forward": forwardedTCPHandler{},
|
||||
}
|
||||
srv.channelHandlers = map[string]channelHandler{
|
||||
"session": sessionHandler,
|
||||
"direct-tcpip": directTcpipHandler,
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) config(ctx Context) *gossh.ServerConfig {
|
||||
config := &gossh.ServerConfig{}
|
||||
for _, signer := range srv.HostSigners {
|
||||
@ -144,6 +162,7 @@ func (srv *Server) Shutdown(ctx context.Context) error {
|
||||
//
|
||||
// Serve always returns a non-nil error.
|
||||
func (srv *Server) Serve(l net.Listener) error {
|
||||
srv.ensureHandlers()
|
||||
defer l.Close()
|
||||
if err := srv.ensureHostSigner(); err != nil {
|
||||
return err
|
||||
@ -217,7 +236,8 @@ func (srv *Server) handleConn(newConn net.Conn) {
|
||||
|
||||
ctx.SetValue(ContextKeyConn, sshConn)
|
||||
applyConnMetadata(ctx, sshConn)
|
||||
go gossh.DiscardRequests(reqs)
|
||||
//go gossh.DiscardRequests(reqs)
|
||||
go srv.handleRequests(ctx, reqs)
|
||||
for ch := range chans {
|
||||
handler, found := srv.channelHandlers[ch.ChannelType()]
|
||||
if !found {
|
||||
@ -228,6 +248,22 @@ func (srv *Server) handleConn(newConn net.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) {
|
||||
for req := range in {
|
||||
handler, found := srv.requestHandlers[req.Type]
|
||||
if !found && req.WantReply {
|
||||
req.Reply(false, nil)
|
||||
continue
|
||||
}
|
||||
/*reqCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel() */
|
||||
ret, payload := handler.HandleRequest(ctx, srv, req)
|
||||
if req.WantReply {
|
||||
req.Reply(ret, payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ListenAndServe listens on the TCP network address srv.Addr and then calls
|
||||
// Serve to handle incoming connections. If srv.Addr is blank, ":22" is used.
|
||||
// ListenAndServe always returns a non-nil error.
|
||||
|
@ -282,6 +282,7 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
|
||||
req.Reply(true, nil)
|
||||
default:
|
||||
// TODO: debug log
|
||||
req.Reply(false, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
)
|
||||
|
||||
func (srv *Server) serveOnce(l net.Listener) error {
|
||||
srv.ensureHandlers()
|
||||
if err := srv.ensureHostSigner(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
3
ssh.go
3
ssh.go
@ -50,6 +50,9 @@ type ConnCallback func(conn net.Conn) net.Conn
|
||||
// LocalPortForwardingCallback is a hook for allowing port forwarding
|
||||
type LocalPortForwardingCallback func(ctx Context, destinationHost string, destinationPort uint32) bool
|
||||
|
||||
// ReversePortForwardingCallback is a hook for allowing reverse port forwarding
|
||||
type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool
|
||||
|
||||
// Window represents the size of a PTY window.
|
||||
type Window struct {
|
||||
Width int
|
||||
|
150
tcpip.go
150
tcpip.go
@ -2,35 +2,41 @@ package ssh
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// direct-tcpip data struct as specified in RFC4254, Section 7.2
|
||||
type forwardData struct {
|
||||
DestinationHost string
|
||||
DestinationPort uint32
|
||||
const (
|
||||
forwardedTCPChannelType = "forwarded-tcpip"
|
||||
)
|
||||
|
||||
OriginatorHost string
|
||||
OriginatorPort uint32
|
||||
// direct-tcpip data struct as specified in RFC4254, Section 7.2
|
||||
type localForwardChannelData struct {
|
||||
DestAddr string
|
||||
DestPort uint32
|
||||
|
||||
OriginAddr string
|
||||
OriginPort uint32
|
||||
}
|
||||
|
||||
func directTcpipHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
|
||||
d := forwardData{}
|
||||
d := localForwardChannelData{}
|
||||
if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil {
|
||||
newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if srv.LocalPortForwardingCallback == nil || !srv.LocalPortForwardingCallback(ctx, d.DestinationHost, d.DestinationPort) {
|
||||
if srv.LocalPortForwardingCallback == nil || !srv.LocalPortForwardingCallback(ctx, d.DestAddr, d.DestPort) {
|
||||
newChan.Reject(gossh.Prohibited, "port forwarding is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
dest := net.JoinHostPort(d.DestinationHost, strconv.FormatInt(int64(d.DestinationPort), 10))
|
||||
|
||||
dest := net.JoinHostPort(d.DestAddr, strconv.FormatInt(int64(d.DestPort), 10))
|
||||
|
||||
var dialer net.Dialer
|
||||
dconn, err := dialer.DialContext(ctx, "tcp", dest)
|
||||
if err != nil {
|
||||
@ -56,3 +62,127 @@ func directTcpipHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewCh
|
||||
io.Copy(dconn, ch)
|
||||
}()
|
||||
}
|
||||
|
||||
type remoteForwardRequest struct {
|
||||
BindAddr string
|
||||
BindPort uint32
|
||||
}
|
||||
|
||||
type remoteForwardSuccess struct {
|
||||
BindPort uint32
|
||||
}
|
||||
|
||||
type remoteForwardCancelRequest struct {
|
||||
BindAddr string
|
||||
BindPort uint32
|
||||
}
|
||||
|
||||
type remoteForwardChannelData struct {
|
||||
DestAddr string
|
||||
DestPort uint32
|
||||
OriginAddr string
|
||||
OriginPort uint32
|
||||
}
|
||||
|
||||
type forwardedTCPHandler struct {
|
||||
forwards map[string]net.Listener
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (h forwardedTCPHandler) HandleRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) {
|
||||
h.Lock()
|
||||
if h.forwards == nil {
|
||||
h.forwards = make(map[string]net.Listener)
|
||||
}
|
||||
h.Unlock()
|
||||
conn := ctx.Value(ContextKeyConn).(*gossh.ServerConn)
|
||||
switch req.Type {
|
||||
case "tcpip-forward":
|
||||
var reqPayload remoteForwardRequest
|
||||
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
|
||||
// TODO: log parse failure
|
||||
return false, []byte{}
|
||||
}
|
||||
if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, reqPayload.BindPort) {
|
||||
return false, []byte("port forwarding is disabled")
|
||||
}
|
||||
addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort)))
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
// TODO: log listen failure
|
||||
return false, []byte{}
|
||||
}
|
||||
_, destPortStr, _ := net.SplitHostPort(ln.Addr().String())
|
||||
destPort, _ := strconv.Atoi(destPortStr)
|
||||
h.Lock()
|
||||
h.forwards[addr] = ln
|
||||
h.Unlock()
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
h.Lock()
|
||||
ln, ok := h.forwards[addr]
|
||||
h.Unlock()
|
||||
if ok {
|
||||
ln.Close()
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
// TODO: log accept failure
|
||||
break
|
||||
}
|
||||
originAddr, orignPortStr, _ := net.SplitHostPort(c.RemoteAddr().String())
|
||||
originPort, _ := strconv.Atoi(orignPortStr)
|
||||
payload := gossh.Marshal(&remoteForwardChannelData{
|
||||
DestAddr: reqPayload.BindAddr,
|
||||
DestPort: uint32(destPort),
|
||||
OriginAddr: originAddr,
|
||||
OriginPort: uint32(originPort),
|
||||
})
|
||||
go func() {
|
||||
ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload)
|
||||
if err != nil {
|
||||
// TODO: log failure to open channel
|
||||
log.Println(err)
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
go gossh.DiscardRequests(reqs)
|
||||
go func() {
|
||||
defer ch.Close()
|
||||
defer c.Close()
|
||||
io.Copy(ch, c)
|
||||
}()
|
||||
go func() {
|
||||
defer ch.Close()
|
||||
defer c.Close()
|
||||
io.Copy(c, ch)
|
||||
}()
|
||||
}()
|
||||
}
|
||||
h.Lock()
|
||||
delete(h.forwards, addr)
|
||||
h.Unlock()
|
||||
}()
|
||||
return true, gossh.Marshal(&remoteForwardSuccess{uint32(destPort)})
|
||||
|
||||
case "cancel-tcpip-forward":
|
||||
var reqPayload remoteForwardCancelRequest
|
||||
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
|
||||
// TODO: log parse failure
|
||||
return false, []byte{}
|
||||
}
|
||||
addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort)))
|
||||
h.Lock()
|
||||
ln, ok := h.forwards[addr]
|
||||
h.Unlock()
|
||||
if ok {
|
||||
ln.Close()
|
||||
}
|
||||
return true, nil
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user