diff --git a/lib/ssh/agent/client.go b/lib/ssh/agent/client.go new file mode 100644 index 0000000..3822f05 --- /dev/null +++ b/lib/ssh/agent/client.go @@ -0,0 +1,663 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package agent implements the ssh-agent protocol, and provides both +// a client and a server. The client can talk to a standard ssh-agent +// that uses UNIX sockets, and one could implement an alternative +// ssh-agent process using the sample server. +// +// References: +// [PROTOCOL.agent]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent?rev=HEAD +package agent // import "github.com/zmap/zgrab/ztools/xssh/agent" + +import ( + "bytes" + "crypto/dsa" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "encoding/base64" + "encoding/binary" + "errors" + "fmt" + "io" + "math/big" + "sync" + + "github.com/zmap/zgrab/ztools/xssh" + "golang.org/x/crypto/ed25519" +) + +// Agent represents the capabilities of an ssh-agent. +type Agent interface { + // List returns the identities known to the agent. + List() ([]*Key, error) + + // Sign has the agent sign the data using a protocol 2 key as defined + // in [PROTOCOL.agent] section 2.6.2. + Sign(key xssh.PublicKey, data []byte) (*xssh.Signature, error) + + // Add adds a private key to the agent. + Add(key AddedKey) error + + // Remove removes all identities with the given public key. + Remove(key xssh.PublicKey) error + + // RemoveAll removes all identities. + RemoveAll() error + + // Lock locks the agent. Sign and Remove will fail, and List will empty an empty list. + Lock(passphrase []byte) error + + // Unlock undoes the effect of Lock + Unlock(passphrase []byte) error + + // Signers returns signers for all the known keys. + Signers() ([]xssh.Signer, error) +} + +// AddedKey describes an SSH key to be added to an Agent. +type AddedKey struct { + // PrivateKey must be a *rsa.PrivateKey, *dsa.PrivateKey or + // *ecdsa.PrivateKey, which will be inserted into the agent. + PrivateKey interface{} + // Certificate, if not nil, is communicated to the agent and will be + // stored with the key. + Certificate *xssh.Certificate + // Comment is an optional, free-form string. + Comment string + // LifetimeSecs, if not zero, is the number of seconds that the + // agent will store the key for. + LifetimeSecs uint32 + // ConfirmBeforeUse, if true, requests that the agent confirm with the + // user before each use of this key. + ConfirmBeforeUse bool +} + +// See [PROTOCOL.agent], section 3. +const ( + agentRequestV1Identities = 1 + agentRemoveAllV1Identities = 9 + + // 3.2 Requests from client to agent for protocol 2 key operations + agentAddIdentity = 17 + agentRemoveIdentity = 18 + agentRemoveAllIdentities = 19 + agentAddIdConstrained = 25 + + // 3.3 Key-type independent requests from client to agent + agentAddSmartcardKey = 20 + agentRemoveSmartcardKey = 21 + agentLock = 22 + agentUnlock = 23 + agentAddSmartcardKeyConstrained = 26 + + // 3.7 Key constraint identifiers + agentConstrainLifetime = 1 + agentConstrainConfirm = 2 +) + +// maxAgentResponseBytes is the maximum agent reply size that is accepted. This +// is a sanity check, not a limit in the spec. +const maxAgentResponseBytes = 16 << 20 + +// Agent messages: +// These structures mirror the wire format of the corresponding xssh agent +// messages found in [PROTOCOL.agent]. + +// 3.4 Generic replies from agent to client +const agentFailure = 5 + +type failureAgentMsg struct{} + +const agentSuccess = 6 + +type successAgentMsg struct{} + +// See [PROTOCOL.agent], section 2.5.2. +const agentRequestIdentities = 11 + +type requestIdentitiesAgentMsg struct{} + +// See [PROTOCOL.agent], section 2.5.2. +const agentIdentitiesAnswer = 12 + +type identitiesAnswerAgentMsg struct { + NumKeys uint32 `sshtype:"12"` + Keys []byte `ssh:"rest"` +} + +// See [PROTOCOL.agent], section 2.6.2. +const agentSignRequest = 13 + +type signRequestAgentMsg struct { + KeyBlob []byte `sshtype:"13"` + Data []byte + Flags uint32 +} + +// See [PROTOCOL.agent], section 2.6.2. + +// 3.6 Replies from agent to client for protocol 2 key operations +const agentSignResponse = 14 + +type signResponseAgentMsg struct { + SigBlob []byte `sshtype:"14"` +} + +type publicKey struct { + Format string + Rest []byte `ssh:"rest"` +} + +// Key represents a protocol 2 public key as defined in +// [PROTOCOL.agent], section 2.5.2. +type Key struct { + Format string + Blob []byte + Comment string +} + +func (k *Key) MarshalJSON() ([]byte, error) { + panic("unimplemented") +} + +func clientErr(err error) error { + return fmt.Errorf("agent: client error: %v", err) +} + +// String returns the storage form of an agent key with the format, base64 +// encoded serialized key, and the comment if it is not empty. +func (k *Key) String() string { + s := string(k.Format) + " " + base64.StdEncoding.EncodeToString(k.Blob) + + if k.Comment != "" { + s += " " + k.Comment + } + + return s +} + +// Type returns the public key type. +func (k *Key) Type() string { + return k.Format +} + +// Marshal returns key blob to satisfy the ssh.PublicKey interface. +func (k *Key) Marshal() []byte { + return k.Blob +} + +// Verify satisfies the ssh.PublicKey interface. +func (k *Key) Verify(data []byte, sig *xssh.Signature) error { + pubKey, err := xssh.ParsePublicKey(k.Blob) + if err != nil { + return fmt.Errorf("agent: bad public key: %v", err) + } + return pubKey.Verify(data, sig) +} + +type wireKey struct { + Format string + Rest []byte `ssh:"rest"` +} + +func parseKey(in []byte) (out *Key, rest []byte, err error) { + var record struct { + Blob []byte + Comment string + Rest []byte `ssh:"rest"` + } + + if err := xssh.Unmarshal(in, &record); err != nil { + return nil, nil, err + } + + var wk wireKey + if err := xssh.Unmarshal(record.Blob, &wk); err != nil { + return nil, nil, err + } + + return &Key{ + Format: wk.Format, + Blob: record.Blob, + Comment: record.Comment, + }, record.Rest, nil +} + +// client is a client for an ssh-agent process. +type client struct { + // conn is typically a *net.UnixConn + conn io.ReadWriter + // mu is used to prevent concurrent access to the agent + mu sync.Mutex +} + +// NewClient returns an Agent that talks to an ssh-agent process over +// the given connection. +func NewClient(rw io.ReadWriter) Agent { + return &client{conn: rw} +} + +// call sends an RPC to the agent. On success, the reply is +// unmarshaled into reply and replyType is set to the first byte of +// the reply, which contains the type of the message. +func (c *client) call(req []byte) (reply interface{}, err error) { + c.mu.Lock() + defer c.mu.Unlock() + + msg := make([]byte, 4+len(req)) + binary.BigEndian.PutUint32(msg, uint32(len(req))) + copy(msg[4:], req) + if _, err = c.conn.Write(msg); err != nil { + return nil, clientErr(err) + } + + var respSizeBuf [4]byte + if _, err = io.ReadFull(c.conn, respSizeBuf[:]); err != nil { + return nil, clientErr(err) + } + respSize := binary.BigEndian.Uint32(respSizeBuf[:]) + if respSize > maxAgentResponseBytes { + return nil, clientErr(err) + } + + buf := make([]byte, respSize) + if _, err = io.ReadFull(c.conn, buf); err != nil { + return nil, clientErr(err) + } + reply, err = unmarshal(buf) + if err != nil { + return nil, clientErr(err) + } + return reply, err +} + +func (c *client) simpleCall(req []byte) error { + resp, err := c.call(req) + if err != nil { + return err + } + if _, ok := resp.(*successAgentMsg); ok { + return nil + } + return errors.New("agent: failure") +} + +func (c *client) RemoveAll() error { + return c.simpleCall([]byte{agentRemoveAllIdentities}) +} + +func (c *client) Remove(key xssh.PublicKey) error { + req := xssh.Marshal(&agentRemoveIdentityMsg{ + KeyBlob: key.Marshal(), + }) + return c.simpleCall(req) +} + +func (c *client) Lock(passphrase []byte) error { + req := xssh.Marshal(&agentLockMsg{ + Passphrase: passphrase, + }) + return c.simpleCall(req) +} + +func (c *client) Unlock(passphrase []byte) error { + req := xssh.Marshal(&agentUnlockMsg{ + Passphrase: passphrase, + }) + return c.simpleCall(req) +} + +// List returns the identities known to the agent. +func (c *client) List() ([]*Key, error) { + // see [PROTOCOL.agent] section 2.5.2. + req := []byte{agentRequestIdentities} + + msg, err := c.call(req) + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *identitiesAnswerAgentMsg: + if msg.NumKeys > maxAgentResponseBytes/8 { + return nil, errors.New("agent: too many keys in agent reply") + } + keys := make([]*Key, msg.NumKeys) + data := msg.Keys + for i := uint32(0); i < msg.NumKeys; i++ { + var key *Key + var err error + if key, data, err = parseKey(data); err != nil { + return nil, err + } + keys[i] = key + } + return keys, nil + case *failureAgentMsg: + return nil, errors.New("agent: failed to list keys") + } + panic("unreachable") +} + +// Sign has the agent sign the data using a protocol 2 key as defined +// in [PROTOCOL.agent] section 2.6.2. +func (c *client) Sign(key xssh.PublicKey, data []byte) (*xssh.Signature, error) { + req := xssh.Marshal(signRequestAgentMsg{ + KeyBlob: key.Marshal(), + Data: data, + }) + + msg, err := c.call(req) + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *signResponseAgentMsg: + var sig xssh.Signature + if err := xssh.Unmarshal(msg.SigBlob, &sig); err != nil { + return nil, err + } + + return &sig, nil + case *failureAgentMsg: + return nil, errors.New("agent: failed to sign challenge") + } + panic("unreachable") +} + +// unmarshal parses an agent message in packet, returning the parsed +// form and the message type of packet. +func unmarshal(packet []byte) (interface{}, error) { + if len(packet) < 1 { + return nil, errors.New("agent: empty packet") + } + var msg interface{} + switch packet[0] { + case agentFailure: + return new(failureAgentMsg), nil + case agentSuccess: + return new(successAgentMsg), nil + case agentIdentitiesAnswer: + msg = new(identitiesAnswerAgentMsg) + case agentSignResponse: + msg = new(signResponseAgentMsg) + case agentV1IdentitiesAnswer: + msg = new(agentV1IdentityMsg) + default: + return nil, fmt.Errorf("agent: unknown type tag %d", packet[0]) + } + if err := xssh.Unmarshal(packet, msg); err != nil { + return nil, err + } + return msg, nil +} + +type rsaKeyMsg struct { + Type string `sshtype:"17|25"` + N *big.Int + E *big.Int + D *big.Int + Iqmp *big.Int // IQMP = Inverse Q Mod P + P *big.Int + Q *big.Int + Comments string + Constraints []byte `ssh:"rest"` +} + +type dsaKeyMsg struct { + Type string `sshtype:"17|25"` + P *big.Int + Q *big.Int + G *big.Int + Y *big.Int + X *big.Int + Comments string + Constraints []byte `ssh:"rest"` +} + +type ecdsaKeyMsg struct { + Type string `sshtype:"17|25"` + Curve string + KeyBytes []byte + D *big.Int + Comments string + Constraints []byte `ssh:"rest"` +} + +type ed25519KeyMsg struct { + Type string `sshtype:"17|25"` + Pub []byte + Priv []byte + Comments string + Constraints []byte `ssh:"rest"` +} + +// Insert adds a private key to the agent. +func (c *client) insertKey(s interface{}, comment string, constraints []byte) error { + var req []byte + switch k := s.(type) { + case *rsa.PrivateKey: + if len(k.Primes) != 2 { + return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes)) + } + k.Precompute() + req = xssh.Marshal(rsaKeyMsg{ + Type: xssh.KeyAlgoRSA, + N: k.N, + E: big.NewInt(int64(k.E)), + D: k.D, + Iqmp: k.Precomputed.Qinv, + P: k.Primes[0], + Q: k.Primes[1], + Comments: comment, + Constraints: constraints, + }) + case *dsa.PrivateKey: + req = xssh.Marshal(dsaKeyMsg{ + Type: xssh.KeyAlgoDSA, + P: k.P, + Q: k.Q, + G: k.G, + Y: k.Y, + X: k.X, + Comments: comment, + Constraints: constraints, + }) + case *ecdsa.PrivateKey: + nistID := fmt.Sprintf("nistp%d", k.Params().BitSize) + req = xssh.Marshal(ecdsaKeyMsg{ + Type: "ecdsa-sha2-" + nistID, + Curve: nistID, + KeyBytes: elliptic.Marshal(k.Curve, k.X, k.Y), + D: k.D, + Comments: comment, + Constraints: constraints, + }) + case *ed25519.PrivateKey: + req = xssh.Marshal(ed25519KeyMsg{ + Type: xssh.KeyAlgoED25519, + Pub: []byte(*k)[32:], + Priv: []byte(*k), + Comments: comment, + Constraints: constraints, + }) + default: + return fmt.Errorf("agent: unsupported key type %T", s) + } + + // if constraints are present then the message type needs to be changed. + if len(constraints) != 0 { + req[0] = agentAddIdConstrained + } + + resp, err := c.call(req) + if err != nil { + return err + } + if _, ok := resp.(*successAgentMsg); ok { + return nil + } + return errors.New("agent: failure") +} + +type rsaCertMsg struct { + Type string `sshtype:"17|25"` + CertBytes []byte + D *big.Int + Iqmp *big.Int // IQMP = Inverse Q Mod P + P *big.Int + Q *big.Int + Comments string + Constraints []byte `ssh:"rest"` +} + +type dsaCertMsg struct { + Type string `sshtype:"17|25"` + CertBytes []byte + X *big.Int + Comments string + Constraints []byte `ssh:"rest"` +} + +type ecdsaCertMsg struct { + Type string `sshtype:"17|25"` + CertBytes []byte + D *big.Int + Comments string + Constraints []byte `ssh:"rest"` +} + +type ed25519CertMsg struct { + Type string `sshtype:"17|25"` + CertBytes []byte + Pub []byte + Priv []byte + Comments string + Constraints []byte `ssh:"rest"` +} + +// Add adds a private key to the agent. If a certificate is given, +// that certificate is added instead as public key. +func (c *client) Add(key AddedKey) error { + var constraints []byte + + if secs := key.LifetimeSecs; secs != 0 { + constraints = append(constraints, agentConstrainLifetime) + + var secsBytes [4]byte + binary.BigEndian.PutUint32(secsBytes[:], secs) + constraints = append(constraints, secsBytes[:]...) + } + + if key.ConfirmBeforeUse { + constraints = append(constraints, agentConstrainConfirm) + } + + if cert := key.Certificate; cert == nil { + return c.insertKey(key.PrivateKey, key.Comment, constraints) + } else { + return c.insertCert(key.PrivateKey, cert, key.Comment, constraints) + } +} + +func (c *client) insertCert(s interface{}, cert *xssh.Certificate, comment string, constraints []byte) error { + var req []byte + switch k := s.(type) { + case *rsa.PrivateKey: + if len(k.Primes) != 2 { + return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes)) + } + k.Precompute() + req = xssh.Marshal(rsaCertMsg{ + Type: cert.Type(), + CertBytes: cert.Marshal(), + D: k.D, + Iqmp: k.Precomputed.Qinv, + P: k.Primes[0], + Q: k.Primes[1], + Comments: comment, + Constraints: constraints, + }) + case *dsa.PrivateKey: + req = xssh.Marshal(dsaCertMsg{ + Type: cert.Type(), + CertBytes: cert.Marshal(), + X: k.X, + Comments: comment, + Constraints: constraints, + }) + case *ecdsa.PrivateKey: + req = xssh.Marshal(ecdsaCertMsg{ + Type: cert.Type(), + CertBytes: cert.Marshal(), + D: k.D, + Comments: comment, + Constraints: constraints, + }) + case *ed25519.PrivateKey: + req = xssh.Marshal(ed25519CertMsg{ + Type: cert.Type(), + CertBytes: cert.Marshal(), + Pub: []byte(*k)[32:], + Priv: []byte(*k), + Comments: comment, + Constraints: constraints, + }) + default: + return fmt.Errorf("agent: unsupported key type %T", s) + } + + // if constraints are present then the message type needs to be changed. + if len(constraints) != 0 { + req[0] = agentAddIdConstrained + } + + signer, err := xssh.NewSignerFromKey(s) + if err != nil { + return err + } + if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 { + return errors.New("agent: signer and cert have different public key") + } + + resp, err := c.call(req) + if err != nil { + return err + } + if _, ok := resp.(*successAgentMsg); ok { + return nil + } + return errors.New("agent: failure") +} + +// Signers provides a callback for client authentication. +func (c *client) Signers() ([]xssh.Signer, error) { + keys, err := c.List() + if err != nil { + return nil, err + } + + var result []xssh.Signer + for _, k := range keys { + result = append(result, &agentKeyringSigner{c, k}) + } + return result, nil +} + +type agentKeyringSigner struct { + agent *client + pub xssh.PublicKey +} + +func (s *agentKeyringSigner) PublicKey() xssh.PublicKey { + return s.pub +} + +func (s *agentKeyringSigner) Sign(rand io.Reader, data []byte) (*xssh.Signature, error) { + // The agent has its own entropy source, so the rand argument is ignored. + return s.agent.Sign(s.pub, data) +} diff --git a/lib/ssh/agent/client_test.go b/lib/ssh/agent/client_test.go new file mode 100644 index 0000000..a350092 --- /dev/null +++ b/lib/ssh/agent/client_test.go @@ -0,0 +1,343 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import ( + "bytes" + "crypto/rand" + "errors" + "net" + "os" + "os/exec" + "path/filepath" + "strconv" + "testing" + "time" + + "github.com/zmap/zgrab/ztools/xssh" +) + +// startAgent executes ssh-agent, and returns a Agent interface to it. +func startAgent(t *testing.T) (client Agent, socket string, cleanup func()) { + if testing.Short() { + // ssh-agent is not always available, and the key + // types supported vary by platform. + t.Skip("skipping test due to -short") + } + + bin, err := exec.LookPath("ssh-agent") + if err != nil { + t.Skip("could not find ssh-agent") + } + + cmd := exec.Command(bin, "-s") + out, err := cmd.Output() + if err != nil { + t.Fatalf("cmd.Output: %v", err) + } + + /* Output looks like: + + SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK; + SSH_AGENT_PID=15542; export SSH_AGENT_PID; + echo Agent pid 15542; + */ + fields := bytes.Split(out, []byte(";")) + line := bytes.SplitN(fields[0], []byte("="), 2) + line[0] = bytes.TrimLeft(line[0], "\n") + if string(line[0]) != "SSH_AUTH_SOCK" { + t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0]) + } + socket = string(line[1]) + + line = bytes.SplitN(fields[2], []byte("="), 2) + line[0] = bytes.TrimLeft(line[0], "\n") + if string(line[0]) != "SSH_AGENT_PID" { + t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2]) + } + pidStr := line[1] + pid, err := strconv.Atoi(string(pidStr)) + if err != nil { + t.Fatalf("Atoi(%q): %v", pidStr, err) + } + + conn, err := net.Dial("unix", string(socket)) + if err != nil { + t.Fatalf("net.Dial: %v", err) + } + + ac := NewClient(conn) + return ac, socket, func() { + proc, _ := os.FindProcess(pid) + if proc != nil { + proc.Kill() + } + conn.Close() + os.RemoveAll(filepath.Dir(socket)) + } +} + +func testAgent(t *testing.T, key interface{}, cert *xssh.Certificate, lifetimeSecs uint32) { + agent, _, cleanup := startAgent(t) + defer cleanup() + + testAgentInterface(t, agent, key, cert, lifetimeSecs) +} + +func testKeyring(t *testing.T, key interface{}, cert *xssh.Certificate, lifetimeSecs uint32) { + a := NewKeyring() + testAgentInterface(t, a, key, cert, lifetimeSecs) +} + +func testAgentInterface(t *testing.T, agent Agent, key interface{}, cert *xssh.Certificate, lifetimeSecs uint32) { + signer, err := xssh.NewSignerFromKey(key) + if err != nil { + t.Fatalf("NewSignerFromKey(%T): %v", key, err) + } + // The agent should start up empty. + if keys, err := agent.List(); err != nil { + t.Fatalf("RequestIdentities: %v", err) + } else if len(keys) > 0 { + t.Fatalf("got %d keys, want 0: %v", len(keys), keys) + } + + // Attempt to insert the key, with certificate if specified. + var pubKey xssh.PublicKey + if cert != nil { + err = agent.Add(AddedKey{ + PrivateKey: key, + Certificate: cert, + Comment: "comment", + LifetimeSecs: lifetimeSecs, + }) + pubKey = cert + } else { + err = agent.Add(AddedKey{PrivateKey: key, Comment: "comment", LifetimeSecs: lifetimeSecs}) + pubKey = signer.PublicKey() + } + if err != nil { + t.Fatalf("insert(%T): %v", key, err) + } + + // Did the key get inserted successfully? + if keys, err := agent.List(); err != nil { + t.Fatalf("List: %v", err) + } else if len(keys) != 1 { + t.Fatalf("got %v, want 1 key", keys) + } else if keys[0].Comment != "comment" { + t.Fatalf("key comment: got %v, want %v", keys[0].Comment, "comment") + } else if !bytes.Equal(keys[0].Blob, pubKey.Marshal()) { + t.Fatalf("key mismatch") + } + + // Can the agent make a valid signature? + data := []byte("hello") + sig, err := agent.Sign(pubKey, data) + if err != nil { + t.Fatalf("Sign(%s): %v", pubKey.Type(), err) + } + + if err := pubKey.Verify(data, sig); err != nil { + t.Fatalf("Verify(%s): %v", pubKey.Type(), err) + } + + // If the key has a lifetime, is it removed when it should be? + if lifetimeSecs > 0 { + time.Sleep(time.Second*time.Duration(lifetimeSecs) + 100*time.Millisecond) + keys, err := agent.List() + if err != nil { + t.Fatalf("List: %v", err) + } + if len(keys) > 0 { + t.Fatalf("key not expired") + } + } + +} + +func DISABLED_TestAgent(t *testing.T) { + for _, keyType := range []string{"rsa", "dsa", "ecdsa", "ed25519"} { + testAgent(t, testPrivateKeys[keyType], nil, 0) + testKeyring(t, testPrivateKeys[keyType], nil, 1) + } +} + +func TestCert(t *testing.T) { + cert := &xssh.Certificate{ + Key: testPublicKeys["rsa"], + ValidBefore: xssh.CertTimeInfinity, + CertType: xssh.UserCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + + testAgent(t, testPrivateKeys["rsa"], cert, 0) + testKeyring(t, testPrivateKeys["rsa"], cert, 1) +} + +// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and +// therefore is buffered (net.Pipe deadlocks if both sides start with +// a write.) +func netPipe() (net.Conn, net.Conn, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, nil, err + } + defer listener.Close() + c1, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + return nil, nil, err + } + + c2, err := listener.Accept() + if err != nil { + c1.Close() + return nil, nil, err + } + + return c1, c2, nil +} + +func TestAuth(t *testing.T) { + a, b, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + + defer a.Close() + defer b.Close() + + agent, _, cleanup := startAgent(t) + defer cleanup() + + if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment"}); err != nil { + t.Errorf("Add: %v", err) + } + + serverConf := xssh.ServerConfig{} + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.PublicKeyCallback = func(c xssh.ConnMetadata, key xssh.PublicKey) (*xssh.Permissions, error) { + if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { + return nil, nil + } + + return nil, errors.New("pubkey rejected") + } + + go func() { + conn, _, _, err := xssh.NewServerConn(a, &serverConf) + if err != nil { + t.Fatalf("Server: %v", err) + } + conn.Close() + }() + + conf := xssh.ClientConfig{} + conf.Auth = append(conf.Auth, xssh.PublicKeysCallback(agent.Signers)) + conn, _, _, err := xssh.NewClientConn(b, "", &conf) + if err != nil { + t.Fatalf("NewClientConn: %v", err) + } + conn.Close() +} + +func TestLockClient(t *testing.T) { + agent, _, cleanup := startAgent(t) + defer cleanup() + testLockAgent(agent, t) +} + +func testLockAgent(agent Agent, t *testing.T) { + if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment 1"}); err != nil { + t.Errorf("Add: %v", err) + } + if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["dsa"], Comment: "comment dsa"}); err != nil { + t.Errorf("Add: %v", err) + } + if keys, err := agent.List(); err != nil { + t.Errorf("List: %v", err) + } else if len(keys) != 2 { + t.Errorf("Want 2 keys, got %v", keys) + } + + passphrase := []byte("secret") + if err := agent.Lock(passphrase); err != nil { + t.Errorf("Lock: %v", err) + } + + if keys, err := agent.List(); err != nil { + t.Errorf("List: %v", err) + } else if len(keys) != 0 { + t.Errorf("Want 0 keys, got %v", keys) + } + + signer, _ := xssh.NewSignerFromKey(testPrivateKeys["rsa"]) + if _, err := agent.Sign(signer.PublicKey(), []byte("hello")); err == nil { + t.Fatalf("Sign did not fail") + } + + if err := agent.Remove(signer.PublicKey()); err == nil { + t.Fatalf("Remove did not fail") + } + + if err := agent.RemoveAll(); err == nil { + t.Fatalf("RemoveAll did not fail") + } + + if err := agent.Unlock(nil); err == nil { + t.Errorf("Unlock with wrong passphrase succeeded") + } + if err := agent.Unlock(passphrase); err != nil { + t.Errorf("Unlock: %v", err) + } + + if err := agent.Remove(signer.PublicKey()); err != nil { + t.Fatalf("Remove: %v", err) + } + + if keys, err := agent.List(); err != nil { + t.Errorf("List: %v", err) + } else if len(keys) != 1 { + t.Errorf("Want 1 keys, got %v", keys) + } +} + +func TestAgentLifetime(t *testing.T) { + agent, _, cleanup := startAgent(t) + defer cleanup() + + for _, keyType := range []string{"rsa", "dsa", "ecdsa"} { + // Add private keys to the agent. + err := agent.Add(AddedKey{ + PrivateKey: testPrivateKeys[keyType], + Comment: "comment", + LifetimeSecs: 1, + }) + if err != nil { + t.Fatalf("add: %v", err) + } + // Add certs to the agent. + cert := &xssh.Certificate{ + Key: testPublicKeys[keyType], + ValidBefore: xssh.CertTimeInfinity, + CertType: xssh.UserCert, + } + cert.SignCert(rand.Reader, testSigners[keyType]) + err = agent.Add(AddedKey{ + PrivateKey: testPrivateKeys[keyType], + Certificate: cert, + Comment: "comment", + LifetimeSecs: 1, + }) + if err != nil { + t.Fatalf("add: %v", err) + } + } + time.Sleep(1100 * time.Millisecond) + if keys, err := agent.List(); err != nil { + t.Errorf("List: %v", err) + } else if len(keys) != 0 { + t.Errorf("Want 0 keys, got %v", len(keys)) + } +} diff --git a/lib/ssh/agent/example_test.go b/lib/ssh/agent/example_test.go new file mode 100644 index 0000000..582b878 --- /dev/null +++ b/lib/ssh/agent/example_test.go @@ -0,0 +1,40 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent_test + +import ( + "log" + "net" + "os" + + "github.com/zmap/zgrab/ztools/xssh" + "github.com/zmap/zgrab/ztools/xssh/agent" +) + +func ExampleClientAgent() { + // ssh-agent has a UNIX socket under $SSH_AUTH_SOCK + socket := os.Getenv("SSH_AUTH_SOCK") + conn, err := net.Dial("unix", socket) + if err != nil { + log.Fatalf("net.Dial: %v", err) + } + agentClient := agent.NewClient(conn) + config := &xssh.ClientConfig{ + User: "username", + Auth: []xssh.AuthMethod{ + // Use a callback rather than PublicKeys + // so we only consult the agent once the remote server + // wants it. + xssh.PublicKeysCallback(agentClient.Signers), + }, + } + + sshc, err := xssh.Dial("tcp", "localhost:22", config) + if err != nil { + log.Fatalf("Dial: %v", err) + } + // .. use sshc + sshc.Close() +} diff --git a/lib/ssh/agent/forward.go b/lib/ssh/agent/forward.go new file mode 100644 index 0000000..c7969ac --- /dev/null +++ b/lib/ssh/agent/forward.go @@ -0,0 +1,103 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import ( + "errors" + "io" + "net" + "sync" + + "github.com/zmap/zgrab/ztools/xssh" +) + +// RequestAgentForwarding sets up agent forwarding for the session. +// ForwardToAgent or ForwardToRemote should be called to route +// the authentication requests. +func RequestAgentForwarding(session *xssh.Session) error { + ok, err := session.SendRequest("auth-agent-req@openssh.com", true, nil) + if err != nil { + return err + } + if !ok { + return errors.New("forwarding request denied") + } + return nil +} + +// ForwardToAgent routes authentication requests to the given keyring. +func ForwardToAgent(client *xssh.Client, keyring Agent) error { + channels := client.HandleChannelOpen(channelType) + if channels == nil { + return errors.New("agent: already have handler for " + channelType) + } + + go func() { + for ch := range channels { + channel, reqs, err := ch.Accept() + if err != nil { + continue + } + go xssh.DiscardRequests(reqs) + go func() { + ServeAgent(keyring, channel) + channel.Close() + }() + } + }() + return nil +} + +const channelType = "auth-agent@openssh.com" + +// ForwardToRemote routes authentication requests to the ssh-agent +// process serving on the given unix socket. +func ForwardToRemote(client *xssh.Client, addr string) error { + channels := client.HandleChannelOpen(channelType) + if channels == nil { + return errors.New("agent: already have handler for " + channelType) + } + conn, err := net.Dial("unix", addr) + if err != nil { + return err + } + conn.Close() + + go func() { + for ch := range channels { + channel, reqs, err := ch.Accept() + if err != nil { + continue + } + go xssh.DiscardRequests(reqs) + go forwardUnixSocket(channel, addr) + } + }() + return nil +} + +func forwardUnixSocket(channel xssh.Channel, addr string) { + conn, err := net.Dial("unix", addr) + if err != nil { + return + } + + var wg sync.WaitGroup + wg.Add(2) + go func() { + io.Copy(conn, channel) + conn.(*net.UnixConn).CloseWrite() + wg.Done() + }() + go func() { + io.Copy(channel, conn) + channel.CloseWrite() + wg.Done() + }() + + wg.Wait() + conn.Close() + channel.Close() +} diff --git a/lib/ssh/agent/keyring.go b/lib/ssh/agent/keyring.go new file mode 100644 index 0000000..71af400 --- /dev/null +++ b/lib/ssh/agent/keyring.go @@ -0,0 +1,215 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import ( + "bytes" + "crypto/rand" + "crypto/subtle" + "errors" + "fmt" + "sync" + "time" + + "github.com/zmap/zgrab/ztools/xssh" +) + +type privKey struct { + signer xssh.Signer + comment string + expire *time.Time +} + +type keyring struct { + mu sync.Mutex + keys []privKey + + locked bool + passphrase []byte +} + +var errLocked = errors.New("agent: locked") + +// NewKeyring returns an Agent that holds keys in memory. It is safe +// for concurrent use by multiple goroutines. +func NewKeyring() Agent { + return &keyring{} +} + +// RemoveAll removes all identities. +func (r *keyring) RemoveAll() error { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return errLocked + } + + r.keys = nil + return nil +} + +// removeLocked does the actual key removal. The caller must already be holding the +// keyring mutex. +func (r *keyring) removeLocked(want []byte) error { + found := false + for i := 0; i < len(r.keys); { + if bytes.Equal(r.keys[i].signer.PublicKey().Marshal(), want) { + found = true + r.keys[i] = r.keys[len(r.keys)-1] + r.keys = r.keys[:len(r.keys)-1] + continue + } else { + i++ + } + } + + if !found { + return errors.New("agent: key not found") + } + return nil +} + +// Remove removes all identities with the given public key. +func (r *keyring) Remove(key xssh.PublicKey) error { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return errLocked + } + + return r.removeLocked(key.Marshal()) +} + +// Lock locks the agent. Sign and Remove will fail, and List will return an empty list. +func (r *keyring) Lock(passphrase []byte) error { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return errLocked + } + + r.locked = true + r.passphrase = passphrase + return nil +} + +// Unlock undoes the effect of Lock +func (r *keyring) Unlock(passphrase []byte) error { + r.mu.Lock() + defer r.mu.Unlock() + if !r.locked { + return errors.New("agent: not locked") + } + if len(passphrase) != len(r.passphrase) || 1 != subtle.ConstantTimeCompare(passphrase, r.passphrase) { + return fmt.Errorf("agent: incorrect passphrase") + } + + r.locked = false + r.passphrase = nil + return nil +} + +// expireKeysLocked removes expired keys from the keyring. If a key was added +// with a lifetimesecs contraint and seconds >= lifetimesecs seconds have +// ellapsed, it is removed. The caller *must* be holding the keyring mutex. +func (r *keyring) expireKeysLocked() { + for _, k := range r.keys { + if k.expire != nil && time.Now().After(*k.expire) { + r.removeLocked(k.signer.PublicKey().Marshal()) + } + } +} + +// List returns the identities known to the agent. +func (r *keyring) List() ([]*Key, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + // section 2.7: locked agents return empty. + return nil, nil + } + + r.expireKeysLocked() + var ids []*Key + for _, k := range r.keys { + pub := k.signer.PublicKey() + ids = append(ids, &Key{ + Format: pub.Type(), + Blob: pub.Marshal(), + Comment: k.comment}) + } + return ids, nil +} + +// Insert adds a private key to the keyring. If a certificate +// is given, that certificate is added as public key. Note that +// any constraints given are ignored. +func (r *keyring) Add(key AddedKey) error { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return errLocked + } + signer, err := xssh.NewSignerFromKey(key.PrivateKey) + + if err != nil { + return err + } + + if cert := key.Certificate; cert != nil { + signer, err = xssh.NewCertSigner(cert, signer) + if err != nil { + return err + } + } + + p := privKey{ + signer: signer, + comment: key.Comment, + } + + if key.LifetimeSecs > 0 { + t := time.Now().Add(time.Duration(key.LifetimeSecs) * time.Second) + p.expire = &t + } + + r.keys = append(r.keys, p) + + return nil +} + +// Sign returns a signature for the data. +func (r *keyring) Sign(key xssh.PublicKey, data []byte) (*xssh.Signature, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return nil, errLocked + } + + r.expireKeysLocked() + wanted := key.Marshal() + for _, k := range r.keys { + if bytes.Equal(k.signer.PublicKey().Marshal(), wanted) { + return k.signer.Sign(rand.Reader, data) + } + } + return nil, errors.New("not found") +} + +// Signers returns signers for all the known keys. +func (r *keyring) Signers() ([]xssh.Signer, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return nil, errLocked + } + + r.expireKeysLocked() + s := make([]xssh.Signer, 0, len(r.keys)) + for _, k := range r.keys { + s = append(s, k.signer) + } + return s, nil +} diff --git a/lib/ssh/agent/keyring_test.go b/lib/ssh/agent/keyring_test.go new file mode 100644 index 0000000..e5d50e7 --- /dev/null +++ b/lib/ssh/agent/keyring_test.go @@ -0,0 +1,76 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import "testing" + +func addTestKey(t *testing.T, a Agent, keyName string) { + err := a.Add(AddedKey{ + PrivateKey: testPrivateKeys[keyName], + Comment: keyName, + }) + if err != nil { + t.Fatalf("failed to add key %q: %v", keyName, err) + } +} + +func removeTestKey(t *testing.T, a Agent, keyName string) { + err := a.Remove(testPublicKeys[keyName]) + if err != nil { + t.Fatalf("failed to remove key %q: %v", keyName, err) + } +} + +func validateListedKeys(t *testing.T, a Agent, expectedKeys []string) { + listedKeys, err := a.List() + if err != nil { + t.Fatalf("failed to list keys: %v", err) + return + } + actualKeys := make(map[string]bool) + for _, key := range listedKeys { + actualKeys[key.Comment] = true + } + + matchedKeys := make(map[string]bool) + for _, expectedKey := range expectedKeys { + if !actualKeys[expectedKey] { + t.Fatalf("expected key %q, but was not found", expectedKey) + } else { + matchedKeys[expectedKey] = true + } + } + + for actualKey := range actualKeys { + if !matchedKeys[actualKey] { + t.Fatalf("key %q was found, but was not expected", actualKey) + } + } +} + +func TestKeyringAddingAndRemoving(t *testing.T) { + keyNames := []string{"dsa", "ecdsa", "rsa", "user"} + + // add all test private keys + k := NewKeyring() + for _, keyName := range keyNames { + addTestKey(t, k, keyName) + } + validateListedKeys(t, k, keyNames) + + // remove a key in the middle + keyToRemove := keyNames[1] + keyNames = append(keyNames[:1], keyNames[2:]...) + + removeTestKey(t, k, keyToRemove) + validateListedKeys(t, k, keyNames) + + // remove all keys + err := k.RemoveAll() + if err != nil { + t.Fatalf("failed to remove all keys: %v", err) + } + validateListedKeys(t, k, []string{}) +} diff --git a/lib/ssh/agent/server.go b/lib/ssh/agent/server.go new file mode 100644 index 0000000..605df62 --- /dev/null +++ b/lib/ssh/agent/server.go @@ -0,0 +1,451 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import ( + "crypto/dsa" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "encoding/binary" + "errors" + "fmt" + "io" + "log" + "math/big" + + "github.com/zmap/zgrab/ztools/xssh" + "golang.org/x/crypto/ed25519" +) + +// Server wraps an Agent and uses it to implement the agent side of +// the SSH-agent, wire protocol. +type server struct { + agent Agent +} + +func (s *server) processRequestBytes(reqData []byte) []byte { + rep, err := s.processRequest(reqData) + if err != nil { + if err != errLocked { + // TODO(hanwen): provide better logging interface? + log.Printf("agent %d: %v", reqData[0], err) + } + return []byte{agentFailure} + } + + if err == nil && rep == nil { + return []byte{agentSuccess} + } + + return xssh.Marshal(rep) +} + +func marshalKey(k *Key) []byte { + var record struct { + Blob []byte + Comment string + } + record.Blob = k.Marshal() + record.Comment = k.Comment + + return xssh.Marshal(&record) +} + +// See [PROTOCOL.agent], section 2.5.1. +const agentV1IdentitiesAnswer = 2 + +type agentV1IdentityMsg struct { + Numkeys uint32 `sshtype:"2"` +} + +type agentRemoveIdentityMsg struct { + KeyBlob []byte `sshtype:"18"` +} + +type agentLockMsg struct { + Passphrase []byte `sshtype:"22"` +} + +type agentUnlockMsg struct { + Passphrase []byte `sshtype:"23"` +} + +func (s *server) processRequest(data []byte) (interface{}, error) { + switch data[0] { + case agentRequestV1Identities: + return &agentV1IdentityMsg{0}, nil + + case agentRemoveAllV1Identities: + return nil, nil + + case agentRemoveIdentity: + var req agentRemoveIdentityMsg + if err := xssh.Unmarshal(data, &req); err != nil { + return nil, err + } + + var wk wireKey + if err := xssh.Unmarshal(req.KeyBlob, &wk); err != nil { + return nil, err + } + + return nil, s.agent.Remove(&Key{Format: wk.Format, Blob: req.KeyBlob}) + + case agentRemoveAllIdentities: + return nil, s.agent.RemoveAll() + + case agentLock: + var req agentLockMsg + if err := xssh.Unmarshal(data, &req); err != nil { + return nil, err + } + + return nil, s.agent.Lock(req.Passphrase) + + case agentUnlock: + var req agentLockMsg + if err := xssh.Unmarshal(data, &req); err != nil { + return nil, err + } + return nil, s.agent.Unlock(req.Passphrase) + + case agentSignRequest: + var req signRequestAgentMsg + if err := xssh.Unmarshal(data, &req); err != nil { + return nil, err + } + + var wk wireKey + if err := xssh.Unmarshal(req.KeyBlob, &wk); err != nil { + return nil, err + } + + k := &Key{ + Format: wk.Format, + Blob: req.KeyBlob, + } + + sig, err := s.agent.Sign(k, req.Data) // TODO(hanwen): flags. + if err != nil { + return nil, err + } + return &signResponseAgentMsg{SigBlob: xssh.Marshal(sig)}, nil + + case agentRequestIdentities: + keys, err := s.agent.List() + if err != nil { + return nil, err + } + + rep := identitiesAnswerAgentMsg{ + NumKeys: uint32(len(keys)), + } + for _, k := range keys { + rep.Keys = append(rep.Keys, marshalKey(k)...) + } + return rep, nil + + case agentAddIdConstrained, agentAddIdentity: + return nil, s.insertIdentity(data) + } + + return nil, fmt.Errorf("unknown opcode %d", data[0]) +} + +func parseRSAKey(req []byte) (*AddedKey, error) { + var k rsaKeyMsg + if err := xssh.Unmarshal(req, &k); err != nil { + return nil, err + } + if k.E.BitLen() > 30 { + return nil, errors.New("agent: RSA public exponent too large") + } + priv := &rsa.PrivateKey{ + PublicKey: rsa.PublicKey{ + E: int(k.E.Int64()), + N: k.N, + }, + D: k.D, + Primes: []*big.Int{k.P, k.Q}, + } + priv.Precompute() + + return &AddedKey{PrivateKey: priv, Comment: k.Comments}, nil +} + +func parseEd25519Key(req []byte) (*AddedKey, error) { + var k ed25519KeyMsg + if err := xssh.Unmarshal(req, &k); err != nil { + return nil, err + } + priv := ed25519.PrivateKey(k.Priv) + return &AddedKey{PrivateKey: &priv, Comment: k.Comments}, nil +} + +func parseDSAKey(req []byte) (*AddedKey, error) { + var k dsaKeyMsg + if err := xssh.Unmarshal(req, &k); err != nil { + return nil, err + } + priv := &dsa.PrivateKey{ + PublicKey: dsa.PublicKey{ + Parameters: dsa.Parameters{ + P: k.P, + Q: k.Q, + G: k.G, + }, + Y: k.Y, + }, + X: k.X, + } + + return &AddedKey{PrivateKey: priv, Comment: k.Comments}, nil +} + +func unmarshalECDSA(curveName string, keyBytes []byte, privScalar *big.Int) (priv *ecdsa.PrivateKey, err error) { + priv = &ecdsa.PrivateKey{ + D: privScalar, + } + + switch curveName { + case "nistp256": + priv.Curve = elliptic.P256() + case "nistp384": + priv.Curve = elliptic.P384() + case "nistp521": + priv.Curve = elliptic.P521() + default: + return nil, fmt.Errorf("agent: unknown curve %q", curveName) + } + + priv.X, priv.Y = elliptic.Unmarshal(priv.Curve, keyBytes) + if priv.X == nil || priv.Y == nil { + return nil, errors.New("agent: point not on curve") + } + + return priv, nil +} + +func parseEd25519Cert(req []byte) (*AddedKey, error) { + var k ed25519CertMsg + if err := xssh.Unmarshal(req, &k); err != nil { + return nil, err + } + pubKey, err := xssh.ParsePublicKey(k.CertBytes) + if err != nil { + return nil, err + } + priv := ed25519.PrivateKey(k.Priv) + cert, ok := pubKey.(*xssh.Certificate) + if !ok { + return nil, errors.New("agent: bad ED25519 certificate") + } + return &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}, nil +} + +func parseECDSAKey(req []byte) (*AddedKey, error) { + var k ecdsaKeyMsg + if err := xssh.Unmarshal(req, &k); err != nil { + return nil, err + } + + priv, err := unmarshalECDSA(k.Curve, k.KeyBytes, k.D) + if err != nil { + return nil, err + } + + return &AddedKey{PrivateKey: priv, Comment: k.Comments}, nil +} + +func parseRSACert(req []byte) (*AddedKey, error) { + var k rsaCertMsg + if err := xssh.Unmarshal(req, &k); err != nil { + return nil, err + } + + pubKey, err := xssh.ParsePublicKey(k.CertBytes) + if err != nil { + return nil, err + } + + cert, ok := pubKey.(*xssh.Certificate) + if !ok { + return nil, errors.New("agent: bad RSA certificate") + } + + // An RSA publickey as marshaled by rsaPublicKey.Marshal() in keys.go + var rsaPub struct { + Name string + E *big.Int + N *big.Int + } + if err := xssh.Unmarshal(cert.Key.Marshal(), &rsaPub); err != nil { + return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err) + } + + if rsaPub.E.BitLen() > 30 { + return nil, errors.New("agent: RSA public exponent too large") + } + + priv := rsa.PrivateKey{ + PublicKey: rsa.PublicKey{ + E: int(rsaPub.E.Int64()), + N: rsaPub.N, + }, + D: k.D, + Primes: []*big.Int{k.Q, k.P}, + } + priv.Precompute() + + return &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}, nil +} + +func parseDSACert(req []byte) (*AddedKey, error) { + var k dsaCertMsg + if err := xssh.Unmarshal(req, &k); err != nil { + return nil, err + } + pubKey, err := xssh.ParsePublicKey(k.CertBytes) + if err != nil { + return nil, err + } + cert, ok := pubKey.(*xssh.Certificate) + if !ok { + return nil, errors.New("agent: bad DSA certificate") + } + + // A DSA publickey as marshaled by dsaPublicKey.Marshal() in keys.go + var w struct { + Name string + P, Q, G, Y *big.Int + } + if err := xssh.Unmarshal(cert.Key.Marshal(), &w); err != nil { + return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err) + } + + priv := &dsa.PrivateKey{ + PublicKey: dsa.PublicKey{ + Parameters: dsa.Parameters{ + P: w.P, + Q: w.Q, + G: w.G, + }, + Y: w.Y, + }, + X: k.X, + } + + return &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}, nil +} + +func parseECDSACert(req []byte) (*AddedKey, error) { + var k ecdsaCertMsg + if err := xssh.Unmarshal(req, &k); err != nil { + return nil, err + } + + pubKey, err := xssh.ParsePublicKey(k.CertBytes) + if err != nil { + return nil, err + } + cert, ok := pubKey.(*xssh.Certificate) + if !ok { + return nil, errors.New("agent: bad ECDSA certificate") + } + + // An ECDSA publickey as marshaled by ecdsaPublicKey.Marshal() in keys.go + var ecdsaPub struct { + Name string + ID string + Key []byte + } + if err := xssh.Unmarshal(cert.Key.Marshal(), &ecdsaPub); err != nil { + return nil, err + } + + priv, err := unmarshalECDSA(ecdsaPub.ID, ecdsaPub.Key, k.D) + if err != nil { + return nil, err + } + + return &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}, nil +} + +func (s *server) insertIdentity(req []byte) error { + var record struct { + Type string `sshtype:"17|25"` + Rest []byte `ssh:"rest"` + } + + if err := xssh.Unmarshal(req, &record); err != nil { + return err + } + + var addedKey *AddedKey + var err error + + switch record.Type { + case xssh.KeyAlgoRSA: + addedKey, err = parseRSAKey(req) + case xssh.KeyAlgoDSA: + addedKey, err = parseDSAKey(req) + case xssh.KeyAlgoECDSA256, xssh.KeyAlgoECDSA384, xssh.KeyAlgoECDSA521: + addedKey, err = parseECDSAKey(req) + case xssh.KeyAlgoED25519: + addedKey, err = parseEd25519Key(req) + case xssh.CertAlgoRSAv01: + addedKey, err = parseRSACert(req) + case xssh.CertAlgoDSAv01: + addedKey, err = parseDSACert(req) + case xssh.CertAlgoECDSA256v01, xssh.CertAlgoECDSA384v01, xssh.CertAlgoECDSA521v01: + addedKey, err = parseECDSACert(req) + case xssh.CertAlgoED25519v01: + addedKey, err = parseEd25519Cert(req) + default: + return fmt.Errorf("agent: not implemented: %q", record.Type) + } + + if err != nil { + return err + } + return s.agent.Add(*addedKey) +} + +// ServeAgent serves the agent protocol on the given connection. It +// returns when an I/O error occurs. +func ServeAgent(agent Agent, c io.ReadWriter) error { + s := &server{agent} + + var length [4]byte + for { + if _, err := io.ReadFull(c, length[:]); err != nil { + return err + } + l := binary.BigEndian.Uint32(length[:]) + if l > maxAgentResponseBytes { + // We also cap requests. + return fmt.Errorf("agent: request too large: %d", l) + } + + req := make([]byte, l) + if _, err := io.ReadFull(c, req); err != nil { + return err + } + + repData := s.processRequestBytes(req) + if len(repData) > maxAgentResponseBytes { + return fmt.Errorf("agent: reply too large: %d bytes", len(repData)) + } + + binary.BigEndian.PutUint32(length[:], uint32(len(repData))) + if _, err := c.Write(length[:]); err != nil { + return err + } + if _, err := c.Write(repData); err != nil { + return err + } + } +} diff --git a/lib/ssh/agent/server_test.go b/lib/ssh/agent/server_test.go new file mode 100644 index 0000000..fa7aeb4 --- /dev/null +++ b/lib/ssh/agent/server_test.go @@ -0,0 +1,207 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import ( + "crypto" + "crypto/rand" + "fmt" + "testing" + + "github.com/zmap/zgrab/ztools/xssh" +) + +func TestServer(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + client := NewClient(c1) + + go ServeAgent(NewKeyring(), c2) + + testAgentInterface(t, client, testPrivateKeys["rsa"], nil, 0) +} + +func TestLockServer(t *testing.T) { + testLockAgent(NewKeyring(), t) +} + +func TestSetupForwardAgent(t *testing.T) { + a, b, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + + defer a.Close() + defer b.Close() + + _, socket, cleanup := startAgent(t) + defer cleanup() + + serverConf := xssh.ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["rsa"]) + incoming := make(chan *xssh.ServerConn, 1) + go func() { + conn, _, _, err := xssh.NewServerConn(a, &serverConf) + if err != nil { + t.Fatalf("Server: %v", err) + } + incoming <- conn + }() + + conf := xssh.ClientConfig{} + conn, chans, reqs, err := xssh.NewClientConn(b, "", &conf) + if err != nil { + t.Fatalf("NewClientConn: %v", err) + } + client := xssh.NewClient(conn, chans, reqs) + + if err := ForwardToRemote(client, socket); err != nil { + t.Fatalf("SetupForwardAgent: %v", err) + } + + server := <-incoming + ch, reqs, err := server.OpenChannel(channelType, nil) + if err != nil { + t.Fatalf("OpenChannel(%q): %v", channelType, err) + } + go xssh.DiscardRequests(reqs) + + agentClient := NewClient(ch) + testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil, 0) + conn.Close() +} + +func TestV1ProtocolMessages(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + c := NewClient(c1) + + go ServeAgent(NewKeyring(), c2) + + testV1ProtocolMessages(t, c.(*client)) +} + +func testV1ProtocolMessages(t *testing.T, c *client) { + reply, err := c.call([]byte{agentRequestV1Identities}) + if err != nil { + t.Fatalf("v1 request all failed: %v", err) + } + if msg, ok := reply.(*agentV1IdentityMsg); !ok || msg.Numkeys != 0 { + t.Fatalf("invalid request all response: %#v", reply) + } + + reply, err = c.call([]byte{agentRemoveAllV1Identities}) + if err != nil { + t.Fatalf("v1 remove all failed: %v", err) + } + if _, ok := reply.(*successAgentMsg); !ok { + t.Fatalf("invalid remove all response: %#v", reply) + } +} + +func verifyKey(sshAgent Agent) error { + keys, err := sshAgent.List() + if err != nil { + return fmt.Errorf("listing keys: %v", err) + } + + if len(keys) != 1 { + return fmt.Errorf("bad number of keys found. expected 1, got %d", len(keys)) + } + + buf := make([]byte, 128) + if _, err := rand.Read(buf); err != nil { + return fmt.Errorf("rand: %v", err) + } + + sig, err := sshAgent.Sign(keys[0], buf) + if err != nil { + return fmt.Errorf("sign: %v", err) + } + + if err := keys[0].Verify(buf, sig); err != nil { + return fmt.Errorf("verify: %v", err) + } + return nil +} + +func addKeyToAgent(key crypto.PrivateKey) error { + sshAgent := NewKeyring() + if err := sshAgent.Add(AddedKey{PrivateKey: key}); err != nil { + return fmt.Errorf("add: %v", err) + } + return verifyKey(sshAgent) +} + +func TestKeyTypes(t *testing.T) { + for k, v := range testPrivateKeys { + if err := addKeyToAgent(v); err != nil { + t.Errorf("error adding key type %s, %v", k, err) + } + if err := addCertToAgentSock(v, nil); err != nil { + t.Errorf("error adding key type %s, %v", k, err) + } + } +} + +func addCertToAgentSock(key crypto.PrivateKey, cert *xssh.Certificate) error { + a, b, err := netPipe() + if err != nil { + return err + } + agentServer := NewKeyring() + go ServeAgent(agentServer, a) + + agentClient := NewClient(b) + if err := agentClient.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil { + return fmt.Errorf("add: %v", err) + } + return verifyKey(agentClient) +} + +func addCertToAgent(key crypto.PrivateKey, cert *xssh.Certificate) error { + sshAgent := NewKeyring() + if err := sshAgent.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil { + return fmt.Errorf("add: %v", err) + } + return verifyKey(sshAgent) +} + +func TestCertTypes(t *testing.T) { + for keyType, key := range testPublicKeys { + cert := &xssh.Certificate{ + ValidPrincipals: []string{"gopher1"}, + ValidAfter: 0, + ValidBefore: xssh.CertTimeInfinity, + Key: key, + Serial: 1, + CertType: xssh.UserCert, + SignatureKey: testPublicKeys["rsa"], + Permissions: xssh.Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + }, + } + if err := cert.SignCert(rand.Reader, testSigners["rsa"]); err != nil { + t.Fatalf("signcert: %v", err) + } + if err := addCertToAgent(testPrivateKeys[keyType], cert); err != nil { + t.Fatalf("%v", err) + } + if err := addCertToAgentSock(testPrivateKeys[keyType], cert); err != nil { + t.Fatalf("%v", err) + } + } +} diff --git a/lib/ssh/agent/testdata_test.go b/lib/ssh/agent/testdata_test.go new file mode 100644 index 0000000..5240d61 --- /dev/null +++ b/lib/ssh/agent/testdata_test.go @@ -0,0 +1,64 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IMPLEMENTATION NOTE: To avoid a package loop, this file is in three places: +// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three +// instances. + +package agent + +import ( + "crypto/rand" + "fmt" + + "github.com/zmap/zgrab/ztools/xssh" + "github.com/zmap/zgrab/ztools/xssh/testdata" +) + +var ( + testPrivateKeys map[string]interface{} + testSigners map[string]xssh.Signer + testPublicKeys map[string]xssh.PublicKey +) + +func init() { + var err error + + n := len(testdata.PEMBytes) + testPrivateKeys = make(map[string]interface{}, n) + testSigners = make(map[string]xssh.Signer, n) + testPublicKeys = make(map[string]xssh.PublicKey, n) + for t, k := range testdata.PEMBytes { + testPrivateKeys[t], err = xssh.ParseRawPrivateKey(k) + if err != nil { + panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) + } + testSigners[t], err = xssh.NewSignerFromKey(testPrivateKeys[t]) + if err != nil { + panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) + } + testPublicKeys[t] = testSigners[t].PublicKey() + } + + // Create a cert and sign it for use in tests. + testCert := &xssh.Certificate{ + Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage + ValidAfter: 0, // unix epoch + ValidBefore: xssh.CertTimeInfinity, // The end of currently representable time. + Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + Key: testPublicKeys["ecdsa"], + SignatureKey: testPublicKeys["rsa"], + Permissions: xssh.Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + }, + } + testCert.SignCert(rand.Reader, testSigners["rsa"]) + testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] + testSigners["cert"], err = xssh.NewCertSigner(testCert, testSigners["ecdsa"]) + if err != nil { + panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) + } +} diff --git a/lib/ssh/benchmark_test.go b/lib/ssh/benchmark_test.go new file mode 100644 index 0000000..968115a --- /dev/null +++ b/lib/ssh/benchmark_test.go @@ -0,0 +1,122 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "errors" + "io" + "net" + "testing" +) + +type server struct { + *ServerConn + chans <-chan NewChannel +} + +func newServer(c net.Conn, conf *ServerConfig) (*server, error) { + sconn, chans, reqs, err := NewServerConn(c, conf) + if err != nil { + return nil, err + } + go DiscardRequests(reqs) + return &server{sconn, chans}, nil +} + +func (s *server) Accept() (NewChannel, error) { + n, ok := <-s.chans + if !ok { + return nil, io.EOF + } + return n, nil +} + +func sshPipe() (Conn, *server, error) { + c1, c2, err := netPipe() + if err != nil { + return nil, nil, err + } + + clientConf := ClientConfig{ + User: "user", + } + serverConf := ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["ecdsa"]) + done := make(chan *server, 1) + go func() { + server, err := newServer(c2, &serverConf) + if err != nil { + done <- nil + } + done <- server + }() + + client, _, reqs, err := NewClientConn(c1, "", &clientConf) + if err != nil { + return nil, nil, err + } + + server := <-done + if server == nil { + return nil, nil, errors.New("server handshake failed.") + } + go DiscardRequests(reqs) + + return client, server, nil +} + +func BenchmarkEndToEnd(b *testing.B) { + b.StopTimer() + + client, server, err := sshPipe() + if err != nil { + b.Fatalf("sshPipe: %v", err) + } + + defer client.Close() + defer server.Close() + + size := (1 << 20) + input := make([]byte, size) + output := make([]byte, size) + b.SetBytes(int64(size)) + done := make(chan int, 1) + + go func() { + newCh, err := server.Accept() + if err != nil { + b.Fatalf("Client: %v", err) + } + ch, incoming, _ := newCh.Accept() + go DiscardRequests(incoming) + for i := 0; i < b.N; i++ { + if _, err := io.ReadFull(ch, output); err != nil { + b.Fatalf("ReadFull: %v", err) + } + } + ch.Close() + done <- 1 + }() + + ch, in, err := client.OpenChannel("speed", nil) + if err != nil { + b.Fatalf("OpenChannel: %v", err) + } + go DiscardRequests(in) + + b.ResetTimer() + b.StartTimer() + for i := 0; i < b.N; i++ { + if _, err := ch.Write(input); err != nil { + b.Fatalf("WriteFull: %v", err) + } + } + ch.Close() + b.StopTimer() + + <-done +} diff --git a/lib/ssh/buffer.go b/lib/ssh/buffer.go new file mode 100644 index 0000000..80d80ae --- /dev/null +++ b/lib/ssh/buffer.go @@ -0,0 +1,98 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "io" + "sync" +) + +// buffer provides a linked list buffer for data exchange +// between producer and consumer. Theoretically the buffer is +// of unlimited capacity as it does no allocation of its own. +type buffer struct { + // protects concurrent access to head, tail and closed + *sync.Cond + + head *element // the buffer that will be read first + tail *element // the buffer that will be read last + + closed bool +} + +// An element represents a single link in a linked list. +type element struct { + buf []byte + next *element +} + +// newBuffer returns an empty buffer that is not closed. +func newBuffer() *buffer { + e := new(element) + b := &buffer{ + Cond: newCond(), + head: e, + tail: e, + } + return b +} + +// write makes buf available for Read to receive. +// buf must not be modified after the call to write. +func (b *buffer) write(buf []byte) { + b.Cond.L.Lock() + e := &element{buf: buf} + b.tail.next = e + b.tail = e + b.Cond.Signal() + b.Cond.L.Unlock() +} + +// eof closes the buffer. Reads from the buffer once all +// the data has been consumed will receive os.EOF. +func (b *buffer) eof() error { + b.Cond.L.Lock() + b.closed = true + b.Cond.Signal() + b.Cond.L.Unlock() + return nil +} + +// Read reads data from the internal buffer in buf. Reads will block +// if no data is available, or until the buffer is closed. +func (b *buffer) Read(buf []byte) (n int, err error) { + b.Cond.L.Lock() + defer b.Cond.L.Unlock() + + for len(buf) > 0 { + // if there is data in b.head, copy it + if len(b.head.buf) > 0 { + r := copy(buf, b.head.buf) + buf, b.head.buf = buf[r:], b.head.buf[r:] + n += r + continue + } + // if there is a next buffer, make it the head + if len(b.head.buf) == 0 && b.head != b.tail { + b.head = b.head.next + continue + } + + // if at least one byte has been copied, return + if n > 0 { + break + } + + // if nothing was read, and there is nothing outstanding + // check to see if the buffer is closed. + if b.closed { + err = io.EOF + break + } + // out of buffers, wait for producer + b.Cond.Wait() + } + return +} diff --git a/lib/ssh/buffer_test.go b/lib/ssh/buffer_test.go new file mode 100644 index 0000000..c474559 --- /dev/null +++ b/lib/ssh/buffer_test.go @@ -0,0 +1,87 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "io" + "testing" +) + +var alphabet = []byte("abcdefghijklmnopqrstuvwxyz") + +func TestBufferReadwrite(t *testing.T) { + b := newBuffer() + b.write(alphabet[:10]) + r, _ := b.Read(make([]byte, 10)) + if r != 10 { + t.Fatalf("Expected written == read == 10, written: 10, read %d", r) + } + + b = newBuffer() + b.write(alphabet[:5]) + r, _ = b.Read(make([]byte, 10)) + if r != 5 { + t.Fatalf("Expected written == read == 5, written: 5, read %d", r) + } + + b = newBuffer() + b.write(alphabet[:10]) + r, _ = b.Read(make([]byte, 5)) + if r != 5 { + t.Fatalf("Expected written == 10, read == 5, written: 10, read %d", r) + } + + b = newBuffer() + b.write(alphabet[:5]) + b.write(alphabet[5:15]) + r, _ = b.Read(make([]byte, 10)) + r2, _ := b.Read(make([]byte, 10)) + if r != 10 || r2 != 5 || 15 != r+r2 { + t.Fatal("Expected written == read == 15") + } +} + +func TestBufferClose(t *testing.T) { + b := newBuffer() + b.write(alphabet[:10]) + b.eof() + _, err := b.Read(make([]byte, 5)) + if err != nil { + t.Fatal("expected read of 5 to not return EOF") + } + b = newBuffer() + b.write(alphabet[:10]) + b.eof() + r, err := b.Read(make([]byte, 5)) + r2, err2 := b.Read(make([]byte, 10)) + if r != 5 || r2 != 5 || err != nil || err2 != nil { + t.Fatal("expected reads of 5 and 5") + } + + b = newBuffer() + b.write(alphabet[:10]) + b.eof() + r, err = b.Read(make([]byte, 5)) + r2, err2 = b.Read(make([]byte, 10)) + r3, err3 := b.Read(make([]byte, 10)) + if r != 5 || r2 != 5 || r3 != 0 || err != nil || err2 != nil || err3 != io.EOF { + t.Fatal("expected reads of 5 and 5 and 0, with EOF") + } + + b = newBuffer() + b.write(make([]byte, 5)) + b.write(make([]byte, 10)) + b.eof() + r, err = b.Read(make([]byte, 9)) + r2, err2 = b.Read(make([]byte, 3)) + r3, err3 = b.Read(make([]byte, 3)) + r4, err4 := b.Read(make([]byte, 10)) + if err != nil || err2 != nil || err3 != nil || err4 != io.EOF { + t.Fatalf("Expected EOF on forth read only, err=%v, err2=%v, err3=%v, err4=%v", err, err2, err3, err4) + } + if r != 9 || r2 != 3 || r3 != 3 || r4 != 0 { + t.Fatal("Expected written == read == 15", r, r2, r3, r4) + } +} diff --git a/lib/ssh/certs.go b/lib/ssh/certs.go new file mode 100644 index 0000000..b192b61 --- /dev/null +++ b/lib/ssh/certs.go @@ -0,0 +1,624 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "sort" + "strconv" + "time" +) + +// These constants from [PROTOCOL.certkeys] represent the algorithm names +// for certificate types supported by this package. +const ( + CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com" + CertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com" + CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com" + CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com" + CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com" + CertAlgoED25519v01 = "ssh-ed25519-cert-v01@openssh.com" +) + +// Certificate types distinguish between host and user +// certificates. The values can be set in the CertType field of +// Certificate. +const ( + UserCert = 1 + HostCert = 2 +) + +// Signature represents a cryptographic signature. +type Signature struct { + Format string `json:"algorithm,omitempty"` + Blob []byte `json:"value,omitempty"` +} + +// CertTimeInfinity can be used for OpenSSHCertV01.ValidBefore to indicate that +// a certificate does not expire. +const CertTimeInfinity = 1<<64 - 1 + +// An Certificate represents an OpenSSH certificate as defined in +// [PROTOCOL.certkeys]?rev=1.8. +type Certificate struct { + Nonce []byte + Key PublicKey + Serial uint64 + CertType uint32 + KeyId string + ValidPrincipals []string + ValidAfter uint64 + ValidBefore uint64 + Permissions + Reserved []byte + SignatureKey PublicKey + Signature *Signature + SignatureRaw []byte +} + +type JsonValidity struct { + ValidAfter time.Time `json:"valid_after"` + ValidBefore time.Time `json:"valid_before"` + Length uint64 `json:"length"` +} + +type JsonCertType struct { + Id uint32 `json:"id"` + Name string `json:"name"` +} + +type JsonPubKeyWrapper struct { + PublicKeyJsonLog + Raw []byte `json:"raw,omitempty"` + Fingerprint string `json:"fingerprint_sha256"` + Algorithm string `json:"algorithm"` +} + +type JsonSignature struct { + Parsed *Signature `json:"parsed,omitempty"` + Raw []byte `json:"raw,omitempty"` +} + +type JsonCertificate struct { + Nonce []byte `json:"nonce,omitempty"` + Key *JsonPubKeyWrapper `json:"key,omitempty"` + Serial string `json:"serial"` + CertType *JsonCertType `json:"cert_type,omitempty"` + KeyId string `json:"key_id,omitempty"` + ValidPrincipals []string `json:"valid_principals,omitempty"` + Validity *JsonValidity `json:"validity,omitempty"` + Reserved []byte `json:"reserved,omitempty"` + SignatureKey *JsonPubKeyWrapper `json:"signature_key"` + Signature *JsonSignature `json:"signature,omitempty"` + ParseError string `json:"parse_error,omitempty"` + Extensions *JsonExtensions `json:"extensions,omitempty"` + CriticalOptions *JsonCriticalOptions `json:"critical_options,omitempty"` +} + +func (c *Certificate) MarshalJSON() ([]byte, error) { + temp := JsonCertificate{ + Nonce: c.Nonce, + Serial: strconv.FormatUint(c.Serial, 10), + KeyId: c.KeyId, + ValidPrincipals: c.ValidPrincipals, + Reserved: c.Reserved, + } + + temp.Signature = new(JsonSignature) + temp.Signature.Raw = c.SignatureRaw + temp.Signature.Parsed = c.Signature + + temp.Key = new(JsonPubKeyWrapper) + + ok := temp.Key.AddPublicKey(c.Key) + if !ok { + temp.ParseError = "Cannot parse key to JSON" + } + + temp.Key.Raw = c.Key.Marshal() + tempHash := sha256.Sum256(temp.Key.Raw) + temp.Key.Fingerprint = hex.EncodeToString(tempHash[:]) + temp.Key.Algorithm = c.Key.Type() + + temp.CertType = new(JsonCertType) + temp.CertType.Id = c.CertType + switch c.CertType { + case 1: + temp.CertType.Name = "USER" + break + + case 2: + temp.CertType.Name = "HOST" + break + + default: + temp.CertType.Name = "unknown" + break + } + + var validityLength uint64 = 0 + if c.ValidBefore > c.ValidAfter { + validityLength = c.ValidBefore - c.ValidAfter + } + + temp.Validity = &JsonValidity{ + ValidAfter: time.Unix(int64(c.ValidAfter), 0).UTC(), + ValidBefore: time.Unix(int64(c.ValidBefore), 0).UTC(), + Length: validityLength, + } + + temp.SignatureKey = new(JsonPubKeyWrapper) + ok = temp.SignatureKey.AddPublicKey(c.SignatureKey) + if !ok { + temp.ParseError = "Cannot parse signature key to JSON" + } + + temp.SignatureKey.Raw = c.SignatureKey.Marshal() + tempHash = sha256.Sum256(temp.SignatureKey.Raw) + temp.SignatureKey.Fingerprint = hex.EncodeToString(tempHash[:]) + temp.SignatureKey.Algorithm = c.SignatureKey.Type() + + if len(c.CriticalOptions) != 0 { + temp.CriticalOptions = new(JsonCriticalOptions) + temp.CriticalOptions.options = c.CriticalOptions + } + + if len(c.Extensions) != 0 { + temp.Extensions = new(JsonExtensions) + temp.Extensions.extensions = c.Extensions + } + + return json.Marshal(temp) +} + +// genericCertData holds the key-independent part of the certificate data. +// Overall, certificates contain an nonce, public key fields and +// key-independent fields. +type genericCertData struct { + Serial uint64 + CertType uint32 + KeyId string + ValidPrincipals []byte + ValidAfter uint64 + ValidBefore uint64 + CriticalOptions []byte + Extensions []byte + Reserved []byte + SignatureKey []byte + Signature []byte +} + +func marshalStringList(namelist []string) []byte { + var to []byte + for _, name := range namelist { + s := struct{ N string }{name} + to = append(to, Marshal(&s)...) + } + return to +} + +type optionsTuple struct { + Key string + Value []byte +} + +type optionsTupleValue struct { + Value string +} + +// serialize a map of critical options or extensions +// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation, +// we need two length prefixes for a non-empty string value +func marshalTuples(tups map[string]string) []byte { + keys := make([]string, 0, len(tups)) + for key := range tups { + keys = append(keys, key) + } + sort.Strings(keys) + + var ret []byte + for _, key := range keys { + s := optionsTuple{Key: key} + if value := tups[key]; len(value) > 0 { + s.Value = Marshal(&optionsTupleValue{value}) + } + ret = append(ret, Marshal(&s)...) + } + return ret +} + +// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation, +// we need two length prefixes for a non-empty option value +func parseTuples(in []byte) (map[string]string, error) { + tups := map[string]string{} + var lastKey string + var haveLastKey bool + + for len(in) > 0 { + var key, val, extra []byte + var ok bool + + if key, in, ok = parseString(in); !ok { + return nil, errShortRead + } + keyStr := string(key) + // according to [PROTOCOL.certkeys], the names must be in + // lexical order. + if haveLastKey && keyStr <= lastKey { + return nil, fmt.Errorf("ssh: certificate options are not in lexical order") + } + lastKey, haveLastKey = keyStr, true + // the next field is a data field, which if non-empty has a string embedded + if val, in, ok = parseString(in); !ok { + return nil, errShortRead + } + if len(val) > 0 { + val, extra, ok = parseString(val) + if !ok { + return nil, errShortRead + } + if len(extra) > 0 { + return nil, fmt.Errorf("ssh: unexpected trailing data after certificate option value") + } + tups[keyStr] = string(val) + } else { + tups[keyStr] = "" + } + } + return tups, nil +} + +func parseCert(in []byte, privAlgo string) (*Certificate, error) { + nonce, rest, ok := parseString(in) + if !ok { + return nil, errShortRead + } + + key, rest, err := parsePubKey(rest, privAlgo) + if err != nil { + return nil, err + } + + var g genericCertData + if err := Unmarshal(rest, &g); err != nil { + return nil, err + } + + c := &Certificate{ + Nonce: nonce, + Key: key, + Serial: g.Serial, + CertType: g.CertType, + KeyId: g.KeyId, + ValidAfter: g.ValidAfter, + ValidBefore: g.ValidBefore, + } + + for principals := g.ValidPrincipals; len(principals) > 0; { + principal, rest, ok := parseString(principals) + if !ok { + return nil, errShortRead + } + c.ValidPrincipals = append(c.ValidPrincipals, string(principal)) + principals = rest + } + + c.CriticalOptions, err = parseTuples(g.CriticalOptions) + if err != nil { + return nil, err + } + c.Extensions, err = parseTuples(g.Extensions) + if err != nil { + return nil, err + } + c.Reserved = g.Reserved + k, err := ParsePublicKey(g.SignatureKey) + if err != nil { + return nil, err + } + + c.SignatureRaw = g.Signature + c.SignatureKey = k + c.Signature, rest, ok = parseSignatureBody(g.Signature) + if !ok || len(rest) > 0 { + return nil, errors.New("ssh: signature parse error") + } + + return c, nil +} + +type openSSHCertSigner struct { + pub *Certificate + signer Signer +} + +// NewCertSigner returns a Signer that signs with the given Certificate, whose +// private key is held by signer. It returns an error if the public key in cert +// doesn't match the key used by signer. +func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) { + if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 { + return nil, errors.New("ssh: signer and cert have different public key") + } + + return &openSSHCertSigner{cert, signer}, nil +} + +func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + return s.signer.Sign(rand, data) +} + +func (s *openSSHCertSigner) PublicKey() PublicKey { + return s.pub +} + +const sourceAddressCriticalOption = "source-address" + +// CertChecker does the work of verifying a certificate. Its methods +// can be plugged into ClientConfig.HostKeyCallback and +// ServerConfig.PublicKeyCallback. For the CertChecker to work, +// minimally, the IsAuthority callback should be set. +type CertChecker struct { + // SupportedCriticalOptions lists the CriticalOptions that the + // server application layer understands. These are only used + // for user certificates. + SupportedCriticalOptions []string + + // IsAuthority should return true if the key is recognized as + // an authority. This allows for certificates to be signed by other + // certificates. + IsAuthority func(auth PublicKey) bool + + // Clock is used for verifying time stamps. If nil, time.Now + // is used. + Clock func() time.Time + + // UserKeyFallback is called when CertChecker.Authenticate encounters a + // public key that is not a certificate. It must implement validation + // of user keys or else, if nil, all such keys are rejected. + UserKeyFallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) + + // HostKeyFallback is called when CertChecker.CheckHostKey encounters a + // public key that is not a certificate. It must implement host key + // validation or else, if nil, all such keys are rejected. + HostKeyFallback func(addr string, remote net.Addr, key PublicKey) error + + // IsRevoked is called for each certificate so that revocation checking + // can be implemented. It should return true if the given certificate + // is revoked and false otherwise. If nil, no certificates are + // considered to have been revoked. + IsRevoked func(cert *Certificate) bool +} + +// CheckHostKey checks a host key certificate. This method can be +// plugged into ClientConfig.HostKeyCallback. +func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) error { + cert, ok := key.(*Certificate) + if !ok { + if c.HostKeyFallback != nil { + return c.HostKeyFallback(addr, remote, key) + } + return errors.New("ssh: non-certificate host key") + } + if cert.CertType != HostCert { + return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType) + } + + return c.CheckCert(addr, cert) +} + +// Authenticate checks a user certificate. Authenticate can be used as +// a value for ServerConfig.PublicKeyCallback. +func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permissions, error) { + cert, ok := pubKey.(*Certificate) + if !ok { + if c.UserKeyFallback != nil { + return c.UserKeyFallback(conn, pubKey) + } + return nil, errors.New("ssh: normal key pairs not accepted") + } + + if cert.CertType != UserCert { + return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType) + } + + if err := c.CheckCert(conn.User(), cert); err != nil { + return nil, err + } + + return &cert.Permissions, nil +} + +// CheckCert checks CriticalOptions, ValidPrincipals, revocation, timestamp and +// the signature of the certificate. +func (c *CertChecker) CheckCert(principal string, cert *Certificate) error { + if c.IsRevoked != nil && c.IsRevoked(cert) { + return fmt.Errorf("ssh: certicate serial %d revoked", cert.Serial) + } + + for opt := range cert.CriticalOptions { + // sourceAddressCriticalOption will be enforced by + // serverAuthenticate + if opt == sourceAddressCriticalOption { + continue + } + + found := false + for _, supp := range c.SupportedCriticalOptions { + if supp == opt { + found = true + break + } + } + if !found { + return fmt.Errorf("ssh: unsupported critical option %q in certificate", opt) + } + } + + if len(cert.ValidPrincipals) > 0 { + // By default, certs are valid for all users/hosts. + found := false + for _, p := range cert.ValidPrincipals { + if p == principal { + found = true + break + } + } + if !found { + return fmt.Errorf("ssh: principal %q not in the set of valid principals for given certificate: %q", principal, cert.ValidPrincipals) + } + } + + if !c.IsAuthority(cert.SignatureKey) { + return fmt.Errorf("ssh: certificate signed by unrecognized authority") + } + + clock := c.Clock + if clock == nil { + clock = time.Now + } + + unixNow := clock().Unix() + if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) { + return fmt.Errorf("ssh: cert is not yet valid") + } + if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(CertTimeInfinity) && (unixNow >= before || before < 0) { + return fmt.Errorf("ssh: cert has expired") + } + if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil { + return fmt.Errorf("ssh: certificate signature does not verify") + } + + return nil +} + +// SignCert sets c.SignatureKey to the authority's public key and stores a +// Signature, by authority, in the certificate. +func (c *Certificate) SignCert(rand io.Reader, authority Signer) error { + c.Nonce = make([]byte, 32) + if _, err := io.ReadFull(rand, c.Nonce); err != nil { + return err + } + c.SignatureKey = authority.PublicKey() + + sig, err := authority.Sign(rand, c.bytesForSigning()) + if err != nil { + return err + } + c.Signature = sig + return nil +} + +var certAlgoNames = map[string]string{ + KeyAlgoRSA: CertAlgoRSAv01, + KeyAlgoDSA: CertAlgoDSAv01, + KeyAlgoECDSA256: CertAlgoECDSA256v01, + KeyAlgoECDSA384: CertAlgoECDSA384v01, + KeyAlgoECDSA521: CertAlgoECDSA521v01, + KeyAlgoED25519: CertAlgoED25519v01, +} + +// certToPrivAlgo returns the underlying algorithm for a certificate algorithm. +// Panics if a non-certificate algorithm is passed. +func certToPrivAlgo(algo string) string { + for privAlgo, pubAlgo := range certAlgoNames { + if pubAlgo == algo { + return privAlgo + } + } + panic("unknown cert algorithm") +} + +func (cert *Certificate) bytesForSigning() []byte { + c2 := *cert + c2.Signature = nil + out := c2.Marshal() + // Drop trailing signature length. + return out[:len(out)-4] +} + +// Marshal serializes c into OpenSSH's wire format. It is part of the +// PublicKey interface. +func (c *Certificate) Marshal() []byte { + generic := genericCertData{ + Serial: c.Serial, + CertType: c.CertType, + KeyId: c.KeyId, + ValidPrincipals: marshalStringList(c.ValidPrincipals), + ValidAfter: uint64(c.ValidAfter), + ValidBefore: uint64(c.ValidBefore), + CriticalOptions: marshalTuples(c.CriticalOptions), + Extensions: marshalTuples(c.Extensions), + Reserved: c.Reserved, + SignatureKey: c.SignatureKey.Marshal(), + } + if c.Signature != nil { + generic.Signature = Marshal(c.Signature) + } + genericBytes := Marshal(&generic) + keyBytes := c.Key.Marshal() + _, keyBytes, _ = parseString(keyBytes) + prefix := Marshal(&struct { + Name string + Nonce []byte + Key []byte `ssh:"rest"` + }{c.Type(), c.Nonce, keyBytes}) + + result := make([]byte, 0, len(prefix)+len(genericBytes)) + result = append(result, prefix...) + result = append(result, genericBytes...) + return result +} + +// Type returns the key name. It is part of the PublicKey interface. +func (c *Certificate) Type() string { + algo, ok := certAlgoNames[c.Key.Type()] + if !ok { + panic("unknown cert key type " + c.Key.Type()) + } + return algo +} + +// Verify verifies a signature against the certificate's public +// key. It is part of the PublicKey interface. +func (c *Certificate) Verify(data []byte, sig *Signature) error { + return c.Key.Verify(data, sig) +} + +func parseSignatureBody(in []byte) (out *Signature, rest []byte, ok bool) { + format, in, ok := parseString(in) + if !ok { + return + } + + out = &Signature{ + Format: string(format), + } + + if out.Blob, in, ok = parseString(in); !ok { + return + } + + return out, in, ok +} + +func parseSignature(in []byte) (out *Signature, rest []byte, ok bool) { + sigBytes, rest, ok := parseString(in) + if !ok { + return + } + + out, trailing, ok := parseSignatureBody(sigBytes) + if !ok || len(trailing) > 0 { + return nil, nil, false + } + return +} diff --git a/lib/ssh/certs_test.go b/lib/ssh/certs_test.go new file mode 100644 index 0000000..a5219f4 --- /dev/null +++ b/lib/ssh/certs_test.go @@ -0,0 +1,216 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "bytes" + "crypto/rand" + "reflect" + "testing" + "time" +) + +// Cert generated by ssh-keygen 6.0p1 Debian-4. +// % ssh-keygen -s ca-key -I test user-key +const exampleSSHCert = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgb1srW/W3ZDjYAO45xLYAwzHBDLsJ4Ux6ICFIkTjb1LEAAAADAQABAAAAYQCkoR51poH0wE8w72cqSB8Sszx+vAhzcMdCO0wqHTj7UNENHWEXGrU0E0UQekD7U+yhkhtoyjbPOVIP7hNa6aRk/ezdh/iUnCIt4Jt1v3Z1h1P+hA4QuYFMHNB+rmjPwAcAAAAAAAAAAAAAAAEAAAAEdGVzdAAAAAAAAAAAAAAAAP//////////AAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAHcAAAAHc3NoLXJzYQAAAAMBAAEAAABhANFS2kaktpSGc+CcmEKPyw9mJC4nZKxHKTgLVZeaGbFZOvJTNzBspQHdy7Q1uKSfktxpgjZnksiu/tFF9ngyY2KFoc+U88ya95IZUycBGCUbBQ8+bhDtw/icdDGQD5WnUwAAAG8AAAAHc3NoLXJzYQAAAGC8Y9Z2LQKhIhxf52773XaWrXdxP0t3GBVo4A10vUWiYoAGepr6rQIoGGXFxT4B9Gp+nEBJjOwKDXPrAevow0T9ca8gZN+0ykbhSrXLE5Ao48rqr3zP4O1/9P7e6gp0gw8=` + +func TestParseCert(t *testing.T) { + authKeyBytes := []byte(exampleSSHCert) + + key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes) + if err != nil { + t.Fatalf("ParseAuthorizedKey: %v", err) + } + if len(rest) > 0 { + t.Errorf("rest: got %q, want empty", rest) + } + + if _, ok := key.(*Certificate); !ok { + t.Fatalf("got %v (%T), want *Certificate", key, key) + } + + marshaled := MarshalAuthorizedKey(key) + // Before comparison, remove the trailing newline that + // MarshalAuthorizedKey adds. + marshaled = marshaled[:len(marshaled)-1] + if !bytes.Equal(authKeyBytes, marshaled) { + t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes) + } +} + +// Cert generated by ssh-keygen OpenSSH_6.8p1 OS X 10.10.3 +// % ssh-keygen -s ca -I testcert -O source-address=192.168.1.0/24 -O force-command=/bin/sleep user.pub +// user.pub key: ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMN +// Critical Options: +// force-command /bin/sleep +// source-address 192.168.1.0/24 +// Extensions: +// permit-X11-forwarding +// permit-agent-forwarding +// permit-port-forwarding +// permit-pty +// permit-user-rc +const exampleSSHCertWithOptions = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgDyysCJY0XrO1n03EeRRoITnTPdjENFmWDs9X58PP3VUAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMNAAAAAAAAAAAAAAABAAAACHRlc3RjZXJ0AAAAAAAAAAAAAAAA//////////8AAABLAAAADWZvcmNlLWNvbW1hbmQAAAAOAAAACi9iaW4vc2xlZXAAAAAOc291cmNlLWFkZHJlc3MAAAASAAAADjE5Mi4xNjguMS4wLzI0AAAAggAAABVwZXJtaXQtWDExLWZvcndhcmRpbmcAAAAAAAAAF3Blcm1pdC1hZ2VudC1mb3J3YXJkaW5nAAAAAAAAABZwZXJtaXQtcG9ydC1mb3J3YXJkaW5nAAAAAAAAAApwZXJtaXQtcHR5AAAAAAAAAA5wZXJtaXQtdXNlci1yYwAAAAAAAAAAAAABFwAAAAdzc2gtcnNhAAAAAwEAAQAAAQEAwU+c5ui5A8+J/CFpjW8wCa52bEODA808WWQDCSuTG/eMXNf59v9Y8Pk0F1E9dGCosSNyVcB/hacUrc6He+i97+HJCyKavBsE6GDxrjRyxYqAlfcOXi/IVmaUGiO8OQ39d4GHrjToInKvExSUeleQyH4Y4/e27T/pILAqPFL3fyrvMLT5qU9QyIt6zIpa7GBP5+urouNavMprV3zsfIqNBbWypinOQAw823a5wN+zwXnhZrgQiHZ/USG09Y6k98y1dTVz8YHlQVR4D3lpTAsKDKJ5hCH9WU4fdf+lU8OyNGaJ/vz0XNqxcToe1l4numLTnaoSuH89pHryjqurB7lJKwAAAQ8AAAAHc3NoLXJzYQAAAQCaHvUIoPL1zWUHIXLvu96/HU1s/i4CAW2IIEuGgxCUCiFj6vyTyYtgxQxcmbfZf6eaITlS6XJZa7Qq4iaFZh75C1DXTX8labXhRSD4E2t//AIP9MC1rtQC5xo6FmbQ+BoKcDskr+mNACcbRSxs3IL3bwCfWDnIw2WbVox9ZdcthJKk4UoCW4ix4QwdHw7zlddlz++fGEEVhmTbll1SUkycGApPFBsAYRTMupUJcYPIeReBI/m8XfkoMk99bV8ZJQTAd7OekHY2/48Ff53jLmyDjP7kNw1F8OaPtkFs6dGJXta4krmaekPy87j+35In5hFj7yoOqvSbmYUkeX70/GGQ` + +func TestParseCertWithOptions(t *testing.T) { + opts := map[string]string{ + "source-address": "192.168.1.0/24", + "force-command": "/bin/sleep", + } + exts := map[string]string{ + "permit-X11-forwarding": "", + "permit-agent-forwarding": "", + "permit-port-forwarding": "", + "permit-pty": "", + "permit-user-rc": "", + } + authKeyBytes := []byte(exampleSSHCertWithOptions) + + key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes) + if err != nil { + t.Fatalf("ParseAuthorizedKey: %v", err) + } + if len(rest) > 0 { + t.Errorf("rest: got %q, want empty", rest) + } + cert, ok := key.(*Certificate) + if !ok { + t.Fatalf("got %v (%T), want *Certificate", key, key) + } + if !reflect.DeepEqual(cert.CriticalOptions, opts) { + t.Errorf("unexpected critical options - got %v, want %v", cert.CriticalOptions, opts) + } + if !reflect.DeepEqual(cert.Extensions, exts) { + t.Errorf("unexpected Extensions - got %v, want %v", cert.Extensions, exts) + } + marshaled := MarshalAuthorizedKey(key) + // Before comparison, remove the trailing newline that + // MarshalAuthorizedKey adds. + marshaled = marshaled[:len(marshaled)-1] + if !bytes.Equal(authKeyBytes, marshaled) { + t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes) + } +} + +func TestValidateCert(t *testing.T) { + key, _, _, _, err := ParseAuthorizedKey([]byte(exampleSSHCert)) + if err != nil { + t.Fatalf("ParseAuthorizedKey: %v", err) + } + validCert, ok := key.(*Certificate) + if !ok { + t.Fatalf("got %v (%T), want *Certificate", key, key) + } + checker := CertChecker{} + checker.IsAuthority = func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal()) + } + + if err := checker.CheckCert("user", validCert); err != nil { + t.Errorf("Unable to validate certificate: %v", err) + } + invalidCert := &Certificate{ + Key: testPublicKeys["rsa"], + SignatureKey: testPublicKeys["ecdsa"], + ValidBefore: CertTimeInfinity, + Signature: &Signature{}, + } + if err := checker.CheckCert("user", invalidCert); err == nil { + t.Error("Invalid cert signature passed validation") + } +} + +func TestValidateCertTime(t *testing.T) { + cert := Certificate{ + ValidPrincipals: []string{"user"}, + Key: testPublicKeys["rsa"], + ValidAfter: 50, + ValidBefore: 100, + } + + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + + for ts, ok := range map[int64]bool{ + 25: false, + 50: true, + 99: true, + 100: false, + 125: false, + } { + checker := CertChecker{ + Clock: func() time.Time { return time.Unix(ts, 0) }, + } + checker.IsAuthority = func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), + testPublicKeys["ecdsa"].Marshal()) + } + + if v := checker.CheckCert("user", &cert); (v == nil) != ok { + t.Errorf("Authenticate(%d): %v", ts, v) + } + } +} + +// TODO(hanwen): tests for +// +// host keys: +// * fallbacks + +func TestHostKeyCert(t *testing.T) { + cert := &Certificate{ + ValidPrincipals: []string{"hostname", "hostname.domain"}, + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: HostCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + + checker := &CertChecker{ + IsAuthority: func(p PublicKey) bool { + return bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal()) + }, + } + + certSigner, err := NewCertSigner(cert, testSigners["rsa"]) + if err != nil { + t.Errorf("NewCertSigner: %v", err) + } + + for _, name := range []string{"hostname", "otherhost"} { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + errc := make(chan error) + + go func() { + conf := ServerConfig{ + NoClientAuth: true, + } + conf.AddHostKey(certSigner) + _, _, _, err := NewServerConn(c1, &conf) + errc <- err + }() + + config := &ClientConfig{ + User: "user", + HostKeyCallback: checker.CheckHostKey, + } + _, _, _, err = NewClientConn(c2, name, config) + + succeed := name == "hostname" + if (err == nil) != succeed { + t.Fatalf("NewClientConn(%q): %v", name, err) + } + + err = <-errc + if (err == nil) != succeed { + t.Fatalf("NewServerConn(%q): %v", name, err) + } + } +} diff --git a/lib/ssh/channel.go b/lib/ssh/channel.go new file mode 100644 index 0000000..a9db6f2 --- /dev/null +++ b/lib/ssh/channel.go @@ -0,0 +1,633 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "log" + "sync" +) + +const ( + minPacketLength = 9 + // channelMaxPacket contains the maximum number of bytes that will be + // sent in a single packet. As per RFC 4253, section 6.1, 32k is also + // the minimum. + channelMaxPacket = 1 << 15 + // We follow OpenSSH here. + channelWindowSize = 64 * channelMaxPacket +) + +// NewChannel represents an incoming request to a channel. It must either be +// accepted for use by calling Accept, or rejected by calling Reject. +type NewChannel interface { + // Accept accepts the channel creation request. It returns the Channel + // and a Go channel containing SSH requests. The Go channel must be + // serviced otherwise the Channel will hang. + Accept() (Channel, <-chan *Request, error) + + // Reject rejects the channel creation request. After calling + // this, no other methods on the Channel may be called. + Reject(reason RejectionReason, message string) error + + // ChannelType returns the type of the channel, as supplied by the + // client. + ChannelType() string + + // ExtraData returns the arbitrary payload for this channel, as supplied + // by the client. This data is specific to the channel type. + ExtraData() []byte +} + +// A Channel is an ordered, reliable, flow-controlled, duplex stream +// that is multiplexed over an SSH connection. +type Channel interface { + // Read reads up to len(data) bytes from the channel. + Read(data []byte) (int, error) + + // Write writes len(data) bytes to the channel. + Write(data []byte) (int, error) + + // Close signals end of channel use. No data may be sent after this + // call. + Close() error + + // CloseWrite signals the end of sending in-band + // data. Requests may still be sent, and the other side may + // still send data + CloseWrite() error + + // SendRequest sends a channel request. If wantReply is true, + // it will wait for a reply and return the result as a + // boolean, otherwise the return value will be false. Channel + // requests are out-of-band messages so they may be sent even + // if the data stream is closed or blocked by flow control. + // If the channel is closed before a reply is returned, io.EOF + // is returned. + SendRequest(name string, wantReply bool, payload []byte) (bool, error) + + // Stderr returns an io.ReadWriter that writes to this channel + // with the extended data type set to stderr. Stderr may + // safely be read and written from a different goroutine than + // Read and Write respectively. + Stderr() io.ReadWriter +} + +// Request is a request sent outside of the normal stream of +// data. Requests can either be specific to an SSH channel, or they +// can be global. +type Request struct { + Type string + WantReply bool + Payload []byte + + ch *channel + mux *mux +} + +// Reply sends a response to a request. It must be called for all requests +// where WantReply is true and is a no-op otherwise. The payload argument is +// ignored for replies to channel-specific requests. +func (r *Request) Reply(ok bool, payload []byte) error { + if !r.WantReply { + return nil + } + + if r.ch == nil { + return r.mux.ackRequest(ok, payload) + } + + return r.ch.ackRequest(ok) +} + +// RejectionReason is an enumeration used when rejecting channel creation +// requests. See RFC 4254, section 5.1. +type RejectionReason uint32 + +const ( + Prohibited RejectionReason = iota + 1 + ConnectionFailed + UnknownChannelType + ResourceShortage +) + +// String converts the rejection reason to human readable form. +func (r RejectionReason) String() string { + switch r { + case Prohibited: + return "administratively prohibited" + case ConnectionFailed: + return "connect failed" + case UnknownChannelType: + return "unknown channel type" + case ResourceShortage: + return "resource shortage" + } + return fmt.Sprintf("unknown reason %d", int(r)) +} + +func min(a uint32, b int) uint32 { + if a < uint32(b) { + return a + } + return uint32(b) +} + +type channelDirection uint8 + +const ( + channelInbound channelDirection = iota + channelOutbound +) + +// channel is an implementation of the Channel interface that works +// with the mux class. +type channel struct { + // R/O after creation + chanType string + extraData []byte + localId, remoteId uint32 + + // maxIncomingPayload and maxRemotePayload are the maximum + // payload sizes of normal and extended data packets for + // receiving and sending, respectively. The wire packet will + // be 9 or 13 bytes larger (excluding encryption overhead). + maxIncomingPayload uint32 + maxRemotePayload uint32 + + mux *mux + + // decided is set to true if an accept or reject message has been sent + // (for outbound channels) or received (for inbound channels). + decided bool + + // direction contains either channelOutbound, for channels created + // locally, or channelInbound, for channels created by the peer. + direction channelDirection + + // Pending internal channel messages. + msg chan interface{} + + // Since requests have no ID, there can be only one request + // with WantReply=true outstanding. This lock is held by a + // goroutine that has such an outgoing request pending. + sentRequestMu sync.Mutex + + incomingRequests chan *Request + + sentEOF bool + + // thread-safe data + remoteWin window + pending *buffer + extPending *buffer + + // windowMu protects myWindow, the flow-control window. + windowMu sync.Mutex + myWindow uint32 + + // writeMu serializes calls to mux.conn.writePacket() and + // protects sentClose and packetPool. This mutex must be + // different from windowMu, as writePacket can block if there + // is a key exchange pending. + writeMu sync.Mutex + sentClose bool + + // packetPool has a buffer for each extended channel ID to + // save allocations during writes. + packetPool map[uint32][]byte +} + +// writePacket sends a packet. If the packet is a channel close, it updates +// sentClose. This method takes the lock c.writeMu. +func (c *channel) writePacket(packet []byte) error { + c.writeMu.Lock() + if c.sentClose { + c.writeMu.Unlock() + return io.EOF + } + c.sentClose = (packet[0] == msgChannelClose) + err := c.mux.conn.writePacket(packet) + c.writeMu.Unlock() + return err +} + +func (c *channel) sendMessage(msg interface{}) error { + if debugMux { + log.Printf("send(%d): %#v", c.mux.chanList.offset, msg) + } + + p := Marshal(msg) + binary.BigEndian.PutUint32(p[1:], c.remoteId) + return c.writePacket(p) +} + +// WriteExtended writes data to a specific extended stream. These streams are +// used, for example, for stderr. +func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) { + if c.sentEOF { + return 0, io.EOF + } + // 1 byte message type, 4 bytes remoteId, 4 bytes data length + opCode := byte(msgChannelData) + headerLength := uint32(9) + if extendedCode > 0 { + headerLength += 4 + opCode = msgChannelExtendedData + } + + c.writeMu.Lock() + packet := c.packetPool[extendedCode] + // We don't remove the buffer from packetPool, so + // WriteExtended calls from different goroutines will be + // flagged as errors by the race detector. + c.writeMu.Unlock() + + for len(data) > 0 { + space := min(c.maxRemotePayload, len(data)) + if space, err = c.remoteWin.reserve(space); err != nil { + return n, err + } + if want := headerLength + space; uint32(cap(packet)) < want { + packet = make([]byte, want) + } else { + packet = packet[:want] + } + + todo := data[:space] + + packet[0] = opCode + binary.BigEndian.PutUint32(packet[1:], c.remoteId) + if extendedCode > 0 { + binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode)) + } + binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo))) + copy(packet[headerLength:], todo) + if err = c.writePacket(packet); err != nil { + return n, err + } + + n += len(todo) + data = data[len(todo):] + } + + c.writeMu.Lock() + c.packetPool[extendedCode] = packet + c.writeMu.Unlock() + + return n, err +} + +func (c *channel) handleData(packet []byte) error { + headerLen := 9 + isExtendedData := packet[0] == msgChannelExtendedData + if isExtendedData { + headerLen = 13 + } + if len(packet) < headerLen { + // malformed data packet + return parseError(packet[0]) + } + + var extended uint32 + if isExtendedData { + extended = binary.BigEndian.Uint32(packet[5:]) + } + + length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen]) + if length == 0 { + return nil + } + if length > c.maxIncomingPayload { + // TODO(hanwen): should send Disconnect? + return errors.New("ssh: incoming packet exceeds maximum payload size") + } + + data := packet[headerLen:] + if length != uint32(len(data)) { + return errors.New("ssh: wrong packet length") + } + + c.windowMu.Lock() + if c.myWindow < length { + c.windowMu.Unlock() + // TODO(hanwen): should send Disconnect with reason? + return errors.New("ssh: remote side wrote too much") + } + c.myWindow -= length + c.windowMu.Unlock() + + if extended == 1 { + c.extPending.write(data) + } else if extended > 0 { + // discard other extended data. + } else { + c.pending.write(data) + } + return nil +} + +func (c *channel) adjustWindow(n uint32) error { + c.windowMu.Lock() + // Since myWindow is managed on our side, and can never exceed + // the initial window setting, we don't worry about overflow. + c.myWindow += uint32(n) + c.windowMu.Unlock() + return c.sendMessage(windowAdjustMsg{ + AdditionalBytes: uint32(n), + }) +} + +func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) { + switch extended { + case 1: + n, err = c.extPending.Read(data) + case 0: + n, err = c.pending.Read(data) + default: + return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended) + } + + if n > 0 { + err = c.adjustWindow(uint32(n)) + // sendWindowAdjust can return io.EOF if the remote + // peer has closed the connection, however we want to + // defer forwarding io.EOF to the caller of Read until + // the buffer has been drained. + if n > 0 && err == io.EOF { + err = nil + } + } + + return n, err +} + +func (c *channel) close() { + c.pending.eof() + c.extPending.eof() + close(c.msg) + close(c.incomingRequests) + c.writeMu.Lock() + // This is not necessary for a normal channel teardown, but if + // there was another error, it is. + c.sentClose = true + c.writeMu.Unlock() + // Unblock writers. + c.remoteWin.close() +} + +// responseMessageReceived is called when a success or failure message is +// received on a channel to check that such a message is reasonable for the +// given channel. +func (c *channel) responseMessageReceived() error { + if c.direction == channelInbound { + return errors.New("ssh: channel response message received on inbound channel") + } + if c.decided { + return errors.New("ssh: duplicate response received for channel") + } + c.decided = true + return nil +} + +func (c *channel) handlePacket(packet []byte) error { + switch packet[0] { + case msgChannelData, msgChannelExtendedData: + return c.handleData(packet) + case msgChannelClose: + c.sendMessage(channelCloseMsg{PeersId: c.remoteId}) + c.mux.chanList.remove(c.localId) + c.close() + return nil + case msgChannelEOF: + // RFC 4254 is mute on how EOF affects dataExt messages but + // it is logical to signal EOF at the same time. + c.extPending.eof() + c.pending.eof() + return nil + } + + decoded, err := decode(packet) + if err != nil { + return err + } + + switch msg := decoded.(type) { + case *channelOpenFailureMsg: + if err := c.responseMessageReceived(); err != nil { + return err + } + c.mux.chanList.remove(msg.PeersId) + c.msg <- msg + case *channelOpenConfirmMsg: + if err := c.responseMessageReceived(); err != nil { + return err + } + if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { + return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize) + } + c.remoteId = msg.MyId + c.maxRemotePayload = msg.MaxPacketSize + c.remoteWin.add(msg.MyWindow) + c.msg <- msg + case *windowAdjustMsg: + if !c.remoteWin.add(msg.AdditionalBytes) { + return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes) + } + case *channelRequestMsg: + req := Request{ + Type: msg.Request, + WantReply: msg.WantReply, + Payload: msg.RequestSpecificData, + ch: c, + } + + c.incomingRequests <- &req + default: + c.msg <- msg + } + return nil +} + +func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel { + ch := &channel{ + remoteWin: window{Cond: newCond()}, + myWindow: channelWindowSize, + pending: newBuffer(), + extPending: newBuffer(), + direction: direction, + incomingRequests: make(chan *Request, 16), + msg: make(chan interface{}, 16), + chanType: chanType, + extraData: extraData, + mux: m, + packetPool: make(map[uint32][]byte), + } + ch.localId = m.chanList.add(ch) + return ch +} + +var errUndecided = errors.New("ssh: must Accept or Reject channel") +var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once") + +type extChannel struct { + code uint32 + ch *channel +} + +func (e *extChannel) Write(data []byte) (n int, err error) { + return e.ch.WriteExtended(data, e.code) +} + +func (e *extChannel) Read(data []byte) (n int, err error) { + return e.ch.ReadExtended(data, e.code) +} + +func (c *channel) Accept() (Channel, <-chan *Request, error) { + if c.decided { + return nil, nil, errDecidedAlready + } + c.maxIncomingPayload = channelMaxPacket + confirm := channelOpenConfirmMsg{ + PeersId: c.remoteId, + MyId: c.localId, + MyWindow: c.myWindow, + MaxPacketSize: c.maxIncomingPayload, + } + c.decided = true + if err := c.sendMessage(confirm); err != nil { + return nil, nil, err + } + + return c, c.incomingRequests, nil +} + +func (ch *channel) Reject(reason RejectionReason, message string) error { + if ch.decided { + return errDecidedAlready + } + reject := channelOpenFailureMsg{ + PeersId: ch.remoteId, + Reason: reason, + Message: message, + Language: "en", + } + ch.decided = true + return ch.sendMessage(reject) +} + +func (ch *channel) Read(data []byte) (int, error) { + if !ch.decided { + return 0, errUndecided + } + return ch.ReadExtended(data, 0) +} + +func (ch *channel) Write(data []byte) (int, error) { + if !ch.decided { + return 0, errUndecided + } + return ch.WriteExtended(data, 0) +} + +func (ch *channel) CloseWrite() error { + if !ch.decided { + return errUndecided + } + ch.sentEOF = true + return ch.sendMessage(channelEOFMsg{ + PeersId: ch.remoteId}) +} + +func (ch *channel) Close() error { + if !ch.decided { + return errUndecided + } + + return ch.sendMessage(channelCloseMsg{ + PeersId: ch.remoteId}) +} + +// Extended returns an io.ReadWriter that sends and receives data on the given, +// SSH extended stream. Such streams are used, for example, for stderr. +func (ch *channel) Extended(code uint32) io.ReadWriter { + if !ch.decided { + return nil + } + return &extChannel{code, ch} +} + +func (ch *channel) Stderr() io.ReadWriter { + return ch.Extended(1) +} + +func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + if !ch.decided { + return false, errUndecided + } + + if wantReply { + ch.sentRequestMu.Lock() + defer ch.sentRequestMu.Unlock() + } + + msg := channelRequestMsg{ + PeersId: ch.remoteId, + Request: name, + WantReply: wantReply, + RequestSpecificData: payload, + } + + if err := ch.sendMessage(msg); err != nil { + return false, err + } + + if wantReply { + m, ok := (<-ch.msg) + if !ok { + return false, io.EOF + } + switch m.(type) { + case *channelRequestFailureMsg: + return false, nil + case *channelRequestSuccessMsg: + return true, nil + default: + return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m) + } + } + + return false, nil +} + +// ackRequest either sends an ack or nack to the channel request. +func (ch *channel) ackRequest(ok bool) error { + if !ch.decided { + return errUndecided + } + + var msg interface{} + if !ok { + msg = channelRequestFailureMsg{ + PeersId: ch.remoteId, + } + } else { + msg = channelRequestSuccessMsg{ + PeersId: ch.remoteId, + } + } + return ch.sendMessage(msg) +} + +func (ch *channel) ChannelType() string { + return ch.chanType +} + +func (ch *channel) ExtraData() []byte { + return ch.extraData +} diff --git a/lib/ssh/cipher.go b/lib/ssh/cipher.go new file mode 100644 index 0000000..dd023ab --- /dev/null +++ b/lib/ssh/cipher.go @@ -0,0 +1,579 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/rc4" + "crypto/subtle" + "encoding/binary" + "errors" + "fmt" + "hash" + "io" + "io/ioutil" +) + +const ( + packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher. + + // RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations + // MUST be able to process (plus a few more kilobytes for padding and mac). The RFC + // indicates implementations SHOULD be able to handle larger packet sizes, but then + // waffles on about reasonable limits. + // + // OpenSSH caps their maxPacket at 256kB so we choose to do + // the same. maxPacket is also used to ensure that uint32 + // length fields do not overflow, so it should remain well + // below 4G. + maxPacket = 256 * 1024 +) + +// noneCipher implements cipher.Stream and provides no encryption. It is used +// by the transport before the first key-exchange. +type noneCipher struct{} + +func (c noneCipher) XORKeyStream(dst, src []byte) { + copy(dst, src) +} + +func newAESCTR(key, iv []byte) (cipher.Stream, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return cipher.NewCTR(c, iv), nil +} + +func newRC4(key, iv []byte) (cipher.Stream, error) { + return rc4.NewCipher(key) +} + +type streamCipherMode struct { + keySize int + ivSize int + skip int + createFunc func(key, iv []byte) (cipher.Stream, error) +} + +func (c *streamCipherMode) createStream(key, iv []byte) (cipher.Stream, error) { + if len(key) < c.keySize { + panic("ssh: key length too small for cipher") + } + if len(iv) < c.ivSize { + panic("ssh: iv too small for cipher") + } + + stream, err := c.createFunc(key[:c.keySize], iv[:c.ivSize]) + if err != nil { + return nil, err + } + + var streamDump []byte + if c.skip > 0 { + streamDump = make([]byte, 512) + } + + for remainingToDump := c.skip; remainingToDump > 0; { + dumpThisTime := remainingToDump + if dumpThisTime > len(streamDump) { + dumpThisTime = len(streamDump) + } + stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime]) + remainingToDump -= dumpThisTime + } + + return stream, nil +} + +// cipherModes documents properties of supported ciphers. Ciphers not included +// are not supported and will not be negotiated, even if explicitly requested in +// ClientConfig.Crypto.Ciphers. +var cipherModes = map[string]*streamCipherMode{ + // Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms + // are defined in the order specified in the RFC. + "aes128-ctr": {16, aes.BlockSize, 0, newAESCTR}, + "aes192-ctr": {24, aes.BlockSize, 0, newAESCTR}, + "aes256-ctr": {32, aes.BlockSize, 0, newAESCTR}, + + // Ciphers from RFC4345, which introduces security-improved arcfour ciphers. + // They are defined in the order specified in the RFC. + "arcfour128": {16, 0, 1536, newRC4}, + "arcfour256": {32, 0, 1536, newRC4}, + + // Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol. + // Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and + // RC4) has problems with weak keys, and should be used with caution." + // RFC4345 introduces improved versions of Arcfour. + "arcfour": {16, 0, 0, newRC4}, + + // AES-GCM is not a stream cipher, so it is constructed with a + // special case. If we add any more non-stream ciphers, we + // should invest a cleaner way to do this. + gcmCipherID: {16, 12, 0, nil}, + + // CBC mode is insecure and so is not included in the default config. + // (See http://www.isg.rhul.ac.uk/~kp/SandPfinal.pdf). If absolutely + // needed, it's possible to specify a custom Config to enable it. + // You should expect that an active attacker can recover plaintext if + // you do. + aes128cbcID: {16, aes.BlockSize, 0, nil}, + + // 3des-cbc is insecure and is disabled by default. + tripledescbcID: {24, des.BlockSize, 0, nil}, +} + +// prefixLen is the length of the packet prefix that contains the packet length +// and number of padding bytes. +const prefixLen = 5 + +// streamPacketCipher is a packetCipher using a stream cipher. +type streamPacketCipher struct { + mac hash.Hash + cipher cipher.Stream + + // The following members are to avoid per-packet allocations. + prefix [prefixLen]byte + seqNumBytes [4]byte + padding [2 * packetSizeMultiple]byte + packetData []byte + macResult []byte +} + +// readPacket reads and decrypt a single packet from the reader argument. +func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { + if _, err := io.ReadFull(r, s.prefix[:]); err != nil { + return nil, err + } + + s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) + length := binary.BigEndian.Uint32(s.prefix[0:4]) + paddingLength := uint32(s.prefix[4]) + + var macSize uint32 + if s.mac != nil { + s.mac.Reset() + binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) + s.mac.Write(s.seqNumBytes[:]) + s.mac.Write(s.prefix[:]) + macSize = uint32(s.mac.Size()) + } + + if length <= paddingLength+1 { + return nil, errors.New("ssh: invalid packet length, packet too small") + } + + if length > maxPacket { + return nil, errors.New("ssh: invalid packet length, packet too large") + } + + // the maxPacket check above ensures that length-1+macSize + // does not overflow. + if uint32(cap(s.packetData)) < length-1+macSize { + s.packetData = make([]byte, length-1+macSize) + } else { + s.packetData = s.packetData[:length-1+macSize] + } + + if _, err := io.ReadFull(r, s.packetData); err != nil { + return nil, err + } + mac := s.packetData[length-1:] + data := s.packetData[:length-1] + s.cipher.XORKeyStream(data, data) + + if s.mac != nil { + s.mac.Write(data) + s.macResult = s.mac.Sum(s.macResult[:0]) + if subtle.ConstantTimeCompare(s.macResult, mac) != 1 { + return nil, errors.New("ssh: MAC failure") + } + } + + return s.packetData[:length-paddingLength-1], nil +} + +// writePacket encrypts and sends a packet of data to the writer argument +func (s *streamPacketCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + if len(packet) > maxPacket { + return errors.New("ssh: packet too large") + } + + paddingLength := packetSizeMultiple - (prefixLen+len(packet))%packetSizeMultiple + if paddingLength < 4 { + paddingLength += packetSizeMultiple + } + + length := len(packet) + 1 + paddingLength + binary.BigEndian.PutUint32(s.prefix[:], uint32(length)) + s.prefix[4] = byte(paddingLength) + padding := s.padding[:paddingLength] + if _, err := io.ReadFull(rand, padding); err != nil { + return err + } + + if s.mac != nil { + s.mac.Reset() + binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) + s.mac.Write(s.seqNumBytes[:]) + s.mac.Write(s.prefix[:]) + s.mac.Write(packet) + s.mac.Write(padding) + } + + s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) + s.cipher.XORKeyStream(packet, packet) + s.cipher.XORKeyStream(padding, padding) + + if _, err := w.Write(s.prefix[:]); err != nil { + return err + } + if _, err := w.Write(packet); err != nil { + return err + } + if _, err := w.Write(padding); err != nil { + return err + } + + if s.mac != nil { + s.macResult = s.mac.Sum(s.macResult[:0]) + if _, err := w.Write(s.macResult); err != nil { + return err + } + } + + return nil +} + +type gcmCipher struct { + aead cipher.AEAD + prefix [4]byte + iv []byte + buf []byte +} + +func newGCMCipher(iv, key, macKey []byte) (packetCipher, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + aead, err := cipher.NewGCM(c) + if err != nil { + return nil, err + } + + return &gcmCipher{ + aead: aead, + iv: iv, + }, nil +} + +const gcmTagSize = 16 + +func (c *gcmCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + // Pad out to multiple of 16 bytes. This is different from the + // stream cipher because that encrypts the length too. + padding := byte(packetSizeMultiple - (1+len(packet))%packetSizeMultiple) + if padding < 4 { + padding += packetSizeMultiple + } + + length := uint32(len(packet) + int(padding) + 1) + binary.BigEndian.PutUint32(c.prefix[:], length) + if _, err := w.Write(c.prefix[:]); err != nil { + return err + } + + if cap(c.buf) < int(length) { + c.buf = make([]byte, length) + } else { + c.buf = c.buf[:length] + } + + c.buf[0] = padding + copy(c.buf[1:], packet) + if _, err := io.ReadFull(rand, c.buf[1+len(packet):]); err != nil { + return err + } + c.buf = c.aead.Seal(c.buf[:0], c.iv, c.buf, c.prefix[:]) + if _, err := w.Write(c.buf); err != nil { + return err + } + c.incIV() + + return nil +} + +func (c *gcmCipher) incIV() { + for i := 4 + 7; i >= 4; i-- { + c.iv[i]++ + if c.iv[i] != 0 { + break + } + } +} + +func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { + if _, err := io.ReadFull(r, c.prefix[:]); err != nil { + return nil, err + } + length := binary.BigEndian.Uint32(c.prefix[:]) + if length > maxPacket { + return nil, errors.New("ssh: max packet length exceeded.") + } + + if cap(c.buf) < int(length+gcmTagSize) { + c.buf = make([]byte, length+gcmTagSize) + } else { + c.buf = c.buf[:length+gcmTagSize] + } + + if _, err := io.ReadFull(r, c.buf); err != nil { + return nil, err + } + + plain, err := c.aead.Open(c.buf[:0], c.iv, c.buf, c.prefix[:]) + if err != nil { + return nil, err + } + c.incIV() + + padding := plain[0] + if padding < 4 || padding >= 20 { + return nil, fmt.Errorf("ssh: illegal padding %d", padding) + } + + if int(padding+1) >= len(plain) { + return nil, fmt.Errorf("ssh: padding %d too large", padding) + } + plain = plain[1 : length-uint32(padding)] + return plain, nil +} + +// cbcCipher implements aes128-cbc cipher defined in RFC 4253 section 6.1 +type cbcCipher struct { + mac hash.Hash + macSize uint32 + decrypter cipher.BlockMode + encrypter cipher.BlockMode + + // The following members are to avoid per-packet allocations. + seqNumBytes [4]byte + packetData []byte + macResult []byte + + // Amount of data we should still read to hide which + // verification error triggered. + oracleCamouflage uint32 +} + +func newCBCCipher(c cipher.Block, iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + cbc := &cbcCipher{ + mac: macModes[algs.MAC].new(macKey), + decrypter: cipher.NewCBCDecrypter(c, iv), + encrypter: cipher.NewCBCEncrypter(c, iv), + packetData: make([]byte, 1024), + } + if cbc.mac != nil { + cbc.macSize = uint32(cbc.mac.Size()) + } + + return cbc, nil +} + +func newAESCBCCipher(iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + cbc, err := newCBCCipher(c, iv, key, macKey, algs) + if err != nil { + return nil, err + } + + return cbc, nil +} + +func newTripleDESCBCCipher(iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + c, err := des.NewTripleDESCipher(key) + if err != nil { + return nil, err + } + + cbc, err := newCBCCipher(c, iv, key, macKey, algs) + if err != nil { + return nil, err + } + + return cbc, nil +} + +func maxUInt32(a, b int) uint32 { + if a > b { + return uint32(a) + } + return uint32(b) +} + +const ( + cbcMinPacketSizeMultiple = 8 + cbcMinPacketSize = 16 + cbcMinPaddingSize = 4 +) + +// cbcError represents a verification error that may leak information. +type cbcError string + +func (e cbcError) Error() string { return string(e) } + +func (c *cbcCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { + p, err := c.readPacketLeaky(seqNum, r) + if err != nil { + if _, ok := err.(cbcError); ok { + // Verification error: read a fixed amount of + // data, to make distinguishing between + // failing MAC and failing length check more + // difficult. + io.CopyN(ioutil.Discard, r, int64(c.oracleCamouflage)) + } + } + return p, err +} + +func (c *cbcCipher) readPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error) { + blockSize := c.decrypter.BlockSize() + + // Read the header, which will include some of the subsequent data in the + // case of block ciphers - this is copied back to the payload later. + // How many bytes of payload/padding will be read with this first read. + firstBlockLength := uint32((prefixLen + blockSize - 1) / blockSize * blockSize) + firstBlock := c.packetData[:firstBlockLength] + if _, err := io.ReadFull(r, firstBlock); err != nil { + return nil, err + } + + c.oracleCamouflage = maxPacket + 4 + c.macSize - firstBlockLength + + c.decrypter.CryptBlocks(firstBlock, firstBlock) + length := binary.BigEndian.Uint32(firstBlock[:4]) + if length > maxPacket { + return nil, cbcError("ssh: packet too large") + } + if length+4 < maxUInt32(cbcMinPacketSize, blockSize) { + // The minimum size of a packet is 16 (or the cipher block size, whichever + // is larger) bytes. + return nil, cbcError("ssh: packet too small") + } + // The length of the packet (including the length field but not the MAC) must + // be a multiple of the block size or 8, whichever is larger. + if (length+4)%maxUInt32(cbcMinPacketSizeMultiple, blockSize) != 0 { + return nil, cbcError("ssh: invalid packet length multiple") + } + + paddingLength := uint32(firstBlock[4]) + if paddingLength < cbcMinPaddingSize || length <= paddingLength+1 { + return nil, cbcError("ssh: invalid packet length") + } + + // Positions within the c.packetData buffer: + macStart := 4 + length + paddingStart := macStart - paddingLength + + // Entire packet size, starting before length, ending at end of mac. + entirePacketSize := macStart + c.macSize + + // Ensure c.packetData is large enough for the entire packet data. + if uint32(cap(c.packetData)) < entirePacketSize { + // Still need to upsize and copy, but this should be rare at runtime, only + // on upsizing the packetData buffer. + c.packetData = make([]byte, entirePacketSize) + copy(c.packetData, firstBlock) + } else { + c.packetData = c.packetData[:entirePacketSize] + } + + if n, err := io.ReadFull(r, c.packetData[firstBlockLength:]); err != nil { + return nil, err + } else { + c.oracleCamouflage -= uint32(n) + } + + remainingCrypted := c.packetData[firstBlockLength:macStart] + c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted) + + mac := c.packetData[macStart:] + if c.mac != nil { + c.mac.Reset() + binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) + c.mac.Write(c.seqNumBytes[:]) + c.mac.Write(c.packetData[:macStart]) + c.macResult = c.mac.Sum(c.macResult[:0]) + if subtle.ConstantTimeCompare(c.macResult, mac) != 1 { + return nil, cbcError("ssh: MAC failure") + } + } + + return c.packetData[prefixLen:paddingStart], nil +} + +func (c *cbcCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + effectiveBlockSize := maxUInt32(cbcMinPacketSizeMultiple, c.encrypter.BlockSize()) + + // Length of encrypted portion of the packet (header, payload, padding). + // Enforce minimum padding and packet size. + encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPaddingSize) + // Enforce block size. + encLength = (encLength + effectiveBlockSize - 1) / effectiveBlockSize * effectiveBlockSize + + length := encLength - 4 + paddingLength := int(length) - (1 + len(packet)) + + // Overall buffer contains: header, payload, padding, mac. + // Space for the MAC is reserved in the capacity but not the slice length. + bufferSize := encLength + c.macSize + if uint32(cap(c.packetData)) < bufferSize { + c.packetData = make([]byte, encLength, bufferSize) + } else { + c.packetData = c.packetData[:encLength] + } + + p := c.packetData + + // Packet header. + binary.BigEndian.PutUint32(p, length) + p = p[4:] + p[0] = byte(paddingLength) + + // Payload. + p = p[1:] + copy(p, packet) + + // Padding. + p = p[len(packet):] + if _, err := io.ReadFull(rand, p); err != nil { + return err + } + + if c.mac != nil { + c.mac.Reset() + binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) + c.mac.Write(c.seqNumBytes[:]) + c.mac.Write(c.packetData) + // The MAC is now appended into the capacity reserved for it earlier. + c.packetData = c.mac.Sum(c.packetData) + } + + c.encrypter.CryptBlocks(c.packetData[:encLength], c.packetData[:encLength]) + + if _, err := w.Write(c.packetData); err != nil { + return err + } + + return nil +} diff --git a/lib/ssh/cipher_test.go b/lib/ssh/cipher_test.go new file mode 100644 index 0000000..26eee25 --- /dev/null +++ b/lib/ssh/cipher_test.go @@ -0,0 +1,127 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "bytes" + "crypto" + "crypto/aes" + "crypto/rand" + "testing" +) + +func TestDefaultCiphersExist(t *testing.T) { + for _, cipherAlgo := range defaultCiphers { + if _, ok := cipherModes[cipherAlgo]; !ok { + t.Errorf("default cipher %q is unknown", cipherAlgo) + } + } +} + +func TestPacketCiphers(t *testing.T) { + // Still test aes128cbc cipher although it's commented out. + cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil} + defer delete(cipherModes, aes128cbcID) + + for cipher := range cipherModes { + kr := &kexResult{Hash: crypto.SHA1} + algs := directionAlgorithms{ + Cipher: cipher, + MAC: "hmac-sha1", + Compression: "none", + } + client, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Errorf("newPacketCipher(client, %q): %v", cipher, err) + continue + } + server, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Errorf("newPacketCipher(client, %q): %v", cipher, err) + continue + } + + want := "bla bla" + input := []byte(want) + buf := &bytes.Buffer{} + if err := client.writePacket(0, buf, rand.Reader, input); err != nil { + t.Errorf("writePacket(%q): %v", cipher, err) + continue + } + + packet, err := server.readPacket(0, buf) + if err != nil { + t.Errorf("readPacket(%q): %v", cipher, err) + continue + } + + if string(packet) != want { + t.Errorf("roundtrip(%q): got %q, want %q", cipher, packet, want) + } + } +} + +func TestCBCOracleCounterMeasure(t *testing.T) { + cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil} + defer delete(cipherModes, aes128cbcID) + + kr := &kexResult{Hash: crypto.SHA1} + algs := directionAlgorithms{ + Cipher: aes128cbcID, + MAC: "hmac-sha1", + Compression: "none", + } + client, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client): %v", err) + } + + want := "bla bla" + input := []byte(want) + buf := &bytes.Buffer{} + if err := client.writePacket(0, buf, rand.Reader, input); err != nil { + t.Errorf("writePacket: %v", err) + } + + packetSize := buf.Len() + buf.Write(make([]byte, 2*maxPacket)) + + // We corrupt each byte, but this usually will only test the + // 'packet too large' or 'MAC failure' cases. + lastRead := -1 + for i := 0; i < packetSize; i++ { + server, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client): %v", err) + } + + fresh := &bytes.Buffer{} + fresh.Write(buf.Bytes()) + fresh.Bytes()[i] ^= 0x01 + + before := fresh.Len() + _, err = server.readPacket(0, fresh) + if err == nil { + t.Errorf("corrupt byte %d: readPacket succeeded ", i) + continue + } + if _, ok := err.(cbcError); !ok { + t.Errorf("corrupt byte %d: got %v (%T), want cbcError", i, err, err) + continue + } + + after := fresh.Len() + bytesRead := before - after + if bytesRead < maxPacket { + t.Errorf("corrupt byte %d: read %d bytes, want more than %d", i, bytesRead, maxPacket) + continue + } + + if i > 0 && bytesRead != lastRead { + t.Errorf("corrupt byte %d: read %d bytes, want %d bytes read", i, bytesRead, lastRead) + } + lastRead = bytesRead + } +} diff --git a/lib/ssh/client.go b/lib/ssh/client.go new file mode 100644 index 0000000..4a8e1b6 --- /dev/null +++ b/lib/ssh/client.go @@ -0,0 +1,272 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "errors" + "fmt" + "net" + "strings" + "sync" + "time" +) + +// Client implements a traditional SSH client that supports shells, +// subprocesses, port forwarding and tunneled dialing. +type Client struct { + Conn + + forwards forwardList // forwarded tcpip connections from the remote side + mu sync.Mutex + channelHandlers map[string]chan NewChannel +} + +// HandleChannelOpen returns a channel on which NewChannel requests +// for the given type are sent. If the type already is being handled, +// nil is returned. The channel is closed when the connection is closed. +func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel { + c.mu.Lock() + defer c.mu.Unlock() + if c.channelHandlers == nil { + // The SSH channel has been closed. + c := make(chan NewChannel) + close(c) + return c + } + + ch := c.channelHandlers[channelType] + if ch != nil { + return nil + } + + ch = make(chan NewChannel, 16) + c.channelHandlers[channelType] = ch + return ch +} + +// NewClient creates a Client on top of the given connection. +func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client { + conn := &Client{ + Conn: c, + channelHandlers: make(map[string]chan NewChannel, 1), + } + + go conn.handleGlobalRequests(reqs) + go conn.handleChannelOpens(chans) + go func() { + conn.Wait() + conn.forwards.closeAll() + }() + go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip")) + return conn +} + +// NewClientConn establishes an authenticated SSH connection using c +// as the underlying transport. The Request and NewChannel channels +// must be serviced or the connection will hang. +func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) { + fullConf := *config + fullConf.SetDefaults() + conn := &connection{ + sshConn: sshConn{conn: c}, + } + + if err := conn.clientHandshake(addr, &fullConf); err != nil { + c.Close() + return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %v", err) + } + conn.mux = newMux(conn.transport) + return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil +} + +// clientHandshake performs the client side key exchange. See RFC 4253 Section +// 7. +func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error { + if config.ClientVersion != "" { + c.clientVersion = []byte(config.ClientVersion) + } else { + c.clientVersion = []byte(packageVersion) + } + var err error + c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion) + if err != nil { + return err + } + + if config.ConnLog != nil { + config.ConnLog.ServerID = new(EndpointId) + config.ConnLog.ServerID.Raw = string(c.serverVersion) + + serverSplitId := strings.SplitN(string(c.serverVersion), " ", 2) + if len(serverSplitId) == 2 { + config.ConnLog.ServerID.Comment = serverSplitId[1] + } + + serverSplitGroup := strings.SplitN(serverSplitId[0], "-", 3) + if serverSplitGroup[0] == "SSH" { + // If ID doesn't start with "SSH", don't attempt to parse. + if len(serverSplitGroup) > 1 { + config.ConnLog.ServerID.ProtoVersion = serverSplitGroup[1] + } + + if len(serverSplitGroup) == 3 { + config.ConnLog.ServerID.SoftwareVersion = serverSplitGroup[2] + } + } + } + if pkgConfig.Verbose { + if config.ConnLog != nil { + //config.ConnLog.ClientIDString = string(c.clientVersion) + } + + if config.ConnLog != nil { + config.ConnLog.ClientID = new(EndpointId) + config.ConnLog.ClientID.Raw = string(c.clientVersion) + + clientSplitId := strings.SplitN(string(c.clientVersion), " ", 2) + if len(clientSplitId) == 2 { + config.ConnLog.ClientID.Comment = clientSplitId[1] + } + + clientSplitGroup := strings.SplitN(clientSplitId[0], "-", 3) + if clientSplitGroup[0] == "SSH" { + // If ID doesn't start with "SSH", don't attempt to parse. + if len(clientSplitGroup) > 1 { + config.ConnLog.ClientID.ProtoVersion = clientSplitGroup[1] + } + + if len(clientSplitGroup) == 3 { + config.ConnLog.ClientID.SoftwareVersion = clientSplitGroup[2] + } + } + } + } + + c.transport = newClientTransport( + newTransport(c.sshConn.conn, config.Rand, true /* is client */), + c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr()) + + if err := c.transport.requestInitialKeyChange(); err != nil { + return err + } + + // We just did the key change, so the session ID is established. + c.sessionID = c.transport.getSessionID() + + return c.clientAuthenticate(config) +} + +// verifyHostKeySignature verifies the host key obtained in the key +// exchange. +func verifyHostKeySignature(hostKey PublicKey, result *kexResult) error { + sig, rest, ok := parseSignatureBody(result.Signature) + if len(rest) > 0 || !ok { + return errors.New("ssh: signature parse error") + } + + return hostKey.Verify(result.H, sig) +} + +// NewSession opens a new Session for this client. (A session is a remote +// execution of a program.) +func (c *Client) NewSession() (*Session, error) { + ch, in, err := c.OpenChannel("session", nil) + if err != nil { + return nil, err + } + return newSession(ch, in) +} + +func (c *Client) handleGlobalRequests(incoming <-chan *Request) { + for r := range incoming { + // This handles keepalive messages and matches + // the behaviour of OpenSSH. + r.Reply(false, nil) + } +} + +// handleChannelOpens channel open messages from the remote side. +func (c *Client) handleChannelOpens(in <-chan NewChannel) { + for ch := range in { + c.mu.Lock() + handler := c.channelHandlers[ch.ChannelType()] + c.mu.Unlock() + + if handler != nil { + handler <- ch + } else { + ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type: %v", ch.ChannelType())) + } + } + + c.mu.Lock() + for _, ch := range c.channelHandlers { + close(ch) + } + c.channelHandlers = nil + c.mu.Unlock() +} + +// Dial starts a client connection to the given SSH server. It is a +// convenience function that connects to the given network address, +// initiates the SSH handshake, and then sets up a Client. For access +// to incoming channels and requests, use net.Dial with NewClientConn +// instead. +func Dial(network, addr string, config *ClientConfig) (*Client, error) { + conn, err := net.DialTimeout(network, addr, config.Timeout) + if err != nil { + return nil, err + } + + if config.Timeout != 0 { + conn.SetDeadline(time.Now().Add(config.Timeout)) + } + c, chans, reqs, err := NewClientConn(conn, addr, config) + if err != nil { + return nil, err + } + return NewClient(c, chans, reqs), nil +} + +// A ClientConfig structure is used to configure a Client. It must not be +// modified after having been passed to an SSH function. +type ClientConfig struct { + // Config contains configuration that is shared between clients and + // servers. + Config + + // User contains the username to authenticate as. + User string + + // Auth contains possible authentication methods to use with the + // server. Only the first instance of a particular RFC 4252 method will + // be used during authentication. + Auth []AuthMethod + + // HostKeyCallback, if not nil, is called during the cryptographic + // handshake to validate the server's host key. A nil HostKeyCallback + // implies that all host keys are accepted. + HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error + + // ClientVersion contains the version identification string that will + // be used for the connection. If empty, a reasonable default is used. + ClientVersion string + + // HostKeyAlgorithms lists the key types that the client will + // accept from the server as host key, in order of + // preference. If empty, a reasonable default is used. Any + // string returned from PublicKey.Type method may be used, or + // any of the CertAlgoXxxx and KeyAlgoXxxx constants. + HostKeyAlgorithms []string + + // Timeout is the maximum amount of time for the TCP connection to establish. + // + // A Timeout of zero means no timeout. + Timeout time.Duration + + // If true, send the "none" Authentication Request to collect the advertised + // userauth method names, but do not attempt to authenticate. + DontAuthenticate bool +} diff --git a/lib/ssh/client_auth.go b/lib/ssh/client_auth.go new file mode 100644 index 0000000..9caf3c4 --- /dev/null +++ b/lib/ssh/client_auth.go @@ -0,0 +1,486 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "bytes" + "errors" + "fmt" + "io" +) + +// clientAuthenticate authenticates with the remote server. See RFC 4252. +func (c *connection) clientAuthenticate(config *ClientConfig) error { + if c.transport.config.ConnLog != nil && !pkgConfig.CollectUserAuth { + // Use ConnLog existence to indicate that this is a run and not testing + return nil + } + + // initiate user auth session + if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil { + return err + } + packet, err := c.transport.readPacket() + if err != nil { + return err + } + var serviceAccept serviceAcceptMsg + if err := Unmarshal(packet, &serviceAccept); err != nil { + return err + } + + // during the authentication phase the client first attempts the "none" method + // then any untried methods suggested by the server. + tried := make(map[string]bool) + var lastMethods []string + for auth := AuthMethod(new(noneAuth)); auth != nil; { + ok, methods, err := auth.auth(c.transport.getSessionID(), config.User, c.transport, config.Rand) + if err != nil { + return err + } + if ok { + // success + return nil + } + + if c.transport.config.ConnLog != nil { + c.transport.config.ConnLog.UserAuth = methods + } + if config.DontAuthenticate { + return nil + } + + tried[auth.method()] = true + if methods == nil { + methods = lastMethods + } + lastMethods = methods + + auth = nil + + findNext: + for _, a := range config.Auth { + candidateMethod := a.method() + if tried[candidateMethod] { + continue + } + for _, meth := range methods { + if meth == candidateMethod { + auth = a + break findNext + } + } + } + } + return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", keys(tried)) +} + +func keys(m map[string]bool) []string { + s := make([]string, 0, len(m)) + + for key := range m { + s = append(s, key) + } + return s +} + +// An AuthMethod represents an instance of an RFC 4252 authentication method. +type AuthMethod interface { + // auth authenticates user over transport t. + // Returns true if authentication is successful. + // If authentication is not successful, a []string of alternative + // method names is returned. If the slice is nil, it will be ignored + // and the previous set of possible methods will be reused. + auth(session []byte, user string, p packetConn, rand io.Reader) (bool, []string, error) + + // method returns the RFC 4252 method name. + method() string +} + +// "none" authentication, RFC 4252 section 5.2. +type noneAuth int + +func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { + if err := c.writePacket(Marshal(&userAuthRequestMsg{ + User: user, + Service: serviceSSH, + Method: "none", + })); err != nil { + return false, nil, err + } + + return handleAuthResponse(c) +} + +func (n *noneAuth) method() string { + return "none" +} + +// passwordCallback is an AuthMethod that fetches the password through +// a function call, e.g. by prompting the user. +type passwordCallback func() (password string, err error) + +func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { + type passwordAuthMsg struct { + User string `sshtype:"50"` + Service string + Method string + Reply bool + Password string + } + + pw, err := cb() + // REVIEW NOTE: is there a need to support skipping a password attempt? + // The program may only find out that the user doesn't have a password + // when prompting. + if err != nil { + return false, nil, err + } + + if err := c.writePacket(Marshal(&passwordAuthMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + Reply: false, + Password: pw, + })); err != nil { + return false, nil, err + } + + return handleAuthResponse(c) +} + +func (cb passwordCallback) method() string { + return "password" +} + +// Password returns an AuthMethod using the given password. +func Password(secret string) AuthMethod { + return passwordCallback(func() (string, error) { return secret, nil }) +} + +// PasswordCallback returns an AuthMethod that uses a callback for +// fetching a password. +func PasswordCallback(prompt func() (secret string, err error)) AuthMethod { + return passwordCallback(prompt) +} + +type publickeyAuthMsg struct { + User string `sshtype:"50"` + Service string + Method string + // HasSig indicates to the receiver packet that the auth request is signed and + // should be used for authentication of the request. + HasSig bool + Algoname string + PubKey []byte + // Sig is tagged with "rest" so Marshal will exclude it during + // validateKey + Sig []byte `ssh:"rest"` +} + +// publicKeyCallback is an AuthMethod that uses a set of key +// pairs for authentication. +type publicKeyCallback func() ([]Signer, error) + +func (cb publicKeyCallback) method() string { + return "publickey" +} + +func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { + // Authentication is performed in two stages. The first stage sends an + // enquiry to test if each key is acceptable to the remote. The second + // stage attempts to authenticate with the valid keys obtained in the + // first stage. + + signers, err := cb() + if err != nil { + return false, nil, err + } + var validKeys []Signer + for _, signer := range signers { + if ok, err := validateKey(signer.PublicKey(), user, c); ok { + validKeys = append(validKeys, signer) + } else { + if err != nil { + return false, nil, err + } + } + } + + // methods that may continue if this auth is not successful. + var methods []string + for _, signer := range validKeys { + pub := signer.PublicKey() + + pubKey := pub.Marshal() + sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + }, []byte(pub.Type()), pubKey)) + if err != nil { + return false, nil, err + } + + // manually wrap the serialized signature in a string + s := Marshal(sign) + sig := make([]byte, stringLength(len(s))) + marshalString(sig, s) + msg := publickeyAuthMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + HasSig: true, + Algoname: pub.Type(), + PubKey: pubKey, + Sig: sig, + } + p := Marshal(&msg) + if err := c.writePacket(p); err != nil { + return false, nil, err + } + var success bool + success, methods, err = handleAuthResponse(c) + if err != nil { + return false, nil, err + } + if success { + return success, methods, err + } + } + return false, methods, nil +} + +// validateKey validates the key provided is acceptable to the server. +func validateKey(key PublicKey, user string, c packetConn) (bool, error) { + pubKey := key.Marshal() + msg := publickeyAuthMsg{ + User: user, + Service: serviceSSH, + Method: "publickey", + HasSig: false, + Algoname: key.Type(), + PubKey: pubKey, + } + if err := c.writePacket(Marshal(&msg)); err != nil { + return false, err + } + + return confirmKeyAck(key, c) +} + +func confirmKeyAck(key PublicKey, c packetConn) (bool, error) { + pubKey := key.Marshal() + algoname := key.Type() + + for { + packet, err := c.readPacket() + if err != nil { + return false, err + } + switch packet[0] { + case msgUserAuthBanner: + // TODO(gpaul): add callback to present the banner to the user + case msgUserAuthPubKeyOk: + var msg userAuthPubKeyOkMsg + if err := Unmarshal(packet, &msg); err != nil { + return false, err + } + if msg.Algo != algoname || !bytes.Equal(msg.PubKey, pubKey) { + return false, nil + } + return true, nil + case msgUserAuthFailure: + return false, nil + default: + return false, unexpectedMessageError(msgUserAuthSuccess, packet[0]) + } + } +} + +// PublicKeys returns an AuthMethod that uses the given key +// pairs. +func PublicKeys(signers ...Signer) AuthMethod { + return publicKeyCallback(func() ([]Signer, error) { return signers, nil }) +} + +// PublicKeysCallback returns an AuthMethod that runs the given +// function to obtain a list of key pairs. +func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod { + return publicKeyCallback(getSigners) +} + +// handleAuthResponse returns whether the preceding authentication request succeeded +// along with a list of remaining authentication methods to try next and +// an error if an unexpected response was received. +func handleAuthResponse(c packetConn) (bool, []string, error) { + for { + packet, err := c.readPacket() + if err != nil { + return false, nil, err + } + + switch packet[0] { + case msgUserAuthBanner: + // TODO: add callback to present the banner to the user + case msgUserAuthFailure: + var msg userAuthFailureMsg + if err := Unmarshal(packet, &msg); err != nil { + return false, nil, err + } + return false, msg.Methods, nil + case msgUserAuthSuccess: + return true, nil, nil + default: + return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0]) + } + } +} + +// KeyboardInteractiveChallenge should print questions, optionally +// disabling echoing (e.g. for passwords), and return all the answers. +// Challenge may be called multiple times in a single session. After +// successful authentication, the server may send a challenge with no +// questions, for which the user and instruction messages should be +// printed. RFC 4256 section 3.3 details how the UI should behave for +// both CLI and GUI environments. +type KeyboardInteractiveChallenge func(user, instruction string, questions []string, echos []bool) (answers []string, err error) + +// KeyboardInteractive returns a AuthMethod using a prompt/response +// sequence controlled by the server. +func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod { + return challenge +} + +func (cb KeyboardInteractiveChallenge) method() string { + return "keyboard-interactive" +} + +func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { + type initiateMsg struct { + User string `sshtype:"50"` + Service string + Method string + Language string + Submethods string + } + + if err := c.writePacket(Marshal(&initiateMsg{ + User: user, + Service: serviceSSH, + Method: "keyboard-interactive", + })); err != nil { + return false, nil, err + } + + for { + packet, err := c.readPacket() + if err != nil { + return false, nil, err + } + + // like handleAuthResponse, but with less options. + switch packet[0] { + case msgUserAuthBanner: + // TODO: Print banners during userauth. + continue + case msgUserAuthInfoRequest: + // OK + case msgUserAuthFailure: + var msg userAuthFailureMsg + if err := Unmarshal(packet, &msg); err != nil { + return false, nil, err + } + return false, msg.Methods, nil + case msgUserAuthSuccess: + return true, nil, nil + default: + return false, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0]) + } + + var msg userAuthInfoRequestMsg + if err := Unmarshal(packet, &msg); err != nil { + return false, nil, err + } + + // Manually unpack the prompt/echo pairs. + rest := msg.Prompts + var prompts []string + var echos []bool + for i := 0; i < int(msg.NumPrompts); i++ { + prompt, r, ok := parseString(rest) + if !ok || len(r) == 0 { + return false, nil, errors.New("ssh: prompt format error") + } + prompts = append(prompts, string(prompt)) + echos = append(echos, r[0] != 0) + rest = r[1:] + } + + if len(rest) != 0 { + return false, nil, errors.New("ssh: extra data following keyboard-interactive pairs") + } + + answers, err := cb(msg.User, msg.Instruction, prompts, echos) + if err != nil { + return false, nil, err + } + + if len(answers) != len(prompts) { + return false, nil, errors.New("ssh: not enough answers from keyboard-interactive callback") + } + responseLength := 1 + 4 + for _, a := range answers { + responseLength += stringLength(len(a)) + } + serialized := make([]byte, responseLength) + p := serialized + p[0] = msgUserAuthInfoResponse + p = p[1:] + p = marshalUint32(p, uint32(len(answers))) + for _, a := range answers { + p = marshalString(p, []byte(a)) + } + + if err := c.writePacket(serialized); err != nil { + return false, nil, err + } + } +} + +type retryableAuthMethod struct { + authMethod AuthMethod + maxTries int +} + +func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader) (ok bool, methods []string, err error) { + for i := 0; r.maxTries <= 0 || i < r.maxTries; i++ { + ok, methods, err = r.authMethod.auth(session, user, c, rand) + if ok || err != nil { // either success or error terminate + return ok, methods, err + } + } + return ok, methods, err +} + +func (r *retryableAuthMethod) method() string { + return r.authMethod.method() +} + +// RetryableAuthMethod is a decorator for other auth methods enabling them to +// be retried up to maxTries before considering that AuthMethod itself failed. +// If maxTries is <= 0, will retry indefinitely +// +// This is useful for interactive clients using challenge/response type +// authentication (e.g. Keyboard-Interactive, Password, etc) where the user +// could mistype their response resulting in the server issuing a +// SSH_MSG_USERAUTH_FAILURE (rfc4252 #8 [password] and rfc4256 #3.4 +// [keyboard-interactive]); Without this decorator, the non-retryable +// AuthMethod would be removed from future consideration, and never tried again +// (and so the user would never be able to retry their entry). +func RetryableAuthMethod(auth AuthMethod, maxTries int) AuthMethod { + return &retryableAuthMethod{authMethod: auth, maxTries: maxTries} +} diff --git a/lib/ssh/client_auth_test.go b/lib/ssh/client_auth_test.go new file mode 100644 index 0000000..1f98c5f --- /dev/null +++ b/lib/ssh/client_auth_test.go @@ -0,0 +1,472 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + "os" + "strings" + "testing" +) + +type keyboardInteractive map[string]string + +func (cr keyboardInteractive) Challenge(user string, instruction string, questions []string, echos []bool) ([]string, error) { + var answers []string + for _, q := range questions { + answers = append(answers, cr[q]) + } + return answers, nil +} + +// reused internally by tests +var clientPassword = "tiger" + +// tryAuth runs a handshake with a given config against an SSH server +// with config serverConfig +func tryAuth(t *testing.T, config *ClientConfig) error { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + certChecker := CertChecker{ + IsAuthority: func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal()) + }, + UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + if conn.User() == "testuser" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { + return nil, nil + } + + return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User()) + }, + IsRevoked: func(c *Certificate) bool { + return c.Serial == 666 + }, + } + + serverConfig := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) { + if conn.User() == "testuser" && string(pass) == clientPassword { + return nil, nil + } + return nil, errors.New("password auth failed") + }, + PublicKeyCallback: certChecker.Authenticate, + KeyboardInteractiveCallback: func(conn ConnMetadata, challenge KeyboardInteractiveChallenge) (*Permissions, error) { + ans, err := challenge("user", + "instruction", + []string{"question1", "question2"}, + []bool{true, true}) + if err != nil { + return nil, err + } + ok := conn.User() == "testuser" && ans[0] == "answer1" && ans[1] == "answer2" + if ok { + challenge("user", "motd", nil, nil) + return nil, nil + } + return nil, errors.New("keyboard-interactive failed") + }, + AuthLogCallback: func(conn ConnMetadata, method string, err error) { + t.Logf("user %q, method %q: %v", conn.User(), method, err) + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + go newServer(c1, serverConfig) + _, _, _, err = NewClientConn(c2, "", config) + return err +} + +func TestClientAuthPublicKey(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodPassword(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + Password(clientPassword), + }, + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodFallback(t *testing.T) { + var passwordCalled bool + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + PasswordCallback( + func() (string, error) { + passwordCalled = true + return "WRONG", nil + }), + }, + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } + + if passwordCalled { + t.Errorf("password auth tried before public-key auth.") + } +} + +func TestAuthMethodWrongPassword(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + Password("wrong"), + PublicKeys(testSigners["rsa"]), + }, + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodKeyboardInteractive(t *testing.T) { + answers := keyboardInteractive(map[string]string{ + "question1": "answer1", + "question2": "answer2", + }) + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + KeyboardInteractive(answers.Challenge), + }, + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodWrongKeyboardInteractive(t *testing.T) { + answers := keyboardInteractive(map[string]string{ + "question1": "answer1", + "question2": "WRONG", + }) + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + KeyboardInteractive(answers.Challenge), + }, + } + + if err := tryAuth(t, config); err == nil { + t.Fatalf("wrong answers should not have authenticated with KeyboardInteractive") + } +} + +// the mock server will only authenticate ssh-rsa keys +func TestAuthMethodInvalidPublicKey(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["dsa"]), + }, + } + + if err := tryAuth(t, config); err == nil { + t.Fatalf("dsa private key should not have authenticated with rsa public key") + } +} + +// the client should authenticate with the second key +func TestAuthMethodRSAandDSA(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["dsa"], testSigners["rsa"]), + }, + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("client could not authenticate with rsa key: %v", err) + } +} + +func TestClientHMAC(t *testing.T) { + for _, mac := range supportedMACs { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + Config: Config{ + MACs: []string{mac}, + }, + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err) + } + } +} + +// issue 4285. +func TestClientUnsupportedCipher(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(), + }, + Config: Config{ + Ciphers: []string{"aes128-cbc"}, // not currently supported + }, + } + if err := tryAuth(t, config); err == nil { + t.Errorf("expected no ciphers in common") + } +} + +func TestClientUnsupportedKex(t *testing.T) { + if os.Getenv("GO_BUILDER_NAME") != "" { + t.Skip("skipping known-flaky test on the Go build dashboard; see golang.org/issue/15198") + } + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(), + }, + Config: Config{ + KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported + }, + } + if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "common algorithm") { + t.Errorf("got %v, expected 'common algorithm'", err) + } +} + +func TestClientLoginCert(t *testing.T) { + cert := &Certificate{ + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: UserCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + certSigner, err := NewCertSigner(cert, testSigners["rsa"]) + if err != nil { + t.Fatalf("NewCertSigner: %v", err) + } + + clientConfig := &ClientConfig{ + User: "user", + } + clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner)) + + t.Log("should succeed") + if err := tryAuth(t, clientConfig); err != nil { + t.Errorf("cert login failed: %v", err) + } + + t.Log("corrupted signature") + cert.Signature.Blob[0]++ + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with corrupted sig") + } + + t.Log("revoked") + cert.Serial = 666 + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("revoked cert login succeeded") + } + cert.Serial = 1 + + t.Log("sign with wrong key") + cert.SignCert(rand.Reader, testSigners["dsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with non-authoritative key") + } + + t.Log("host cert") + cert.CertType = HostCert + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with wrong type") + } + cert.CertType = UserCert + + t.Log("principal specified") + cert.ValidPrincipals = []string{"user"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err != nil { + t.Errorf("cert login failed: %v", err) + } + + t.Log("wrong principal specified") + cert.ValidPrincipals = []string{"fred"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with wrong principal") + } + cert.ValidPrincipals = nil + + t.Log("added critical option") + cert.CriticalOptions = map[string]string{"root-access": "yes"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with unrecognized critical option") + } + + t.Log("allowed source address") + cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42/24"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err != nil { + t.Errorf("cert login with source-address failed: %v", err) + } + + t.Log("disallowed source address") + cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login with source-address succeeded") + } +} + +func testPermissionsPassing(withPermissions bool, t *testing.T) { + serverConfig := &ServerConfig{ + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + if conn.User() == "nopermissions" { + return nil, nil + } else { + return &Permissions{}, nil + } + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + clientConfig := &ClientConfig{ + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + } + if withPermissions { + clientConfig.User = "permissions" + } else { + clientConfig.User = "nopermissions" + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go NewClientConn(c2, "", clientConfig) + serverConn, err := newServer(c1, serverConfig) + if err != nil { + t.Fatal(err) + } + if p := serverConn.Permissions; (p != nil) != withPermissions { + t.Fatalf("withPermissions is %t, but Permissions object is %#v", withPermissions, p) + } +} + +func TestPermissionsPassing(t *testing.T) { + testPermissionsPassing(true, t) +} + +func TestNoPermissionsPassing(t *testing.T) { + testPermissionsPassing(false, t) +} + +func TestRetryableAuth(t *testing.T) { + n := 0 + passwords := []string{"WRONG1", "WRONG2"} + + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + RetryableAuthMethod(PasswordCallback(func() (string, error) { + p := passwords[n] + n++ + return p, nil + }), 2), + PublicKeys(testSigners["rsa"]), + }, + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } + if n != 2 { + t.Fatalf("Did not try all passwords") + } +} + +func ExampleRetryableAuthMethod(t *testing.T) { + user := "testuser" + NumberOfPrompts := 3 + + // Normally this would be a callback that prompts the user to answer the + // provided questions + Cb := func(user, instruction string, questions []string, echos []bool) (answers []string, err error) { + return []string{"answer1", "answer2"}, nil + } + + config := &ClientConfig{ + User: user, + Auth: []AuthMethod{ + RetryableAuthMethod(KeyboardInteractiveChallenge(Cb), NumberOfPrompts), + }, + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +// Test if username is received on server side when NoClientAuth is used +func TestClientAuthNone(t *testing.T) { + user := "testuser" + serverConfig := &ServerConfig{ + NoClientAuth: true, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + clientConfig := &ClientConfig{ + User: user, + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go NewClientConn(c2, "", clientConfig) + serverConn, err := newServer(c1, serverConfig) + if err != nil { + t.Fatalf("newServer: %v", err) + } + if serverConn.User() != user { + t.Fatalf("server: got %q, want %q", serverConn.User(), user) + } +} diff --git a/lib/ssh/client_test.go b/lib/ssh/client_test.go new file mode 100644 index 0000000..830f4f6 --- /dev/null +++ b/lib/ssh/client_test.go @@ -0,0 +1,39 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "net" + "testing" +) + +func testClientVersion(t *testing.T, config *ClientConfig, expected string) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + receivedVersion := make(chan string, 1) + go func() { + version, err := readVersion(serverConn) + if err != nil { + receivedVersion <- "" + } else { + receivedVersion <- string(version) + } + serverConn.Close() + }() + NewClientConn(clientConn, "", config) + actual := <-receivedVersion + if actual != expected { + t.Fatalf("got %s; want %s", actual, expected) + } +} + +func TestCustomClientVersion(t *testing.T) { + version := "Test-Client-Version-0.0" + testClientVersion(t, &ClientConfig{ClientVersion: version}, version) +} + +func TestDefaultClientVersion(t *testing.T) { + testClientVersion(t, &ClientConfig{}, packageVersion) +} diff --git a/lib/ssh/common.go b/lib/ssh/common.go new file mode 100644 index 0000000..67ebb3a --- /dev/null +++ b/lib/ssh/common.go @@ -0,0 +1,396 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "crypto" + "crypto/rand" + "encoding/json" + "fmt" + "io" + "sync" + + _ "crypto/sha1" + _ "crypto/sha256" + _ "crypto/sha512" +) + +// These are string constants in the SSH protocol. +const ( + compressionNone = "none" + serviceUserAuth = "ssh-userauth" + serviceSSH = "ssh-connection" +) + +// defaultCiphers specifies the default ciphers in preference order. +var defaultCiphers = []string{ + "aes128-ctr", "aes192-ctr", "aes256-ctr", + "aes128-gcm@openssh.com", + "arcfour256", "arcfour128", +} + +// allSupportedCiphers specifies all ciphers which are supported +var allSupportedCiphers = []string{ + "aes128-ctr", "aes192-ctr", "aes256-ctr", + gcmCipherID, + "arcfour256", "arcfour128", + // Not offered by default: + "arcfour", aes128cbcID, tripledescbcID, +} + +// defaultKexAlgos specifies the default key-exchange algorithms in +// preference order. +var defaultKexAlgos = []string{ + kexAlgoCurve25519SHA256, + // P384 and P521 are not constant-time yet, but since we don't + // reuse ephemeral keys, using them for ECDH should be OK. + kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, + kexAlgoDH14SHA1, kexAlgoDH1SHA1, +} + +// allSupportedKexAlgos specifies all key-exchange algorithms supported +var allSupportedKexAlgos = []string{ + kexAlgoCurve25519SHA256, + // P384 and P521 are not constant-time yet, but since we don't + // reuse ephemeral keys, using them for ECDH should be OK. + kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, + kexAlgoDH14SHA1, kexAlgoDH1SHA1, + // Not enabled by default: + kexAlgoDHGEXSHA1, kexAlgoDHGEXSHA256, +} + +// supportedHostKeyAlgos specifies the supported host-key algorithms (i.e. methods +// of authenticating servers) in preference order. +var supportedHostKeyAlgos = []string{ + CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, + CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01, + + KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, + KeyAlgoRSA, KeyAlgoDSA, + + KeyAlgoED25519, +} + +// supportedMACs specifies a default set of MAC algorithms in preference order. +// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed +// because they have reached the end of their useful life. +var supportedMACs = []string{ + "hmac-sha2-256", "hmac-sha1", "hmac-sha1-96", +} + +var supportedCompressions = []string{compressionNone} + +// hashFuncs keeps the mapping of supported algorithms to their respective +// hashes needed for signature verification. +var hashFuncs = map[string]crypto.Hash{ + KeyAlgoRSA: crypto.SHA1, + KeyAlgoDSA: crypto.SHA1, + KeyAlgoECDSA256: crypto.SHA256, + KeyAlgoECDSA384: crypto.SHA384, + KeyAlgoECDSA521: crypto.SHA512, + CertAlgoRSAv01: crypto.SHA1, + CertAlgoDSAv01: crypto.SHA1, + CertAlgoECDSA256v01: crypto.SHA256, + CertAlgoECDSA384v01: crypto.SHA384, + CertAlgoECDSA521v01: crypto.SHA512, +} + +// unexpectedMessageError results when the SSH message that we received didn't +// match what we wanted. +func unexpectedMessageError(expected, got uint8) error { + return fmt.Errorf("ssh: unexpected message type %d (expected %d)", got, expected) +} + +// parseError results from a malformed SSH message. +func parseError(tag uint8) error { + return fmt.Errorf("ssh: parse error in message type %d", tag) +} + +func findCommon(what string, client []string, server []string) (common string, err error) { + for _, c := range client { + for _, s := range server { + if c == s { + return c, nil + } + } + } + return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server) +} + +type directionAlgorithms struct { + Cipher string `json:"cipher"` + MAC string `json:"mac"` + Compression string `json:"compression"` +} + +type algorithms struct { + kex string + hostKey string + w directionAlgorithms + r directionAlgorithms +} + +func (alg *algorithms) MarshalJSON() ([]byte, error) { + aux := struct { + Kex string `json:"dh_kex_algorithm"` + HostKey string `json:"host_key_algorithm"` + W directionAlgorithms `json:"client_to_server_alg_group"` + R directionAlgorithms `json:"server_to_client_alg_group"` + }{ + Kex: alg.kex, + HostKey: alg.hostKey, + W: alg.w, + R: alg.r, + } + + return json.Marshal(aux) +} + +func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) { + result := &algorithms{} + + result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos) + if err != nil { + return + } + + result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos) + if err != nil { + return + } + + result.w.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer) + if err != nil { + return + } + + result.r.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient) + if err != nil { + return + } + + result.w.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer) + if err != nil { + return + } + + result.r.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient) + if err != nil { + return + } + + result.w.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer) + if err != nil { + return + } + + result.r.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient) + if err != nil { + return + } + + return result, nil +} + +// If rekeythreshold is too small, we can't make any progress sending +// stuff. +const minRekeyThreshold uint64 = 256 + +// Config contains configuration data common to both ServerConfig and +// ClientConfig. +type Config struct { + // Rand provides the source of entropy for cryptographic + // primitives. If Rand is nil, the cryptographic random reader + // in package crypto/rand will be used. + Rand io.Reader + + // The maximum number of bytes sent or received after which a + // new key is negotiated. It must be at least 256. If + // unspecified, 1 gigabyte is used. + RekeyThreshold uint64 + + // The allowed key exchanges algorithms. If unspecified then a + // default set of algorithms is used. + KeyExchanges []string + + // The allowed cipher algorithms. If unspecified then a sensible + // default is used. + Ciphers []string + + // The allowed MAC algorithms. If unspecified then a sensible default + // is used. + MACs []string + + // A pointer to the handshake log IOT allow incremental building + ConnLog *HandshakeLog +} + +// SetDefaults sets sensible values for unset fields in config. This is +// exported for testing: Configs passed to SSH functions are copied and have +// default values set automatically. +func (c *Config) SetDefaults() { + if c.Rand == nil { + c.Rand = rand.Reader + } + if c.Ciphers == nil { + c.Ciphers = defaultCiphers + } + var ciphers []string + for _, c := range c.Ciphers { + if cipherModes[c] != nil { + // reject the cipher if we have no cipherModes definition + ciphers = append(ciphers, c) + } + } + c.Ciphers = ciphers + + if c.KeyExchanges == nil { + c.KeyExchanges = defaultKexAlgos + } + + if c.MACs == nil { + c.MACs = supportedMACs + } + + if c.RekeyThreshold == 0 { + // RFC 4253, section 9 suggests rekeying after 1G. + c.RekeyThreshold = 1 << 30 + } + if c.RekeyThreshold < minRekeyThreshold { + c.RekeyThreshold = minRekeyThreshold + } +} + +// buildDataSignedForAuth returns the data that is signed in order to prove +// possession of a private key. See RFC 4252, section 7. +func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte { + data := struct { + Session []byte + Type byte + User string + Service string + Method string + Sign bool + Algo []byte + PubKey []byte + }{ + sessionId, + msgUserAuthRequest, + req.User, + req.Service, + req.Method, + true, + algo, + pubKey, + } + return Marshal(data) +} + +func appendU16(buf []byte, n uint16) []byte { + return append(buf, byte(n>>8), byte(n)) +} + +func appendU32(buf []byte, n uint32) []byte { + return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) +} + +func appendU64(buf []byte, n uint64) []byte { + return append(buf, + byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32), + byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) +} + +func appendInt(buf []byte, n int) []byte { + return appendU32(buf, uint32(n)) +} + +func appendString(buf []byte, s string) []byte { + buf = appendU32(buf, uint32(len(s))) + buf = append(buf, s...) + return buf +} + +func appendBool(buf []byte, b bool) []byte { + if b { + return append(buf, 1) + } + return append(buf, 0) +} + +// newCond is a helper to hide the fact that there is no usable zero +// value for sync.Cond. +func newCond() *sync.Cond { return sync.NewCond(new(sync.Mutex)) } + +// window represents the buffer available to clients +// wishing to write to a channel. +type window struct { + *sync.Cond + win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1 + writeWaiters int + closed bool +} + +// add adds win to the amount of window available +// for consumers. +func (w *window) add(win uint32) bool { + // a zero sized window adjust is a noop. + if win == 0 { + return true + } + w.L.Lock() + if w.win+win < win { + w.L.Unlock() + return false + } + w.win += win + // It is unusual that multiple goroutines would be attempting to reserve + // window space, but not guaranteed. Use broadcast to notify all waiters + // that additional window is available. + w.Broadcast() + w.L.Unlock() + return true +} + +// close sets the window to closed, so all reservations fail +// immediately. +func (w *window) close() { + w.L.Lock() + w.closed = true + w.Broadcast() + w.L.Unlock() +} + +// reserve reserves win from the available window capacity. +// If no capacity remains, reserve will block. reserve may +// return less than requested. +func (w *window) reserve(win uint32) (uint32, error) { + var err error + w.L.Lock() + w.writeWaiters++ + w.Broadcast() + for w.win == 0 && !w.closed { + w.Wait() + } + w.writeWaiters-- + if w.win < win { + win = w.win + } + w.win -= win + if w.closed { + err = io.EOF + } + w.L.Unlock() + return win, err +} + +// waitWriterBlocked waits until some goroutine is blocked for further +// writes. It is used in tests only. +func (w *window) waitWriterBlocked() { + w.Cond.L.Lock() + for w.writeWaiters == 0 { + w.Cond.Wait() + } + w.Cond.L.Unlock() +} diff --git a/lib/ssh/config.go b/lib/ssh/config.go new file mode 100644 index 0000000..7798913 --- /dev/null +++ b/lib/ssh/config.go @@ -0,0 +1,11 @@ +package xssh + +func MakeXSSHConfig() *ClientConfig { + ret := new(ClientConfig) + ret.DontAuthenticate = true // IOT scan ethically, never attempt to authenticate + ret.ClientVersion = pkgConfig.ClientID + ret.HostKeyAlgorithms = pkgConfig.HostKeyAlgorithms.Get() + ret.KeyExchanges = pkgConfig.KexAlgorithms.Get() + ret.Ciphers = pkgConfig.Ciphers.Get() + return ret +} diff --git a/lib/ssh/connection.go b/lib/ssh/connection.go new file mode 100644 index 0000000..f01f70c --- /dev/null +++ b/lib/ssh/connection.go @@ -0,0 +1,143 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "fmt" + "net" +) + +// OpenChannelError is returned if the other side rejects an +// OpenChannel request. +type OpenChannelError struct { + Reason RejectionReason + Message string +} + +func (e *OpenChannelError) Error() string { + return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) +} + +// ConnMetadata holds metadata for the connection. +type ConnMetadata interface { + // User returns the user ID for this connection. + User() string + + // SessionID returns the sesson hash, also denoted by H. + SessionID() []byte + + // ClientVersion returns the client's version string as hashed + // into the session ID. + ClientVersion() []byte + + // ServerVersion returns the server's version string as hashed + // into the session ID. + ServerVersion() []byte + + // RemoteAddr returns the remote address for this connection. + RemoteAddr() net.Addr + + // LocalAddr returns the local address for this connection. + LocalAddr() net.Addr +} + +// Conn represents an SSH connection for both server and client roles. +// Conn is the basis for implementing an application layer, such +// as ClientConn, which implements the traditional shell access for +// clients. +type Conn interface { + ConnMetadata + + // SendRequest sends a global request, and returns the + // reply. If wantReply is true, it returns the response status + // and payload. See also RFC4254, section 4. + SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) + + // OpenChannel tries to open an channel. If the request is + // rejected, it returns *OpenChannelError. On success it returns + // the SSH Channel and a Go channel for incoming, out-of-band + // requests. The Go channel must be serviced, or the + // connection will hang. + OpenChannel(name string, data []byte) (Channel, <-chan *Request, error) + + // Close closes the underlying network connection + Close() error + + // Wait blocks until the connection has shut down, and returns the + // error causing the shutdown. + Wait() error + + // TODO(hanwen): consider exposing: + // RequestKeyChange + // Disconnect +} + +// DiscardRequests consumes and rejects all requests from the +// passed-in channel. +func DiscardRequests(in <-chan *Request) { + for req := range in { + if req.WantReply { + req.Reply(false, nil) + } + } +} + +// A connection represents an incoming connection. +type connection struct { + transport *handshakeTransport + sshConn + + // The connection protocol. + *mux +} + +func (c *connection) Close() error { + return c.sshConn.conn.Close() +} + +// sshconn provides net.Conn metadata, but disallows direct reads and +// writes. +type sshConn struct { + conn net.Conn + + user string + sessionID []byte + clientVersion []byte + serverVersion []byte +} + +func dup(src []byte) []byte { + dst := make([]byte, len(src)) + copy(dst, src) + return dst +} + +func (c *sshConn) User() string { + return c.user +} + +func (c *sshConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *sshConn) Close() error { + return c.conn.Close() +} + +func (c *sshConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *sshConn) SessionID() []byte { + return dup(c.sessionID) +} + +func (c *sshConn) ClientVersion() []byte { + return dup(c.clientVersion) +} + +func (c *sshConn) ServerVersion() []byte { + return dup(c.serverVersion) +} diff --git a/lib/ssh/doc.go b/lib/ssh/doc.go new file mode 100644 index 0000000..4bbc826 --- /dev/null +++ b/lib/ssh/doc.go @@ -0,0 +1,18 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package ssh implements an SSH client and server. + +SSH is a transport security protocol, an authentication protocol and a +family of application protocols. The most typical application level +protocol is a remote shell and this is specifically implemented. However, +the multiplexed nature of SSH is exposed to users that wish to support +others. + +References: + [PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD + [SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1 +*/ +package xssh // import "github.com/zmap/zgrab/ztools/xssh" diff --git a/lib/ssh/example_test.go b/lib/ssh/example_test.go new file mode 100644 index 0000000..2b4563a --- /dev/null +++ b/lib/ssh/example_test.go @@ -0,0 +1,262 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh_test + +import ( + "bytes" + "fmt" + "io/ioutil" + "log" + "net" + "net/http" + + "github.com/zmap/zgrab/ztools/xssh" + "github.com/zmap/zgrab/ztools/xssh/terminal" +) + +func ExampleNewServerConn() { + // Public key authentication is done by comparing + // the public key of a received connection + // with the entries in the authorized_keys file. + authorizedKeysBytes, err := ioutil.ReadFile("authorized_keys") + if err != nil { + log.Fatalf("Failed to load authorized_keys, err: %v", err) + } + + authorizedKeysMap := map[string]bool{} + for len(authorizedKeysBytes) > 0 { + pubKey, _, _, rest, err := xssh.ParseAuthorizedKey(authorizedKeysBytes) + if err != nil { + log.Fatal(err) + } + + authorizedKeysMap[string(pubKey.Marshal())] = true + authorizedKeysBytes = rest + } + + // An SSH server is represented by a ServerConfig, which holds + // certificate details and handles authentication of ServerConns. + config := &xssh.ServerConfig{ + // Remove to disable password auth. + PasswordCallback: func(c xssh.ConnMetadata, pass []byte) (*xssh.Permissions, error) { + // Should use constant-time compare (or better, salt+hash) in + // a production setting. + if c.User() == "testuser" && string(pass) == "tiger" { + return nil, nil + } + return nil, fmt.Errorf("password rejected for %q", c.User()) + }, + + // Remove to disable public key auth. + PublicKeyCallback: func(c xssh.ConnMetadata, pubKey xssh.PublicKey) (*xssh.Permissions, error) { + if authorizedKeysMap[string(pubKey.Marshal())] { + return nil, nil + } + return nil, fmt.Errorf("unknown public key for %q", c.User()) + }, + } + + privateBytes, err := ioutil.ReadFile("id_rsa") + if err != nil { + log.Fatal("Failed to load private key: ", err) + } + + private, err := xssh.ParsePrivateKey(privateBytes) + if err != nil { + log.Fatal("Failed to parse private key: ", err) + } + + config.AddHostKey(private) + + // Once a ServerConfig has been configured, connections can be + // accepted. + listener, err := net.Listen("tcp", "0.0.0.0:2022") + if err != nil { + log.Fatal("failed to listen for connection: ", err) + } + nConn, err := listener.Accept() + if err != nil { + log.Fatal("failed to accept incoming connection: ", err) + } + + // Before use, a handshake must be performed on the incoming + // net.Conn. + _, chans, reqs, err := xssh.NewServerConn(nConn, config) + if err != nil { + log.Fatal("failed to handshake: ", err) + } + // The incoming Request channel must be serviced. + go xssh.DiscardRequests(reqs) + + // Service the incoming Channel channel. + + // Service the incoming Channel channel. + for newChannel := range chans { + // Channels have a type, depending on the application level + // protocol intended. In the case of a shell, the type is + // "session" and ServerShell may be used to present a simple + // terminal interface. + if newChannel.ChannelType() != "session" { + newChannel.Reject(xssh.UnknownChannelType, "unknown channel type") + continue + } + channel, requests, err := newChannel.Accept() + if err != nil { + log.Fatalf("Could not accept channel: %v", err) + } + + // Sessions have out-of-band requests such as "shell", + // "pty-req" and "env". Here we handle only the + // "shell" request. + go func(in <-chan *xssh.Request) { + for req := range in { + req.Reply(req.Type == "shell", nil) + } + }(requests) + + term := terminal.NewTerminal(channel, "> ") + + go func() { + defer channel.Close() + for { + line, err := term.ReadLine() + if err != nil { + break + } + fmt.Println(line) + } + }() + } +} + +func ExampleDial() { + // An SSH client is represented with a ClientConn. + // + // To authenticate with the remote server you must pass at least one + // implementation of AuthMethod via the Auth field in ClientConfig. + config := &xssh.ClientConfig{ + User: "username", + Auth: []xssh.AuthMethod{ + xssh.Password("yourpassword"), + }, + } + client, err := xssh.Dial("tcp", "yourserver.com:22", config) + if err != nil { + log.Fatal("Failed to dial: ", err) + } + + // Each ClientConn can support multiple interactive sessions, + // represented by a Session. + session, err := client.NewSession() + if err != nil { + log.Fatal("Failed to create session: ", err) + } + defer session.Close() + + // Once a Session is created, you can execute a single command on + // the remote side using the Run method. + var b bytes.Buffer + session.Stdout = &b + if err := session.Run("/usr/bin/whoami"); err != nil { + log.Fatal("Failed to run: " + err.Error()) + } + fmt.Println(b.String()) +} + +func ExamplePublicKeys() { + // A public key may be used to authenticate against the remote + // server by using an unencrypted PEM-encoded private key file. + // + // If you have an encrypted private key, the crypto/x509 package + // can be used to decrypt it. + key, err := ioutil.ReadFile("/home/user/.ssh/id_rsa") + if err != nil { + log.Fatalf("unable to read private key: %v", err) + } + + // Create the Signer for this private key. + signer, err := xssh.ParsePrivateKey(key) + if err != nil { + log.Fatalf("unable to parse private key: %v", err) + } + + config := &xssh.ClientConfig{ + User: "user", + Auth: []xssh.AuthMethod{ + // Use the PublicKeys method for remote authentication. + xssh.PublicKeys(signer), + }, + } + + // Connect to the remote server and perform the SSH handshake. + client, err := xssh.Dial("tcp", "host.com:22", config) + if err != nil { + log.Fatalf("unable to connect: %v", err) + } + defer client.Close() +} + +func ExampleClient_Listen() { + config := &xssh.ClientConfig{ + User: "username", + Auth: []xssh.AuthMethod{ + xssh.Password("password"), + }, + } + // Dial your xssh server. + conn, err := xssh.Dial("tcp", "localhost:22", config) + if err != nil { + log.Fatal("unable to connect: ", err) + } + defer conn.Close() + + // Request the remote side to open port 8080 on all interfaces. + l, err := conn.Listen("tcp", "0.0.0.0:8080") + if err != nil { + log.Fatal("unable to register tcp forward: ", err) + } + defer l.Close() + + // Serve HTTP with your SSH server acting as a reverse proxy. + http.Serve(l, http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + fmt.Fprintf(resp, "Hello world!\n") + })) +} + +func ExampleSession_RequestPty() { + // Create client config + config := &xssh.ClientConfig{ + User: "username", + Auth: []xssh.AuthMethod{ + xssh.Password("password"), + }, + } + // Connect to xssh server + conn, err := xssh.Dial("tcp", "localhost:22", config) + if err != nil { + log.Fatal("unable to connect: ", err) + } + defer conn.Close() + // Create a session + session, err := conn.NewSession() + if err != nil { + log.Fatal("unable to create session: ", err) + } + defer session.Close() + // Set up terminal modes + modes := xssh.TerminalModes{ + xssh.ECHO: 0, // disable echoing + xssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud + xssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud + } + // Request pseudo terminal + if err := session.RequestPty("xterm", 40, 80, modes); err != nil { + log.Fatal("request for pseudo terminal failed: ", err) + } + // Start remote shell + if err := session.Shell(); err != nil { + log.Fatal("failed to start shell: ", err) + } +} diff --git a/lib/ssh/flag.go b/lib/ssh/flag.go new file mode 100644 index 0000000..61c9829 --- /dev/null +++ b/lib/ssh/flag.go @@ -0,0 +1,157 @@ +package xssh + +import ( + "errors" + "flag" + "fmt" + "strings" +) + +var pkgConfig XSSHConfig + +type XSSHConfig struct { + ClientID string + HostKeyAlgorithms HostKeyAlgorithmsList + KexAlgorithms KexAlgorithmsList + Verbose bool + CollectUserAuth bool + Ciphers CipherList + GexMinBits uint + GexMaxBits uint + GexPreferredBits uint +} + +type HostKeyAlgorithmsList struct { + Algorithms []string +} + +func (hkaList *HostKeyAlgorithmsList) String() string { + return strings.Join(hkaList.Algorithms, ",") +} + +func (hkaList *HostKeyAlgorithmsList) Set(value string) error { + for _, alg := range strings.Split(value, ",") { + isValid := false + for _, val := range supportedHostKeyAlgos { + if val == alg { + isValid = true + break + } + } + + if !isValid { + return errors.New(fmt.Sprintf(`host key algorithm not supported: "%s"`, alg)) + } + + hkaList.Algorithms = append(hkaList.Algorithms, alg) + } + return nil +} + +func (hkaList *HostKeyAlgorithmsList) Get() []string { + if len(hkaList.Algorithms) == 0 { + return supportedHostKeyAlgos + } else { + return hkaList.Algorithms + } +} + +type KexAlgorithmsList struct { + Algorithms []string +} + +func (kaList *KexAlgorithmsList) String() string { + return strings.Join(kaList.Algorithms, ",") +} + +func (kaList *KexAlgorithmsList) Set(value string) error { + for _, alg := range strings.Split(value, ",") { + isValid := false + for _, val := range allSupportedKexAlgos { + if val == alg { + isValid = true + break + } + } + + if !isValid { + return errors.New(fmt.Sprintf(`DH KEX algorithm not supported: "%s"`, alg)) + } + + kaList.Algorithms = append(kaList.Algorithms, alg) + } + return nil +} + +func (kaList *KexAlgorithmsList) Get() []string { + if len(kaList.Algorithms) == 0 { + return defaultKexAlgos + } else { + return kaList.Algorithms + } +} + +type CipherList struct { + Ciphers []string +} + +func (cList *CipherList) String() string { + return strings.Join(cList.Ciphers, ",") +} + +func (cList *CipherList) Set(value string) error { + for _, inCipher := range strings.Split(value, ",") { + isValid := false + for _, knownCipher := range allSupportedCiphers { + if inCipher == knownCipher { + isValid = true + break + } + } + + if !isValid { + return errors.New(fmt.Sprintf(`cipher not supported: "%s"`, inCipher)) + } + + cList.Ciphers = append(cList.Ciphers, inCipher) + } + + return nil +} + +func (cList *CipherList) Get() []string { + if len(cList.Ciphers) == 0 { + return defaultCiphers + } else { + return cList.Ciphers + } +} + +func init() { + flag.StringVar(&pkgConfig.ClientID, "xssh-client-id", packageVersion, "Specify the client ID string to use") + + hostKeyAlgUsage := fmt.Sprintf( + "A comma-separated list of which host key algorithms to offer (default \"%s\")", + strings.Join(supportedHostKeyAlgos, ","), + ) + flag.Var(&pkgConfig.HostKeyAlgorithms, "xssh-host-key-algorithms", hostKeyAlgUsage) + + kexAlgUsage := fmt.Sprintf( + "A comma-separated list of which DH key exchange algorithms to offer (default \"%s\")", + strings.Join(defaultKexAlgos, ","), + ) + flag.Var(&pkgConfig.KexAlgorithms, "xssh-kex-algorithms", kexAlgUsage) + + ciphersUsage := fmt.Sprintf( + `A comma-separated list of which ciphers to offer (default "%s")`, + strings.Join(defaultCiphers, ",")) + flag.Var(&pkgConfig.Ciphers, "xssh-ciphers", ciphersUsage) + + flag.BoolVar(&pkgConfig.Verbose, "xssh-verbose", false, "Output additional information, including X/SSH client properties from the SSH handshake.") + + flag.BoolVar(&pkgConfig.CollectUserAuth, "xssh-userauth", false, "Use the 'none' authentication request to see what userauth methods are allowed.") + + flag.UintVar(&pkgConfig.GexMinBits, "xssh-gex-min-bits", 1024, "The minimum number of bits for the DH GEX prime.") + flag.UintVar(&pkgConfig.GexMaxBits, "xssh-gex-max-bits", 8192, "The maximum number of bits for the DH GEX prime.") + flag.UintVar(&pkgConfig.GexPreferredBits, "xssh-gex-preferred-bits", 2048, "The preferred number of bits for the DH GEX prime.") +} diff --git a/lib/ssh/handshake.go b/lib/ssh/handshake.go new file mode 100644 index 0000000..5f0e6c2 --- /dev/null +++ b/lib/ssh/handshake.go @@ -0,0 +1,482 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "crypto/rand" + "errors" + "fmt" + "io" + "log" + "net" + "sync" +) + +// debugHandshake, if set, prints messages sent and received. Key +// exchange messages are printed as if DH were used, so the debug +// messages are wrong when using ECDH. +const debugHandshake = false + +// keyingTransport is a packet based transport that supports key +// changes. It need not be thread-safe. It should pass through +// msgNewKeys in both directions. +type keyingTransport interface { + packetConn + + // prepareKeyChange sets up a key change. The key change for a + // direction will be effected if a msgNewKeys message is sent + // or received. + prepareKeyChange(*algorithms, *kexResult) error +} + +// handshakeTransport implements rekeying on top of a keyingTransport +// and offers a thread-safe writePacket() interface. +type handshakeTransport struct { + conn keyingTransport + config *Config + + serverVersion []byte + clientVersion []byte + + // hostKeys is non-empty if we are the server. In that case, + // it contains all host keys that can be used to sign the + // connection. + hostKeys []Signer + + // hostKeyAlgorithms is non-empty if we are the client. In that case, + // we accept these key types from the server as host key. + hostKeyAlgorithms []string + + // On read error, incoming is closed, and readError is set. + incoming chan []byte + readError error + + // data for host key checking + hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error + dialAddress string + remoteAddr net.Addr + + readSinceKex uint64 + + // Protects the writing side of the connection + mu sync.Mutex + cond *sync.Cond + sentInitPacket []byte + sentInitMsg *kexInitMsg + writtenSinceKex uint64 + writeError error + + // The session ID or nil if first kex did not complete yet. + sessionID []byte +} + +func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport { + t := &handshakeTransport{ + conn: conn, + serverVersion: serverVersion, + clientVersion: clientVersion, + incoming: make(chan []byte, 16), + config: config, + } + t.cond = sync.NewCond(&t.mu) + return t +} + +func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport { + t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) + t.dialAddress = dialAddr + t.remoteAddr = addr + t.hostKeyCallback = config.HostKeyCallback + if config.HostKeyAlgorithms != nil { + t.hostKeyAlgorithms = config.HostKeyAlgorithms + } else { + t.hostKeyAlgorithms = supportedHostKeyAlgos + } + go t.readLoop() + return t +} + +func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport { + t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) + t.hostKeys = config.hostKeys + go t.readLoop() + return t +} + +func (t *handshakeTransport) getSessionID() []byte { + return t.sessionID +} + +func (t *handshakeTransport) id() string { + if len(t.hostKeys) > 0 { + return "server" + } + return "client" +} + +func (t *handshakeTransport) readPacket() ([]byte, error) { + p, ok := <-t.incoming + if !ok { + return nil, t.readError + } + return p, nil +} + +func (t *handshakeTransport) readLoop() { + for { + p, err := t.readOnePacket() + if err != nil { + t.readError = err + close(t.incoming) + break + } + if p[0] == msgIgnore || p[0] == msgDebug { + continue + } + t.incoming <- p + } + + // If we can't read, declare the writing part dead too. + t.mu.Lock() + defer t.mu.Unlock() + if t.writeError == nil { + t.writeError = t.readError + } + t.cond.Broadcast() +} + +func (t *handshakeTransport) readOnePacket() ([]byte, error) { + if t.readSinceKex > t.config.RekeyThreshold { + if err := t.requestKeyChange(); err != nil { + return nil, err + } + } + + p, err := t.conn.readPacket() + if err != nil { + return nil, err + } + + t.readSinceKex += uint64(len(p)) + if debugHandshake { + if p[0] == msgChannelData || p[0] == msgChannelExtendedData { + log.Printf("%s got data (packet %d bytes)", t.id(), len(p)) + } else { + msg, err := decode(p) + log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err) + } + } + if p[0] != msgKexInit { + return p, nil + } + + t.mu.Lock() + + firstKex := t.sessionID == nil + + err = t.enterKeyExchangeLocked(p) + if err != nil { + // drop connection + t.conn.Close() + t.writeError = err + } + + if debugHandshake { + log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err) + } + + // Unblock writers. + t.sentInitMsg = nil + t.sentInitPacket = nil + t.cond.Broadcast() + t.writtenSinceKex = 0 + t.mu.Unlock() + + if err != nil { + return nil, err + } + + t.readSinceKex = 0 + + // By default, a key exchange is hidden from higher layers by + // translating it into msgIgnore. + successPacket := []byte{msgIgnore} + if firstKex { + // sendKexInit() for the first kex waits for + // msgNewKeys so the authentication process is + // guaranteed to happen over an encrypted transport. + successPacket = []byte{msgNewKeys} + } + + return successPacket, nil +} + +// keyChangeCategory describes whether a key exchange is the first on a +// connection, or a subsequent one. +type keyChangeCategory bool + +const ( + firstKeyExchange keyChangeCategory = true + subsequentKeyExchange keyChangeCategory = false +) + +// sendKexInit sends a key change message, and returns the message +// that was sent. After initiating the key change, all writes will be +// blocked until the change is done, and a failed key change will +// close the underlying transport. This function is safe for +// concurrent use by multiple goroutines. +func (t *handshakeTransport) sendKexInit(isFirst keyChangeCategory) error { + var err error + + t.mu.Lock() + // If this is the initial key change, but we already have a sessionID, + // then do nothing because the key exchange has already completed + // asynchronously. + if !isFirst || t.sessionID == nil { + _, _, err = t.sendKexInitLocked(isFirst) + } + t.mu.Unlock() + if err != nil { + return err + } + if isFirst { + if packet, err := t.readPacket(); err != nil { + return err + } else if packet[0] != msgNewKeys { + return unexpectedMessageError(msgNewKeys, packet[0]) + } + } + return nil +} + +func (t *handshakeTransport) requestInitialKeyChange() error { + return t.sendKexInit(firstKeyExchange) +} + +func (t *handshakeTransport) requestKeyChange() error { + return t.sendKexInit(subsequentKeyExchange) +} + +// sendKexInitLocked sends a key change message. t.mu must be locked +// while this happens. +func (t *handshakeTransport) sendKexInitLocked(isFirst keyChangeCategory) (*kexInitMsg, []byte, error) { + // kexInits may be sent either in response to the other side, + // or because our side wants to initiate a key change, so we + // may have already sent a kexInit. In that case, don't send a + // second kexInit. + if t.sentInitMsg != nil { + return t.sentInitMsg, t.sentInitPacket, nil + } + + msg := &kexInitMsg{ + KexAlgos: t.config.KeyExchanges, + CiphersClientServer: t.config.Ciphers, + CiphersServerClient: t.config.Ciphers, + MACsClientServer: t.config.MACs, + MACsServerClient: t.config.MACs, + CompressionClientServer: supportedCompressions, + CompressionServerClient: supportedCompressions, + } + io.ReadFull(rand.Reader, msg.Cookie[:]) + + if len(t.hostKeys) > 0 { + for _, k := range t.hostKeys { + msg.ServerHostKeyAlgos = append( + msg.ServerHostKeyAlgos, k.PublicKey().Type()) + } + } else { + msg.ServerHostKeyAlgos = t.hostKeyAlgorithms + } + packet := Marshal(msg) + + // writePacket destroys the contents, so save a copy. + packetCopy := make([]byte, len(packet)) + copy(packetCopy, packet) + + if err := t.conn.writePacket(packetCopy); err != nil { + return nil, nil, err + } + + t.sentInitMsg = msg + t.sentInitPacket = packet + return msg, packet, nil +} + +func (t *handshakeTransport) writePacket(p []byte) error { + t.mu.Lock() + defer t.mu.Unlock() + + if t.writtenSinceKex > t.config.RekeyThreshold { + t.sendKexInitLocked(subsequentKeyExchange) + } + for t.sentInitMsg != nil && t.writeError == nil { + t.cond.Wait() + } + if t.writeError != nil { + return t.writeError + } + t.writtenSinceKex += uint64(len(p)) + + switch p[0] { + case msgKexInit: + return errors.New("ssh: only handshakeTransport can send kexInit") + case msgNewKeys: + return errors.New("ssh: only handshakeTransport can send newKeys") + default: + return t.conn.writePacket(p) + } +} + +func (t *handshakeTransport) Close() error { + return t.conn.Close() +} + +// enterKeyExchange runs the key exchange. t.mu must be held while running this. +func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) error { + if debugHandshake { + log.Printf("%s entered key exchange", t.id()) + } + myInit, myInitPacket, err := t.sendKexInitLocked(subsequentKeyExchange) + if err != nil { + return err + } + + if pkgConfig.Verbose { + if t.config.ConnLog != nil { + t.config.ConnLog.ClientKex = myInit + } + } + + otherInit := &kexInitMsg{} + if err := Unmarshal(otherInitPacket, otherInit); err != nil { + return err + } + if t.config.ConnLog != nil { + t.config.ConnLog.ServerKex = otherInit + } + + magics := handshakeMagics{ + clientVersion: t.clientVersion, + serverVersion: t.serverVersion, + clientKexInit: otherInitPacket, + serverKexInit: myInitPacket, + } + + clientInit := otherInit + serverInit := myInit + if len(t.hostKeys) == 0 { + clientInit = myInit + serverInit = otherInit + + magics.clientKexInit = myInitPacket + magics.serverKexInit = otherInitPacket + } + + algs, err := findAgreedAlgorithms(clientInit, serverInit) + if err != nil { + return err + } + if t.config.ConnLog != nil { + t.config.ConnLog.AlgorithmSelection = algs + } + + // We don't send FirstKexFollows, but we handle receiving it. + // + // RFC 4253 section 7 defines the kex and the agreement method for + // first_kex_packet_follows. It states that the guessed packet + // should be ignored if the "kex algorithm and/or the host + // key algorithm is guessed wrong (server and client have + // different preferred algorithm), or if any of the other + // algorithms cannot be agreed upon". The other algorithms have + // already been checked above so the kex algorithm and host key + // algorithm are checked here. + if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) { + // other side sent a kex message for the wrong algorithm, + // which we have to ignore. + if _, err := t.conn.readPacket(); err != nil { + return err + } + } + + kex, ok := kexAlgoMap[algs.kex] + if !ok { + return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex) + } + + kex = kex.GetNew(algs.kex) + + if t.config.ConnLog != nil { + t.config.ConnLog.DHKeyExchange = kex + } + + var result *kexResult + if len(t.hostKeys) > 0 { + result, err = t.server(kex, algs, &magics) + } else { + result, err = t.client(kex, algs, &magics) + } + if pkgConfig.Verbose { + if t.config.ConnLog != nil { + t.config.ConnLog.Crypto = result + } + } + if err != nil { + return err + } + + if t.sessionID == nil { + t.sessionID = result.H + } + result.SessionID = t.sessionID + + t.conn.prepareKeyChange(algs, result) + if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { + return err + } + if packet, err := t.conn.readPacket(); err != nil { + return err + } else if packet[0] != msgNewKeys { + return unexpectedMessageError(msgNewKeys, packet[0]) + } + + return nil +} + +func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { + var hostKey Signer + for _, k := range t.hostKeys { + if algs.hostKey == k.PublicKey().Type() { + hostKey = k + } + } + + r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey) + return r, err +} + +func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { + result, err := kex.Client(t.conn, t.config.Rand, magics) + if err != nil { + return nil, err + } + + hostKey, err := ParsePublicKey(result.HostKey) + if err != nil { + return nil, err + } + + if err := verifyHostKeySignature(hostKey, result); err != nil { + return nil, err + } + + if t.hostKeyCallback != nil { + err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) + if err != nil { + return nil, err + } + } + + return result, nil +} diff --git a/lib/ssh/handshake_test.go b/lib/ssh/handshake_test.go new file mode 100644 index 0000000..41603b9 --- /dev/null +++ b/lib/ssh/handshake_test.go @@ -0,0 +1,486 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + "net" + "reflect" + "runtime" + "strings" + "sync" + "testing" +) + +type testChecker struct { + calls []string +} + +func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { + if dialAddr == "bad" { + return fmt.Errorf("dialAddr is bad") + } + + if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil { + return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr) + } + + t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal())) + + return nil +} + +// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and +// therefore is buffered (net.Pipe deadlocks if both sides start with +// a write.) +func netPipe() (net.Conn, net.Conn, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, nil, err + } + defer listener.Close() + c1, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + return nil, nil, err + } + + c2, err := listener.Accept() + if err != nil { + c1.Close() + return nil, nil, err + } + + return c1, c2, nil +} + +func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) { + a, b, err := netPipe() + if err != nil { + return nil, nil, err + } + + trC := newTransport(a, rand.Reader, true) + trS := newTransport(b, rand.Reader, false) + clientConf.SetDefaults() + + v := []byte("version") + client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr()) + + serverConf := &ServerConfig{} + serverConf.AddHostKey(testSigners["ecdsa"]) + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.SetDefaults() + server = newServerTransport(trS, v, v, serverConf) + + return client, server, nil +} + +func TestHandshakeBasic(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + checker := &testChecker{} + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + + defer trC.Close() + defer trS.Close() + + go func() { + // Client writes a bunch of stuff, and does a key + // change in the middle. This should not confuse the + // handshake in progress + for i := 0; i < 10; i++ { + p := []byte{msgRequestSuccess, byte(i)} + if err := trC.writePacket(p); err != nil { + t.Fatalf("sendPacket: %v", err) + } + if i == 5 { + // halfway through, we request a key change. + err := trC.sendKexInit(subsequentKeyExchange) + if err != nil { + t.Fatalf("sendKexInit: %v", err) + } + } + } + trC.Close() + }() + + // Server checks that client messages come in cleanly + i := 0 + for { + p, err := trS.readPacket() + if err != nil { + break + } + if p[0] == msgNewKeys { + continue + } + want := []byte{msgRequestSuccess, byte(i)} + if bytes.Compare(p, want) != 0 { + t.Errorf("message %d: got %q, want %q", i, p, want) + } + i++ + } + if i != 10 { + t.Errorf("received %d messages, want 10.", i) + } + + // If all went well, we registered exactly 1 key change. + if len(checker.calls) != 1 { + t.Fatalf("got %d host key checks, want 1", len(checker.calls)) + } + + pub := testSigners["ecdsa"].PublicKey() + want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal()) + if want != checker.calls[0] { + t.Errorf("got %q want %q for host key check", checker.calls[0], want) + } +} + +func TestHandshakeError(t *testing.T) { + checker := &testChecker{} + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + // send a packet + packet := []byte{msgRequestSuccess, 42} + if err := trC.writePacket(packet); err != nil { + t.Errorf("writePacket: %v", err) + } + + // Now request a key change. + err = trC.sendKexInit(subsequentKeyExchange) + if err != nil { + t.Errorf("sendKexInit: %v", err) + } + + // the key change will fail, and afterwards we can't write. + if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil { + t.Errorf("writePacket after botched rekey succeeded.") + } + + readback, err := trS.readPacket() + if err != nil { + t.Fatalf("server closed too soon: %v", err) + } + if bytes.Compare(readback, packet) != 0 { + t.Errorf("got %q want %q", readback, packet) + } + readback, err = trS.readPacket() + if err == nil { + t.Errorf("got a message %q after failed key change", readback) + } +} + +func TestForceFirstKex(t *testing.T) { + checker := &testChecker{} + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + + defer trC.Close() + defer trS.Close() + + trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})) + + // We setup the initial key exchange, but the remote side + // tries to send serviceRequestMsg in cleartext, which is + // disallowed. + + err = trS.sendKexInit(firstKeyExchange) + if err == nil { + t.Errorf("server first kex init should reject unexpected packet") + } +} + +func TestHandshakeTwice(t *testing.T) { + checker := &testChecker{} + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + + defer trC.Close() + defer trS.Close() + + // Both sides should ask for the first key exchange first. + err = trS.sendKexInit(firstKeyExchange) + if err != nil { + t.Errorf("server sendKexInit: %v", err) + } + + err = trC.sendKexInit(firstKeyExchange) + if err != nil { + t.Errorf("client sendKexInit: %v", err) + } + + sent := 0 + // send a packet + packet := make([]byte, 5) + packet[0] = msgRequestSuccess + if err := trC.writePacket(packet); err != nil { + t.Errorf("writePacket: %v", err) + } + sent++ + + // Send another packet. Use a fresh one, since writePacket destroys. + packet = make([]byte, 5) + packet[0] = msgRequestSuccess + if err := trC.writePacket(packet); err != nil { + t.Errorf("writePacket: %v", err) + } + sent++ + + // 2nd key change. + err = trC.sendKexInit(subsequentKeyExchange) + if err != nil { + t.Errorf("sendKexInit: %v", err) + } + + packet = make([]byte, 5) + packet[0] = msgRequestSuccess + if err := trC.writePacket(packet); err != nil { + t.Errorf("writePacket: %v", err) + } + sent++ + + packet = make([]byte, 5) + packet[0] = msgRequestSuccess + for i := 0; i < sent; i++ { + msg, err := trS.readPacket() + if err != nil { + t.Fatalf("server closed too soon: %v", err) + } + + if bytes.Compare(msg, packet) != 0 { + t.Errorf("packet %d: got %q want %q", i, msg, packet) + } + } + if len(checker.calls) != 2 { + t.Errorf("got %d key changes, want 2", len(checker.calls)) + } +} + +func TestHandshakeAutoRekeyWrite(t *testing.T) { + checker := &testChecker{} + clientConf := &ClientConfig{HostKeyCallback: checker.Check} + clientConf.RekeyThreshold = 500 + trC, trS, err := handshakePair(clientConf, "addr") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + for i := 0; i < 5; i++ { + packet := make([]byte, 251) + packet[0] = msgRequestSuccess + if err := trC.writePacket(packet); err != nil { + t.Errorf("writePacket: %v", err) + } + } + + j := 0 + for ; j < 5; j++ { + _, err := trS.readPacket() + if err != nil { + break + } + } + + if j != 5 { + t.Errorf("got %d, want 5 messages", j) + } + + if len(checker.calls) != 2 { + t.Errorf("got %d key changes, wanted 2", len(checker.calls)) + } +} + +type syncChecker struct { + called chan int +} + +func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { + t.called <- 1 + return nil +} + +func TestHandshakeAutoRekeyRead(t *testing.T) { + sync := &syncChecker{make(chan int, 2)} + clientConf := &ClientConfig{ + HostKeyCallback: sync.Check, + } + clientConf.RekeyThreshold = 500 + + trC, trS, err := handshakePair(clientConf, "addr") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + packet := make([]byte, 501) + packet[0] = msgRequestSuccess + if err := trS.writePacket(packet); err != nil { + t.Fatalf("writePacket: %v", err) + } + // While we read out the packet, a key change will be + // initiated. + if _, err := trC.readPacket(); err != nil { + t.Fatalf("readPacket(client): %v", err) + } + + <-sync.called +} + +// errorKeyingTransport generates errors after a given number of +// read/write operations. +type errorKeyingTransport struct { + packetConn + readLeft, writeLeft int +} + +func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error { + return nil +} +func (n *errorKeyingTransport) getSessionID() []byte { + return nil +} + +func (n *errorKeyingTransport) writePacket(packet []byte) error { + if n.writeLeft == 0 { + n.Close() + return errors.New("barf") + } + + n.writeLeft-- + return n.packetConn.writePacket(packet) +} + +func (n *errorKeyingTransport) readPacket() ([]byte, error) { + if n.readLeft == 0 { + n.Close() + return nil, errors.New("barf") + } + + n.readLeft-- + return n.packetConn.readPacket() +} + +func TestHandshakeErrorHandlingRead(t *testing.T) { + for i := 0; i < 20; i++ { + testHandshakeErrorHandlingN(t, i, -1) + } +} + +func TestHandshakeErrorHandlingWrite(t *testing.T) { + for i := 0; i < 20; i++ { + testHandshakeErrorHandlingN(t, -1, i) + } +} + +// testHandshakeErrorHandlingN runs handshakes, injecting errors. If +// handshakeTransport deadlocks, the go runtime will detect it and +// panic. +func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) { + msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)}) + + a, b := memPipe() + defer a.Close() + defer b.Close() + + key := testSigners["ecdsa"] + serverConf := Config{RekeyThreshold: minRekeyThreshold} + serverConf.SetDefaults() + serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'}) + serverConn.hostKeys = []Signer{key} + go serverConn.readLoop() + + clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold} + clientConf.SetDefaults() + clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'}) + clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()} + go clientConn.readLoop() + + var wg sync.WaitGroup + wg.Add(4) + + for _, hs := range []packetConn{serverConn, clientConn} { + go func(c packetConn) { + for { + err := c.writePacket(msg) + if err != nil { + break + } + } + wg.Done() + }(hs) + go func(c packetConn) { + for { + _, err := c.readPacket() + if err != nil { + break + } + } + wg.Done() + }(hs) + } + + wg.Wait() +} + +func TestDisconnect(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + checker := &testChecker{} + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + + defer trC.Close() + defer trS.Close() + + trC.writePacket([]byte{msgRequestSuccess, 0, 0}) + errMsg := &disconnectMsg{ + Reason: 42, + Message: "such is life", + } + trC.writePacket(Marshal(errMsg)) + trC.writePacket([]byte{msgRequestSuccess, 0, 0}) + + packet, err := trS.readPacket() + if err != nil { + t.Fatalf("readPacket 1: %v", err) + } + if packet[0] != msgRequestSuccess { + t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess) + } + + _, err = trS.readPacket() + if err == nil { + t.Errorf("readPacket 2 succeeded") + } else if !reflect.DeepEqual(err, errMsg) { + t.Errorf("got error %#v, want %#v", err, errMsg) + } + + _, err = trS.readPacket() + if err == nil { + t.Errorf("readPacket 3 succeeded") + } +} diff --git a/lib/ssh/kex.go b/lib/ssh/kex.go new file mode 100644 index 0000000..3b36802 --- /dev/null +++ b/lib/ssh/kex.go @@ -0,0 +1,764 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/hex" + "encoding/json" + "errors" + "io" + "math/big" + + ztoolsKeys "github.com/zmap/zgrab/ztools/keys" + + "golang.org/x/crypto/curve25519" +) + +const ( + kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1" + kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1" + kexAlgoECDH256 = "ecdh-sha2-nistp256" + kexAlgoECDH384 = "ecdh-sha2-nistp384" + kexAlgoECDH521 = "ecdh-sha2-nistp521" + kexAlgoCurve25519SHA256 = "curve25519-sha256@libssh.org" +) + +// kexResult captures the outcome of a key exchange. +type kexResult struct { + // Session hash. See also RFC 4253, section 8. + H []byte `json:"H,omitempty"` + + // Shared secret. See also RFC 4253, section 8. + K []byte `json:"K,omitempty"` + + // Host key as hashed into H. + HostKey []byte `json:"-"` + + // Signature of H. + Signature []byte `json:"-"` + + // A cryptographic hash function that matches the security + // level of the key exchange algorithm. It is used for + // calculating H, and for deriving keys from H and K. + Hash crypto.Hash `json:"-"` + + // The session ID, which is the first H computed. This is used + // to derive key material inside the transport. + SessionID []byte `json:"session_id,omitempty"` +} + +// handshakeMagics contains data that is always included in the +// session hash. +type handshakeMagics struct { + clientVersion, serverVersion []byte + clientKexInit, serverKexInit []byte +} + +func (m *handshakeMagics) write(w io.Writer) { + writeString(w, m.clientVersion) + writeString(w, m.serverVersion) + writeString(w, m.clientKexInit) + writeString(w, m.serverKexInit) +} + +type PublicKeyJsonLog struct { + RSAHostKey *PublicKey `json:"rsa_public_key,omitempty"` + DSAHostKey *PublicKey `json:"dsa_public_key,omitempty"` + ECDSAHostKey *PublicKey `json:"ecdsa_public_key,omitempty"` + Ed25519HostKey *PublicKey `json:"ed25519_public_key,omitempty"` + CertKeyHostKey *PublicKey `json:"certkey_public_key,omitempty"` +} + +func (pkLog *PublicKeyJsonLog) AddPublicKey(pubKey PublicKey) bool { + switch pubKey.Type() { + case KeyAlgoRSA: + pkLog.RSAHostKey = &pubKey + return true + + case KeyAlgoDSA: + pkLog.DSAHostKey = &pubKey + return true + + case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521: + pkLog.ECDSAHostKey = &pubKey + return true + + case KeyAlgoED25519: + pkLog.Ed25519HostKey = &pubKey + return true + + case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01: + pkLog.CertKeyHostKey = &pubKey + return true + + default: + return false + } +} + +type ServerHostKeyJsonLog struct { + PublicKeyJsonLog + Raw []byte `json:"raw"` + Algorithm string `json:"algorithm"` + Fingerprint string `json:"fingerprint_sha256,omitempty"` + TrailingData []byte `json:"trailing_data,omitempty"` + ParseError string `json:"parse_error,omitempty"` +} + +func LogServerHostKey(sshRawKey []byte) *ServerHostKeyJsonLog { + ret := new(ServerHostKeyJsonLog) + ret.Raw = sshRawKey + tempHash := sha256.Sum256(sshRawKey) + ret.Fingerprint = hex.EncodeToString(tempHash[:]) + + keyAlgorithm, keyBytes, ok := parseString(sshRawKey) + if !ok { + ret.Algorithm = "unknown" + return ret + } + ret.Algorithm = string(keyAlgorithm) + + keyObj, rest, err := parsePubKey(keyBytes, ret.Algorithm) + if err != nil { + ret.ParseError = err.Error() + return ret + } + ret.TrailingData = rest + + ok = ret.PublicKeyJsonLog.AddPublicKey(keyObj) + if !ok { + ret.ParseError = "Cannot parse to JSON" + } + + return ret +} + +// kexAlgorithm abstracts different key exchange algorithms. +type kexAlgorithm interface { + // Server runs server-side key agreement, signing the result + // with a hostkey. + Server(p packetConn, rand io.Reader, magics *handshakeMagics, s Signer) (*kexResult, error) + + // Client runs the client-side key agreement. Caller is + // responsible for verifying the host key signature. + Client(p packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) + + // Create a JSON object for the kexAlgorithm group + MarshalJSON() ([]byte, error) + + // Get a new instance of this interface + // Because the base x/crypto package passes the same object to each connection + GetNew(keyType string) kexAlgorithm +} + +// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement. +type dhGroup struct { + g, p, pMinus1 *big.Int + JsonLog dhGroupJsonLog +} +type dhGroupJsonLog struct { + Parameters *ztoolsKeys.DHParams `json:"dh_params,omitempty"` + ServerSignature *JsonSignature `json:"server_signature,omitempty"` + ServerHostKey *ServerHostKeyJsonLog `json:"server_host_key,omitempty"` +} + +func (group *dhGroup) MarshalJSON() ([]byte, error) { + group.JsonLog.Parameters.Generator = group.g + group.JsonLog.Parameters.Prime = group.p + return json.Marshal(group.JsonLog) +} + +func (group *dhGroup) GetNew(keyType string) kexAlgorithm { + ret := new(dhGroup) + ret.g = new(big.Int).SetInt64(2) + + switch keyType { + case kexAlgoDH1SHA1: + ret.p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16) + ret.pMinus1 = new(big.Int).Sub(ret.p, bigOne) + break + + case kexAlgoDH14SHA1: + ret.p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) + ret.pMinus1 = new(big.Int).Sub(ret.p, bigOne) + break + + default: + panic("Unimplemented DH KEX selected") + } + + return ret +} + +func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) { + if theirPublic.Cmp(bigOne) <= 0 || theirPublic.Cmp(group.pMinus1) >= 0 { + return nil, errors.New("ssh: DH parameter out of bounds") + } + return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil +} + +func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) { + group.JsonLog.Parameters = new(ztoolsKeys.DHParams) + hashFunc := crypto.SHA1 + + var x *big.Int + for { + var err error + if x, err = rand.Int(randSource, group.pMinus1); err != nil { + return nil, err + } + if x.Sign() > 0 { + break + } + } + + if pkgConfig.Verbose { + group.JsonLog.Parameters.ClientPrivate = x + } + + X := new(big.Int).Exp(group.g, x, group.p) + + if pkgConfig.Verbose { + group.JsonLog.Parameters.ClientPublic = X + } + + kexDHInit := kexDHInitMsg{ + X: X, + } + if err := c.writePacket(Marshal(&kexDHInit)); err != nil { + return nil, err + } + + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var kexDHReply kexDHReplyMsg + if err = Unmarshal(packet, &kexDHReply); err != nil { + return nil, err + } + + group.JsonLog.Parameters.ServerPublic = kexDHReply.Y + group.JsonLog.ServerSignature = new(JsonSignature) + group.JsonLog.ServerSignature.Raw = kexDHReply.Signature + group.JsonLog.ServerSignature.Parsed, _, _ = parseSignatureBody(kexDHReply.Signature) + group.JsonLog.ServerHostKey = LogServerHostKey(kexDHReply.HostKey) + + kInt, err := group.diffieHellman(kexDHReply.Y, x) + if err != nil { + return nil, err + } + + h := hashFunc.New() + magics.write(h) + writeString(h, kexDHReply.HostKey) + writeInt(h, X) + writeInt(h, kexDHReply.Y) + K := make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: kexDHReply.HostKey, + Signature: kexDHReply.Signature, + Hash: crypto.SHA1, + }, nil +} + +func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { + hashFunc := crypto.SHA1 + packet, err := c.readPacket() + if err != nil { + return + } + var kexDHInit kexDHInitMsg + if err = Unmarshal(packet, &kexDHInit); err != nil { + return + } + + var y *big.Int + for { + if y, err = rand.Int(randSource, group.pMinus1); err != nil { + return + } + if y.Sign() > 0 { + break + } + } + + Y := new(big.Int).Exp(group.g, y, group.p) + kInt, err := group.diffieHellman(kexDHInit.X, y) + if err != nil { + return nil, err + } + + hostKeyBytes := priv.PublicKey().Marshal() + + h := hashFunc.New() + magics.write(h) + writeString(h, hostKeyBytes) + writeInt(h, kexDHInit.X) + writeInt(h, Y) + + K := make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + H := h.Sum(nil) + + // H is already a hash, but the hostkey signing will apply its + // own key-specific hash algorithm. + sig, err := signAndMarshal(priv, randSource, H) + if err != nil { + return nil, err + } + + kexDHReply := kexDHReplyMsg{ + HostKey: hostKeyBytes, + Y: Y, + Signature: sig, + } + packet = Marshal(&kexDHReply) + + err = c.writePacket(packet) + return &kexResult{ + H: H, + K: K, + HostKey: hostKeyBytes, + Signature: sig, + Hash: crypto.SHA1, + }, nil +} + +// ecdh performs Elliptic Curve Diffie-Hellman key exchange as +// described in RFC 5656, section 4. +type ecdh struct { + curve elliptic.Curve + JsonLog ecdhJsonLog +} + +type ecdhJsonLog struct { + Parameters *ztoolsKeys.ECDHParams `json:"ecdh_params,omitempty"` + ServerSignature *JsonSignature `json:"server_signature,omitempty"` + ServerHostKey *ServerHostKeyJsonLog `json:"server_host_key,omitempty"` +} + +func (kex *ecdh) MarshalJSON() ([]byte, error) { + return json.Marshal(kex.JsonLog) +} + +func (kex *ecdh) GetNew(keyType string) kexAlgorithm { + ret := new(ecdh) + + switch keyType { + case kexAlgoECDH521: + ret.curve = elliptic.P521() + break + + case kexAlgoECDH384: + ret.curve = elliptic.P384() + break + + case kexAlgoECDH256: + ret.curve = elliptic.P256() + break + + default: + panic("Unimplemented ECDH KEX selected") + } + + return ret +} + +func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { + kex.JsonLog.Parameters = new(ztoolsKeys.ECDHParams) + kex.JsonLog.Parameters.ServerPublic = new(ztoolsKeys.ECPoint) + if pkgConfig.Verbose { + kex.JsonLog.Parameters.ClientPublic = new(ztoolsKeys.ECPoint) + kex.JsonLog.Parameters.ClientPrivate = new(ztoolsKeys.ECDHPrivateParams) + } + + ephKey, err := ecdsa.GenerateKey(kex.curve, rand) + if err != nil { + return nil, err + } + + if pkgConfig.Verbose { + kex.JsonLog.Parameters.ClientPublic.X = ephKey.PublicKey.X + kex.JsonLog.Parameters.ClientPublic.Y = ephKey.PublicKey.Y + kex.JsonLog.Parameters.ClientPrivate.Value = ephKey.D.Bytes() + kex.JsonLog.Parameters.ClientPrivate.Length = ephKey.D.BitLen() + } + + kexInit := kexECDHInitMsg{ + ClientPubKey: elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y), + } + + serialized := Marshal(&kexInit) + if err := c.writePacket(serialized); err != nil { + return nil, err + } + + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var reply kexECDHReplyMsg + if err = Unmarshal(packet, &reply); err != nil { + return nil, err + } + + x, y, err := unmarshalECKey(kex.curve, reply.EphemeralPubKey) + kex.JsonLog.Parameters.ServerPublic.X = x + kex.JsonLog.Parameters.ServerPublic.Y = y + kex.JsonLog.ServerHostKey = LogServerHostKey(reply.HostKey) + kex.JsonLog.ServerSignature = new(JsonSignature) + kex.JsonLog.ServerSignature.Raw = reply.Signature + kex.JsonLog.ServerSignature.Parsed, _, _ = parseSignatureBody(reply.Signature) + if err != nil { + return nil, err + } + + // generate shared secret + secret, _ := kex.curve.ScalarMult(x, y, ephKey.D.Bytes()) + + h := ecHash(kex.curve).New() + magics.write(h) + writeString(h, reply.HostKey) + writeString(h, kexInit.ClientPubKey) + writeString(h, reply.EphemeralPubKey) + K := make([]byte, intLength(secret)) + marshalInt(K, secret) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: reply.HostKey, + Signature: reply.Signature, + Hash: ecHash(kex.curve), + }, nil +} + +// unmarshalECKey parses and checks an EC key. +func unmarshalECKey(curve elliptic.Curve, pubkey []byte) (x, y *big.Int, err error) { + x, y = elliptic.Unmarshal(curve, pubkey) + if x == nil { + return nil, nil, errors.New("ssh: elliptic.Unmarshal failure") + } + if !validateECPublicKey(curve, x, y) { + return nil, nil, errors.New("ssh: public key not on curve") + } + return x, y, nil +} + +// validateECPublicKey checks that the point is a valid public key for +// the given curve. See [SEC1], 3.2.2 +func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool { + if x.Sign() == 0 && y.Sign() == 0 { + return false + } + + if x.Cmp(curve.Params().P) >= 0 { + return false + } + + if y.Cmp(curve.Params().P) >= 0 { + return false + } + + if !curve.IsOnCurve(x, y) { + return false + } + + // We don't check if N * PubKey == 0, since + // + // - the NIST curves have cofactor = 1, so this is implicit. + // (We don't foresee an implementation that supports non NIST + // curves) + // + // - for ephemeral keys, we don't need to worry about small + // subgroup attacks. + return true +} + +func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var kexECDHInit kexECDHInitMsg + if err = Unmarshal(packet, &kexECDHInit); err != nil { + return nil, err + } + + clientX, clientY, err := unmarshalECKey(kex.curve, kexECDHInit.ClientPubKey) + if err != nil { + return nil, err + } + + // We could cache this key across multiple users/multiple + // connection attempts, but the benefit is small. OpenSSH + // generates a new key for each incoming connection. + ephKey, err := ecdsa.GenerateKey(kex.curve, rand) + if err != nil { + return nil, err + } + + hostKeyBytes := priv.PublicKey().Marshal() + + serializedEphKey := elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y) + + // generate shared secret + secret, _ := kex.curve.ScalarMult(clientX, clientY, ephKey.D.Bytes()) + + h := ecHash(kex.curve).New() + magics.write(h) + writeString(h, hostKeyBytes) + writeString(h, kexECDHInit.ClientPubKey) + writeString(h, serializedEphKey) + + K := make([]byte, intLength(secret)) + marshalInt(K, secret) + h.Write(K) + + H := h.Sum(nil) + + // H is already a hash, but the hostkey signing will apply its + // own key-specific hash algorithm. + sig, err := signAndMarshal(priv, rand, H) + if err != nil { + return nil, err + } + + reply := kexECDHReplyMsg{ + EphemeralPubKey: serializedEphKey, + HostKey: hostKeyBytes, + Signature: sig, + } + + serialized := Marshal(&reply) + if err := c.writePacket(serialized); err != nil { + return nil, err + } + + return &kexResult{ + H: H, + K: K, + HostKey: reply.HostKey, + Signature: sig, + Hash: ecHash(kex.curve), + }, nil +} + +var kexAlgoMap = map[string]kexAlgorithm{} + +func init() { + // This is the group called diffie-hellman-group1-sha1 in RFC + // 4253 and Oakley Group 2 in RFC 2409. + p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16) + kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{ + g: new(big.Int).SetInt64(2), + p: p, + pMinus1: new(big.Int).Sub(p, bigOne), + } + + // This is the group called diffie-hellman-group14-sha1 in RFC + // 4253 and Oakley Group 14 in RFC 3526. + p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) + + kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{ + g: new(big.Int).SetInt64(2), + p: p, + pMinus1: new(big.Int).Sub(p, bigOne), + } + + kexAlgoMap[kexAlgoECDH521] = &ecdh{curve: elliptic.P521()} + kexAlgoMap[kexAlgoECDH384] = &ecdh{curve: elliptic.P384()} + kexAlgoMap[kexAlgoECDH256] = &ecdh{curve: elliptic.P256()} + kexAlgoMap[kexAlgoCurve25519SHA256] = &curve25519sha256{} + kexAlgoMap[kexAlgoDHGEXSHA1] = &dhGEXSHA{hashFunc: crypto.SHA1} + kexAlgoMap[kexAlgoDHGEXSHA256] = &dhGEXSHA{hashFunc: crypto.SHA256} +} + +// curve25519sha256 implements the curve25519-sha256@libssh.org key +// agreement protocol, as described in +// https://git.libssh.org/projects/libssh.git/tree/doc/curve25519-sha256@libssh.org.txt +type curve25519sha256 struct { + JsonLog curve25519sha256JsonLog +} + +type curve25519sha256JsonLog struct { + Parameters curve25519sha256JsonLogParameters `json:"curve25519_sha256_params"` + ServerSignature *JsonSignature `json:"server_signature,omitempty"` + ServerHostKey *ServerHostKeyJsonLog `json:"server_host_key,omitempty"` +} + +type curve25519sha256JsonLogParameters struct { + ClientPublic []byte `json:"client_public,omitempty"` + ClientPrivate []byte `json:"client_private,omitempty"` + ServerPublic []byte `json:"server_public,omitempty"` +} + +func (kex *curve25519sha256) MarshalJSON() ([]byte, error) { + return json.Marshal(kex.JsonLog) +} + +func (kex *curve25519sha256) GetNew(keyType string) kexAlgorithm { + return new(curve25519sha256) +} + +type curve25519KeyPair struct { + priv [32]byte + pub [32]byte +} + +func (kp *curve25519KeyPair) generate(rand io.Reader) error { + if _, err := io.ReadFull(rand, kp.priv[:]); err != nil { + return err + } + curve25519.ScalarBaseMult(&kp.pub, &kp.priv) + return nil +} + +// curve25519Zeros is just an array of 32 zero bytes so that we have something +// convenient to compare against in order to reject curve25519 points with the +// wrong order. +var curve25519Zeros [32]byte + +func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { + var kp curve25519KeyPair + if err := kp.generate(rand); err != nil { + return nil, err + } + + if pkgConfig.Verbose { + kex.JsonLog.Parameters.ClientPublic = kp.pub[:] + kex.JsonLog.Parameters.ClientPrivate = kp.priv[:] + } + + if err := c.writePacket(Marshal(&kexECDHInitMsg{kp.pub[:]})); err != nil { + return nil, err + } + + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var reply kexECDHReplyMsg + if err = Unmarshal(packet, &reply); err != nil { + return nil, err + } + + kex.JsonLog.Parameters.ServerPublic = reply.EphemeralPubKey + kex.JsonLog.ServerHostKey = LogServerHostKey(reply.HostKey) + kex.JsonLog.ServerSignature = new(JsonSignature) + kex.JsonLog.ServerSignature.Raw = reply.Signature + kex.JsonLog.ServerSignature.Parsed, _, _ = parseSignatureBody(reply.Signature) + if len(reply.EphemeralPubKey) != 32 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong length") + } + + var servPub, secret [32]byte + copy(servPub[:], reply.EphemeralPubKey) + curve25519.ScalarMult(&secret, &kp.priv, &servPub) + if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong order") + } + + h := crypto.SHA256.New() + magics.write(h) + writeString(h, reply.HostKey) + writeString(h, kp.pub[:]) + writeString(h, reply.EphemeralPubKey) + + kInt := new(big.Int).SetBytes(secret[:]) + K := make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: reply.HostKey, + Signature: reply.Signature, + Hash: crypto.SHA256, + }, nil +} + +func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { + packet, err := c.readPacket() + if err != nil { + return + } + var kexInit kexECDHInitMsg + if err = Unmarshal(packet, &kexInit); err != nil { + return + } + + if len(kexInit.ClientPubKey) != 32 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong length") + } + + var kp curve25519KeyPair + if err := kp.generate(rand); err != nil { + return nil, err + } + + var clientPub, secret [32]byte + copy(clientPub[:], kexInit.ClientPubKey) + curve25519.ScalarMult(&secret, &kp.priv, &clientPub) + if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong order") + } + + hostKeyBytes := priv.PublicKey().Marshal() + + h := crypto.SHA256.New() + magics.write(h) + writeString(h, hostKeyBytes) + writeString(h, kexInit.ClientPubKey) + writeString(h, kp.pub[:]) + + kInt := new(big.Int).SetBytes(secret[:]) + K := make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + H := h.Sum(nil) + + sig, err := signAndMarshal(priv, rand, H) + if err != nil { + return nil, err + } + + reply := kexECDHReplyMsg{ + EphemeralPubKey: kp.pub[:], + HostKey: hostKeyBytes, + Signature: sig, + } + if err := c.writePacket(Marshal(&reply)); err != nil { + return nil, err + } + return &kexResult{ + H: H, + K: K, + HostKey: hostKeyBytes, + Signature: sig, + Hash: crypto.SHA256, + }, nil +} diff --git a/lib/ssh/kex_gex.go b/lib/ssh/kex_gex.go new file mode 100644 index 0000000..b2b6c57 --- /dev/null +++ b/lib/ssh/kex_gex.go @@ -0,0 +1,291 @@ +package xssh + +import ( + "crypto" + "crypto/rand" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "math/big" + + ztoolsKeys "github.com/zmap/zgrab/ztools/keys" +) + +const ( + kexAlgoDHGEXSHA1 = "diffie-hellman-group-exchange-sha1" + kexAlgoDHGEXSHA256 = "diffie-hellman-group-exchange-sha256" +) + +// Messages +type kexDHGexGroupMsg struct { + P *big.Int `sshtype:"31"` + G *big.Int +} + +type kexDHGexInitMsg struct { + X *big.Int `sshtype:"32"` +} + +type kexDHGexReplyMsg struct { + HostKey []byte `sshtype:"33"` + Y *big.Int + Signature []byte +} + +type kexDHGexRequestMsg struct { + MinBits uint32 `sshtype:"34"` + PreferedBits uint32 + MaxBits uint32 +} + +type gexJsonLog struct { + Parameters *ztoolsKeys.DHParams `json:"dh_params,omitempty"` + ServerSignature *JsonSignature `json:"server_signature,omitempty"` + ServerHostKey *ServerHostKeyJsonLog `json:"server_host_key,omitempty"` +} + +type dhGEXSHA struct { + g, p *big.Int + hashFunc crypto.Hash + JsonLog *gexJsonLog +} + +func (gex *dhGEXSHA) GetNew(keyType string) kexAlgorithm { + switch keyType { + case kexAlgoDHGEXSHA1: + ret := new(dhGEXSHA) + ret.hashFunc = crypto.SHA1 + ret.JsonLog = new(gexJsonLog) + ret.JsonLog.Parameters = new(ztoolsKeys.DHParams) + return ret + + case kexAlgoDHGEXSHA256: + ret := new(dhGEXSHA) + ret.hashFunc = crypto.SHA256 + ret.JsonLog = new(gexJsonLog) + ret.JsonLog.Parameters = new(ztoolsKeys.DHParams) + return ret + + default: + panic("Unimplemented GEX selected") + } +} + +func (gex *dhGEXSHA) MarshalJSON() ([]byte, error) { + return json.Marshal(gex.JsonLog) +} + +func (gex *dhGEXSHA) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) { + if theirPublic.Sign() <= 0 || theirPublic.Cmp(gex.p) >= 0 { + return nil, fmt.Errorf("ssh: DH parameter out of bounds") + } + return new(big.Int).Exp(theirPublic, myPrivate, gex.p), nil +} + +func (gex *dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) { + // Send GexRequest + kexDHGexRequest := kexDHGexRequestMsg{ + MinBits: uint32(pkgConfig.GexMinBits), + PreferedBits: uint32(pkgConfig.GexPreferredBits), + MaxBits: uint32(pkgConfig.GexMaxBits), + } + if err := c.writePacket(Marshal(&kexDHGexRequest)); err != nil { + return nil, err + } + + // *Receive GexGroup* + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var kexDHGexGroup kexDHGexGroupMsg + if err = Unmarshal(packet, &kexDHGexGroup); err != nil { + return nil, err + } + + if gex.JsonLog != nil { + gex.JsonLog.Parameters.Prime = kexDHGexGroup.P + gex.JsonLog.Parameters.Generator = kexDHGexGroup.G + } + + // reject if p's bit length < pkgConfig.GexMinBits or > pkgConfig.GexMaxBits + if kexDHGexGroup.P.BitLen() < int(pkgConfig.GexMinBits) || kexDHGexGroup.P.BitLen() > int(pkgConfig.GexMaxBits) { + return nil, fmt.Errorf("Server-generated gex p (dont't ask) is out of range (%d bits)", kexDHGexGroup.P.BitLen()) + } + + gex.p = kexDHGexGroup.P + gex.g = kexDHGexGroup.G + + // *Send GexInit + x, err := rand.Int(randSource, gex.p) + if err != nil { + return nil, err + } + X := new(big.Int).Exp(gex.g, x, gex.p) + kexDHGexInit := kexDHGexInitMsg{ + X: X, + } + + if gex.JsonLog != nil && pkgConfig.Verbose { + gex.JsonLog.Parameters.ClientPrivate = x + gex.JsonLog.Parameters.ClientPublic = X + } + + if err := c.writePacket(Marshal(&kexDHGexInit)); err != nil { + return nil, err + } + + // Receive GexReply + packet, err = c.readPacket() + if err != nil { + return nil, err + } + + var kexDHGexReply kexDHGexReplyMsg + if err = Unmarshal(packet, &kexDHGexReply); err != nil { + return nil, err + } + + if gex.JsonLog != nil { + gex.JsonLog.Parameters.ServerPublic = kexDHGexReply.Y + gex.JsonLog.ServerSignature = new(JsonSignature) + gex.JsonLog.ServerSignature.Raw = kexDHGexReply.Signature + gex.JsonLog.ServerSignature.Parsed, _, _ = parseSignatureBody(kexDHGexReply.Signature) + gex.JsonLog.ServerHostKey = LogServerHostKey(kexDHGexReply.HostKey) + } + + kInt, err := gex.diffieHellman(kexDHGexReply.Y, x) + if err != nil { + return nil, err + } + + h := gex.hashFunc.New() + magics.write(h) + writeString(h, kexDHGexReply.HostKey) + binary.Write(h, binary.BigEndian, uint32(pkgConfig.GexMinBits)) + binary.Write(h, binary.BigEndian, uint32(pkgConfig.GexPreferredBits)) + binary.Write(h, binary.BigEndian, uint32(pkgConfig.GexMaxBits)) + writeInt(h, gex.p) + writeInt(h, gex.g) + writeInt(h, X) + writeInt(h, kexDHGexReply.Y) + K := make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: kexDHGexReply.HostKey, + Signature: kexDHGexReply.Signature, + Hash: gex.hashFunc, + }, nil +} + +func (gex *dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { + // *Receive GexRequest* + packet, err := c.readPacket() + if err != nil { + return + } + var kexDHGexRequest kexDHGexRequestMsg + if err = Unmarshal(packet, &kexDHGexRequest); err != nil { + return + } + + // smoosh the user's preferred size into our own limits + if kexDHGexRequest.PreferedBits > uint32(pkgConfig.GexMaxBits) { + kexDHGexRequest.PreferedBits = uint32(pkgConfig.GexMaxBits) + } + if kexDHGexRequest.PreferedBits < uint32(pkgConfig.GexMinBits) { + kexDHGexRequest.PreferedBits = uint32(pkgConfig.GexMinBits) + } + // fix min/max if they're inconsistent. technically, we could just pout + // and hang up, but there's no harm in giving them the benefit of the + // doubt and just picking a bitsize for them. + if kexDHGexRequest.MinBits > kexDHGexRequest.PreferedBits { + kexDHGexRequest.MinBits = kexDHGexRequest.PreferedBits + } + if kexDHGexRequest.MaxBits < kexDHGexRequest.PreferedBits { + kexDHGexRequest.MaxBits = kexDHGexRequest.PreferedBits + } + + // *Send GexGroup* + // generate prime + // TODO: Not implemented yet, should load primes from /etc/ssh/moduli + gex.p, _ = new(big.Int).SetString("D1391174233D315398FE2830AC6B2B66BCCD01B0A634899F339B7879F1DB85712E9DC4E4B1C6C8355570C1D2DCB53493DF18175A9C53D1128B592B4C72D97136F5542FEB981CBFE8012FDD30361F288A42BD5EBB08BAB0A5640E1AC48763B2ABD1945FEE36B2D55E1D50A1C86CED9DD141C4E7BE2D32D9B562A0F8E2E927020E91F58B57EB9ACDDA106A59302D7E92AD5F6E851A45FA1CFE86029A0F727F65A8F475F33572E2FDAB6073F0C21B8B54C3823DB2EF068927E5D747498F96361507", 16) + gex.g = big.NewInt(5) + kexDHGexGroup := kexDHGexGroupMsg{ + P: gex.p, + G: gex.g, + } + if err := c.writePacket(Marshal(&kexDHGexGroup)); err != nil { + return nil, err + } + + // *Receive GexInit + packet, err = c.readPacket() + if err != nil { + return + } + var kexDHGexInit kexDHGexInitMsg + if err = Unmarshal(packet, &kexDHGexInit); err != nil { + return + } + + y, err := rand.Int(randSource, gex.p) + if err != nil { + return + } + + Y := new(big.Int).Exp(gex.g, y, gex.p) + kInt, err := gex.diffieHellman(kexDHGexInit.X, y) + if err != nil { + return nil, err + } + + hostKeyBytes := priv.PublicKey().Marshal() + + h := gex.hashFunc.New() + magics.write(h) + writeString(h, hostKeyBytes) + binary.Write(h, binary.BigEndian, uint32(pkgConfig.GexMinBits)) + binary.Write(h, binary.BigEndian, uint32(pkgConfig.GexPreferredBits)) + binary.Write(h, binary.BigEndian, uint32(pkgConfig.GexMaxBits)) + writeInt(h, gex.p) + writeInt(h, gex.g) + writeInt(h, kexDHGexInit.X) + writeInt(h, Y) + + K := make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + H := h.Sum(nil) + + // H is already a hash, but the hostkey signing will apply its + // own key-specific hash algorithm. + sig, err := signAndMarshal(priv, randSource, H) + if err != nil { + return nil, err + } + + kexDHGexReply := kexDHGexReplyMsg{ + HostKey: hostKeyBytes, + Y: Y, + Signature: sig, + } + packet = Marshal(&kexDHGexReply) + + err = c.writePacket(packet) + + return &kexResult{ + H: H, + K: K, + HostKey: hostKeyBytes, + Signature: sig, + Hash: gex.hashFunc, + }, nil +} diff --git a/lib/ssh/kex_test.go b/lib/ssh/kex_test.go new file mode 100644 index 0000000..61d41f0 --- /dev/null +++ b/lib/ssh/kex_test.go @@ -0,0 +1,50 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +// Key exchange tests. + +import ( + "crypto/rand" + "reflect" + "testing" +) + +func TestKexes(t *testing.T) { + type kexResultErr struct { + result *kexResult + err error + } + + for name, kex := range kexAlgoMap { + a, b := memPipe() + + s := make(chan kexResultErr, 1) + c := make(chan kexResultErr, 1) + var magics handshakeMagics + go func() { + r, e := kex.Client(a, rand.Reader, &magics) + a.Close() + c <- kexResultErr{r, e} + }() + go func() { + r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"]) + b.Close() + s <- kexResultErr{r, e} + }() + + clientRes := <-c + serverRes := <-s + if clientRes.err != nil { + t.Errorf("client: %v", clientRes.err) + } + if serverRes.err != nil { + t.Errorf("server: %v", serverRes.err) + } + if !reflect.DeepEqual(clientRes.result, serverRes.result) { + t.Errorf("kex %q: mismatch %#v, %#v", name, clientRes.result, serverRes.result) + } + } +} diff --git a/lib/ssh/keys.go b/lib/ssh/keys.go new file mode 100644 index 0000000..bb86255 --- /dev/null +++ b/lib/ssh/keys.go @@ -0,0 +1,937 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "bytes" + "crypto" + "crypto/dsa" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/md5" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/asn1" + "encoding/base64" + "encoding/hex" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "io" + "math/big" + "strings" + + "golang.org/x/crypto/ed25519" + + ztoolsX509 "github.com/zmap/zcrypto/x509" + ztoolsKeys "github.com/zmap/zgrab/ztools/keys" +) + +// These constants represent the algorithm names for key types supported by this +// package. +const ( + KeyAlgoRSA = "ssh-rsa" + KeyAlgoDSA = "ssh-dss" + KeyAlgoECDSA256 = "ecdsa-sha2-nistp256" + KeyAlgoECDSA384 = "ecdsa-sha2-nistp384" + KeyAlgoECDSA521 = "ecdsa-sha2-nistp521" + KeyAlgoED25519 = "ssh-ed25519" +) + +// parsePubKey parses a public key of the given algorithm. +// Use ParsePublicKey for keys with prepended algorithm. +func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err error) { + switch algo { + case KeyAlgoRSA: + return parseRSA(in) + case KeyAlgoDSA: + return parseDSA(in) + case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521: + return parseECDSA(in) + case KeyAlgoED25519: + return parseED25519(in) + case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01: + cert, err := parseCert(in, certToPrivAlgo(algo)) + if err != nil { + return nil, nil, err + } + return cert, nil, nil + } + return nil, nil, fmt.Errorf("ssh: unknown key algorithm: %v", algo) +} + +// parseAuthorizedKey parses a public key in OpenSSH authorized_keys format +// (see sshd(8) manual page) once the options and key type fields have been +// removed. +func parseAuthorizedKey(in []byte) (out PublicKey, comment string, err error) { + in = bytes.TrimSpace(in) + + i := bytes.IndexAny(in, " \t") + if i == -1 { + i = len(in) + } + base64Key := in[:i] + + key := make([]byte, base64.StdEncoding.DecodedLen(len(base64Key))) + n, err := base64.StdEncoding.Decode(key, base64Key) + if err != nil { + return nil, "", err + } + key = key[:n] + out, err = ParsePublicKey(key) + if err != nil { + return nil, "", err + } + comment = string(bytes.TrimSpace(in[i:])) + return out, comment, nil +} + +// ParseKnownHosts parses an entry in the format of the known_hosts file. +// +// The known_hosts format is documented in the sshd(8) manual page. This +// function will parse a single entry from in. On successful return, marker +// will contain the optional marker value (i.e. "cert-authority" or "revoked") +// or else be empty, hosts will contain the hosts that this entry matches, +// pubKey will contain the public key and comment will contain any trailing +// comment at the end of the line. See the sshd(8) manual page for the various +// forms that a host string can take. +// +// The unparsed remainder of the input will be returned in rest. This function +// can be called repeatedly to parse multiple entries. +// +// If no entries were found in the input then err will be io.EOF. Otherwise a +// non-nil err value indicates a parse error. +func ParseKnownHosts(in []byte) (marker string, hosts []string, pubKey PublicKey, comment string, rest []byte, err error) { + for len(in) > 0 { + end := bytes.IndexByte(in, '\n') + if end != -1 { + rest = in[end+1:] + in = in[:end] + } else { + rest = nil + } + + end = bytes.IndexByte(in, '\r') + if end != -1 { + in = in[:end] + } + + in = bytes.TrimSpace(in) + if len(in) == 0 || in[0] == '#' { + in = rest + continue + } + + i := bytes.IndexAny(in, " \t") + if i == -1 { + in = rest + continue + } + + // Strip out the beginning of the known_host key. + // This is either an optional marker or a (set of) hostname(s). + keyFields := bytes.Fields(in) + if len(keyFields) < 3 || len(keyFields) > 5 { + return "", nil, nil, "", nil, errors.New("ssh: invalid entry in known_hosts data") + } + + // keyFields[0] is either "@cert-authority", "@revoked" or a comma separated + // list of hosts + marker := "" + if keyFields[0][0] == '@' { + marker = string(keyFields[0][1:]) + keyFields = keyFields[1:] + } + + hosts := string(keyFields[0]) + // keyFields[1] contains the key type (e.g. “ssh-rsa”). + // However, that information is duplicated inside the + // base64-encoded key and so is ignored here. + + key := bytes.Join(keyFields[2:], []byte(" ")) + if pubKey, comment, err = parseAuthorizedKey(key); err != nil { + return "", nil, nil, "", nil, err + } + + return marker, strings.Split(hosts, ","), pubKey, comment, rest, nil + } + + return "", nil, nil, "", nil, io.EOF +} + +// ParseAuthorizedKeys parses a public key from an authorized_keys +// file used in OpenSSH according to the sshd(8) manual page. +func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) { + for len(in) > 0 { + end := bytes.IndexByte(in, '\n') + if end != -1 { + rest = in[end+1:] + in = in[:end] + } else { + rest = nil + } + + end = bytes.IndexByte(in, '\r') + if end != -1 { + in = in[:end] + } + + in = bytes.TrimSpace(in) + if len(in) == 0 || in[0] == '#' { + in = rest + continue + } + + i := bytes.IndexAny(in, " \t") + if i == -1 { + in = rest + continue + } + + if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { + return out, comment, options, rest, nil + } + + // No key type recognised. Maybe there's an options field at + // the beginning. + var b byte + inQuote := false + var candidateOptions []string + optionStart := 0 + for i, b = range in { + isEnd := !inQuote && (b == ' ' || b == '\t') + if (b == ',' && !inQuote) || isEnd { + if i-optionStart > 0 { + candidateOptions = append(candidateOptions, string(in[optionStart:i])) + } + optionStart = i + 1 + } + if isEnd { + break + } + if b == '"' && (i == 0 || (i > 0 && in[i-1] != '\\')) { + inQuote = !inQuote + } + } + for i < len(in) && (in[i] == ' ' || in[i] == '\t') { + i++ + } + if i == len(in) { + // Invalid line: unmatched quote + in = rest + continue + } + + in = in[i:] + i = bytes.IndexAny(in, " \t") + if i == -1 { + in = rest + continue + } + + if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { + options = candidateOptions + return out, comment, options, rest, nil + } + + in = rest + continue + } + + return nil, "", nil, nil, errors.New("ssh: no key found") +} + +// ParsePublicKey parses an SSH public key formatted for use in +// the SSH wire protocol according to RFC 4253, section 6.6. +func ParsePublicKey(in []byte) (out PublicKey, err error) { + algo, in, ok := parseString(in) + if !ok { + return nil, errShortRead + } + var rest []byte + out, rest, err = parsePubKey(in, string(algo)) + if len(rest) > 0 { + return nil, errors.New("ssh: trailing junk in public key") + } + + return out, err +} + +// MarshalAuthorizedKey serializes key for inclusion in an OpenSSH +// authorized_keys file. The return value ends with newline. +func MarshalAuthorizedKey(key PublicKey) []byte { + b := &bytes.Buffer{} + b.WriteString(key.Type()) + b.WriteByte(' ') + e := base64.NewEncoder(base64.StdEncoding, b) + e.Write(key.Marshal()) + e.Close() + b.WriteByte('\n') + return b.Bytes() +} + +// PublicKey is an abstraction of different types of public keys. +type PublicKey interface { + // Type returns the key's type, e.g. "ssh-rsa". + Type() string + + // Marshal returns the serialized key data in SSH wire format, + // with the name prefix. + Marshal() []byte + + // Verify that sig is a signature on the given data using this + // key. This function will hash the data appropriately first. + Verify(data []byte, sig *Signature) error + + // Create the JSON representation for ZGrab + MarshalJSON() ([]byte, error) +} + +// CryptoPublicKey, if implemented by a PublicKey, +// returns the underlying crypto.PublicKey form of the key. +type CryptoPublicKey interface { + CryptoPublicKey() crypto.PublicKey +} + +// A Signer can create signatures that verify against a public key. +type Signer interface { + // PublicKey returns an associated PublicKey instance. + PublicKey() PublicKey + + // Sign returns raw signature for the given data. This method + // will apply the hash specified for the keytype to the data. + Sign(rand io.Reader, data []byte) (*Signature, error) +} + +type rsaPublicKey rsa.PublicKey + +func (r *rsaPublicKey) Type() string { + return "ssh-rsa" +} + +func (r *rsaPublicKey) MarshalJSON() ([]byte, error) { + ret := new(ztoolsKeys.RSAPublicKey) + ret.PublicKey = (*rsa.PublicKey)(r) + return json.Marshal(ret) +} + +// parseRSA parses an RSA key according to RFC 4253, section 6.6. +func parseRSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + E *big.Int + N *big.Int + Rest []byte `ssh:"rest"` + } + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + if w.E.BitLen() > 24 { + return nil, nil, errors.New("ssh: exponent too large") + } + e := w.E.Int64() + if e < 3 || e&1 == 0 { + return nil, nil, errors.New("ssh: incorrect exponent") + } + + var key rsa.PublicKey + key.E = int(e) + key.N = w.N + return (*rsaPublicKey)(&key), w.Rest, nil +} + +func (r *rsaPublicKey) Marshal() []byte { + e := new(big.Int).SetInt64(int64(r.E)) + // RSA publickey struct layout should match the struct used by + // parseRSACert in the github.com/zmap/zgrab/ztools/xssh/agent package. + wirekey := struct { + Name string + E *big.Int + N *big.Int + }{ + KeyAlgoRSA, + e, + r.N, + } + return Marshal(&wirekey) +} + +func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != r.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type()) + } + h := crypto.SHA1.New() + h.Write(data) + digest := h.Sum(nil) + return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), crypto.SHA1, digest, sig.Blob) +} + +func (r *rsaPublicKey) CryptoPublicKey() crypto.PublicKey { + return (*rsa.PublicKey)(r) +} + +type dsaPublicKey dsa.PublicKey + +func (r *dsaPublicKey) Type() string { + return "ssh-dss" +} + +func (r *dsaPublicKey) MarshalJSON() ([]byte, error) { + temp := make(map[string]interface{}) + ztoolsX509.AddDSAPublicKeyToKeyMap(temp, (*dsa.PublicKey)(r)) + return json.Marshal(temp) +} + +// parseDSA parses an DSA key according to RFC 4253, section 6.6. +func parseDSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + P, Q, G, Y *big.Int + Rest []byte `ssh:"rest"` + } + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + key := &dsaPublicKey{ + Parameters: dsa.Parameters{ + P: w.P, + Q: w.Q, + G: w.G, + }, + Y: w.Y, + } + return key, w.Rest, nil +} + +func (k *dsaPublicKey) Marshal() []byte { + // DSA publickey struct layout should match the struct used by + // parseDSACert in the x/github.com/zmap/zgrab/ztools/xssh/agent package. + w := struct { + Name string + P, Q, G, Y *big.Int + }{ + k.Type(), + k.P, + k.Q, + k.G, + k.Y, + } + + return Marshal(&w) +} + +func (k *dsaPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + h := crypto.SHA1.New() + h.Write(data) + digest := h.Sum(nil) + + // Per RFC 4253, section 6.6, + // The value for 'dss_signature_blob' is encoded as a string containing + // r, followed by s (which are 160-bit integers, without lengths or + // padding, unsigned, and in network byte order). + // For DSS purposes, sig.Blob should be exactly 40 bytes in length. + if len(sig.Blob) != 40 { + return errors.New("ssh: DSA signature parse error") + } + r := new(big.Int).SetBytes(sig.Blob[:20]) + s := new(big.Int).SetBytes(sig.Blob[20:]) + if dsa.Verify((*dsa.PublicKey)(k), digest, r, s) { + return nil + } + return errors.New("ssh: signature did not verify") +} + +func (k *dsaPublicKey) CryptoPublicKey() crypto.PublicKey { + return (*dsa.PublicKey)(k) +} + +type dsaPrivateKey struct { + *dsa.PrivateKey +} + +func (k *dsaPrivateKey) PublicKey() PublicKey { + return (*dsaPublicKey)(&k.PrivateKey.PublicKey) +} + +func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { + h := crypto.SHA1.New() + h.Write(data) + digest := h.Sum(nil) + r, s, err := dsa.Sign(rand, k.PrivateKey, digest) + if err != nil { + return nil, err + } + + sig := make([]byte, 40) + rb := r.Bytes() + sb := s.Bytes() + + copy(sig[20-len(rb):20], rb) + copy(sig[40-len(sb):], sb) + + return &Signature{ + Format: k.PublicKey().Type(), + Blob: sig, + }, nil +} + +type ecdsaPublicKey ecdsa.PublicKey + +func (key *ecdsaPublicKey) Type() string { + return "ecdsa-sha2-" + key.nistID() +} + +func (key *ecdsaPublicKey) MarshalJSON() ([]byte, error) { + temp := make(map[string]interface{}) + ztoolsX509.AddECDSAPublicKeyToKeyMap(temp, (*ecdsa.PublicKey)(key)) + return json.Marshal(temp) +} + +func (key *ecdsaPublicKey) nistID() string { + switch key.Params().BitSize { + case 256: + return "nistp256" + case 384: + return "nistp384" + case 521: + return "nistp521" + } + panic("ssh: unsupported ecdsa key size") +} + +type ed25519PublicKey ed25519.PublicKey + +func (key ed25519PublicKey) Type() string { + return KeyAlgoED25519 +} + +func (key ed25519PublicKey) MarshalJSON() ([]byte, error) { + temp := make(map[string]interface{}) + baseKey := (ed25519.PublicKey)(key) + temp["public_bytes"] = ([]byte)(baseKey) + return json.Marshal(temp) +} + +func parseED25519(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + KeyBytes []byte + Rest []byte `ssh:"rest"` + } + + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + key := ed25519.PublicKey(w.KeyBytes) + + return (ed25519PublicKey)(key), w.Rest, nil +} + +func (key ed25519PublicKey) Marshal() []byte { + w := struct { + Name string + KeyBytes []byte + }{ + KeyAlgoED25519, + []byte(key), + } + return Marshal(&w) +} + +func (key ed25519PublicKey) Verify(b []byte, sig *Signature) error { + if sig.Format != key.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type()) + } + + edKey := (ed25519.PublicKey)(key) + if ok := ed25519.Verify(edKey, b, sig.Blob); !ok { + return errors.New("ssh: signature did not verify") + } + + return nil +} + +func (k ed25519PublicKey) CryptoPublicKey() crypto.PublicKey { + return ed25519.PublicKey(k) +} + +func supportedEllipticCurve(curve elliptic.Curve) bool { + return curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521() +} + +// ecHash returns the hash to match the given elliptic curve, see RFC +// 5656, section 6.2.1 +func ecHash(curve elliptic.Curve) crypto.Hash { + bitSize := curve.Params().BitSize + switch { + case bitSize <= 256: + return crypto.SHA256 + case bitSize <= 384: + return crypto.SHA384 + } + return crypto.SHA512 +} + +// parseECDSA parses an ECDSA key according to RFC 5656, section 3.1. +func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + Curve string + KeyBytes []byte + Rest []byte `ssh:"rest"` + } + + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + key := new(ecdsa.PublicKey) + + switch w.Curve { + case "nistp256": + key.Curve = elliptic.P256() + case "nistp384": + key.Curve = elliptic.P384() + case "nistp521": + key.Curve = elliptic.P521() + default: + return nil, nil, errors.New("ssh: unsupported curve") + } + + key.X, key.Y = elliptic.Unmarshal(key.Curve, w.KeyBytes) + if key.X == nil || key.Y == nil { + return nil, nil, errors.New("ssh: invalid curve point") + } + return (*ecdsaPublicKey)(key), w.Rest, nil +} + +func (key *ecdsaPublicKey) Marshal() []byte { + // See RFC 5656, section 3.1. + keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y) + // ECDSA publickey struct layout should match the struct used by + // parseECDSACert in the github.com/zmap/zgrab/ztools/xssh/agent package. + w := struct { + Name string + ID string + Key []byte + }{ + key.Type(), + key.nistID(), + keyBytes, + } + + return Marshal(&w) +} + +func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != key.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type()) + } + + h := ecHash(key.Curve).New() + h.Write(data) + digest := h.Sum(nil) + + // Per RFC 5656, section 3.1.2, + // The ecdsa_signature_blob value has the following specific encoding: + // mpint r + // mpint s + var ecSig struct { + R *big.Int + S *big.Int + } + + if err := Unmarshal(sig.Blob, &ecSig); err != nil { + return err + } + + if ecdsa.Verify((*ecdsa.PublicKey)(key), digest, ecSig.R, ecSig.S) { + return nil + } + return errors.New("ssh: signature did not verify") +} + +func (k *ecdsaPublicKey) CryptoPublicKey() crypto.PublicKey { + return (*ecdsa.PublicKey)(k) +} + +// NewSignerFromKey takes an *rsa.PrivateKey, *dsa.PrivateKey, +// *ecdsa.PrivateKey or any other crypto.Signer and returns a corresponding +// Signer instance. ECDSA keys must use P-256, P-384 or P-521. +func NewSignerFromKey(key interface{}) (Signer, error) { + switch key := key.(type) { + case crypto.Signer: + return NewSignerFromSigner(key) + case *dsa.PrivateKey: + return &dsaPrivateKey{key}, nil + default: + return nil, fmt.Errorf("ssh: unsupported key type %T", key) + } +} + +type wrappedSigner struct { + signer crypto.Signer + pubKey PublicKey +} + +// NewSignerFromSigner takes any crypto.Signer implementation and +// returns a corresponding Signer interface. This can be used, for +// example, with keys kept in hardware modules. +func NewSignerFromSigner(signer crypto.Signer) (Signer, error) { + pubKey, err := NewPublicKey(signer.Public()) + if err != nil { + return nil, err + } + + return &wrappedSigner{signer, pubKey}, nil +} + +func (s *wrappedSigner) PublicKey() PublicKey { + return s.pubKey +} + +func (s *wrappedSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + var hashFunc crypto.Hash + + switch key := s.pubKey.(type) { + case *rsaPublicKey, *dsaPublicKey: + hashFunc = crypto.SHA1 + case *ecdsaPublicKey: + hashFunc = ecHash(key.Curve) + case ed25519PublicKey: + default: + return nil, fmt.Errorf("ssh: unsupported key type %T", key) + } + + var digest []byte + if hashFunc != 0 { + h := hashFunc.New() + h.Write(data) + digest = h.Sum(nil) + } else { + digest = data + } + + signature, err := s.signer.Sign(rand, digest, hashFunc) + if err != nil { + return nil, err + } + + // crypto.Signer.Sign is expected to return an ASN.1-encoded signature + // for ECDSA and DSA, but that's not the encoding expected by SSH, so + // re-encode. + switch s.pubKey.(type) { + case *ecdsaPublicKey, *dsaPublicKey: + type asn1Signature struct { + R, S *big.Int + } + asn1Sig := new(asn1Signature) + _, err := asn1.Unmarshal(signature, asn1Sig) + if err != nil { + return nil, err + } + + switch s.pubKey.(type) { + case *ecdsaPublicKey: + signature = Marshal(asn1Sig) + + case *dsaPublicKey: + signature = make([]byte, 40) + r := asn1Sig.R.Bytes() + s := asn1Sig.S.Bytes() + copy(signature[20-len(r):20], r) + copy(signature[40-len(s):40], s) + } + } + + return &Signature{ + Format: s.pubKey.Type(), + Blob: signature, + }, nil +} + +// NewPublicKey takes an *rsa.PublicKey, *dsa.PublicKey, *ecdsa.PublicKey, +// or ed25519.PublicKey returns a corresponding PublicKey instance. +// ECDSA keys must use P-256, P-384 or P-521. +func NewPublicKey(key interface{}) (PublicKey, error) { + switch key := key.(type) { + case *rsa.PublicKey: + return (*rsaPublicKey)(key), nil + case *ecdsa.PublicKey: + if !supportedEllipticCurve(key.Curve) { + return nil, errors.New("ssh: only P-256, P-384 and P-521 EC keys are supported.") + } + return (*ecdsaPublicKey)(key), nil + case *dsa.PublicKey: + return (*dsaPublicKey)(key), nil + case ed25519.PublicKey: + return (ed25519PublicKey)(key), nil + default: + return nil, fmt.Errorf("ssh: unsupported key type %T", key) + } +} + +// ParsePrivateKey returns a Signer from a PEM encoded private key. It supports +// the same keys as ParseRawPrivateKey. +func ParsePrivateKey(pemBytes []byte) (Signer, error) { + key, err := ParseRawPrivateKey(pemBytes) + if err != nil { + return nil, err + } + + return NewSignerFromKey(key) +} + +// encryptedBlock tells whether a private key is +// encrypted by examining its Proc-Type header +// for a mention of ENCRYPTED +// according to RFC 1421 Section 4.6.1.1. +func encryptedBlock(block *pem.Block) bool { + return strings.Contains(block.Headers["Proc-Type"], "ENCRYPTED") +} + +// ParseRawPrivateKey returns a private key from a PEM encoded private key. It +// supports RSA (PKCS#1), DSA (OpenSSL), and ECDSA private keys. +func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) { + block, _ := pem.Decode(pemBytes) + if block == nil { + return nil, errors.New("ssh: no key found") + } + + if encryptedBlock(block) { + return nil, errors.New("ssh: cannot decode encrypted private keys") + } + + switch block.Type { + case "RSA PRIVATE KEY": + return x509.ParsePKCS1PrivateKey(block.Bytes) + case "EC PRIVATE KEY": + return x509.ParseECPrivateKey(block.Bytes) + case "DSA PRIVATE KEY": + return ParseDSAPrivateKey(block.Bytes) + case "OPENSSH PRIVATE KEY": + return parseOpenSSHPrivateKey(block.Bytes) + default: + return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type) + } +} + +// ParseDSAPrivateKey returns a DSA private key from its ASN.1 DER encoding, as +// specified by the OpenSSL DSA man page. +func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) { + var k struct { + Version int + P *big.Int + Q *big.Int + G *big.Int + Priv *big.Int + Pub *big.Int + } + rest, err := asn1.Unmarshal(der, &k) + if err != nil { + return nil, errors.New("ssh: failed to parse DSA key: " + err.Error()) + } + if len(rest) > 0 { + return nil, errors.New("ssh: garbage after DSA key") + } + + return &dsa.PrivateKey{ + PublicKey: dsa.PublicKey{ + Parameters: dsa.Parameters{ + P: k.P, + Q: k.Q, + G: k.G, + }, + Y: k.Priv, + }, + X: k.Pub, + }, nil +} + +// Implemented based on the documentation at +// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key +func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) { + magic := append([]byte("openssh-key-v1"), 0) + if !bytes.Equal(magic, key[0:len(magic)]) { + return nil, errors.New("ssh: invalid openssh private key format") + } + remaining := key[len(magic):] + + var w struct { + CipherName string + KdfName string + KdfOpts string + NumKeys uint32 + PubKey []byte + PrivKeyBlock []byte + } + + if err := Unmarshal(remaining, &w); err != nil { + return nil, err + } + + pk1 := struct { + Check1 uint32 + Check2 uint32 + Keytype string + Pub []byte + Priv []byte + Comment string + Pad []byte `ssh:"rest"` + }{} + + if err := Unmarshal(w.PrivKeyBlock, &pk1); err != nil { + return nil, err + } + + if pk1.Check1 != pk1.Check2 { + return nil, errors.New("ssh: checkint mismatch") + } + + // we only handle ed25519 keys currently + if pk1.Keytype != KeyAlgoED25519 { + return nil, errors.New("ssh: unhandled key type") + } + + for i, b := range pk1.Pad { + if int(b) != i+1 { + return nil, errors.New("ssh: padding not as expected") + } + } + + if len(pk1.Priv) != ed25519.PrivateKeySize { + return nil, errors.New("ssh: private key unexpected length") + } + + pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize)) + copy(pk, pk1.Priv) + return &pk, nil +} + +// FingerprintLegacyMD5 returns the user presentation of the key's +// fingerprint as described by RFC 4716 section 4. +func FingerprintLegacyMD5(pubKey PublicKey) string { + md5sum := md5.Sum(pubKey.Marshal()) + hexarray := make([]string, len(md5sum)) + for i, c := range md5sum { + hexarray[i] = hex.EncodeToString([]byte{c}) + } + return strings.Join(hexarray, ":") +} + +// FingerprintSHA256 returns the user presentation of the key's +// fingerprint as unpadded base64 encoded sha256 hash. +// This format was introduced from OpenSSH 6.8. +// https://www.openssh.com/txt/release-6.8 +// https://tools.ietf.org/html/rfc4648#section-3.2 (unpadded base64 encoding) +func FingerprintSHA256(pubKey PublicKey) string { + sha256sum := sha256.Sum256(pubKey.Marshal()) + hash := base64.RawStdEncoding.EncodeToString(sha256sum[:]) + return "SHA256:" + hash +} diff --git a/lib/ssh/keys_test.go b/lib/ssh/keys_test.go new file mode 100644 index 0000000..4006553 --- /dev/null +++ b/lib/ssh/keys_test.go @@ -0,0 +1,480 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "bytes" + "crypto/dsa" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/zmap/zgrab/ztools/xssh/testdata" + "golang.org/x/crypto/ed25519" +) + +func rawKey(pub PublicKey) interface{} { + switch k := pub.(type) { + case *rsaPublicKey: + return (*rsa.PublicKey)(k) + case *dsaPublicKey: + return (*dsa.PublicKey)(k) + case *ecdsaPublicKey: + return (*ecdsa.PublicKey)(k) + case ed25519PublicKey: + return (ed25519.PublicKey)(k) + case *Certificate: + return k + } + panic("unknown key type") +} + +func TestKeyMarshalParse(t *testing.T) { + for _, priv := range testSigners { + pub := priv.PublicKey() + roundtrip, err := ParsePublicKey(pub.Marshal()) + if err != nil { + t.Errorf("ParsePublicKey(%T): %v", pub, err) + } + + if cert, ok := roundtrip.(*Certificate); ok { + // Certificate.SignatureRaw is an attribute added + // explicitly to smuggle out data for zgrab logging. + cert.SignatureRaw = nil + } + + k1 := rawKey(pub) + k2 := rawKey(roundtrip) + + if !reflect.DeepEqual(k1, k2) { + t.Errorf("got %#v in roundtrip, want %#v", k2, k1) + } + } +} + +func TestUnsupportedCurves(t *testing.T) { + raw, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + + if _, err = NewSignerFromKey(raw); err == nil || !strings.Contains(err.Error(), "only P-256") { + t.Fatalf("NewPrivateKey should not succeed with P-224, got: %v", err) + } + + if _, err = NewPublicKey(&raw.PublicKey); err == nil || !strings.Contains(err.Error(), "only P-256") { + t.Fatalf("NewPublicKey should not succeed with P-224, got: %v", err) + } +} + +func TestNewPublicKey(t *testing.T) { + for _, k := range testSigners { + raw := rawKey(k.PublicKey()) + // Skip certificates, as NewPublicKey does not support them. + if _, ok := raw.(*Certificate); ok { + continue + } + pub, err := NewPublicKey(raw) + if err != nil { + t.Errorf("NewPublicKey(%#v): %v", raw, err) + } + if !reflect.DeepEqual(k.PublicKey(), pub) { + t.Errorf("NewPublicKey(%#v) = %#v, want %#v", raw, pub, k.PublicKey()) + } + } +} + +func TestKeySignVerify(t *testing.T) { + for _, priv := range testSigners { + pub := priv.PublicKey() + + data := []byte("sign me") + sig, err := priv.Sign(rand.Reader, data) + if err != nil { + t.Fatalf("Sign(%T): %v", priv, err) + } + + if err := pub.Verify(data, sig); err != nil { + t.Errorf("publicKey.Verify(%T): %v", priv, err) + } + sig.Blob[5]++ + if err := pub.Verify(data, sig); err == nil { + t.Errorf("publicKey.Verify on broken sig did not fail") + } + } +} + +func TestParseRSAPrivateKey(t *testing.T) { + key := testPrivateKeys["rsa"] + + rsa, ok := key.(*rsa.PrivateKey) + if !ok { + t.Fatalf("got %T, want *rsa.PrivateKey", rsa) + } + + if err := rsa.Validate(); err != nil { + t.Errorf("Validate: %v", err) + } +} + +func TestParseECPrivateKey(t *testing.T) { + key := testPrivateKeys["ecdsa"] + + ecKey, ok := key.(*ecdsa.PrivateKey) + if !ok { + t.Fatalf("got %T, want *ecdsa.PrivateKey", ecKey) + } + + if !validateECPublicKey(ecKey.Curve, ecKey.X, ecKey.Y) { + t.Fatalf("public key does not validate.") + } +} + +// See Issue https://github.com/golang/go/issues/6650. +func TestParseEncryptedPrivateKeysFails(t *testing.T) { + const wantSubstring = "encrypted" + for i, tt := range testdata.PEMEncryptedKeys { + _, err := ParsePrivateKey(tt.PEMBytes) + if err == nil { + t.Errorf("#%d key %s: ParsePrivateKey successfully parsed, expected an error", i, tt.Name) + continue + } + + if !strings.Contains(err.Error(), wantSubstring) { + t.Errorf("#%d key %s: got error %q, want substring %q", i, tt.Name, err, wantSubstring) + } + } +} + +func TestParseDSA(t *testing.T) { + // We actually exercise the ParsePrivateKey codepath here, as opposed to + // using the ParseRawPrivateKey+NewSignerFromKey path that testdata_test.go + // uses. + s, err := ParsePrivateKey(testdata.PEMBytes["dsa"]) + if err != nil { + t.Fatalf("ParsePrivateKey returned error: %s", err) + } + + data := []byte("sign me") + sig, err := s.Sign(rand.Reader, data) + if err != nil { + t.Fatalf("dsa.Sign: %v", err) + } + + if err := s.PublicKey().Verify(data, sig); err != nil { + t.Errorf("Verify failed: %v", err) + } +} + +// Tests for authorized_keys parsing. + +// getTestKey returns a public key, and its base64 encoding. +func getTestKey() (PublicKey, string) { + k := testPublicKeys["rsa"] + + b := &bytes.Buffer{} + e := base64.NewEncoder(base64.StdEncoding, b) + e.Write(k.Marshal()) + e.Close() + + return k, b.String() +} + +func TestMarshalParsePublicKey(t *testing.T) { + pub, pubSerialized := getTestKey() + line := fmt.Sprintf("%s %s user@host", pub.Type(), pubSerialized) + + authKeys := MarshalAuthorizedKey(pub) + actualFields := strings.Fields(string(authKeys)) + if len(actualFields) == 0 { + t.Fatalf("failed authKeys: %v", authKeys) + } + + // drop the comment + expectedFields := strings.Fields(line)[0:2] + + if !reflect.DeepEqual(actualFields, expectedFields) { + t.Errorf("got %v, expected %v", actualFields, expectedFields) + } + + actPub, _, _, _, err := ParseAuthorizedKey([]byte(line)) + if err != nil { + t.Fatalf("cannot parse %v: %v", line, err) + } + if !reflect.DeepEqual(actPub, pub) { + t.Errorf("got %v, expected %v", actPub, pub) + } +} + +type authResult struct { + pubKey PublicKey + options []string + comments string + rest string + ok bool +} + +func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []authResult) { + rest := authKeys + var values []authResult + for len(rest) > 0 { + var r authResult + var err error + r.pubKey, r.comments, r.options, rest, err = ParseAuthorizedKey(rest) + r.ok = (err == nil) + t.Log(err) + r.rest = string(rest) + values = append(values, r) + } + + if !reflect.DeepEqual(values, expected) { + t.Errorf("got %#v, expected %#v", values, expected) + } +} + +func TestAuthorizedKeyBasic(t *testing.T) { + pub, pubSerialized := getTestKey() + line := "ssh-rsa " + pubSerialized + " user@host" + testAuthorizedKeys(t, []byte(line), + []authResult{ + {pub, nil, "user@host", "", true}, + }) +} + +func TestAuth(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithOptions := []string{ + `# comments to ignore before any keys...`, + ``, + `env="HOME=/home/root",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`, + `# comments to ignore, along with a blank line`, + ``, + `env="HOME=/home/root2" ssh-rsa ` + pubSerialized + ` user2@host2`, + ``, + `# more comments, plus a invalid entry`, + `ssh-rsa data-that-will-not-parse user@host3`, + } + for _, eol := range []string{"\n", "\r\n"} { + authOptions := strings.Join(authWithOptions, eol) + rest2 := strings.Join(authWithOptions[3:], eol) + rest3 := strings.Join(authWithOptions[6:], eol) + testAuthorizedKeys(t, []byte(authOptions), []authResult{ + {pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true}, + {pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true}, + {nil, nil, "", "", false}, + }) + } +} + +func TestAuthWithQuotedSpaceInEnv(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithQuotedSpaceInEnv := []byte(`env="HOME=/home/root dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) + testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []authResult{ + {pub, []string{`env="HOME=/home/root dir"`, "no-port-forwarding"}, "user@host", "", true}, + }) +} + +func TestAuthWithQuotedCommaInEnv(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithQuotedCommaInEnv := []byte(`env="HOME=/home/root,dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) + testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []authResult{ + {pub, []string{`env="HOME=/home/root,dir"`, "no-port-forwarding"}, "user@host", "", true}, + }) +} + +func TestAuthWithQuotedQuoteInEnv(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithQuotedQuoteInEnv := []byte(`env="HOME=/home/\"root dir",no-port-forwarding` + "\t" + `ssh-rsa` + "\t" + pubSerialized + ` user@host`) + authWithDoubleQuotedQuote := []byte(`no-port-forwarding,env="HOME=/home/ \"root dir\"" ssh-rsa ` + pubSerialized + "\t" + `user@host`) + testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []authResult{ + {pub, []string{`env="HOME=/home/\"root dir"`, "no-port-forwarding"}, "user@host", "", true}, + }) + + testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []authResult{ + {pub, []string{"no-port-forwarding", `env="HOME=/home/ \"root dir\""`}, "user@host", "", true}, + }) +} + +func TestAuthWithInvalidSpace(t *testing.T) { + _, pubSerialized := getTestKey() + authWithInvalidSpace := []byte(`env="HOME=/home/root dir", no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host +#more to follow but still no valid keys`) + testAuthorizedKeys(t, []byte(authWithInvalidSpace), []authResult{ + {nil, nil, "", "", false}, + }) +} + +func TestAuthWithMissingQuote(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithMissingQuote := []byte(`env="HOME=/home/root,no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host +env="HOME=/home/root",shared-control ssh-rsa ` + pubSerialized + ` user@host`) + + testAuthorizedKeys(t, []byte(authWithMissingQuote), []authResult{ + {pub, []string{`env="HOME=/home/root"`, `shared-control`}, "user@host", "", true}, + }) +} + +func TestInvalidEntry(t *testing.T) { + authInvalid := []byte(`ssh-rsa`) + _, _, _, _, err := ParseAuthorizedKey(authInvalid) + if err == nil { + t.Errorf("got valid entry for %q", authInvalid) + } +} + +var knownHostsParseTests = []struct { + input string + err string + + marker string + comment string + hosts []string + rest string +}{ + { + "", + "EOF", + + "", "", nil, "", + }, + { + "# Just a comment", + "EOF", + + "", "", nil, "", + }, + { + " \t ", + "EOF", + + "", "", nil, "", + }, + { + "localhost ssh-rsa {RSAPUB}", + "", + + "", "", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}", + "", + + "", "", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}\tcomment comment", + "", + + "", "comment comment", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}\tcomment comment\n", + "", + + "", "comment comment", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}\tcomment comment\r\n", + "", + + "", "comment comment", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}\tcomment comment\r\nnext line", + "", + + "", "comment comment", []string{"localhost"}, "next line", + }, + { + "localhost,[host2:123]\tssh-rsa {RSAPUB}\tcomment comment", + "", + + "", "comment comment", []string{"localhost", "[host2:123]"}, "", + }, + { + "@marker \tlocalhost,[host2:123]\tssh-rsa {RSAPUB}", + "", + + "marker", "", []string{"localhost", "[host2:123]"}, "", + }, + { + "@marker \tlocalhost,[host2:123]\tssh-rsa aabbccdd", + "short read", + + "", "", nil, "", + }, +} + +func TestKnownHostsParsing(t *testing.T) { + rsaPub, rsaPubSerialized := getTestKey() + + for i, test := range knownHostsParseTests { + var expectedKey PublicKey + const rsaKeyToken = "{RSAPUB}" + + input := test.input + if strings.Contains(input, rsaKeyToken) { + expectedKey = rsaPub + input = strings.Replace(test.input, rsaKeyToken, rsaPubSerialized, -1) + } + + marker, hosts, pubKey, comment, rest, err := ParseKnownHosts([]byte(input)) + if err != nil { + if len(test.err) == 0 { + t.Errorf("#%d: unexpectedly failed with %q", i, err) + } else if !strings.Contains(err.Error(), test.err) { + t.Errorf("#%d: expected error containing %q, but got %q", i, test.err, err) + } + continue + } else if len(test.err) != 0 { + t.Errorf("#%d: succeeded but expected error including %q", i, test.err) + continue + } + + if !reflect.DeepEqual(expectedKey, pubKey) { + t.Errorf("#%d: expected key %#v, but got %#v", i, expectedKey, pubKey) + } + + if marker != test.marker { + t.Errorf("#%d: expected marker %q, but got %q", i, test.marker, marker) + } + + if comment != test.comment { + t.Errorf("#%d: expected comment %q, but got %q", i, test.comment, comment) + } + + if !reflect.DeepEqual(test.hosts, hosts) { + t.Errorf("#%d: expected hosts %#v, but got %#v", i, test.hosts, hosts) + } + + if rest := string(rest); rest != test.rest { + t.Errorf("#%d: expected remaining input to be %q, but got %q", i, test.rest, rest) + } + } +} + +func TestFingerprintLegacyMD5(t *testing.T) { + pub, _ := getTestKey() + fingerprint := FingerprintLegacyMD5(pub) + want := "fb:61:6d:1a:e3:f0:95:45:3c:a0:79:be:4a:93:63:66" // ssh-keygen -lf -E md5 rsa + if fingerprint != want { + t.Errorf("got fingerprint %q want %q", fingerprint, want) + } +} + +func TestFingerprintSHA256(t *testing.T) { + pub, _ := getTestKey() + fingerprint := FingerprintSHA256(pub) + want := "SHA256:Anr3LjZK8YVpjrxu79myrW9Hrb/wpcMNpVvTq/RcBm8" // ssh-keygen -lf rsa + if fingerprint != want { + t.Errorf("got fingerprint %q want %q", fingerprint, want) + } +} diff --git a/lib/ssh/log.go b/lib/ssh/log.go new file mode 100644 index 0000000..163f607 --- /dev/null +++ b/lib/ssh/log.go @@ -0,0 +1,35 @@ +/* + * ZGrab Copyright 2015 Regents of the University of Michigan + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy + * of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package xssh + +// HandshakeLog contains detailed information about each step of the +// SSH handshake, and can be encoded to JSON. +type HandshakeLog struct { + ServerID *EndpointId `json:"server_id,omitempty"` + ClientID *EndpointId `json:"client_id,omitempty"` + ServerKex *kexInitMsg `json:"server_key_exchange,omitempty"` + ClientKex *kexInitMsg `json:"client_key_exchange,omitempty"` + AlgorithmSelection *algorithms `json:"algorithm_selection,omitempty"` + DHKeyExchange kexAlgorithm `json:"key_exchange,omitempty"` + UserAuth []string `json:"userauth,omitempty"` + Crypto *kexResult `json:"crypto,omitempty"` +} + +type EndpointId struct { + Raw string `json:"raw,omitempty"` + ProtoVersion string `json:"version,omitempty"` + SoftwareVersion string `json:"software,omitempty"` + Comment string `json:"comment,omitempty"` +} diff --git a/lib/ssh/mac.go b/lib/ssh/mac.go new file mode 100644 index 0000000..d037bf2 --- /dev/null +++ b/lib/ssh/mac.go @@ -0,0 +1,57 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +// Message authentication support + +import ( + "crypto/hmac" + "crypto/sha1" + "crypto/sha256" + "hash" +) + +type macMode struct { + keySize int + new func(key []byte) hash.Hash +} + +// truncatingMAC wraps around a hash.Hash and truncates the output digest to +// a given size. +type truncatingMAC struct { + length int + hmac hash.Hash +} + +func (t truncatingMAC) Write(data []byte) (int, error) { + return t.hmac.Write(data) +} + +func (t truncatingMAC) Sum(in []byte) []byte { + out := t.hmac.Sum(in) + return out[:len(in)+t.length] +} + +func (t truncatingMAC) Reset() { + t.hmac.Reset() +} + +func (t truncatingMAC) Size() int { + return t.length +} + +func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() } + +var macModes = map[string]*macMode{ + "hmac-sha2-256": {32, func(key []byte) hash.Hash { + return hmac.New(sha256.New, key) + }}, + "hmac-sha1": {20, func(key []byte) hash.Hash { + return hmac.New(sha1.New, key) + }}, + "hmac-sha1-96": {20, func(key []byte) hash.Hash { + return truncatingMAC{12, hmac.New(sha1.New, key)} + }}, +} diff --git a/lib/ssh/mempipe_test.go b/lib/ssh/mempipe_test.go new file mode 100644 index 0000000..3c44a89 --- /dev/null +++ b/lib/ssh/mempipe_test.go @@ -0,0 +1,110 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "io" + "sync" + "testing" +) + +// An in-memory packetConn. It is safe to call Close and writePacket +// from different goroutines. +type memTransport struct { + eof bool + pending [][]byte + write *memTransport + sync.Mutex + *sync.Cond +} + +func (t *memTransport) readPacket() ([]byte, error) { + t.Lock() + defer t.Unlock() + for { + if len(t.pending) > 0 { + r := t.pending[0] + t.pending = t.pending[1:] + return r, nil + } + if t.eof { + return nil, io.EOF + } + t.Cond.Wait() + } +} + +func (t *memTransport) closeSelf() error { + t.Lock() + defer t.Unlock() + if t.eof { + return io.EOF + } + t.eof = true + t.Cond.Broadcast() + return nil +} + +func (t *memTransport) Close() error { + err := t.write.closeSelf() + t.closeSelf() + return err +} + +func (t *memTransport) writePacket(p []byte) error { + t.write.Lock() + defer t.write.Unlock() + if t.write.eof { + return io.EOF + } + c := make([]byte, len(p)) + copy(c, p) + t.write.pending = append(t.write.pending, c) + t.write.Cond.Signal() + return nil +} + +func memPipe() (a, b packetConn) { + t1 := memTransport{} + t2 := memTransport{} + t1.write = &t2 + t2.write = &t1 + t1.Cond = sync.NewCond(&t1.Mutex) + t2.Cond = sync.NewCond(&t2.Mutex) + return &t1, &t2 +} + +func TestMemPipe(t *testing.T) { + a, b := memPipe() + if err := a.writePacket([]byte{42}); err != nil { + t.Fatalf("writePacket: %v", err) + } + if err := a.Close(); err != nil { + t.Fatal("Close: ", err) + } + p, err := b.readPacket() + if err != nil { + t.Fatal("readPacket: ", err) + } + if len(p) != 1 || p[0] != 42 { + t.Fatalf("got %v, want {42}", p) + } + p, err = b.readPacket() + if err != io.EOF { + t.Fatalf("got %v, %v, want EOF", p, err) + } +} + +func TestDoubleClose(t *testing.T) { + a, _ := memPipe() + err := a.Close() + if err != nil { + t.Errorf("Close: %v", err) + } + err = a.Close() + if err != io.EOF { + t.Errorf("expect EOF on double close.") + } +} diff --git a/lib/ssh/messages.go b/lib/ssh/messages.go new file mode 100644 index 0000000..98ddd31 --- /dev/null +++ b/lib/ssh/messages.go @@ -0,0 +1,794 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "math/big" + "reflect" + "strconv" + "strings" +) + +// These are SSH message type numbers. They are scattered around several +// documents but many were taken from [SSH-PARAMETERS]. +const ( + msgIgnore = 2 + msgUnimplemented = 3 + msgDebug = 4 + msgNewKeys = 21 + + // Standard authentication messages + msgUserAuthSuccess = 52 + msgUserAuthBanner = 53 +) + +// SSH messages: +// +// These structures mirror the wire format of the corresponding SSH messages. +// They are marshaled using reflection with the marshal and unmarshal functions +// in this file. The only wrinkle is that a final member of type []byte with a +// ssh tag of "rest" receives the remainder of a packet when unmarshaling. + +// See RFC 4253, section 11.1. +const msgDisconnect = 1 + +// disconnectMsg is the message that signals a disconnect. It is also +// the error type returned from mux.Wait() +type disconnectMsg struct { + Reason uint32 `sshtype:"1"` + Message string + Language string +} + +func (d *disconnectMsg) Error() string { + return fmt.Sprintf("ssh: disconnect, reason %d: %s", d.Reason, d.Message) +} + +// See RFC 4253, section 7.1. +const msgKexInit = 20 + +type kexInitMsg struct { + Cookie [16]byte `sshtype:"20"` + KexAlgos []string + ServerHostKeyAlgos []string + CiphersClientServer []string + CiphersServerClient []string + MACsClientServer []string + MACsServerClient []string + CompressionClientServer []string + CompressionServerClient []string + LanguagesClientServer []string + LanguagesServerClient []string + FirstKexFollows bool + Reserved uint32 +} + +type JsonKexInitMsg struct { + Cookie []byte `json:"cookie,omitempty"` + KexAlgos []string `json:"kex_algorithms,omitempty"` + ServerHostKeyAlgos []string `json:"host_key_algorithms,omitempty"` + CiphersClientServer []string `json:"client_to_server_ciphers,omitempty"` + CiphersServerClient []string `json:"server_to_client_ciphers,omitempty"` + MACsClientServer []string `json:"client_to_server_macs,omitempty"` + MACsServerClient []string `json:"server_to_client_macs,omitempty"` + CompressionClientServer []string `json:"client_to_server_compression,omitempty"` + CompressionServerClient []string `json:"server_to_client_compression,omitempty"` + LanguagesClientServer []string `json:"client_to_server_languages,omitempty"` + LanguagesServerClient []string `json:"server_to_client_languages,omitempty"` + FirstKexFollows bool `json:"first_kex_follows"` + Reserved uint32 `json:"reserved"` +} + +func (kex *kexInitMsg) MarshalJSON() ([]byte, error) { + temp := JsonKexInitMsg{ + Cookie: kex.Cookie[:], + KexAlgos: kex.KexAlgos, + ServerHostKeyAlgos: kex.ServerHostKeyAlgos, + CiphersClientServer: kex.CiphersClientServer, + CiphersServerClient: kex.CiphersServerClient, + MACsClientServer: kex.MACsClientServer, + MACsServerClient: kex.MACsServerClient, + CompressionClientServer: kex.CompressionClientServer, + CompressionServerClient: kex.CompressionServerClient, + LanguagesClientServer: kex.LanguagesClientServer, + LanguagesServerClient: kex.LanguagesServerClient, + FirstKexFollows: kex.FirstKexFollows, + Reserved: kex.Reserved, + } + return json.Marshal(temp) +} + +// See RFC 4253, section 8. + +// Diffie-Helman +const msgKexDHInit = 30 + +type kexDHInitMsg struct { + X *big.Int `sshtype:"30"` +} + +const msgKexECDHInit = 30 + +type kexECDHInitMsg struct { + ClientPubKey []byte `sshtype:"30"` +} + +const msgKexECDHReply = 31 + +type kexECDHReplyMsg struct { + HostKey []byte `sshtype:"31"` + EphemeralPubKey []byte + Signature []byte +} + +const msgKexDHReply = 31 + +type kexDHReplyMsg struct { + HostKey []byte `sshtype:"31"` + Y *big.Int + Signature []byte +} + +// See RFC 4253, section 10. +const msgServiceRequest = 5 + +type serviceRequestMsg struct { + Service string `sshtype:"5"` +} + +// See RFC 4253, section 10. +const msgServiceAccept = 6 + +type serviceAcceptMsg struct { + Service string `sshtype:"6"` +} + +// See RFC 4252, section 5. +const msgUserAuthRequest = 50 + +type userAuthRequestMsg struct { + User string `sshtype:"50"` + Service string + Method string + Payload []byte `ssh:"rest"` +} + +// Used for debug printouts of packets. +type userAuthSuccessMsg struct { +} + +// See RFC 4252, section 5.1 +const msgUserAuthFailure = 51 + +type userAuthFailureMsg struct { + Methods []string `sshtype:"51"` + PartialSuccess bool +} + +// See RFC 4256, section 3.2 +const msgUserAuthInfoRequest = 60 +const msgUserAuthInfoResponse = 61 + +type userAuthInfoRequestMsg struct { + User string `sshtype:"60"` + Instruction string + DeprecatedLanguage string + NumPrompts uint32 + Prompts []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. +const msgChannelOpen = 90 + +type channelOpenMsg struct { + ChanType string `sshtype:"90"` + PeersId uint32 + PeersWindow uint32 + MaxPacketSize uint32 + TypeSpecificData []byte `ssh:"rest"` +} + +const msgChannelExtendedData = 95 +const msgChannelData = 94 + +// Used for debug print outs of packets. +type channelDataMsg struct { + PeersId uint32 `sshtype:"94"` + Length uint32 + Rest []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. +const msgChannelOpenConfirm = 91 + +type channelOpenConfirmMsg struct { + PeersId uint32 `sshtype:"91"` + MyId uint32 + MyWindow uint32 + MaxPacketSize uint32 + TypeSpecificData []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. +const msgChannelOpenFailure = 92 + +type channelOpenFailureMsg struct { + PeersId uint32 `sshtype:"92"` + Reason RejectionReason + Message string + Language string +} + +const msgChannelRequest = 98 + +type channelRequestMsg struct { + PeersId uint32 `sshtype:"98"` + Request string + WantReply bool + RequestSpecificData []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.4. +const msgChannelSuccess = 99 + +type channelRequestSuccessMsg struct { + PeersId uint32 `sshtype:"99"` +} + +// See RFC 4254, section 5.4. +const msgChannelFailure = 100 + +type channelRequestFailureMsg struct { + PeersId uint32 `sshtype:"100"` +} + +// See RFC 4254, section 5.3 +const msgChannelClose = 97 + +type channelCloseMsg struct { + PeersId uint32 `sshtype:"97"` +} + +// See RFC 4254, section 5.3 +const msgChannelEOF = 96 + +type channelEOFMsg struct { + PeersId uint32 `sshtype:"96"` +} + +// See RFC 4254, section 4 +const msgGlobalRequest = 80 + +type globalRequestMsg struct { + Type string `sshtype:"80"` + WantReply bool + Data []byte `ssh:"rest"` +} + +// See RFC 4254, section 4 +const msgRequestSuccess = 81 + +type globalRequestSuccessMsg struct { + Data []byte `ssh:"rest" sshtype:"81"` +} + +// See RFC 4254, section 4 +const msgRequestFailure = 82 + +type globalRequestFailureMsg struct { + Data []byte `ssh:"rest" sshtype:"82"` +} + +// See RFC 4254, section 5.2 +const msgChannelWindowAdjust = 93 + +type windowAdjustMsg struct { + PeersId uint32 `sshtype:"93"` + AdditionalBytes uint32 +} + +// See RFC 4252, section 7 +const msgUserAuthPubKeyOk = 60 + +type userAuthPubKeyOkMsg struct { + Algo string `sshtype:"60"` + PubKey []byte +} + +// typeTags returns the possible type bytes for the given reflect.Type, which +// should be a struct. The possible values are separated by a '|' character. +func typeTags(structType reflect.Type) (tags []byte) { + tagStr := structType.Field(0).Tag.Get("sshtype") + + for _, tag := range strings.Split(tagStr, "|") { + i, err := strconv.Atoi(tag) + if err == nil { + tags = append(tags, byte(i)) + } + } + + return tags +} + +func fieldError(t reflect.Type, field int, problem string) error { + if problem != "" { + problem = ": " + problem + } + return fmt.Errorf("ssh: unmarshal error for field %s of type %s%s", t.Field(field).Name, t.Name(), problem) +} + +var errShortRead = errors.New("ssh: short read") + +// Unmarshal parses data in SSH wire format into a structure. The out +// argument should be a pointer to struct. If the first member of the +// struct has the "sshtype" tag set to a '|'-separated set of numbers +// in decimal, the packet must start with one of those numbers. In +// case of error, Unmarshal returns a ParseError or +// UnexpectedMessageError. +func Unmarshal(data []byte, out interface{}) error { + v := reflect.ValueOf(out).Elem() + structType := v.Type() + expectedTypes := typeTags(structType) + + var expectedType byte + if len(expectedTypes) > 0 { + expectedType = expectedTypes[0] + } + + if len(data) == 0 { + return parseError(expectedType) + } + + if len(expectedTypes) > 0 { + goodType := false + for _, e := range expectedTypes { + if e > 0 && data[0] == e { + goodType = true + break + } + } + if !goodType { + return fmt.Errorf("ssh: unexpected message type %d (expected one of %v)", data[0], expectedTypes) + } + data = data[1:] + } + + var ok bool + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + t := field.Type() + switch t.Kind() { + case reflect.Bool: + if len(data) < 1 { + return errShortRead + } + field.SetBool(data[0] != 0) + data = data[1:] + case reflect.Array: + if t.Elem().Kind() != reflect.Uint8 { + return fieldError(structType, i, "array of unsupported type") + } + if len(data) < t.Len() { + return errShortRead + } + for j, n := 0, t.Len(); j < n; j++ { + field.Index(j).Set(reflect.ValueOf(data[j])) + } + data = data[t.Len():] + case reflect.Uint64: + var u64 uint64 + if u64, data, ok = parseUint64(data); !ok { + return errShortRead + } + field.SetUint(u64) + case reflect.Uint32: + var u32 uint32 + if u32, data, ok = parseUint32(data); !ok { + return errShortRead + } + field.SetUint(uint64(u32)) + case reflect.Uint8: + if len(data) < 1 { + return errShortRead + } + field.SetUint(uint64(data[0])) + data = data[1:] + case reflect.String: + var s []byte + if s, data, ok = parseString(data); !ok { + return fieldError(structType, i, "") + } + field.SetString(string(s)) + case reflect.Slice: + switch t.Elem().Kind() { + case reflect.Uint8: + if structType.Field(i).Tag.Get("ssh") == "rest" { + field.Set(reflect.ValueOf(data)) + data = nil + } else { + var s []byte + if s, data, ok = parseString(data); !ok { + return errShortRead + } + field.Set(reflect.ValueOf(s)) + } + case reflect.String: + var nl []string + if nl, data, ok = parseNameList(data); !ok { + return errShortRead + } + field.Set(reflect.ValueOf(nl)) + default: + return fieldError(structType, i, "slice of unsupported type") + } + case reflect.Ptr: + if t == bigIntType { + var n *big.Int + if n, data, ok = parseInt(data); !ok { + return errShortRead + } + field.Set(reflect.ValueOf(n)) + } else { + return fieldError(structType, i, "pointer to unsupported type") + } + default: + return fieldError(structType, i, fmt.Sprintf("unsupported type: %v", t)) + } + } + + if len(data) != 0 { + return parseError(expectedType) + } + + return nil +} + +// Marshal serializes the message in msg to SSH wire format. The msg +// argument should be a struct or pointer to struct. If the first +// member has the "sshtype" tag set to a number in decimal, that +// number is prepended to the result. If the last of member has the +// "ssh" tag set to "rest", its contents are appended to the output. +func Marshal(msg interface{}) []byte { + out := make([]byte, 0, 64) + return marshalStruct(out, msg) +} + +func marshalStruct(out []byte, msg interface{}) []byte { + v := reflect.Indirect(reflect.ValueOf(msg)) + msgTypes := typeTags(v.Type()) + if len(msgTypes) > 0 { + out = append(out, msgTypes[0]) + } + + for i, n := 0, v.NumField(); i < n; i++ { + field := v.Field(i) + switch t := field.Type(); t.Kind() { + case reflect.Bool: + var v uint8 + if field.Bool() { + v = 1 + } + out = append(out, v) + case reflect.Array: + if t.Elem().Kind() != reflect.Uint8 { + panic(fmt.Sprintf("array of non-uint8 in field %d: %T", i, field.Interface())) + } + for j, l := 0, t.Len(); j < l; j++ { + out = append(out, uint8(field.Index(j).Uint())) + } + case reflect.Uint32: + out = appendU32(out, uint32(field.Uint())) + case reflect.Uint64: + out = appendU64(out, uint64(field.Uint())) + case reflect.Uint8: + out = append(out, uint8(field.Uint())) + case reflect.String: + s := field.String() + out = appendInt(out, len(s)) + out = append(out, s...) + case reflect.Slice: + switch t.Elem().Kind() { + case reflect.Uint8: + if v.Type().Field(i).Tag.Get("ssh") != "rest" { + out = appendInt(out, field.Len()) + } + out = append(out, field.Bytes()...) + case reflect.String: + offset := len(out) + out = appendU32(out, 0) + if n := field.Len(); n > 0 { + for j := 0; j < n; j++ { + f := field.Index(j) + if j != 0 { + out = append(out, ',') + } + out = append(out, f.String()...) + } + // overwrite length value + binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4)) + } + default: + panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface())) + } + case reflect.Ptr: + if t == bigIntType { + var n *big.Int + nValue := reflect.ValueOf(&n) + nValue.Elem().Set(field) + needed := intLength(n) + oldLength := len(out) + + if cap(out)-len(out) < needed { + newOut := make([]byte, len(out), 2*(len(out)+needed)) + copy(newOut, out) + out = newOut + } + out = out[:oldLength+needed] + marshalInt(out[oldLength:], n) + } else { + panic(fmt.Sprintf("pointer to unknown type in field %d: %T", i, field.Interface())) + } + } + } + + return out +} + +var bigOne = big.NewInt(1) + +func parseString(in []byte) (out, rest []byte, ok bool) { + if len(in) < 4 { + return + } + length := binary.BigEndian.Uint32(in) + in = in[4:] + if uint32(len(in)) < length { + return + } + out = in[:length] + rest = in[length:] + ok = true + return +} + +var ( + comma = []byte{','} + emptyNameList = []string{} +) + +func parseNameList(in []byte) (out []string, rest []byte, ok bool) { + contents, rest, ok := parseString(in) + if !ok { + return + } + if len(contents) == 0 { + out = emptyNameList + return + } + parts := bytes.Split(contents, comma) + out = make([]string, len(parts)) + for i, part := range parts { + out[i] = string(part) + } + return +} + +func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) { + contents, rest, ok := parseString(in) + if !ok { + return + } + out = new(big.Int) + + if len(contents) > 0 && contents[0]&0x80 == 0x80 { + // This is a negative number + notBytes := make([]byte, len(contents)) + for i := range notBytes { + notBytes[i] = ^contents[i] + } + out.SetBytes(notBytes) + out.Add(out, bigOne) + out.Neg(out) + } else { + // Positive number + out.SetBytes(contents) + } + ok = true + return +} + +func parseUint32(in []byte) (uint32, []byte, bool) { + if len(in) < 4 { + return 0, nil, false + } + return binary.BigEndian.Uint32(in), in[4:], true +} + +func parseUint64(in []byte) (uint64, []byte, bool) { + if len(in) < 8 { + return 0, nil, false + } + return binary.BigEndian.Uint64(in), in[8:], true +} + +func intLength(n *big.Int) int { + length := 4 /* length bytes */ + if n.Sign() < 0 { + nMinus1 := new(big.Int).Neg(n) + nMinus1.Sub(nMinus1, bigOne) + bitLen := nMinus1.BitLen() + if bitLen%8 == 0 { + // The number will need 0xff padding + length++ + } + length += (bitLen + 7) / 8 + } else if n.Sign() == 0 { + // A zero is the zero length string + } else { + bitLen := n.BitLen() + if bitLen%8 == 0 { + // The number will need 0x00 padding + length++ + } + length += (bitLen + 7) / 8 + } + + return length +} + +func marshalUint32(to []byte, n uint32) []byte { + binary.BigEndian.PutUint32(to, n) + return to[4:] +} + +func marshalUint64(to []byte, n uint64) []byte { + binary.BigEndian.PutUint64(to, n) + return to[8:] +} + +func marshalInt(to []byte, n *big.Int) []byte { + lengthBytes := to + to = to[4:] + length := 0 + + if n.Sign() < 0 { + // A negative number has to be converted to two's-complement + // form. So we'll subtract 1 and invert. If the + // most-significant-bit isn't set then we'll need to pad the + // beginning with 0xff in order to keep the number negative. + nMinus1 := new(big.Int).Neg(n) + nMinus1.Sub(nMinus1, bigOne) + bytes := nMinus1.Bytes() + for i := range bytes { + bytes[i] ^= 0xff + } + if len(bytes) == 0 || bytes[0]&0x80 == 0 { + to[0] = 0xff + to = to[1:] + length++ + } + nBytes := copy(to, bytes) + to = to[nBytes:] + length += nBytes + } else if n.Sign() == 0 { + // A zero is the zero length string + } else { + bytes := n.Bytes() + if len(bytes) > 0 && bytes[0]&0x80 != 0 { + // We'll have to pad this with a 0x00 in order to + // stop it looking like a negative number. + to[0] = 0 + to = to[1:] + length++ + } + nBytes := copy(to, bytes) + to = to[nBytes:] + length += nBytes + } + + lengthBytes[0] = byte(length >> 24) + lengthBytes[1] = byte(length >> 16) + lengthBytes[2] = byte(length >> 8) + lengthBytes[3] = byte(length) + return to +} + +func writeInt(w io.Writer, n *big.Int) { + length := intLength(n) + buf := make([]byte, length) + marshalInt(buf, n) + w.Write(buf) +} + +func writeString(w io.Writer, s []byte) { + var lengthBytes [4]byte + lengthBytes[0] = byte(len(s) >> 24) + lengthBytes[1] = byte(len(s) >> 16) + lengthBytes[2] = byte(len(s) >> 8) + lengthBytes[3] = byte(len(s)) + w.Write(lengthBytes[:]) + w.Write(s) +} + +func stringLength(n int) int { + return 4 + n +} + +func marshalString(to []byte, s []byte) []byte { + to[0] = byte(len(s) >> 24) + to[1] = byte(len(s) >> 16) + to[2] = byte(len(s) >> 8) + to[3] = byte(len(s)) + to = to[4:] + copy(to, s) + return to[len(s):] +} + +var bigIntType = reflect.TypeOf((*big.Int)(nil)) + +// Decode a packet into its corresponding message. +func decode(packet []byte) (interface{}, error) { + var msg interface{} + switch packet[0] { + case msgDisconnect: + msg = new(disconnectMsg) + case msgServiceRequest: + msg = new(serviceRequestMsg) + case msgServiceAccept: + msg = new(serviceAcceptMsg) + case msgKexInit: + msg = new(kexInitMsg) + case msgKexDHInit: + msg = new(kexDHInitMsg) + case msgKexDHReply: + msg = new(kexDHReplyMsg) + case msgUserAuthRequest: + msg = new(userAuthRequestMsg) + case msgUserAuthSuccess: + return new(userAuthSuccessMsg), nil + case msgUserAuthFailure: + msg = new(userAuthFailureMsg) + case msgUserAuthPubKeyOk: + msg = new(userAuthPubKeyOkMsg) + case msgGlobalRequest: + msg = new(globalRequestMsg) + case msgRequestSuccess: + msg = new(globalRequestSuccessMsg) + case msgRequestFailure: + msg = new(globalRequestFailureMsg) + case msgChannelOpen: + msg = new(channelOpenMsg) + case msgChannelData: + msg = new(channelDataMsg) + case msgChannelOpenConfirm: + msg = new(channelOpenConfirmMsg) + case msgChannelOpenFailure: + msg = new(channelOpenFailureMsg) + case msgChannelWindowAdjust: + msg = new(windowAdjustMsg) + case msgChannelEOF: + msg = new(channelEOFMsg) + case msgChannelClose: + msg = new(channelCloseMsg) + case msgChannelRequest: + msg = new(channelRequestMsg) + case msgChannelSuccess: + msg = new(channelRequestSuccessMsg) + case msgChannelFailure: + msg = new(channelRequestFailureMsg) + default: + return nil, unexpectedMessageError(0, packet[0]) + } + if err := Unmarshal(packet, msg); err != nil { + return nil, err + } + return msg, nil +} diff --git a/lib/ssh/messages_test.go b/lib/ssh/messages_test.go new file mode 100644 index 0000000..c6ac051 --- /dev/null +++ b/lib/ssh/messages_test.go @@ -0,0 +1,288 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "bytes" + "math/big" + "math/rand" + "reflect" + "testing" + "testing/quick" +) + +var intLengthTests = []struct { + val, length int +}{ + {0, 4 + 0}, + {1, 4 + 1}, + {127, 4 + 1}, + {128, 4 + 2}, + {-1, 4 + 1}, +} + +func TestIntLength(t *testing.T) { + for _, test := range intLengthTests { + v := new(big.Int).SetInt64(int64(test.val)) + length := intLength(v) + if length != test.length { + t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length) + } + } +} + +type msgAllTypes struct { + Bool bool `sshtype:"21"` + Array [16]byte + Uint64 uint64 + Uint32 uint32 + Uint8 uint8 + String string + Strings []string + Bytes []byte + Int *big.Int + Rest []byte `ssh:"rest"` +} + +func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value { + m := &msgAllTypes{} + m.Bool = rand.Intn(2) == 1 + randomBytes(m.Array[:], rand) + m.Uint64 = uint64(rand.Int63n(1<<63 - 1)) + m.Uint32 = uint32(rand.Intn((1 << 31) - 1)) + m.Uint8 = uint8(rand.Intn(1 << 8)) + m.String = string(m.Array[:]) + m.Strings = randomNameList(rand) + m.Bytes = m.Array[:] + m.Int = randomInt(rand) + m.Rest = m.Array[:] + return reflect.ValueOf(m) +} + +func TestMarshalUnmarshal(t *testing.T) { + rand := rand.New(rand.NewSource(0)) + iface := &msgAllTypes{} + ty := reflect.ValueOf(iface).Type() + + n := 100 + if testing.Short() { + n = 5 + } + for j := 0; j < n; j++ { + v, ok := quick.Value(ty, rand) + if !ok { + t.Errorf("failed to create value") + break + } + + m1 := v.Elem().Interface() + m2 := iface + + marshaled := Marshal(m1) + if err := Unmarshal(marshaled, m2); err != nil { + t.Errorf("Unmarshal %#v: %s", m1, err) + break + } + + if !reflect.DeepEqual(v.Interface(), m2) { + t.Errorf("got: %#v\nwant:%#v\n%x", m2, m1, marshaled) + break + } + } +} + +func TestUnmarshalEmptyPacket(t *testing.T) { + var b []byte + var m channelRequestSuccessMsg + if err := Unmarshal(b, &m); err == nil { + t.Fatalf("unmarshal of empty slice succeeded") + } +} + +func TestUnmarshalUnexpectedPacket(t *testing.T) { + type S struct { + I uint32 `sshtype:"43"` + S string + B bool + } + + s := S{11, "hello", true} + packet := Marshal(s) + packet[0] = 42 + roundtrip := S{} + err := Unmarshal(packet, &roundtrip) + if err == nil { + t.Fatal("expected error, not nil") + } +} + +func TestMarshalPtr(t *testing.T) { + s := struct { + S string + }{"hello"} + + m1 := Marshal(s) + m2 := Marshal(&s) + if !bytes.Equal(m1, m2) { + t.Errorf("got %q, want %q for marshaled pointer", m2, m1) + } +} + +func TestBareMarshalUnmarshal(t *testing.T) { + type S struct { + I uint32 + S string + B bool + } + + s := S{42, "hello", true} + packet := Marshal(s) + roundtrip := S{} + Unmarshal(packet, &roundtrip) + + if !reflect.DeepEqual(s, roundtrip) { + t.Errorf("got %#v, want %#v", roundtrip, s) + } +} + +func TestBareMarshal(t *testing.T) { + type S2 struct { + I uint32 + } + s := S2{42} + packet := Marshal(s) + i, rest, ok := parseUint32(packet) + if len(rest) > 0 || !ok { + t.Errorf("parseInt(%q): parse error", packet) + } + if i != s.I { + t.Errorf("got %d, want %d", i, s.I) + } +} + +func TestUnmarshalShortKexInitPacket(t *testing.T) { + // This used to panic. + // Issue 11348 + packet := []byte{0x14, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0xff, 0xff, 0xff, 0xff} + kim := &kexInitMsg{} + if err := Unmarshal(packet, kim); err == nil { + t.Error("truncated packet unmarshaled without error") + } +} + +func TestMarshalMultiTag(t *testing.T) { + var res struct { + A uint32 `sshtype:"1|2"` + } + + good1 := struct { + A uint32 `sshtype:"1"` + }{ + 1, + } + good2 := struct { + A uint32 `sshtype:"2"` + }{ + 1, + } + + if e := Unmarshal(Marshal(good1), &res); e != nil { + t.Errorf("error unmarshaling multipart tag: %v", e) + } + + if e := Unmarshal(Marshal(good2), &res); e != nil { + t.Errorf("error unmarshaling multipart tag: %v", e) + } + + bad1 := struct { + A uint32 `sshtype:"3"` + }{ + 1, + } + if e := Unmarshal(Marshal(bad1), &res); e == nil { + t.Errorf("bad struct unmarshaled without error") + } +} + +func randomBytes(out []byte, rand *rand.Rand) { + for i := 0; i < len(out); i++ { + out[i] = byte(rand.Int31()) + } +} + +func randomNameList(rand *rand.Rand) []string { + ret := make([]string, rand.Int31()&15) + for i := range ret { + s := make([]byte, 1+(rand.Int31()&15)) + for j := range s { + s[j] = 'a' + uint8(rand.Int31()&15) + } + ret[i] = string(s) + } + return ret +} + +func randomInt(rand *rand.Rand) *big.Int { + return new(big.Int).SetInt64(int64(int32(rand.Uint32()))) +} + +func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { + ki := &kexInitMsg{} + randomBytes(ki.Cookie[:], rand) + ki.KexAlgos = randomNameList(rand) + ki.ServerHostKeyAlgos = randomNameList(rand) + ki.CiphersClientServer = randomNameList(rand) + ki.CiphersServerClient = randomNameList(rand) + ki.MACsClientServer = randomNameList(rand) + ki.MACsServerClient = randomNameList(rand) + ki.CompressionClientServer = randomNameList(rand) + ki.CompressionServerClient = randomNameList(rand) + ki.LanguagesClientServer = randomNameList(rand) + ki.LanguagesServerClient = randomNameList(rand) + if rand.Int31()&1 == 1 { + ki.FirstKexFollows = true + } + return reflect.ValueOf(ki) +} + +func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { + dhi := &kexDHInitMsg{} + dhi.X = randomInt(rand) + return reflect.ValueOf(dhi) +} + +var ( + _kexInitMsg = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() + _kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() + + _kexInit = Marshal(_kexInitMsg) + _kexDHInit = Marshal(_kexDHInitMsg) +) + +func BenchmarkMarshalKexInitMsg(b *testing.B) { + for i := 0; i < b.N; i++ { + Marshal(_kexInitMsg) + } +} + +func BenchmarkUnmarshalKexInitMsg(b *testing.B) { + m := new(kexInitMsg) + for i := 0; i < b.N; i++ { + Unmarshal(_kexInit, m) + } +} + +func BenchmarkMarshalKexDHInitMsg(b *testing.B) { + for i := 0; i < b.N; i++ { + Marshal(_kexDHInitMsg) + } +} + +func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) { + m := new(kexDHInitMsg) + for i := 0; i < b.N; i++ { + Unmarshal(_kexDHInit, m) + } +} diff --git a/lib/ssh/mux.go b/lib/ssh/mux.go new file mode 100644 index 0000000..b00e1d9 --- /dev/null +++ b/lib/ssh/mux.go @@ -0,0 +1,330 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "encoding/binary" + "fmt" + "io" + "log" + "sync" + "sync/atomic" +) + +// debugMux, if set, causes messages in the connection protocol to be +// logged. +const debugMux = false + +// chanList is a thread safe channel list. +type chanList struct { + // protects concurrent access to chans + sync.Mutex + + // chans are indexed by the local id of the channel, which the + // other side should send in the PeersId field. + chans []*channel + + // This is a debugging aid: it offsets all IDs by this + // amount. This helps distinguish otherwise identical + // server/client muxes + offset uint32 +} + +// Assigns a channel ID to the given channel. +func (c *chanList) add(ch *channel) uint32 { + c.Lock() + defer c.Unlock() + for i := range c.chans { + if c.chans[i] == nil { + c.chans[i] = ch + return uint32(i) + c.offset + } + } + c.chans = append(c.chans, ch) + return uint32(len(c.chans)-1) + c.offset +} + +// getChan returns the channel for the given ID. +func (c *chanList) getChan(id uint32) *channel { + id -= c.offset + + c.Lock() + defer c.Unlock() + if id < uint32(len(c.chans)) { + return c.chans[id] + } + return nil +} + +func (c *chanList) remove(id uint32) { + id -= c.offset + c.Lock() + if id < uint32(len(c.chans)) { + c.chans[id] = nil + } + c.Unlock() +} + +// dropAll forgets all channels it knows, returning them in a slice. +func (c *chanList) dropAll() []*channel { + c.Lock() + defer c.Unlock() + var r []*channel + + for _, ch := range c.chans { + if ch == nil { + continue + } + r = append(r, ch) + } + c.chans = nil + return r +} + +// mux represents the state for the SSH connection protocol, which +// multiplexes many channels onto a single packet transport. +type mux struct { + conn packetConn + chanList chanList + + incomingChannels chan NewChannel + + globalSentMu sync.Mutex + globalResponses chan interface{} + incomingRequests chan *Request + + errCond *sync.Cond + err error +} + +// When debugging, each new chanList instantiation has a different +// offset. +var globalOff uint32 + +func (m *mux) Wait() error { + m.errCond.L.Lock() + defer m.errCond.L.Unlock() + for m.err == nil { + m.errCond.Wait() + } + return m.err +} + +// newMux returns a mux that runs over the given connection. +func newMux(p packetConn) *mux { + m := &mux{ + conn: p, + incomingChannels: make(chan NewChannel, 16), + globalResponses: make(chan interface{}, 1), + incomingRequests: make(chan *Request, 16), + errCond: newCond(), + } + if debugMux { + m.chanList.offset = atomic.AddUint32(&globalOff, 1) + } + + go m.loop() + return m +} + +func (m *mux) sendMessage(msg interface{}) error { + p := Marshal(msg) + if debugMux { + log.Printf("send global(%d): %#v", m.chanList.offset, msg) + } + return m.conn.writePacket(p) +} + +func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { + if wantReply { + m.globalSentMu.Lock() + defer m.globalSentMu.Unlock() + } + + if err := m.sendMessage(globalRequestMsg{ + Type: name, + WantReply: wantReply, + Data: payload, + }); err != nil { + return false, nil, err + } + + if !wantReply { + return false, nil, nil + } + + msg, ok := <-m.globalResponses + if !ok { + return false, nil, io.EOF + } + switch msg := msg.(type) { + case *globalRequestFailureMsg: + return false, msg.Data, nil + case *globalRequestSuccessMsg: + return true, msg.Data, nil + default: + return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg) + } +} + +// ackRequest must be called after processing a global request that +// has WantReply set. +func (m *mux) ackRequest(ok bool, data []byte) error { + if ok { + return m.sendMessage(globalRequestSuccessMsg{Data: data}) + } + return m.sendMessage(globalRequestFailureMsg{Data: data}) +} + +func (m *mux) Close() error { + return m.conn.Close() +} + +// loop runs the connection machine. It will process packets until an +// error is encountered. To synchronize on loop exit, use mux.Wait. +func (m *mux) loop() { + var err error + for err == nil { + err = m.onePacket() + } + + for _, ch := range m.chanList.dropAll() { + ch.close() + } + + close(m.incomingChannels) + close(m.incomingRequests) + close(m.globalResponses) + + m.conn.Close() + + m.errCond.L.Lock() + m.err = err + m.errCond.Broadcast() + m.errCond.L.Unlock() + + if debugMux { + log.Println("loop exit", err) + } +} + +// onePacket reads and processes one packet. +func (m *mux) onePacket() error { + packet, err := m.conn.readPacket() + if err != nil { + return err + } + + if debugMux { + if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData { + log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet)) + } else { + p, _ := decode(packet) + log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet)) + } + } + + switch packet[0] { + case msgChannelOpen: + return m.handleChannelOpen(packet) + case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: + return m.handleGlobalPacket(packet) + } + + // assume a channel packet. + if len(packet) < 5 { + return parseError(packet[0]) + } + id := binary.BigEndian.Uint32(packet[1:]) + ch := m.chanList.getChan(id) + if ch == nil { + return fmt.Errorf("ssh: invalid channel %d", id) + } + + return ch.handlePacket(packet) +} + +func (m *mux) handleGlobalPacket(packet []byte) error { + msg, err := decode(packet) + if err != nil { + return err + } + + switch msg := msg.(type) { + case *globalRequestMsg: + m.incomingRequests <- &Request{ + Type: msg.Type, + WantReply: msg.WantReply, + Payload: msg.Data, + mux: m, + } + case *globalRequestSuccessMsg, *globalRequestFailureMsg: + m.globalResponses <- msg + default: + panic(fmt.Sprintf("not a global message %#v", msg)) + } + + return nil +} + +// handleChannelOpen schedules a channel to be Accept()ed. +func (m *mux) handleChannelOpen(packet []byte) error { + var msg channelOpenMsg + if err := Unmarshal(packet, &msg); err != nil { + return err + } + + if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { + failMsg := channelOpenFailureMsg{ + PeersId: msg.PeersId, + Reason: ConnectionFailed, + Message: "invalid request", + Language: "en_US.UTF-8", + } + return m.sendMessage(failMsg) + } + + c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) + c.remoteId = msg.PeersId + c.maxRemotePayload = msg.MaxPacketSize + c.remoteWin.add(msg.PeersWindow) + m.incomingChannels <- c + return nil +} + +func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) { + ch, err := m.openChannel(chanType, extra) + if err != nil { + return nil, nil, err + } + + return ch, ch.incomingRequests, nil +} + +func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) { + ch := m.newChannel(chanType, channelOutbound, extra) + + ch.maxIncomingPayload = channelMaxPacket + + open := channelOpenMsg{ + ChanType: chanType, + PeersWindow: ch.myWindow, + MaxPacketSize: ch.maxIncomingPayload, + TypeSpecificData: extra, + PeersId: ch.localId, + } + if err := m.sendMessage(open); err != nil { + return nil, err + } + + switch msg := (<-ch.msg).(type) { + case *channelOpenConfirmMsg: + return ch, nil + case *channelOpenFailureMsg: + return nil, &OpenChannelError{msg.Reason, msg.Message} + default: + return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg) + } +} diff --git a/lib/ssh/mux_test.go b/lib/ssh/mux_test.go new file mode 100644 index 0000000..c8164ac --- /dev/null +++ b/lib/ssh/mux_test.go @@ -0,0 +1,502 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "io" + "io/ioutil" + "sync" + "testing" +) + +func muxPair() (*mux, *mux) { + a, b := memPipe() + + s := newMux(a) + c := newMux(b) + + return s, c +} + +// Returns both ends of a channel, and the mux for the the 2nd +// channel. +func channelPair(t *testing.T) (*channel, *channel, *mux) { + c, s := muxPair() + + res := make(chan *channel, 1) + go func() { + newCh, ok := <-s.incomingChannels + if !ok { + t.Fatalf("No incoming channel") + } + if newCh.ChannelType() != "chan" { + t.Fatalf("got type %q want chan", newCh.ChannelType()) + } + ch, _, err := newCh.Accept() + if err != nil { + t.Fatalf("Accept %v", err) + } + res <- ch.(*channel) + }() + + ch, err := c.openChannel("chan", nil) + if err != nil { + t.Fatalf("OpenChannel: %v", err) + } + + return <-res, ch, c +} + +// Test that stderr and stdout can be addressed from different +// goroutines. This is intended for use with the race detector. +func TestMuxChannelExtendedThreadSafety(t *testing.T) { + writer, reader, mux := channelPair(t) + defer writer.Close() + defer reader.Close() + defer mux.Close() + + var wr, rd sync.WaitGroup + magic := "hello world" + + wr.Add(2) + go func() { + io.WriteString(writer, magic) + wr.Done() + }() + go func() { + io.WriteString(writer.Stderr(), magic) + wr.Done() + }() + + rd.Add(2) + go func() { + c, err := ioutil.ReadAll(reader) + if string(c) != magic { + t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err) + } + rd.Done() + }() + go func() { + c, err := ioutil.ReadAll(reader.Stderr()) + if string(c) != magic { + t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err) + } + rd.Done() + }() + + wr.Wait() + writer.CloseWrite() + rd.Wait() +} + +func TestMuxReadWrite(t *testing.T) { + s, c, mux := channelPair(t) + defer s.Close() + defer c.Close() + defer mux.Close() + + magic := "hello world" + magicExt := "hello stderr" + go func() { + _, err := s.Write([]byte(magic)) + if err != nil { + t.Fatalf("Write: %v", err) + } + _, err = s.Extended(1).Write([]byte(magicExt)) + if err != nil { + t.Fatalf("Write: %v", err) + } + err = s.Close() + if err != nil { + t.Fatalf("Close: %v", err) + } + }() + + var buf [1024]byte + n, err := c.Read(buf[:]) + if err != nil { + t.Fatalf("server Read: %v", err) + } + got := string(buf[:n]) + if got != magic { + t.Fatalf("server: got %q want %q", got, magic) + } + + n, err = c.Extended(1).Read(buf[:]) + if err != nil { + t.Fatalf("server Read: %v", err) + } + + got = string(buf[:n]) + if got != magicExt { + t.Fatalf("server: got %q want %q", got, magic) + } +} + +func TestMuxChannelOverflow(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + wDone := make(chan int, 1) + go func() { + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + writer.Write(make([]byte, 1)) + wDone <- 1 + }() + writer.remoteWin.waitWriterBlocked() + + // Send 1 byte. + packet := make([]byte, 1+4+4+1) + packet[0] = msgChannelData + marshalUint32(packet[1:], writer.remoteId) + marshalUint32(packet[5:], uint32(1)) + packet[9] = 42 + + if err := writer.mux.conn.writePacket(packet); err != nil { + t.Errorf("could not send packet") + } + if _, err := reader.SendRequest("hello", true, nil); err == nil { + t.Errorf("SendRequest succeeded.") + } + <-wDone +} + +func TestMuxChannelCloseWriteUnblock(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + wDone := make(chan int, 1) + go func() { + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + if _, err := writer.Write(make([]byte, 1)); err != io.EOF { + t.Errorf("got %v, want EOF for unblock write", err) + } + wDone <- 1 + }() + + writer.remoteWin.waitWriterBlocked() + reader.Close() + <-wDone +} + +func TestMuxConnectionCloseWriteUnblock(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + wDone := make(chan int, 1) + go func() { + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + if _, err := writer.Write(make([]byte, 1)); err != io.EOF { + t.Errorf("got %v, want EOF for unblock write", err) + } + wDone <- 1 + }() + + writer.remoteWin.waitWriterBlocked() + mux.Close() + <-wDone +} + +func TestMuxReject(t *testing.T) { + client, server := muxPair() + defer server.Close() + defer client.Close() + + go func() { + ch, ok := <-server.incomingChannels + if !ok { + t.Fatalf("Accept") + } + if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" { + t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData()) + } + ch.Reject(RejectionReason(42), "message") + }() + + ch, err := client.openChannel("ch", []byte("extra")) + if ch != nil { + t.Fatal("openChannel not rejected") + } + + ocf, ok := err.(*OpenChannelError) + if !ok { + t.Errorf("got %#v want *OpenChannelError", err) + } else if ocf.Reason != 42 || ocf.Message != "message" { + t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message") + } + + want := "ssh: rejected: unknown reason 42 (message)" + if err.Error() != want { + t.Errorf("got %q, want %q", err.Error(), want) + } +} + +func TestMuxChannelRequest(t *testing.T) { + client, server, mux := channelPair(t) + defer server.Close() + defer client.Close() + defer mux.Close() + + var received int + var wg sync.WaitGroup + wg.Add(1) + go func() { + for r := range server.incomingRequests { + received++ + r.Reply(r.Type == "yes", nil) + } + wg.Done() + }() + _, err := client.SendRequest("yes", false, nil) + if err != nil { + t.Fatalf("SendRequest: %v", err) + } + ok, err := client.SendRequest("yes", true, nil) + if err != nil { + t.Fatalf("SendRequest: %v", err) + } + + if !ok { + t.Errorf("SendRequest(yes): %v", ok) + + } + + ok, err = client.SendRequest("no", true, nil) + if err != nil { + t.Fatalf("SendRequest: %v", err) + } + if ok { + t.Errorf("SendRequest(no): %v", ok) + + } + + client.Close() + wg.Wait() + + if received != 3 { + t.Errorf("got %d requests, want %d", received, 3) + } +} + +func TestMuxGlobalRequest(t *testing.T) { + clientMux, serverMux := muxPair() + defer serverMux.Close() + defer clientMux.Close() + + var seen bool + go func() { + for r := range serverMux.incomingRequests { + seen = seen || r.Type == "peek" + if r.WantReply { + err := r.Reply(r.Type == "yes", + append([]byte(r.Type), r.Payload...)) + if err != nil { + t.Errorf("AckRequest: %v", err) + } + } + } + }() + + _, _, err := clientMux.SendRequest("peek", false, nil) + if err != nil { + t.Errorf("SendRequest: %v", err) + } + + ok, data, err := clientMux.SendRequest("yes", true, []byte("a")) + if !ok || string(data) != "yesa" || err != nil { + t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", + ok, data, err) + } + if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil { + t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", + ok, data, err) + } + + if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil { + t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v", + ok, data, err) + } + + if !seen { + t.Errorf("never saw 'peek' request") + } +} + +func TestMuxGlobalRequestUnblock(t *testing.T) { + clientMux, serverMux := muxPair() + defer serverMux.Close() + defer clientMux.Close() + + result := make(chan error, 1) + go func() { + _, _, err := clientMux.SendRequest("hello", true, nil) + result <- err + }() + + <-serverMux.incomingRequests + serverMux.conn.Close() + err := <-result + + if err != io.EOF { + t.Errorf("want EOF, got %v", io.EOF) + } +} + +func TestMuxChannelRequestUnblock(t *testing.T) { + a, b, connB := channelPair(t) + defer a.Close() + defer b.Close() + defer connB.Close() + + result := make(chan error, 1) + go func() { + _, err := a.SendRequest("hello", true, nil) + result <- err + }() + + <-b.incomingRequests + connB.conn.Close() + err := <-result + + if err != io.EOF { + t.Errorf("want EOF, got %v", err) + } +} + +func TestMuxCloseChannel(t *testing.T) { + r, w, mux := channelPair(t) + defer mux.Close() + defer r.Close() + defer w.Close() + + result := make(chan error, 1) + go func() { + var b [1024]byte + _, err := r.Read(b[:]) + result <- err + }() + if err := w.Close(); err != nil { + t.Errorf("w.Close: %v", err) + } + + if _, err := w.Write([]byte("hello")); err != io.EOF { + t.Errorf("got err %v, want io.EOF after Close", err) + } + + if err := <-result; err != io.EOF { + t.Errorf("got %v (%T), want io.EOF", err, err) + } +} + +func TestMuxCloseWriteChannel(t *testing.T) { + r, w, mux := channelPair(t) + defer mux.Close() + + result := make(chan error, 1) + go func() { + var b [1024]byte + _, err := r.Read(b[:]) + result <- err + }() + if err := w.CloseWrite(); err != nil { + t.Errorf("w.CloseWrite: %v", err) + } + + if _, err := w.Write([]byte("hello")); err != io.EOF { + t.Errorf("got err %v, want io.EOF after CloseWrite", err) + } + + if err := <-result; err != io.EOF { + t.Errorf("got %v (%T), want io.EOF", err, err) + } +} + +func TestMuxInvalidRecord(t *testing.T) { + a, b := muxPair() + defer a.Close() + defer b.Close() + + packet := make([]byte, 1+4+4+1) + packet[0] = msgChannelData + marshalUint32(packet[1:], 29348723 /* invalid channel id */) + marshalUint32(packet[5:], 1) + packet[9] = 42 + + a.conn.writePacket(packet) + go a.SendRequest("hello", false, nil) + // 'a' wrote an invalid packet, so 'b' has exited. + req, ok := <-b.incomingRequests + if ok { + t.Errorf("got request %#v after receiving invalid packet", req) + } +} + +func TestZeroWindowAdjust(t *testing.T) { + a, b, mux := channelPair(t) + defer a.Close() + defer b.Close() + defer mux.Close() + + go func() { + io.WriteString(a, "hello") + // bogus adjust. + a.sendMessage(windowAdjustMsg{}) + io.WriteString(a, "world") + a.Close() + }() + + want := "helloworld" + c, _ := ioutil.ReadAll(b) + if string(c) != want { + t.Errorf("got %q want %q", c, want) + } +} + +func TestMuxMaxPacketSize(t *testing.T) { + a, b, mux := channelPair(t) + defer a.Close() + defer b.Close() + defer mux.Close() + + large := make([]byte, a.maxRemotePayload+1) + packet := make([]byte, 1+4+4+1+len(large)) + packet[0] = msgChannelData + marshalUint32(packet[1:], a.remoteId) + marshalUint32(packet[5:], uint32(len(large))) + packet[9] = 42 + + if err := a.mux.conn.writePacket(packet); err != nil { + t.Errorf("could not send packet") + } + + go a.SendRequest("hello", false, nil) + + _, ok := <-b.incomingRequests + if ok { + t.Errorf("connection still alive after receiving large packet.") + } +} + +// Don't ship code with debug=true. +func TestDebug(t *testing.T) { + if debugMux { + t.Error("mux debug switched on") + } + if debugHandshake { + t.Error("handshake debug switched on") + } +} diff --git a/lib/ssh/server.go b/lib/ssh/server.go new file mode 100644 index 0000000..4f98e7f --- /dev/null +++ b/lib/ssh/server.go @@ -0,0 +1,567 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "strings" +) + +// The Permissions type holds fine-grained permissions that are +// specific to a user or a specific authentication method for a +// user. Permissions, except for "source-address", must be enforced in +// the server application layer, after successful authentication. The +// Permissions are passed on in ServerConn so a server implementation +// can honor them. +type Permissions struct { + // Critical options restrict default permissions. Common + // restrictions are "source-address" and "force-command". If + // the server cannot enforce the restriction, or does not + // recognize it, the user should not authenticate. + CriticalOptions map[string]string + + // Extensions are extra functionality that the server may + // offer on authenticated connections. Common extensions are + // "permit-agent-forwarding", "permit-X11-forwarding". Lack of + // support for an extension does not preclude authenticating a + // user. + Extensions map[string]string +} + +// The Critical Options and Extensions as stated in PROTOCOL.certkeys as of 3May2016 +// http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD +var allKnownCriticalOptions = []string{"force-command", "source-address"} +var allKnownExtensions = []string{"permit-X11-forwarding", "permit-agent-forwarding", "permit-port-forwarding", "permit-pty", "permit-user-rc"} + +type JsonCriticalOptions struct { + options map[string]string +} + +func (jOptions *JsonCriticalOptions) MarshalJSON() ([]byte, error) { + knownOpt := make(map[string]string) + var unknownOpt []string + + for key, value := range jOptions.options { + isKnown := false + for _, known := range allKnownCriticalOptions { + if key == known { + isKnown = true + break + } + } + + jsonSafeKey := strings.Replace(key, "-", "_", -1) + if isKnown { + knownOpt[jsonSafeKey] = value + } else { + unknownOpt = append(unknownOpt, fmt.Sprintf("%s : %s", jsonSafeKey, value)) + } + } + + temp := make(map[string]interface{}) + if len(knownOpt) > 0 { + temp["known"] = knownOpt + } + if len(unknownOpt) > 0 { + temp["unknown"] = unknownOpt + } + + return json.Marshal(temp) +} + +type JsonExtensions struct { + extensions map[string]string +} + +func (ext *JsonExtensions) MarshalJSON() ([]byte, error) { + knownExt := make(map[string]string) + var unknownExt []string + + for key, value := range ext.extensions { + isKnown := false + for _, known := range allKnownExtensions { + if key == known { + isKnown = true + break + } + } + + jsonSafeKey := strings.Replace(key, "-", "_", -1) + if isKnown { + knownExt[jsonSafeKey] = value + } else { + unknownExt = append(unknownExt, fmt.Sprintf("%s : %s", jsonSafeKey, value)) + } + } + + temp := make(map[string]interface{}) + if len(knownExt) > 0 { + temp["known"] = knownExt + } + if len(unknownExt) > 0 { + temp["unknown"] = unknownExt + } + + return json.Marshal(temp) +} + +// ServerConfig holds server specific configuration data. +type ServerConfig struct { + // Config contains configuration shared between client and server. + Config + + hostKeys []Signer + + // NoClientAuth is true if clients are allowed to connect without + // authenticating. + NoClientAuth bool + + // PasswordCallback, if non-nil, is called when a user + // attempts to authenticate using a password. + PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error) + + // PublicKeyCallback, if non-nil, is called when a client attempts public + // key authentication. It must return true if the given public key is + // valid for the given user. For example, see CertChecker.Authenticate. + PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) + + // KeyboardInteractiveCallback, if non-nil, is called when + // keyboard-interactive authentication is selected (RFC + // 4256). The client object's Challenge function should be + // used to query the user. The callback may offer multiple + // Challenge rounds. To avoid information leaks, the client + // should be presented a challenge even if the user is + // unknown. + KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) + + // AuthLogCallback, if non-nil, is called to log all authentication + // attempts. + AuthLogCallback func(conn ConnMetadata, method string, err error) + + // ServerVersion is the version identification string to announce in + // the public handshake. + // If empty, a reasonable default is used. + // Note that RFC 4253 section 4.2 requires that this string start with + // "SSH-2.0-". + ServerVersion string +} + +// AddHostKey adds a private key as a host key. If an existing host +// key exists with the same algorithm, it is overwritten. Each server +// config must have at least one host key. +func (s *ServerConfig) AddHostKey(key Signer) { + for i, k := range s.hostKeys { + if k.PublicKey().Type() == key.PublicKey().Type() { + s.hostKeys[i] = key + return + } + } + + s.hostKeys = append(s.hostKeys, key) +} + +// cachedPubKey contains the results of querying whether a public key is +// acceptable for a user. +type cachedPubKey struct { + user string + pubKeyData []byte + result error + perms *Permissions +} + +const maxCachedPubKeys = 16 + +// pubKeyCache caches tests for public keys. Since SSH clients +// will query whether a public key is acceptable before attempting to +// authenticate with it, we end up with duplicate queries for public +// key validity. The cache only applies to a single ServerConn. +type pubKeyCache struct { + keys []cachedPubKey +} + +// get returns the result for a given user/algo/key tuple. +func (c *pubKeyCache) get(user string, pubKeyData []byte) (cachedPubKey, bool) { + for _, k := range c.keys { + if k.user == user && bytes.Equal(k.pubKeyData, pubKeyData) { + return k, true + } + } + return cachedPubKey{}, false +} + +// add adds the given tuple to the cache. +func (c *pubKeyCache) add(candidate cachedPubKey) { + if len(c.keys) < maxCachedPubKeys { + c.keys = append(c.keys, candidate) + } +} + +// ServerConn is an authenticated SSH connection, as seen from the +// server +type ServerConn struct { + Conn + + // If the succeeding authentication callback returned a + // non-nil Permissions pointer, it is stored here. + Permissions *Permissions +} + +// NewServerConn starts a new SSH server with c as the underlying +// transport. It starts with a handshake and, if the handshake is +// unsuccessful, it closes the connection and returns an error. The +// Request and NewChannel channels must be serviced, or the connection +// will hang. +func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) { + fullConf := *config + fullConf.SetDefaults() + s := &connection{ + sshConn: sshConn{conn: c}, + } + perms, err := s.serverHandshake(&fullConf) + if err != nil { + c.Close() + return nil, nil, nil, err + } + return &ServerConn{s, perms}, s.mux.incomingChannels, s.mux.incomingRequests, nil +} + +// signAndMarshal signs the data with the appropriate algorithm, +// and serializes the result in SSH wire format. +func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) { + sig, err := k.Sign(rand, data) + if err != nil { + return nil, err + } + + return Marshal(sig), nil +} + +// handshake performs key exchange and user authentication. +func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) { + if len(config.hostKeys) == 0 { + return nil, errors.New("ssh: server has no host keys") + } + + if !config.NoClientAuth && config.PasswordCallback == nil && config.PublicKeyCallback == nil && config.KeyboardInteractiveCallback == nil { + return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false") + } + + if config.ServerVersion != "" { + s.serverVersion = []byte(config.ServerVersion) + } else { + s.serverVersion = []byte(packageVersion) + } + var err error + s.clientVersion, err = exchangeVersions(s.sshConn.conn, s.serverVersion) + if err != nil { + return nil, err + } + + tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */) + s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config) + + if err := s.transport.requestInitialKeyChange(); err != nil { + return nil, err + } + + // We just did the key change, so the session ID is established. + s.sessionID = s.transport.getSessionID() + + var packet []byte + if packet, err = s.transport.readPacket(); err != nil { + return nil, err + } + + var serviceRequest serviceRequestMsg + if err = Unmarshal(packet, &serviceRequest); err != nil { + return nil, err + } + if serviceRequest.Service != serviceUserAuth { + return nil, errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating") + } + serviceAccept := serviceAcceptMsg{ + Service: serviceUserAuth, + } + if err := s.transport.writePacket(Marshal(&serviceAccept)); err != nil { + return nil, err + } + + perms, err := s.serverAuthenticate(config) + if err != nil { + return nil, err + } + s.mux = newMux(s.transport) + return perms, err +} + +func isAcceptableAlgo(algo string) bool { + switch algo { + case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, KeyAlgoED25519, + CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01: + return true + } + return false +} + +func checkSourceAddress(addr net.Addr, sourceAddr string) error { + if addr == nil { + return errors.New("ssh: no address known for client, but source-address match required") + } + + tcpAddr, ok := addr.(*net.TCPAddr) + if !ok { + return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr) + } + + if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil { + if bytes.Equal(allowedIP, tcpAddr.IP) { + return nil + } + } else { + _, ipNet, err := net.ParseCIDR(sourceAddr) + if err != nil { + return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err) + } + + if ipNet.Contains(tcpAddr.IP) { + return nil + } + } + + return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr) +} + +func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) { + var err error + var cache pubKeyCache + var perms *Permissions + +userAuthLoop: + for { + var userAuthReq userAuthRequestMsg + if packet, err := s.transport.readPacket(); err != nil { + return nil, err + } else if err = Unmarshal(packet, &userAuthReq); err != nil { + return nil, err + } + + if userAuthReq.Service != serviceSSH { + return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service) + } + + s.user = userAuthReq.User + perms = nil + authErr := errors.New("no auth passed yet") + + switch userAuthReq.Method { + case "none": + if config.NoClientAuth { + authErr = nil + } + case "password": + if config.PasswordCallback == nil { + authErr = errors.New("ssh: password auth not configured") + break + } + payload := userAuthReq.Payload + if len(payload) < 1 || payload[0] != 0 { + return nil, parseError(msgUserAuthRequest) + } + payload = payload[1:] + password, payload, ok := parseString(payload) + if !ok || len(payload) > 0 { + return nil, parseError(msgUserAuthRequest) + } + + perms, authErr = config.PasswordCallback(s, password) + case "keyboard-interactive": + if config.KeyboardInteractiveCallback == nil { + authErr = errors.New("ssh: keyboard-interactive auth not configubred") + break + } + + prompter := &sshClientKeyboardInteractive{s} + perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge) + case "publickey": + if config.PublicKeyCallback == nil { + authErr = errors.New("ssh: publickey auth not configured") + break + } + payload := userAuthReq.Payload + if len(payload) < 1 { + return nil, parseError(msgUserAuthRequest) + } + isQuery := payload[0] == 0 + payload = payload[1:] + algoBytes, payload, ok := parseString(payload) + if !ok { + return nil, parseError(msgUserAuthRequest) + } + algo := string(algoBytes) + if !isAcceptableAlgo(algo) { + authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo) + break + } + + pubKeyData, payload, ok := parseString(payload) + if !ok { + return nil, parseError(msgUserAuthRequest) + } + + pubKey, err := ParsePublicKey(pubKeyData) + if err != nil { + return nil, err + } + + candidate, ok := cache.get(s.user, pubKeyData) + if !ok { + candidate.user = s.user + candidate.pubKeyData = pubKeyData + candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey) + if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" { + candidate.result = checkSourceAddress( + s.RemoteAddr(), + candidate.perms.CriticalOptions[sourceAddressCriticalOption]) + } + cache.add(candidate) + } + + if isQuery { + // The client can query if the given public key + // would be okay. + if len(payload) > 0 { + return nil, parseError(msgUserAuthRequest) + } + + if candidate.result == nil { + okMsg := userAuthPubKeyOkMsg{ + Algo: algo, + PubKey: pubKeyData, + } + if err = s.transport.writePacket(Marshal(&okMsg)); err != nil { + return nil, err + } + continue userAuthLoop + } + authErr = candidate.result + } else { + sig, payload, ok := parseSignature(payload) + if !ok || len(payload) > 0 { + return nil, parseError(msgUserAuthRequest) + } + // Ensure the public key algo and signature algo + // are supported. Compare the private key + // algorithm name that corresponds to algo with + // sig.Format. This is usually the same, but + // for certs, the names differ. + if !isAcceptableAlgo(sig.Format) { + break + } + signedData := buildDataSignedForAuth(s.transport.getSessionID(), userAuthReq, algoBytes, pubKeyData) + + if err := pubKey.Verify(signedData, sig); err != nil { + return nil, err + } + + authErr = candidate.result + perms = candidate.perms + } + default: + authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method) + } + + if config.AuthLogCallback != nil { + config.AuthLogCallback(s, userAuthReq.Method, authErr) + } + + if authErr == nil { + break userAuthLoop + } + + var failureMsg userAuthFailureMsg + if config.PasswordCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "password") + } + if config.PublicKeyCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "publickey") + } + if config.KeyboardInteractiveCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive") + } + + if len(failureMsg.Methods) == 0 { + return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false") + } + + if err = s.transport.writePacket(Marshal(&failureMsg)); err != nil { + return nil, err + } + } + + if err = s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil { + return nil, err + } + return perms, nil +} + +// sshClientKeyboardInteractive implements a ClientKeyboardInteractive by +// asking the client on the other side of a ServerConn. +type sshClientKeyboardInteractive struct { + *connection +} + +func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, questions []string, echos []bool) (answers []string, err error) { + if len(questions) != len(echos) { + return nil, errors.New("ssh: echos and questions must have equal length") + } + + var prompts []byte + for i := range questions { + prompts = appendString(prompts, questions[i]) + prompts = appendBool(prompts, echos[i]) + } + + if err := c.transport.writePacket(Marshal(&userAuthInfoRequestMsg{ + Instruction: instruction, + NumPrompts: uint32(len(questions)), + Prompts: prompts, + })); err != nil { + return nil, err + } + + packet, err := c.transport.readPacket() + if err != nil { + return nil, err + } + if packet[0] != msgUserAuthInfoResponse { + return nil, unexpectedMessageError(msgUserAuthInfoResponse, packet[0]) + } + packet = packet[1:] + + n, packet, ok := parseUint32(packet) + if !ok || int(n) != len(questions) { + return nil, parseError(msgUserAuthInfoResponse) + } + + for i := uint32(0); i < n; i++ { + ans, rest, ok := parseString(packet) + if !ok { + return nil, parseError(msgUserAuthInfoResponse) + } + + answers = append(answers, string(ans)) + packet = rest + } + if len(packet) != 0 { + return nil, errors.New("ssh: junk at end of message") + } + + return answers, nil +} diff --git a/lib/ssh/session.go b/lib/ssh/session.go new file mode 100644 index 0000000..5a9ec18 --- /dev/null +++ b/lib/ssh/session.go @@ -0,0 +1,627 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +// Session implements an interactive session described in +// "RFC 4254, section 6". + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "io/ioutil" + "sync" +) + +type Signal string + +// POSIX signals as listed in RFC 4254 Section 6.10. +const ( + SIGABRT Signal = "ABRT" + SIGALRM Signal = "ALRM" + SIGFPE Signal = "FPE" + SIGHUP Signal = "HUP" + SIGILL Signal = "ILL" + SIGINT Signal = "INT" + SIGKILL Signal = "KILL" + SIGPIPE Signal = "PIPE" + SIGQUIT Signal = "QUIT" + SIGSEGV Signal = "SEGV" + SIGTERM Signal = "TERM" + SIGUSR1 Signal = "USR1" + SIGUSR2 Signal = "USR2" +) + +var signals = map[Signal]int{ + SIGABRT: 6, + SIGALRM: 14, + SIGFPE: 8, + SIGHUP: 1, + SIGILL: 4, + SIGINT: 2, + SIGKILL: 9, + SIGPIPE: 13, + SIGQUIT: 3, + SIGSEGV: 11, + SIGTERM: 15, +} + +type TerminalModes map[uint8]uint32 + +// POSIX terminal mode flags as listed in RFC 4254 Section 8. +const ( + tty_OP_END = 0 + VINTR = 1 + VQUIT = 2 + VERASE = 3 + VKILL = 4 + VEOF = 5 + VEOL = 6 + VEOL2 = 7 + VSTART = 8 + VSTOP = 9 + VSUSP = 10 + VDSUSP = 11 + VREPRINT = 12 + VWERASE = 13 + VLNEXT = 14 + VFLUSH = 15 + VSWTCH = 16 + VSTATUS = 17 + VDISCARD = 18 + IGNPAR = 30 + PARMRK = 31 + INPCK = 32 + ISTRIP = 33 + INLCR = 34 + IGNCR = 35 + ICRNL = 36 + IUCLC = 37 + IXON = 38 + IXANY = 39 + IXOFF = 40 + IMAXBEL = 41 + ISIG = 50 + ICANON = 51 + XCASE = 52 + ECHO = 53 + ECHOE = 54 + ECHOK = 55 + ECHONL = 56 + NOFLSH = 57 + TOSTOP = 58 + IEXTEN = 59 + ECHOCTL = 60 + ECHOKE = 61 + PENDIN = 62 + OPOST = 70 + OLCUC = 71 + ONLCR = 72 + OCRNL = 73 + ONOCR = 74 + ONLRET = 75 + CS7 = 90 + CS8 = 91 + PARENB = 92 + PARODD = 93 + TTY_OP_ISPEED = 128 + TTY_OP_OSPEED = 129 +) + +// A Session represents a connection to a remote command or shell. +type Session struct { + // Stdin specifies the remote process's standard input. + // If Stdin is nil, the remote process reads from an empty + // bytes.Buffer. + Stdin io.Reader + + // Stdout and Stderr specify the remote process's standard + // output and error. + // + // If either is nil, Run connects the corresponding file + // descriptor to an instance of ioutil.Discard. There is a + // fixed amount of buffering that is shared for the two streams. + // If either blocks it may eventually cause the remote + // command to block. + Stdout io.Writer + Stderr io.Writer + + ch Channel // the channel backing this session + started bool // true once Start, Run or Shell is invoked. + copyFuncs []func() error + errors chan error // one send per copyFunc + + // true if pipe method is active + stdinpipe, stdoutpipe, stderrpipe bool + + // stdinPipeWriter is non-nil if StdinPipe has not been called + // and Stdin was specified by the user; it is the write end of + // a pipe connecting Session.Stdin to the stdin channel. + stdinPipeWriter io.WriteCloser + + exitStatus chan error +} + +// SendRequest sends an out-of-band channel request on the SSH channel +// underlying the session. +func (s *Session) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + return s.ch.SendRequest(name, wantReply, payload) +} + +func (s *Session) Close() error { + return s.ch.Close() +} + +// RFC 4254 Section 6.4. +type setenvRequest struct { + Name string + Value string +} + +// Setenv sets an environment variable that will be applied to any +// command executed by Shell or Run. +func (s *Session) Setenv(name, value string) error { + msg := setenvRequest{ + Name: name, + Value: value, + } + ok, err := s.ch.SendRequest("env", true, Marshal(&msg)) + if err == nil && !ok { + err = errors.New("ssh: setenv failed") + } + return err +} + +// RFC 4254 Section 6.2. +type ptyRequestMsg struct { + Term string + Columns uint32 + Rows uint32 + Width uint32 + Height uint32 + Modelist string +} + +// RequestPty requests the association of a pty with the session on the remote host. +func (s *Session) RequestPty(term string, h, w int, termmodes TerminalModes) error { + var tm []byte + for k, v := range termmodes { + kv := struct { + Key byte + Val uint32 + }{k, v} + + tm = append(tm, Marshal(&kv)...) + } + tm = append(tm, tty_OP_END) + req := ptyRequestMsg{ + Term: term, + Columns: uint32(w), + Rows: uint32(h), + Width: uint32(w * 8), + Height: uint32(h * 8), + Modelist: string(tm), + } + ok, err := s.ch.SendRequest("pty-req", true, Marshal(&req)) + if err == nil && !ok { + err = errors.New("ssh: pty-req failed") + } + return err +} + +// RFC 4254 Section 6.5. +type subsystemRequestMsg struct { + Subsystem string +} + +// RequestSubsystem requests the association of a subsystem with the session on the remote host. +// A subsystem is a predefined command that runs in the background when the ssh session is initiated +func (s *Session) RequestSubsystem(subsystem string) error { + msg := subsystemRequestMsg{ + Subsystem: subsystem, + } + ok, err := s.ch.SendRequest("subsystem", true, Marshal(&msg)) + if err == nil && !ok { + err = errors.New("ssh: subsystem request failed") + } + return err +} + +// RFC 4254 Section 6.9. +type signalMsg struct { + Signal string +} + +// Signal sends the given signal to the remote process. +// sig is one of the SIG* constants. +func (s *Session) Signal(sig Signal) error { + msg := signalMsg{ + Signal: string(sig), + } + + _, err := s.ch.SendRequest("signal", false, Marshal(&msg)) + return err +} + +// RFC 4254 Section 6.5. +type execMsg struct { + Command string +} + +// Start runs cmd on the remote host. Typically, the remote +// server passes cmd to the shell for interpretation. +// A Session only accepts one call to Run, Start or Shell. +func (s *Session) Start(cmd string) error { + if s.started { + return errors.New("ssh: session already started") + } + req := execMsg{ + Command: cmd, + } + + ok, err := s.ch.SendRequest("exec", true, Marshal(&req)) + if err == nil && !ok { + err = fmt.Errorf("ssh: command %v failed", cmd) + } + if err != nil { + return err + } + return s.start() +} + +// Run runs cmd on the remote host. Typically, the remote +// server passes cmd to the shell for interpretation. +// A Session only accepts one call to Run, Start, Shell, Output, +// or CombinedOutput. +// +// The returned error is nil if the command runs, has no problems +// copying stdin, stdout, and stderr, and exits with a zero exit +// status. +// +// If the remote server does not send an exit status, an error of type +// *ExitMissingError is returned. If the command completes +// unsuccessfully or is interrupted by a signal, the error is of type +// *ExitError. Other error types may be returned for I/O problems. +func (s *Session) Run(cmd string) error { + err := s.Start(cmd) + if err != nil { + return err + } + return s.Wait() +} + +// Output runs cmd on the remote host and returns its standard output. +func (s *Session) Output(cmd string) ([]byte, error) { + if s.Stdout != nil { + return nil, errors.New("ssh: Stdout already set") + } + var b bytes.Buffer + s.Stdout = &b + err := s.Run(cmd) + return b.Bytes(), err +} + +type singleWriter struct { + b bytes.Buffer + mu sync.Mutex +} + +func (w *singleWriter) Write(p []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + return w.b.Write(p) +} + +// CombinedOutput runs cmd on the remote host and returns its combined +// standard output and standard error. +func (s *Session) CombinedOutput(cmd string) ([]byte, error) { + if s.Stdout != nil { + return nil, errors.New("ssh: Stdout already set") + } + if s.Stderr != nil { + return nil, errors.New("ssh: Stderr already set") + } + var b singleWriter + s.Stdout = &b + s.Stderr = &b + err := s.Run(cmd) + return b.b.Bytes(), err +} + +// Shell starts a login shell on the remote host. A Session only +// accepts one call to Run, Start, Shell, Output, or CombinedOutput. +func (s *Session) Shell() error { + if s.started { + return errors.New("ssh: session already started") + } + + ok, err := s.ch.SendRequest("shell", true, nil) + if err == nil && !ok { + return errors.New("ssh: could not start shell") + } + if err != nil { + return err + } + return s.start() +} + +func (s *Session) start() error { + s.started = true + + type F func(*Session) + for _, setupFd := range []F{(*Session).stdin, (*Session).stdout, (*Session).stderr} { + setupFd(s) + } + + s.errors = make(chan error, len(s.copyFuncs)) + for _, fn := range s.copyFuncs { + go func(fn func() error) { + s.errors <- fn() + }(fn) + } + return nil +} + +// Wait waits for the remote command to exit. +// +// The returned error is nil if the command runs, has no problems +// copying stdin, stdout, and stderr, and exits with a zero exit +// status. +// +// If the remote server does not send an exit status, an error of type +// *ExitMissingError is returned. If the command completes +// unsuccessfully or is interrupted by a signal, the error is of type +// *ExitError. Other error types may be returned for I/O problems. +func (s *Session) Wait() error { + if !s.started { + return errors.New("ssh: session not started") + } + waitErr := <-s.exitStatus + + if s.stdinPipeWriter != nil { + s.stdinPipeWriter.Close() + } + var copyError error + for range s.copyFuncs { + if err := <-s.errors; err != nil && copyError == nil { + copyError = err + } + } + if waitErr != nil { + return waitErr + } + return copyError +} + +func (s *Session) wait(reqs <-chan *Request) error { + wm := Waitmsg{status: -1} + // Wait for msg channel to be closed before returning. + for msg := range reqs { + switch msg.Type { + case "exit-status": + wm.status = int(binary.BigEndian.Uint32(msg.Payload)) + case "exit-signal": + var sigval struct { + Signal string + CoreDumped bool + Error string + Lang string + } + if err := Unmarshal(msg.Payload, &sigval); err != nil { + return err + } + + // Must sanitize strings? + wm.signal = sigval.Signal + wm.msg = sigval.Error + wm.lang = sigval.Lang + default: + // This handles keepalives and matches + // OpenSSH's behaviour. + if msg.WantReply { + msg.Reply(false, nil) + } + } + } + if wm.status == 0 { + return nil + } + if wm.status == -1 { + // exit-status was never sent from server + if wm.signal == "" { + // signal was not sent either. RFC 4254 + // section 6.10 recommends against this + // behavior, but it is allowed, so we let + // clients handle it. + return &ExitMissingError{} + } + wm.status = 128 + if _, ok := signals[Signal(wm.signal)]; ok { + wm.status += signals[Signal(wm.signal)] + } + } + + return &ExitError{wm} +} + +// ExitMissingError is returned if a session is torn down cleanly, but +// the server sends no confirmation of the exit status. +type ExitMissingError struct{} + +func (e *ExitMissingError) Error() string { + return "wait: remote command exited without exit status or exit signal" +} + +func (s *Session) stdin() { + if s.stdinpipe { + return + } + var stdin io.Reader + if s.Stdin == nil { + stdin = new(bytes.Buffer) + } else { + r, w := io.Pipe() + go func() { + _, err := io.Copy(w, s.Stdin) + w.CloseWithError(err) + }() + stdin, s.stdinPipeWriter = r, w + } + s.copyFuncs = append(s.copyFuncs, func() error { + _, err := io.Copy(s.ch, stdin) + if err1 := s.ch.CloseWrite(); err == nil && err1 != io.EOF { + err = err1 + } + return err + }) +} + +func (s *Session) stdout() { + if s.stdoutpipe { + return + } + if s.Stdout == nil { + s.Stdout = ioutil.Discard + } + s.copyFuncs = append(s.copyFuncs, func() error { + _, err := io.Copy(s.Stdout, s.ch) + return err + }) +} + +func (s *Session) stderr() { + if s.stderrpipe { + return + } + if s.Stderr == nil { + s.Stderr = ioutil.Discard + } + s.copyFuncs = append(s.copyFuncs, func() error { + _, err := io.Copy(s.Stderr, s.ch.Stderr()) + return err + }) +} + +// sessionStdin reroutes Close to CloseWrite. +type sessionStdin struct { + io.Writer + ch Channel +} + +func (s *sessionStdin) Close() error { + return s.ch.CloseWrite() +} + +// StdinPipe returns a pipe that will be connected to the +// remote command's standard input when the command starts. +func (s *Session) StdinPipe() (io.WriteCloser, error) { + if s.Stdin != nil { + return nil, errors.New("ssh: Stdin already set") + } + if s.started { + return nil, errors.New("ssh: StdinPipe after process started") + } + s.stdinpipe = true + return &sessionStdin{s.ch, s.ch}, nil +} + +// StdoutPipe returns a pipe that will be connected to the +// remote command's standard output when the command starts. +// There is a fixed amount of buffering that is shared between +// stdout and stderr streams. If the StdoutPipe reader is +// not serviced fast enough it may eventually cause the +// remote command to block. +func (s *Session) StdoutPipe() (io.Reader, error) { + if s.Stdout != nil { + return nil, errors.New("ssh: Stdout already set") + } + if s.started { + return nil, errors.New("ssh: StdoutPipe after process started") + } + s.stdoutpipe = true + return s.ch, nil +} + +// StderrPipe returns a pipe that will be connected to the +// remote command's standard error when the command starts. +// There is a fixed amount of buffering that is shared between +// stdout and stderr streams. If the StderrPipe reader is +// not serviced fast enough it may eventually cause the +// remote command to block. +func (s *Session) StderrPipe() (io.Reader, error) { + if s.Stderr != nil { + return nil, errors.New("ssh: Stderr already set") + } + if s.started { + return nil, errors.New("ssh: StderrPipe after process started") + } + s.stderrpipe = true + return s.ch.Stderr(), nil +} + +// newSession returns a new interactive session on the remote host. +func newSession(ch Channel, reqs <-chan *Request) (*Session, error) { + s := &Session{ + ch: ch, + } + s.exitStatus = make(chan error, 1) + go func() { + s.exitStatus <- s.wait(reqs) + }() + + return s, nil +} + +// An ExitError reports unsuccessful completion of a remote command. +type ExitError struct { + Waitmsg +} + +func (e *ExitError) Error() string { + return e.Waitmsg.String() +} + +// Waitmsg stores the information about an exited remote command +// as reported by Wait. +type Waitmsg struct { + status int + signal string + msg string + lang string +} + +// ExitStatus returns the exit status of the remote command. +func (w Waitmsg) ExitStatus() int { + return w.status +} + +// Signal returns the exit signal of the remote command if +// it was terminated violently. +func (w Waitmsg) Signal() string { + return w.signal +} + +// Msg returns the exit message given by the remote command +func (w Waitmsg) Msg() string { + return w.msg +} + +// Lang returns the language tag. See RFC 3066 +func (w Waitmsg) Lang() string { + return w.lang +} + +func (w Waitmsg) String() string { + str := fmt.Sprintf("Process exited with status %v", w.status) + if w.signal != "" { + str += fmt.Sprintf(" from signal %v", w.signal) + } + if w.msg != "" { + str += fmt.Sprintf(". Reason was: %v", w.msg) + } + return str +} diff --git a/lib/ssh/session_test.go b/lib/ssh/session_test.go new file mode 100644 index 0000000..62cb64a --- /dev/null +++ b/lib/ssh/session_test.go @@ -0,0 +1,770 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +// Session tests. + +import ( + "bytes" + crypto_rand "crypto/rand" + "errors" + "io" + "io/ioutil" + "math/rand" + "net" + "testing" + + "github.com/zmap/zgrab/ztools/xssh/terminal" +) + +type serverType func(Channel, <-chan *Request, *testing.T) + +// dial constructs a new test server and returns a *ClientConn. +func dial(handler serverType, t *testing.T) *Client { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + + go func() { + defer c1.Close() + conf := ServerConfig{ + NoClientAuth: true, + } + conf.AddHostKey(testSigners["rsa"]) + + _, chans, reqs, err := NewServerConn(c1, &conf) + if err != nil { + t.Fatalf("Unable to handshake: %v", err) + } + go DiscardRequests(reqs) + + for newCh := range chans { + if newCh.ChannelType() != "session" { + newCh.Reject(UnknownChannelType, "unknown channel type") + continue + } + + ch, inReqs, err := newCh.Accept() + if err != nil { + t.Errorf("Accept: %v", err) + continue + } + go func() { + handler(ch, inReqs, t) + }() + } + }() + + config := &ClientConfig{ + User: "testuser", + } + + conn, chans, reqs, err := NewClientConn(c2, "", config) + if err != nil { + t.Fatalf("unable to dial remote side: %v", err) + } + + return NewClient(conn, chans, reqs) +} + +// Test a simple string is returned to session.Stdout. +func TestSessionShell(t *testing.T) { + conn := dial(shellHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + stdout := new(bytes.Buffer) + session.Stdout = stdout + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %s", err) + } + if err := session.Wait(); err != nil { + t.Fatalf("Remote command did not exit cleanly: %v", err) + } + actual := stdout.String() + if actual != "golang" { + t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual) + } +} + +// TODO(dfc) add support for Std{in,err}Pipe when the Server supports it. + +// Test a simple string is returned via StdoutPipe. +func TestSessionStdoutPipe(t *testing.T) { + conn := dial(shellHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("Unable to request StdoutPipe(): %v", err) + } + var buf bytes.Buffer + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + done := make(chan bool, 1) + go func() { + if _, err := io.Copy(&buf, stdout); err != nil { + t.Errorf("Copy of stdout failed: %v", err) + } + done <- true + }() + if err := session.Wait(); err != nil { + t.Fatalf("Remote command did not exit cleanly: %v", err) + } + <-done + actual := buf.String() + if actual != "golang" { + t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual) + } +} + +// Test that a simple string is returned via the Output helper, +// and that stderr is discarded. +func TestSessionOutput(t *testing.T) { + conn := dial(fixedOutputHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + + buf, err := session.Output("") // cmd is ignored by fixedOutputHandler + if err != nil { + t.Error("Remote command did not exit cleanly:", err) + } + w := "this-is-stdout." + g := string(buf) + if g != w { + t.Error("Remote command did not return expected string:") + t.Logf("want %q", w) + t.Logf("got %q", g) + } +} + +// Test that both stdout and stderr are returned +// via the CombinedOutput helper. +func TestSessionCombinedOutput(t *testing.T) { + conn := dial(fixedOutputHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + + buf, err := session.CombinedOutput("") // cmd is ignored by fixedOutputHandler + if err != nil { + t.Error("Remote command did not exit cleanly:", err) + } + const stdout = "this-is-stdout." + const stderr = "this-is-stderr." + g := string(buf) + if g != stdout+stderr && g != stderr+stdout { + t.Error("Remote command did not return expected string:") + t.Logf("want %q, or %q", stdout+stderr, stderr+stdout) + t.Logf("got %q", g) + } +} + +// Test non-0 exit status is returned correctly. +func TestExitStatusNonZero(t *testing.T) { + conn := dial(exitStatusNonZeroHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + e, ok := err.(*ExitError) + if !ok { + t.Fatalf("expected *ExitError but got %T", err) + } + if e.ExitStatus() != 15 { + t.Fatalf("expected command to exit with 15 but got %v", e.ExitStatus()) + } +} + +// Test 0 exit status is returned correctly. +func TestExitStatusZero(t *testing.T) { + conn := dial(exitStatusZeroHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err != nil { + t.Fatalf("expected nil but got %v", err) + } +} + +// Test exit signal and status are both returned correctly. +func TestExitSignalAndStatus(t *testing.T) { + conn := dial(exitSignalAndStatusHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + e, ok := err.(*ExitError) + if !ok { + t.Fatalf("expected *ExitError but got %T", err) + } + if e.Signal() != "TERM" || e.ExitStatus() != 15 { + t.Fatalf("expected command to exit with signal TERM and status 15 but got signal %s and status %v", e.Signal(), e.ExitStatus()) + } +} + +// Test exit signal and status are both returned correctly. +func TestKnownExitSignalOnly(t *testing.T) { + conn := dial(exitSignalHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + e, ok := err.(*ExitError) + if !ok { + t.Fatalf("expected *ExitError but got %T", err) + } + if e.Signal() != "TERM" || e.ExitStatus() != 143 { + t.Fatalf("expected command to exit with signal TERM and status 143 but got signal %s and status %v", e.Signal(), e.ExitStatus()) + } +} + +// Test exit signal and status are both returned correctly. +func TestUnknownExitSignal(t *testing.T) { + conn := dial(exitSignalUnknownHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + e, ok := err.(*ExitError) + if !ok { + t.Fatalf("expected *ExitError but got %T", err) + } + if e.Signal() != "SYS" || e.ExitStatus() != 128 { + t.Fatalf("expected command to exit with signal SYS and status 128 but got signal %s and status %v", e.Signal(), e.ExitStatus()) + } +} + +func TestExitWithoutStatusOrSignal(t *testing.T) { + conn := dial(exitWithoutSignalOrStatus, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + if _, ok := err.(*ExitMissingError); !ok { + t.Fatalf("got %T want *ExitMissingError", err) + } +} + +// windowTestBytes is the number of bytes that we'll send to the SSH server. +const windowTestBytes = 16000 * 200 + +// TestServerWindow writes random data to the server. The server is expected to echo +// the same data back, which is compared against the original. +func TestServerWindow(t *testing.T) { + origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes)) + io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes) + origBytes := origBuf.Bytes() + + conn := dial(echoHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatal(err) + } + defer session.Close() + result := make(chan []byte) + + go func() { + defer close(result) + echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes)) + serverStdout, err := session.StdoutPipe() + if err != nil { + t.Errorf("StdoutPipe failed: %v", err) + return + } + n, err := copyNRandomly("stdout", echoedBuf, serverStdout, windowTestBytes) + if err != nil && err != io.EOF { + t.Errorf("Read only %d bytes from server, expected %d: %v", n, windowTestBytes, err) + } + result <- echoedBuf.Bytes() + }() + + serverStdin, err := session.StdinPipe() + if err != nil { + t.Fatalf("StdinPipe failed: %v", err) + } + written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes) + if err != nil { + t.Fatalf("failed to copy origBuf to serverStdin: %v", err) + } + if written != windowTestBytes { + t.Fatalf("Wrote only %d of %d bytes to server", written, windowTestBytes) + } + + echoedBytes := <-result + + if !bytes.Equal(origBytes, echoedBytes) { + t.Fatalf("Echoed buffer differed from original, orig %d, echoed %d", len(origBytes), len(echoedBytes)) + } +} + +// Verify the client can handle a keepalive packet from the server. +func TestClientHandlesKeepalives(t *testing.T) { + conn := dial(channelKeepaliveSender, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatal(err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err != nil { + t.Fatalf("expected nil but got: %v", err) + } +} + +type exitStatusMsg struct { + Status uint32 +} + +type exitSignalMsg struct { + Signal string + CoreDumped bool + Errmsg string + Lang string +} + +func handleTerminalRequests(in <-chan *Request) { + for req := range in { + ok := false + switch req.Type { + case "shell": + ok = true + if len(req.Payload) > 0 { + // We don't accept any commands, only the default shell. + ok = false + } + case "env": + ok = true + } + req.Reply(ok, nil) + } +} + +func newServerShell(ch Channel, in <-chan *Request, prompt string) *terminal.Terminal { + term := terminal.NewTerminal(ch, prompt) + go handleTerminalRequests(in) + return term +} + +func exitStatusZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + // this string is returned to stdout + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendStatus(0, ch, t) +} + +func exitStatusNonZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendStatus(15, ch, t) +} + +func exitSignalAndStatusHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendStatus(15, ch, t) + sendSignal("TERM", ch, t) +} + +func exitSignalHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendSignal("TERM", ch, t) +} + +func exitSignalUnknownHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendSignal("SYS", ch, t) +} + +func exitWithoutSignalOrStatus(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) +} + +func shellHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + // this string is returned to stdout + shell := newServerShell(ch, in, "golang") + readLine(shell, t) + sendStatus(0, ch, t) +} + +// Ignores the command, writes fixed strings to stderr and stdout. +// Strings are "this-is-stdout." and "this-is-stderr.". +func fixedOutputHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + _, err := ch.Read(nil) + + req, ok := <-in + if !ok { + t.Fatalf("error: expected channel request, got: %#v", err) + return + } + + // ignore request, always send some text + req.Reply(true, nil) + + _, err = io.WriteString(ch, "this-is-stdout.") + if err != nil { + t.Fatalf("error writing on server: %v", err) + } + _, err = io.WriteString(ch.Stderr(), "this-is-stderr.") + if err != nil { + t.Fatalf("error writing on server: %v", err) + } + sendStatus(0, ch, t) +} + +func readLine(shell *terminal.Terminal, t *testing.T) { + if _, err := shell.ReadLine(); err != nil && err != io.EOF { + t.Errorf("unable to read line: %v", err) + } +} + +func sendStatus(status uint32, ch Channel, t *testing.T) { + msg := exitStatusMsg{ + Status: status, + } + if _, err := ch.SendRequest("exit-status", false, Marshal(&msg)); err != nil { + t.Errorf("unable to send status: %v", err) + } +} + +func sendSignal(signal string, ch Channel, t *testing.T) { + sig := exitSignalMsg{ + Signal: signal, + CoreDumped: false, + Errmsg: "Process terminated", + Lang: "en-GB-oed", + } + if _, err := ch.SendRequest("exit-signal", false, Marshal(&sig)); err != nil { + t.Errorf("unable to send signal: %v", err) + } +} + +func discardHandler(ch Channel, t *testing.T) { + defer ch.Close() + io.Copy(ioutil.Discard, ch) +} + +func echoHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil { + t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err) + } +} + +// copyNRandomly copies n bytes from src to dst. It uses a variable, and random, +// buffer size to exercise more code paths. +func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, error) { + var ( + buf = make([]byte, 32*1024) + written int + remaining = n + ) + for remaining > 0 { + l := rand.Intn(1 << 15) + if remaining < l { + l = remaining + } + nr, er := src.Read(buf[:l]) + nw, ew := dst.Write(buf[:nr]) + remaining -= nw + written += nw + if ew != nil { + return written, ew + } + if nr != nw { + return written, io.ErrShortWrite + } + if er != nil && er != io.EOF { + return written, er + } + } + return written, nil +} + +func channelKeepaliveSender(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err != nil { + t.Errorf("unable to send channel keepalive request: %v", err) + } + sendStatus(0, ch, t) +} + +func TestClientWriteEOF(t *testing.T) { + conn := dial(simpleEchoHandler, t) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatal(err) + } + defer session.Close() + stdin, err := session.StdinPipe() + if err != nil { + t.Fatalf("StdinPipe failed: %v", err) + } + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("StdoutPipe failed: %v", err) + } + + data := []byte(`0000`) + _, err = stdin.Write(data) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + stdin.Close() + + res, err := ioutil.ReadAll(stdout) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + if !bytes.Equal(data, res) { + t.Fatalf("Read differed from write, wrote: %v, read: %v", data, res) + } +} + +func simpleEchoHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + data, err := ioutil.ReadAll(ch) + if err != nil { + t.Errorf("handler read error: %v", err) + } + _, err = ch.Write(data) + if err != nil { + t.Errorf("handler write error: %v", err) + } +} + +func TestSessionID(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverID := make(chan []byte, 1) + clientID := make(chan []byte, 1) + + serverConf := &ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["ecdsa"]) + clientConf := &ClientConfig{ + User: "user", + } + + go func() { + conn, chans, reqs, err := NewServerConn(c1, serverConf) + if err != nil { + t.Fatalf("server handshake: %v", err) + } + serverID <- conn.SessionID() + go DiscardRequests(reqs) + for ch := range chans { + ch.Reject(Prohibited, "") + } + }() + + go func() { + conn, chans, reqs, err := NewClientConn(c2, "", clientConf) + if err != nil { + t.Fatalf("client handshake: %v", err) + } + clientID <- conn.SessionID() + go DiscardRequests(reqs) + for ch := range chans { + ch.Reject(Prohibited, "") + } + }() + + s := <-serverID + c := <-clientID + if bytes.Compare(s, c) != 0 { + t.Errorf("server session ID (%x) != client session ID (%x)", s, c) + } else if len(s) == 0 { + t.Errorf("client and server SessionID were empty.") + } +} + +type noReadConn struct { + readSeen bool + net.Conn +} + +func (c *noReadConn) Close() error { + return nil +} + +func (c *noReadConn) Read(b []byte) (int, error) { + c.readSeen = true + return 0, errors.New("noReadConn error") +} + +func TestInvalidServerConfiguration(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serveConn := noReadConn{Conn: c1} + serverConf := &ServerConfig{} + + NewServerConn(&serveConn, serverConf) + if serveConn.readSeen { + t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing host key") + } + + serverConf.AddHostKey(testSigners["ecdsa"]) + + NewServerConn(&serveConn, serverConf) + if serveConn.readSeen { + t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing authentication method") + } +} + +func TestHostKeyAlgorithms(t *testing.T) { + serverConf := &ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.AddHostKey(testSigners["ecdsa"]) + + connect := func(clientConf *ClientConfig, want string) { + var alg string + clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error { + alg = key.Type() + return nil + } + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go NewServerConn(c1, serverConf) + _, _, _, err = NewClientConn(c2, "", clientConf) + if err != nil { + t.Fatalf("NewClientConn: %v", err) + } + if alg != want { + t.Errorf("selected key algorithm %s, want %s", alg, want) + } + } + + // By default, we get the preferred algorithm, which is ECDSA 256. + + clientConf := &ClientConfig{} + connect(clientConf, KeyAlgoECDSA256) + + // Client asks for RSA explicitly. + clientConf.HostKeyAlgorithms = []string{KeyAlgoRSA} + connect(clientConf, KeyAlgoRSA) + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go NewServerConn(c1, serverConf) + clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"} + _, _, _, err = NewClientConn(c2, "", clientConf) + if err == nil { + t.Fatal("succeeded connecting with unknown hostkey algorithm") + } +} diff --git a/lib/ssh/tcpip.go b/lib/ssh/tcpip.go new file mode 100644 index 0000000..514c735 --- /dev/null +++ b/lib/ssh/tcpip.go @@ -0,0 +1,407 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "errors" + "fmt" + "io" + "math/rand" + "net" + "strconv" + "strings" + "sync" + "time" +) + +// Listen requests the remote peer open a listening socket on +// addr. Incoming connections will be available by calling Accept on +// the returned net.Listener. The listener must be serviced, or the +// SSH connection may hang. +func (c *Client) Listen(n, addr string) (net.Listener, error) { + laddr, err := net.ResolveTCPAddr(n, addr) + if err != nil { + return nil, err + } + return c.ListenTCP(laddr) +} + +// Automatic port allocation is broken with OpenSSH before 6.0. See +// also https://bugzilla.mindrot.org/show_bug.cgi?id=2017. In +// particular, OpenSSH 5.9 sends a channelOpenMsg with port number 0, +// rather than the actual port number. This means you can never open +// two different listeners with auto allocated ports. We work around +// this by trying explicit ports until we succeed. + +const openSSHPrefix = "OpenSSH_" + +var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano())) + +// isBrokenOpenSSHVersion returns true if the given version string +// specifies a version of OpenSSH that is known to have a bug in port +// forwarding. +func isBrokenOpenSSHVersion(versionStr string) bool { + i := strings.Index(versionStr, openSSHPrefix) + if i < 0 { + return false + } + i += len(openSSHPrefix) + j := i + for ; j < len(versionStr); j++ { + if versionStr[j] < '0' || versionStr[j] > '9' { + break + } + } + version, _ := strconv.Atoi(versionStr[i:j]) + return version < 6 +} + +// autoPortListenWorkaround simulates automatic port allocation by +// trying random ports repeatedly. +func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) { + var sshListener net.Listener + var err error + const tries = 10 + for i := 0; i < tries; i++ { + addr := *laddr + addr.Port = 1024 + portRandomizer.Intn(60000) + sshListener, err = c.ListenTCP(&addr) + if err == nil { + laddr.Port = addr.Port + return sshListener, err + } + } + return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err) +} + +// RFC 4254 7.1 +type channelForwardMsg struct { + addr string + rport uint32 +} + +// ListenTCP requests the remote peer open a listening socket +// on laddr. Incoming connections will be available by calling +// Accept on the returned net.Listener. +func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) { + if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) { + return c.autoPortListenWorkaround(laddr) + } + + m := channelForwardMsg{ + laddr.IP.String(), + uint32(laddr.Port), + } + // send message + ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m)) + if err != nil { + return nil, err + } + if !ok { + return nil, errors.New("ssh: tcpip-forward request denied by peer") + } + + // If the original port was 0, then the remote side will + // supply a real port number in the response. + if laddr.Port == 0 { + var p struct { + Port uint32 + } + if err := Unmarshal(resp, &p); err != nil { + return nil, err + } + laddr.Port = int(p.Port) + } + + // Register this forward, using the port number we obtained. + ch := c.forwards.add(*laddr) + + return &tcpListener{laddr, c, ch}, nil +} + +// forwardList stores a mapping between remote +// forward requests and the tcpListeners. +type forwardList struct { + sync.Mutex + entries []forwardEntry +} + +// forwardEntry represents an established mapping of a laddr on a +// remote ssh server to a channel connected to a tcpListener. +type forwardEntry struct { + laddr net.TCPAddr + c chan forward +} + +// forward represents an incoming forwarded tcpip connection. The +// arguments to add/remove/lookup should be address as specified in +// the original forward-request. +type forward struct { + newCh NewChannel // the ssh client channel underlying this forward + raddr *net.TCPAddr // the raddr of the incoming connection +} + +func (l *forwardList) add(addr net.TCPAddr) chan forward { + l.Lock() + defer l.Unlock() + f := forwardEntry{ + addr, + make(chan forward, 1), + } + l.entries = append(l.entries, f) + return f.c +} + +// See RFC 4254, section 7.2 +type forwardedTCPPayload struct { + Addr string + Port uint32 + OriginAddr string + OriginPort uint32 +} + +// parseTCPAddr parses the originating address from the remote into a *net.TCPAddr. +func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) { + if port == 0 || port > 65535 { + return nil, fmt.Errorf("ssh: port number out of range: %d", port) + } + ip := net.ParseIP(string(addr)) + if ip == nil { + return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr) + } + return &net.TCPAddr{IP: ip, Port: int(port)}, nil +} + +func (l *forwardList) handleChannels(in <-chan NewChannel) { + for ch := range in { + var payload forwardedTCPPayload + if err := Unmarshal(ch.ExtraData(), &payload); err != nil { + ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error()) + continue + } + + // RFC 4254 section 7.2 specifies that incoming + // addresses should list the address, in string + // format. It is implied that this should be an IP + // address, as it would be impossible to connect to it + // otherwise. + laddr, err := parseTCPAddr(payload.Addr, payload.Port) + if err != nil { + ch.Reject(ConnectionFailed, err.Error()) + continue + } + raddr, err := parseTCPAddr(payload.OriginAddr, payload.OriginPort) + if err != nil { + ch.Reject(ConnectionFailed, err.Error()) + continue + } + + if ok := l.forward(*laddr, *raddr, ch); !ok { + // Section 7.2, implementations MUST reject spurious incoming + // connections. + ch.Reject(Prohibited, "no forward for address") + continue + } + } +} + +// remove removes the forward entry, and the channel feeding its +// listener. +func (l *forwardList) remove(addr net.TCPAddr) { + l.Lock() + defer l.Unlock() + for i, f := range l.entries { + if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port { + l.entries = append(l.entries[:i], l.entries[i+1:]...) + close(f.c) + return + } + } +} + +// closeAll closes and clears all forwards. +func (l *forwardList) closeAll() { + l.Lock() + defer l.Unlock() + for _, f := range l.entries { + close(f.c) + } + l.entries = nil +} + +func (l *forwardList) forward(laddr, raddr net.TCPAddr, ch NewChannel) bool { + l.Lock() + defer l.Unlock() + for _, f := range l.entries { + if laddr.IP.Equal(f.laddr.IP) && laddr.Port == f.laddr.Port { + f.c <- forward{ch, &raddr} + return true + } + } + return false +} + +type tcpListener struct { + laddr *net.TCPAddr + + conn *Client + in <-chan forward +} + +// Accept waits for and returns the next connection to the listener. +func (l *tcpListener) Accept() (net.Conn, error) { + s, ok := <-l.in + if !ok { + return nil, io.EOF + } + ch, incoming, err := s.newCh.Accept() + if err != nil { + return nil, err + } + go DiscardRequests(incoming) + + return &tcpChanConn{ + Channel: ch, + laddr: l.laddr, + raddr: s.raddr, + }, nil +} + +// Close closes the listener. +func (l *tcpListener) Close() error { + m := channelForwardMsg{ + l.laddr.IP.String(), + uint32(l.laddr.Port), + } + + // this also closes the listener. + l.conn.forwards.remove(*l.laddr) + ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m)) + if err == nil && !ok { + err = errors.New("ssh: cancel-tcpip-forward failed") + } + return err +} + +// Addr returns the listener's network address. +func (l *tcpListener) Addr() net.Addr { + return l.laddr +} + +// Dial initiates a connection to the addr from the remote host. +// The resulting connection has a zero LocalAddr() and RemoteAddr(). +func (c *Client) Dial(n, addr string) (net.Conn, error) { + // Parse the address into host and numeric port. + host, portString, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + port, err := strconv.ParseUint(portString, 10, 16) + if err != nil { + return nil, err + } + // Use a zero address for local and remote address. + zeroAddr := &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } + ch, err := c.dial(net.IPv4zero.String(), 0, host, int(port)) + if err != nil { + return nil, err + } + return &tcpChanConn{ + Channel: ch, + laddr: zeroAddr, + raddr: zeroAddr, + }, nil +} + +// DialTCP connects to the remote address raddr on the network net, +// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used +// as the local address for the connection. +func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) { + if laddr == nil { + laddr = &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } + } + ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port) + if err != nil { + return nil, err + } + return &tcpChanConn{ + Channel: ch, + laddr: laddr, + raddr: raddr, + }, nil +} + +// RFC 4254 7.2 +type channelOpenDirectMsg struct { + raddr string + rport uint32 + laddr string + lport uint32 +} + +func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) { + msg := channelOpenDirectMsg{ + raddr: raddr, + rport: uint32(rport), + laddr: laddr, + lport: uint32(lport), + } + ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg)) + if err != nil { + return nil, err + } + go DiscardRequests(in) + return ch, err +} + +type tcpChan struct { + Channel // the backing channel +} + +// tcpChanConn fulfills the net.Conn interface without +// the tcpChan having to hold laddr or raddr directly. +type tcpChanConn struct { + Channel + laddr, raddr net.Addr +} + +// LocalAddr returns the local network address. +func (t *tcpChanConn) LocalAddr() net.Addr { + return t.laddr +} + +// RemoteAddr returns the remote network address. +func (t *tcpChanConn) RemoteAddr() net.Addr { + return t.raddr +} + +// SetDeadline sets the read and write deadlines associated +// with the connection. +func (t *tcpChanConn) SetDeadline(deadline time.Time) error { + if err := t.SetReadDeadline(deadline); err != nil { + return err + } + return t.SetWriteDeadline(deadline) +} + +// SetReadDeadline sets the read deadline. +// A zero value for t means Read will not time out. +// After the deadline, the error from Read will implement net.Error +// with Timeout() == true. +func (t *tcpChanConn) SetReadDeadline(deadline time.Time) error { + return errors.New("ssh: tcpChan: deadline not supported") +} + +// SetWriteDeadline exists to satisfy the net.Conn interface +// but is not implemented by this type. It always returns an error. +func (t *tcpChanConn) SetWriteDeadline(deadline time.Time) error { + return errors.New("ssh: tcpChan: deadline not supported") +} diff --git a/lib/ssh/tcpip_test.go b/lib/ssh/tcpip_test.go new file mode 100644 index 0000000..5879f1d --- /dev/null +++ b/lib/ssh/tcpip_test.go @@ -0,0 +1,20 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "testing" +) + +func TestAutoPortListenBroken(t *testing.T) { + broken := "SSH-2.0-OpenSSH_5.9hh11" + works := "SSH-2.0-OpenSSH_6.1" + if !isBrokenOpenSSHVersion(broken) { + t.Errorf("version %q not marked as broken", broken) + } + if isBrokenOpenSSHVersion(works) { + t.Errorf("version %q marked as broken", works) + } +} diff --git a/lib/ssh/terminal/terminal.go b/lib/ssh/terminal/terminal.go new file mode 100644 index 0000000..f816773 --- /dev/null +++ b/lib/ssh/terminal/terminal.go @@ -0,0 +1,924 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package terminal + +import ( + "bytes" + "io" + "sync" + "unicode/utf8" +) + +// EscapeCodes contains escape sequences that can be written to the terminal in +// order to achieve different styles of text. +type EscapeCodes struct { + // Foreground colors + Black, Red, Green, Yellow, Blue, Magenta, Cyan, White []byte + + // Reset all attributes + Reset []byte +} + +var vt100EscapeCodes = EscapeCodes{ + Black: []byte{keyEscape, '[', '3', '0', 'm'}, + Red: []byte{keyEscape, '[', '3', '1', 'm'}, + Green: []byte{keyEscape, '[', '3', '2', 'm'}, + Yellow: []byte{keyEscape, '[', '3', '3', 'm'}, + Blue: []byte{keyEscape, '[', '3', '4', 'm'}, + Magenta: []byte{keyEscape, '[', '3', '5', 'm'}, + Cyan: []byte{keyEscape, '[', '3', '6', 'm'}, + White: []byte{keyEscape, '[', '3', '7', 'm'}, + + Reset: []byte{keyEscape, '[', '0', 'm'}, +} + +// Terminal contains the state for running a VT100 terminal that is capable of +// reading lines of input. +type Terminal struct { + // AutoCompleteCallback, if non-null, is called for each keypress with + // the full input line and the current position of the cursor (in + // bytes, as an index into |line|). If it returns ok=false, the key + // press is processed normally. Otherwise it returns a replacement line + // and the new cursor position. + AutoCompleteCallback func(line string, pos int, key rune) (newLine string, newPos int, ok bool) + + // Escape contains a pointer to the escape codes for this terminal. + // It's always a valid pointer, although the escape codes themselves + // may be empty if the terminal doesn't support them. + Escape *EscapeCodes + + // lock protects the terminal and the state in this object from + // concurrent processing of a key press and a Write() call. + lock sync.Mutex + + c io.ReadWriter + prompt []rune + + // line is the current line being entered. + line []rune + // pos is the logical position of the cursor in line + pos int + // echo is true if local echo is enabled + echo bool + // pasteActive is true iff there is a bracketed paste operation in + // progress. + pasteActive bool + + // cursorX contains the current X value of the cursor where the left + // edge is 0. cursorY contains the row number where the first row of + // the current line is 0. + cursorX, cursorY int + // maxLine is the greatest value of cursorY so far. + maxLine int + + termWidth, termHeight int + + // outBuf contains the terminal data to be sent. + outBuf []byte + // remainder contains the remainder of any partial key sequences after + // a read. It aliases into inBuf. + remainder []byte + inBuf [256]byte + + // history contains previously entered commands so that they can be + // accessed with the up and down keys. + history stRingBuffer + // historyIndex stores the currently accessed history entry, where zero + // means the immediately previous entry. + historyIndex int + // When navigating up and down the history it's possible to return to + // the incomplete, initial line. That value is stored in + // historyPending. + historyPending string +} + +// NewTerminal runs a VT100 terminal on the given ReadWriter. If the ReadWriter is +// a local terminal, that terminal must first have been put into raw mode. +// prompt is a string that is written at the start of each input line (i.e. +// "> "). +func NewTerminal(c io.ReadWriter, prompt string) *Terminal { + return &Terminal{ + Escape: &vt100EscapeCodes, + c: c, + prompt: []rune(prompt), + termWidth: 80, + termHeight: 24, + echo: true, + historyIndex: -1, + } +} + +const ( + keyCtrlD = 4 + keyCtrlU = 21 + keyEnter = '\r' + keyEscape = 27 + keyBackspace = 127 + keyUnknown = 0xd800 /* UTF-16 surrogate area */ + iota + keyUp + keyDown + keyLeft + keyRight + keyAltLeft + keyAltRight + keyHome + keyEnd + keyDeleteWord + keyDeleteLine + keyClearScreen + keyPasteStart + keyPasteEnd +) + +var ( + crlf = []byte{'\r', '\n'} + pasteStart = []byte{keyEscape, '[', '2', '0', '0', '~'} + pasteEnd = []byte{keyEscape, '[', '2', '0', '1', '~'} +) + +// bytesToKey tries to parse a key sequence from b. If successful, it returns +// the key and the remainder of the input. Otherwise it returns utf8.RuneError. +func bytesToKey(b []byte, pasteActive bool) (rune, []byte) { + if len(b) == 0 { + return utf8.RuneError, nil + } + + if !pasteActive { + switch b[0] { + case 1: // ^A + return keyHome, b[1:] + case 5: // ^E + return keyEnd, b[1:] + case 8: // ^H + return keyBackspace, b[1:] + case 11: // ^K + return keyDeleteLine, b[1:] + case 12: // ^L + return keyClearScreen, b[1:] + case 23: // ^W + return keyDeleteWord, b[1:] + } + } + + if b[0] != keyEscape { + if !utf8.FullRune(b) { + return utf8.RuneError, b + } + r, l := utf8.DecodeRune(b) + return r, b[l:] + } + + if !pasteActive && len(b) >= 3 && b[0] == keyEscape && b[1] == '[' { + switch b[2] { + case 'A': + return keyUp, b[3:] + case 'B': + return keyDown, b[3:] + case 'C': + return keyRight, b[3:] + case 'D': + return keyLeft, b[3:] + case 'H': + return keyHome, b[3:] + case 'F': + return keyEnd, b[3:] + } + } + + if !pasteActive && len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' { + switch b[5] { + case 'C': + return keyAltRight, b[6:] + case 'D': + return keyAltLeft, b[6:] + } + } + + if !pasteActive && len(b) >= 6 && bytes.Equal(b[:6], pasteStart) { + return keyPasteStart, b[6:] + } + + if pasteActive && len(b) >= 6 && bytes.Equal(b[:6], pasteEnd) { + return keyPasteEnd, b[6:] + } + + // If we get here then we have a key that we don't recognise, or a + // partial sequence. It's not clear how one should find the end of a + // sequence without knowing them all, but it seems that [a-zA-Z~] only + // appears at the end of a sequence. + for i, c := range b[0:] { + if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c == '~' { + return keyUnknown, b[i+1:] + } + } + + return utf8.RuneError, b +} + +// queue appends data to the end of t.outBuf +func (t *Terminal) queue(data []rune) { + t.outBuf = append(t.outBuf, []byte(string(data))...) +} + +var eraseUnderCursor = []rune{' ', keyEscape, '[', 'D'} +var space = []rune{' '} + +func isPrintable(key rune) bool { + isInSurrogateArea := key >= 0xd800 && key <= 0xdbff + return key >= 32 && !isInSurrogateArea +} + +// moveCursorToPos appends data to t.outBuf which will move the cursor to the +// given, logical position in the text. +func (t *Terminal) moveCursorToPos(pos int) { + if !t.echo { + return + } + + x := visualLength(t.prompt) + pos + y := x / t.termWidth + x = x % t.termWidth + + up := 0 + if y < t.cursorY { + up = t.cursorY - y + } + + down := 0 + if y > t.cursorY { + down = y - t.cursorY + } + + left := 0 + if x < t.cursorX { + left = t.cursorX - x + } + + right := 0 + if x > t.cursorX { + right = x - t.cursorX + } + + t.cursorX = x + t.cursorY = y + t.move(up, down, left, right) +} + +func (t *Terminal) move(up, down, left, right int) { + movement := make([]rune, 3*(up+down+left+right)) + m := movement + for i := 0; i < up; i++ { + m[0] = keyEscape + m[1] = '[' + m[2] = 'A' + m = m[3:] + } + for i := 0; i < down; i++ { + m[0] = keyEscape + m[1] = '[' + m[2] = 'B' + m = m[3:] + } + for i := 0; i < left; i++ { + m[0] = keyEscape + m[1] = '[' + m[2] = 'D' + m = m[3:] + } + for i := 0; i < right; i++ { + m[0] = keyEscape + m[1] = '[' + m[2] = 'C' + m = m[3:] + } + + t.queue(movement) +} + +func (t *Terminal) clearLineToRight() { + op := []rune{keyEscape, '[', 'K'} + t.queue(op) +} + +const maxLineLength = 4096 + +func (t *Terminal) setLine(newLine []rune, newPos int) { + if t.echo { + t.moveCursorToPos(0) + t.writeLine(newLine) + for i := len(newLine); i < len(t.line); i++ { + t.writeLine(space) + } + t.moveCursorToPos(newPos) + } + t.line = newLine + t.pos = newPos +} + +func (t *Terminal) advanceCursor(places int) { + t.cursorX += places + t.cursorY += t.cursorX / t.termWidth + if t.cursorY > t.maxLine { + t.maxLine = t.cursorY + } + t.cursorX = t.cursorX % t.termWidth + + if places > 0 && t.cursorX == 0 { + // Normally terminals will advance the current position + // when writing a character. But that doesn't happen + // for the last character in a line. However, when + // writing a character (except a new line) that causes + // a line wrap, the position will be advanced two + // places. + // + // So, if we are stopping at the end of a line, we + // need to write a newline so that our cursor can be + // advanced to the next line. + t.outBuf = append(t.outBuf, '\r', '\n') + } +} + +func (t *Terminal) eraseNPreviousChars(n int) { + if n == 0 { + return + } + + if t.pos < n { + n = t.pos + } + t.pos -= n + t.moveCursorToPos(t.pos) + + copy(t.line[t.pos:], t.line[n+t.pos:]) + t.line = t.line[:len(t.line)-n] + if t.echo { + t.writeLine(t.line[t.pos:]) + for i := 0; i < n; i++ { + t.queue(space) + } + t.advanceCursor(n) + t.moveCursorToPos(t.pos) + } +} + +// countToLeftWord returns then number of characters from the cursor to the +// start of the previous word. +func (t *Terminal) countToLeftWord() int { + if t.pos == 0 { + return 0 + } + + pos := t.pos - 1 + for pos > 0 { + if t.line[pos] != ' ' { + break + } + pos-- + } + for pos > 0 { + if t.line[pos] == ' ' { + pos++ + break + } + pos-- + } + + return t.pos - pos +} + +// countToRightWord returns then number of characters from the cursor to the +// start of the next word. +func (t *Terminal) countToRightWord() int { + pos := t.pos + for pos < len(t.line) { + if t.line[pos] == ' ' { + break + } + pos++ + } + for pos < len(t.line) { + if t.line[pos] != ' ' { + break + } + pos++ + } + return pos - t.pos +} + +// visualLength returns the number of visible glyphs in s. +func visualLength(runes []rune) int { + inEscapeSeq := false + length := 0 + + for _, r := range runes { + switch { + case inEscapeSeq: + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') { + inEscapeSeq = false + } + case r == '\x1b': + inEscapeSeq = true + default: + length++ + } + } + + return length +} + +// handleKey processes the given key and, optionally, returns a line of text +// that the user has entered. +func (t *Terminal) handleKey(key rune) (line string, ok bool) { + if t.pasteActive && key != keyEnter { + t.addKeyToLine(key) + return + } + + switch key { + case keyBackspace: + if t.pos == 0 { + return + } + t.eraseNPreviousChars(1) + case keyAltLeft: + // move left by a word. + t.pos -= t.countToLeftWord() + t.moveCursorToPos(t.pos) + case keyAltRight: + // move right by a word. + t.pos += t.countToRightWord() + t.moveCursorToPos(t.pos) + case keyLeft: + if t.pos == 0 { + return + } + t.pos-- + t.moveCursorToPos(t.pos) + case keyRight: + if t.pos == len(t.line) { + return + } + t.pos++ + t.moveCursorToPos(t.pos) + case keyHome: + if t.pos == 0 { + return + } + t.pos = 0 + t.moveCursorToPos(t.pos) + case keyEnd: + if t.pos == len(t.line) { + return + } + t.pos = len(t.line) + t.moveCursorToPos(t.pos) + case keyUp: + entry, ok := t.history.NthPreviousEntry(t.historyIndex + 1) + if !ok { + return "", false + } + if t.historyIndex == -1 { + t.historyPending = string(t.line) + } + t.historyIndex++ + runes := []rune(entry) + t.setLine(runes, len(runes)) + case keyDown: + switch t.historyIndex { + case -1: + return + case 0: + runes := []rune(t.historyPending) + t.setLine(runes, len(runes)) + t.historyIndex-- + default: + entry, ok := t.history.NthPreviousEntry(t.historyIndex - 1) + if ok { + t.historyIndex-- + runes := []rune(entry) + t.setLine(runes, len(runes)) + } + } + case keyEnter: + t.moveCursorToPos(len(t.line)) + t.queue([]rune("\r\n")) + line = string(t.line) + ok = true + t.line = t.line[:0] + t.pos = 0 + t.cursorX = 0 + t.cursorY = 0 + t.maxLine = 0 + case keyDeleteWord: + // Delete zero or more spaces and then one or more characters. + t.eraseNPreviousChars(t.countToLeftWord()) + case keyDeleteLine: + // Delete everything from the current cursor position to the + // end of line. + for i := t.pos; i < len(t.line); i++ { + t.queue(space) + t.advanceCursor(1) + } + t.line = t.line[:t.pos] + t.moveCursorToPos(t.pos) + case keyCtrlD: + // Erase the character under the current position. + // The EOF case when the line is empty is handled in + // readLine(). + if t.pos < len(t.line) { + t.pos++ + t.eraseNPreviousChars(1) + } + case keyCtrlU: + t.eraseNPreviousChars(t.pos) + case keyClearScreen: + // Erases the screen and moves the cursor to the home position. + t.queue([]rune("\x1b[2J\x1b[H")) + t.queue(t.prompt) + t.cursorX, t.cursorY = 0, 0 + t.advanceCursor(visualLength(t.prompt)) + t.setLine(t.line, t.pos) + default: + if t.AutoCompleteCallback != nil { + prefix := string(t.line[:t.pos]) + suffix := string(t.line[t.pos:]) + + t.lock.Unlock() + newLine, newPos, completeOk := t.AutoCompleteCallback(prefix+suffix, len(prefix), key) + t.lock.Lock() + + if completeOk { + t.setLine([]rune(newLine), utf8.RuneCount([]byte(newLine)[:newPos])) + return + } + } + if !isPrintable(key) { + return + } + if len(t.line) == maxLineLength { + return + } + t.addKeyToLine(key) + } + return +} + +// addKeyToLine inserts the given key at the current position in the current +// line. +func (t *Terminal) addKeyToLine(key rune) { + if len(t.line) == cap(t.line) { + newLine := make([]rune, len(t.line), 2*(1+len(t.line))) + copy(newLine, t.line) + t.line = newLine + } + t.line = t.line[:len(t.line)+1] + copy(t.line[t.pos+1:], t.line[t.pos:]) + t.line[t.pos] = key + if t.echo { + t.writeLine(t.line[t.pos:]) + } + t.pos++ + t.moveCursorToPos(t.pos) +} + +func (t *Terminal) writeLine(line []rune) { + for len(line) != 0 { + remainingOnLine := t.termWidth - t.cursorX + todo := len(line) + if todo > remainingOnLine { + todo = remainingOnLine + } + t.queue(line[:todo]) + t.advanceCursor(visualLength(line[:todo])) + line = line[todo:] + } +} + +// writeWithCRLF writes buf to w but replaces all occurrences of \n with \r\n. +func writeWithCRLF(w io.Writer, buf []byte) (n int, err error) { + for len(buf) > 0 { + i := bytes.IndexByte(buf, '\n') + todo := len(buf) + if i >= 0 { + todo = i + } + + var nn int + nn, err = w.Write(buf[:todo]) + n += nn + if err != nil { + return n, err + } + buf = buf[todo:] + + if i >= 0 { + if _, err = w.Write(crlf); err != nil { + return n, err + } + n += 1 + buf = buf[1:] + } + } + + return n, nil +} + +func (t *Terminal) Write(buf []byte) (n int, err error) { + t.lock.Lock() + defer t.lock.Unlock() + + if t.cursorX == 0 && t.cursorY == 0 { + // This is the easy case: there's nothing on the screen that we + // have to move out of the way. + return writeWithCRLF(t.c, buf) + } + + // We have a prompt and possibly user input on the screen. We + // have to clear it first. + t.move(0 /* up */, 0 /* down */, t.cursorX /* left */, 0 /* right */) + t.cursorX = 0 + t.clearLineToRight() + + for t.cursorY > 0 { + t.move(1 /* up */, 0, 0, 0) + t.cursorY-- + t.clearLineToRight() + } + + if _, err = t.c.Write(t.outBuf); err != nil { + return + } + t.outBuf = t.outBuf[:0] + + if n, err = writeWithCRLF(t.c, buf); err != nil { + return + } + + t.writeLine(t.prompt) + if t.echo { + t.writeLine(t.line) + } + + t.moveCursorToPos(t.pos) + + if _, err = t.c.Write(t.outBuf); err != nil { + return + } + t.outBuf = t.outBuf[:0] + return +} + +// ReadPassword temporarily changes the prompt and reads a password, without +// echo, from the terminal. +func (t *Terminal) ReadPassword(prompt string) (line string, err error) { + t.lock.Lock() + defer t.lock.Unlock() + + oldPrompt := t.prompt + t.prompt = []rune(prompt) + t.echo = false + + line, err = t.readLine() + + t.prompt = oldPrompt + t.echo = true + + return +} + +// ReadLine returns a line of input from the terminal. +func (t *Terminal) ReadLine() (line string, err error) { + t.lock.Lock() + defer t.lock.Unlock() + + return t.readLine() +} + +func (t *Terminal) readLine() (line string, err error) { + // t.lock must be held at this point + + if t.cursorX == 0 && t.cursorY == 0 { + t.writeLine(t.prompt) + t.c.Write(t.outBuf) + t.outBuf = t.outBuf[:0] + } + + lineIsPasted := t.pasteActive + + for { + rest := t.remainder + lineOk := false + for !lineOk { + var key rune + key, rest = bytesToKey(rest, t.pasteActive) + if key == utf8.RuneError { + break + } + if !t.pasteActive { + if key == keyCtrlD { + if len(t.line) == 0 { + return "", io.EOF + } + } + if key == keyPasteStart { + t.pasteActive = true + if len(t.line) == 0 { + lineIsPasted = true + } + continue + } + } else if key == keyPasteEnd { + t.pasteActive = false + continue + } + if !t.pasteActive { + lineIsPasted = false + } + line, lineOk = t.handleKey(key) + } + if len(rest) > 0 { + n := copy(t.inBuf[:], rest) + t.remainder = t.inBuf[:n] + } else { + t.remainder = nil + } + t.c.Write(t.outBuf) + t.outBuf = t.outBuf[:0] + if lineOk { + if t.echo { + t.historyIndex = -1 + t.history.Add(line) + } + if lineIsPasted { + err = ErrPasteIndicator + } + return + } + + // t.remainder is a slice at the beginning of t.inBuf + // containing a partial key sequence + readBuf := t.inBuf[len(t.remainder):] + var n int + + t.lock.Unlock() + n, err = t.c.Read(readBuf) + t.lock.Lock() + + if err != nil { + return + } + + t.remainder = t.inBuf[:n+len(t.remainder)] + } + + panic("unreachable") // for Go 1.0. +} + +// SetPrompt sets the prompt to be used when reading subsequent lines. +func (t *Terminal) SetPrompt(prompt string) { + t.lock.Lock() + defer t.lock.Unlock() + + t.prompt = []rune(prompt) +} + +func (t *Terminal) clearAndRepaintLinePlusNPrevious(numPrevLines int) { + // Move cursor to column zero at the start of the line. + t.move(t.cursorY, 0, t.cursorX, 0) + t.cursorX, t.cursorY = 0, 0 + t.clearLineToRight() + for t.cursorY < numPrevLines { + // Move down a line + t.move(0, 1, 0, 0) + t.cursorY++ + t.clearLineToRight() + } + // Move back to beginning. + t.move(t.cursorY, 0, 0, 0) + t.cursorX, t.cursorY = 0, 0 + + t.queue(t.prompt) + t.advanceCursor(visualLength(t.prompt)) + t.writeLine(t.line) + t.moveCursorToPos(t.pos) +} + +func (t *Terminal) SetSize(width, height int) error { + t.lock.Lock() + defer t.lock.Unlock() + + if width == 0 { + width = 1 + } + + oldWidth := t.termWidth + t.termWidth, t.termHeight = width, height + + switch { + case width == oldWidth: + // If the width didn't change then nothing else needs to be + // done. + return nil + case len(t.line) == 0 && t.cursorX == 0 && t.cursorY == 0: + // If there is nothing on current line and no prompt printed, + // just do nothing + return nil + case width < oldWidth: + // Some terminals (e.g. xterm) will truncate lines that were + // too long when shinking. Others, (e.g. gnome-terminal) will + // attempt to wrap them. For the former, repainting t.maxLine + // works great, but that behaviour goes badly wrong in the case + // of the latter because they have doubled every full line. + + // We assume that we are working on a terminal that wraps lines + // and adjust the cursor position based on every previous line + // wrapping and turning into two. This causes the prompt on + // xterms to move upwards, which isn't great, but it avoids a + // huge mess with gnome-terminal. + if t.cursorX >= t.termWidth { + t.cursorX = t.termWidth - 1 + } + t.cursorY *= 2 + t.clearAndRepaintLinePlusNPrevious(t.maxLine * 2) + case width > oldWidth: + // If the terminal expands then our position calculations will + // be wrong in the future because we think the cursor is + // |t.pos| chars into the string, but there will be a gap at + // the end of any wrapped line. + // + // But the position will actually be correct until we move, so + // we can move back to the beginning and repaint everything. + t.clearAndRepaintLinePlusNPrevious(t.maxLine) + } + + _, err := t.c.Write(t.outBuf) + t.outBuf = t.outBuf[:0] + return err +} + +type pasteIndicatorError struct{} + +func (pasteIndicatorError) Error() string { + return "terminal: ErrPasteIndicator not correctly handled" +} + +// ErrPasteIndicator may be returned from ReadLine as the error, in addition +// to valid line data. It indicates that bracketed paste mode is enabled and +// that the returned line consists only of pasted data. Programs may wish to +// interpret pasted data more literally than typed data. +var ErrPasteIndicator = pasteIndicatorError{} + +// SetBracketedPasteMode requests that the terminal bracket paste operations +// with markers. Not all terminals support this but, if it is supported, then +// enabling this mode will stop any autocomplete callback from running due to +// pastes. Additionally, any lines that are completely pasted will be returned +// from ReadLine with the error set to ErrPasteIndicator. +func (t *Terminal) SetBracketedPasteMode(on bool) { + if on { + io.WriteString(t.c, "\x1b[?2004h") + } else { + io.WriteString(t.c, "\x1b[?2004l") + } +} + +// stRingBuffer is a ring buffer of strings. +type stRingBuffer struct { + // entries contains max elements. + entries []string + max int + // head contains the index of the element most recently added to the ring. + head int + // size contains the number of elements in the ring. + size int +} + +func (s *stRingBuffer) Add(a string) { + if s.entries == nil { + const defaultNumEntries = 100 + s.entries = make([]string, defaultNumEntries) + s.max = defaultNumEntries + } + + s.head = (s.head + 1) % s.max + s.entries[s.head] = a + if s.size < s.max { + s.size++ + } +} + +// NthPreviousEntry returns the value passed to the nth previous call to Add. +// If n is zero then the immediately prior value is returned, if one, then the +// next most recent, and so on. If such an element doesn't exist then ok is +// false. +func (s *stRingBuffer) NthPreviousEntry(n int) (value string, ok bool) { + if n >= s.size { + return "", false + } + index := s.head - n + if index < 0 { + index += s.max + } + return s.entries[index], true +} diff --git a/lib/ssh/terminal/terminal_test.go b/lib/ssh/terminal/terminal_test.go new file mode 100644 index 0000000..1d54c4f --- /dev/null +++ b/lib/ssh/terminal/terminal_test.go @@ -0,0 +1,306 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package terminal + +import ( + "bytes" + "io" + "os" + "testing" +) + +type MockTerminal struct { + toSend []byte + bytesPerRead int + received []byte +} + +func (c *MockTerminal) Read(data []byte) (n int, err error) { + n = len(data) + if n == 0 { + return + } + if n > len(c.toSend) { + n = len(c.toSend) + } + if n == 0 { + return 0, io.EOF + } + if c.bytesPerRead > 0 && n > c.bytesPerRead { + n = c.bytesPerRead + } + copy(data, c.toSend[:n]) + c.toSend = c.toSend[n:] + return +} + +func (c *MockTerminal) Write(data []byte) (n int, err error) { + c.received = append(c.received, data...) + return len(data), nil +} + +func TestClose(t *testing.T) { + c := &MockTerminal{} + ss := NewTerminal(c, "> ") + line, err := ss.ReadLine() + if line != "" { + t.Errorf("Expected empty line but got: %s", line) + } + if err != io.EOF { + t.Errorf("Error should have been EOF but got: %s", err) + } +} + +var keyPressTests = []struct { + in string + line string + err error + throwAwayLines int +}{ + { + err: io.EOF, + }, + { + in: "\r", + line: "", + }, + { + in: "foo\r", + line: "foo", + }, + { + in: "a\x1b[Cb\r", // right + line: "ab", + }, + { + in: "a\x1b[Db\r", // left + line: "ba", + }, + { + in: "a\177b\r", // backspace + line: "b", + }, + { + in: "\x1b[A\r", // up + }, + { + in: "\x1b[B\r", // down + }, + { + in: "line\x1b[A\x1b[B\r", // up then down + line: "line", + }, + { + in: "line1\rline2\x1b[A\r", // recall previous line. + line: "line1", + throwAwayLines: 1, + }, + { + // recall two previous lines and append. + in: "line1\rline2\rline3\x1b[A\x1b[Axxx\r", + line: "line1xxx", + throwAwayLines: 2, + }, + { + // Ctrl-A to move to beginning of line followed by ^K to kill + // line. + in: "a b \001\013\r", + line: "", + }, + { + // Ctrl-A to move to beginning of line, Ctrl-E to move to end, + // finally ^K to kill nothing. + in: "a b \001\005\013\r", + line: "a b ", + }, + { + in: "\027\r", + line: "", + }, + { + in: "a\027\r", + line: "", + }, + { + in: "a \027\r", + line: "", + }, + { + in: "a b\027\r", + line: "a ", + }, + { + in: "a b \027\r", + line: "a ", + }, + { + in: "one two thr\x1b[D\027\r", + line: "one two r", + }, + { + in: "\013\r", + line: "", + }, + { + in: "a\013\r", + line: "a", + }, + { + in: "ab\x1b[D\013\r", + line: "a", + }, + { + in: "Ξεσκεπάζω\r", + line: "Ξεσκεπάζω", + }, + { + in: "£\r\x1b[A\177\r", // non-ASCII char, enter, up, backspace. + line: "", + throwAwayLines: 1, + }, + { + in: "£\r££\x1b[A\x1b[B\177\r", // non-ASCII char, enter, 2x non-ASCII, up, down, backspace, enter. + line: "£", + throwAwayLines: 1, + }, + { + // Ctrl-D at the end of the line should be ignored. + in: "a\004\r", + line: "a", + }, + { + // a, b, left, Ctrl-D should erase the b. + in: "ab\x1b[D\004\r", + line: "a", + }, + { + // a, b, c, d, left, left, ^U should erase to the beginning of + // the line. + in: "abcd\x1b[D\x1b[D\025\r", + line: "cd", + }, + { + // Bracketed paste mode: control sequences should be returned + // verbatim in paste mode. + in: "abc\x1b[200~de\177f\x1b[201~\177\r", + line: "abcde\177", + }, + { + // Enter in bracketed paste mode should still work. + in: "abc\x1b[200~d\refg\x1b[201~h\r", + line: "efgh", + throwAwayLines: 1, + }, + { + // Lines consisting entirely of pasted data should be indicated as such. + in: "\x1b[200~a\r", + line: "a", + err: ErrPasteIndicator, + }, +} + +func TestKeyPresses(t *testing.T) { + for i, test := range keyPressTests { + for j := 1; j < len(test.in); j++ { + c := &MockTerminal{ + toSend: []byte(test.in), + bytesPerRead: j, + } + ss := NewTerminal(c, "> ") + for k := 0; k < test.throwAwayLines; k++ { + _, err := ss.ReadLine() + if err != nil { + t.Errorf("Throwaway line %d from test %d resulted in error: %s", k, i, err) + } + } + line, err := ss.ReadLine() + if line != test.line { + t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line) + break + } + if err != test.err { + t.Errorf("Error resulting from test %d (%d bytes per read) was '%v', expected '%v'", i, j, err, test.err) + break + } + } + } +} + +func TestPasswordNotSaved(t *testing.T) { + c := &MockTerminal{ + toSend: []byte("password\r\x1b[A\r"), + bytesPerRead: 1, + } + ss := NewTerminal(c, "> ") + pw, _ := ss.ReadPassword("> ") + if pw != "password" { + t.Fatalf("failed to read password, got %s", pw) + } + line, _ := ss.ReadLine() + if len(line) > 0 { + t.Fatalf("password was saved in history") + } +} + +var setSizeTests = []struct { + width, height int +}{ + {40, 13}, + {80, 24}, + {132, 43}, +} + +func TestTerminalSetSize(t *testing.T) { + for _, setSize := range setSizeTests { + c := &MockTerminal{ + toSend: []byte("password\r\x1b[A\r"), + bytesPerRead: 1, + } + ss := NewTerminal(c, "> ") + ss.SetSize(setSize.width, setSize.height) + pw, _ := ss.ReadPassword("Password: ") + if pw != "password" { + t.Fatalf("failed to read password, got %s", pw) + } + if string(c.received) != "Password: \r\n" { + t.Errorf("failed to set the temporary prompt expected %q, got %q", "Password: ", c.received) + } + } +} + +func TestMakeRawState(t *testing.T) { + fd := int(os.Stdout.Fd()) + if !IsTerminal(fd) { + t.Skip("stdout is not a terminal; skipping test") + } + + st, err := GetState(fd) + if err != nil { + t.Fatalf("failed to get terminal state from GetState: %s", err) + } + defer Restore(fd, st) + raw, err := MakeRaw(fd) + if err != nil { + t.Fatalf("failed to get terminal state from MakeRaw: %s", err) + } + + if *st != *raw { + t.Errorf("states do not match; was %v, expected %v", raw, st) + } +} + +func TestOutputNewlines(t *testing.T) { + // \n should be changed to \r\n in terminal output. + buf := new(bytes.Buffer) + term := NewTerminal(buf, ">") + + term.Write([]byte("1\n2\n")) + output := string(buf.Bytes()) + const expected = "1\r\n2\r\n" + + if output != expected { + t.Errorf("incorrect output: was %q, expected %q", output, expected) + } +} diff --git a/lib/ssh/terminal/util.go b/lib/ssh/terminal/util.go new file mode 100644 index 0000000..c162b8b --- /dev/null +++ b/lib/ssh/terminal/util.go @@ -0,0 +1,133 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux,!appengine netbsd openbsd + +// Package terminal provides support functions for dealing with terminals, as +// commonly found on UNIX systems. +// +// Putting a terminal into raw mode is the most common requirement: +// +// oldState, err := terminal.MakeRaw(0) +// if err != nil { +// panic(err) +// } +// defer terminal.Restore(0, oldState) +package terminal // import "github.com/zmap/zgrab/ztools/xssh/terminal" + +import ( + "io" + "syscall" + "unsafe" +) + +// State contains the state of a terminal. +type State struct { + termios syscall.Termios +} + +// IsTerminal returns true if the given file descriptor is a terminal. +func IsTerminal(fd int) bool { + var termios syscall.Termios + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&termios)), 0, 0, 0) + return err == 0 +} + +// MakeRaw put the terminal connected to the given file descriptor into raw +// mode and returns the previous state of the terminal so that it can be +// restored. +func MakeRaw(fd int) (*State, error) { + var oldState State + if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 { + return nil, err + } + + newState := oldState.termios + // This attempts to replicate the behaviour documented for cfmakeraw in + // the termios(3) manpage. + newState.Iflag &^= syscall.IGNBRK | syscall.BRKINT | syscall.PARMRK | syscall.ISTRIP | syscall.INLCR | syscall.IGNCR | syscall.ICRNL | syscall.IXON + newState.Oflag &^= syscall.OPOST + newState.Lflag &^= syscall.ECHO | syscall.ECHONL | syscall.ICANON | syscall.ISIG | syscall.IEXTEN + newState.Cflag &^= syscall.CSIZE | syscall.PARENB + newState.Cflag |= syscall.CS8 + if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 { + return nil, err + } + + return &oldState, nil +} + +// GetState returns the current state of a terminal which may be useful to +// restore the terminal after a signal. +func GetState(fd int) (*State, error) { + var oldState State + if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 { + return nil, err + } + + return &oldState, nil +} + +// Restore restores the terminal connected to the given file descriptor to a +// previous state. +func Restore(fd int, state *State) error { + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&state.termios)), 0, 0, 0) + return err +} + +// GetSize returns the dimensions of the given terminal. +func GetSize(fd int) (width, height int, err error) { + var dimensions [4]uint16 + + if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), uintptr(syscall.TIOCGWINSZ), uintptr(unsafe.Pointer(&dimensions)), 0, 0, 0); err != 0 { + return -1, -1, err + } + return int(dimensions[1]), int(dimensions[0]), nil +} + +// ReadPassword reads a line of input from a terminal without local echo. This +// is commonly used for inputting passwords and other sensitive data. The slice +// returned does not include the \n. +func ReadPassword(fd int) ([]byte, error) { + var oldState syscall.Termios + if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0); err != 0 { + return nil, err + } + + newState := oldState + newState.Lflag &^= syscall.ECHO + newState.Lflag |= syscall.ICANON | syscall.ISIG + newState.Iflag |= syscall.ICRNL + if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 { + return nil, err + } + + defer func() { + syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0) + }() + + var buf [16]byte + var ret []byte + for { + n, err := syscall.Read(fd, buf[:]) + if err != nil { + return nil, err + } + if n == 0 { + if len(ret) == 0 { + return nil, io.EOF + } + break + } + if buf[n-1] == '\n' { + n-- + } + ret = append(ret, buf[:n]...) + if n < len(buf) { + break + } + } + + return ret, nil +} diff --git a/lib/ssh/terminal/util_bsd.go b/lib/ssh/terminal/util_bsd.go new file mode 100644 index 0000000..9c1ffd1 --- /dev/null +++ b/lib/ssh/terminal/util_bsd.go @@ -0,0 +1,12 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd netbsd openbsd + +package terminal + +import "syscall" + +const ioctlReadTermios = syscall.TIOCGETA +const ioctlWriteTermios = syscall.TIOCSETA diff --git a/lib/ssh/terminal/util_linux.go b/lib/ssh/terminal/util_linux.go new file mode 100644 index 0000000..5883b22 --- /dev/null +++ b/lib/ssh/terminal/util_linux.go @@ -0,0 +1,11 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package terminal + +// These constants are declared here, rather than importing +// them from the syscall package as some syscall packages, even +// on linux, for example gccgo, do not declare them. +const ioctlReadTermios = 0x5401 // syscall.TCGETS +const ioctlWriteTermios = 0x5402 // syscall.TCSETS diff --git a/lib/ssh/terminal/util_plan9.go b/lib/ssh/terminal/util_plan9.go new file mode 100644 index 0000000..799f049 --- /dev/null +++ b/lib/ssh/terminal/util_plan9.go @@ -0,0 +1,58 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package terminal provides support functions for dealing with terminals, as +// commonly found on UNIX systems. +// +// Putting a terminal into raw mode is the most common requirement: +// +// oldState, err := terminal.MakeRaw(0) +// if err != nil { +// panic(err) +// } +// defer terminal.Restore(0, oldState) +package terminal + +import ( + "fmt" + "runtime" +) + +type State struct{} + +// IsTerminal returns true if the given file descriptor is a terminal. +func IsTerminal(fd int) bool { + return false +} + +// MakeRaw put the terminal connected to the given file descriptor into raw +// mode and returns the previous state of the terminal so that it can be +// restored. +func MakeRaw(fd int) (*State, error) { + return nil, fmt.Errorf("terminal: MakeRaw not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} + +// GetState returns the current state of a terminal which may be useful to +// restore the terminal after a signal. +func GetState(fd int) (*State, error) { + return nil, fmt.Errorf("terminal: GetState not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} + +// Restore restores the terminal connected to the given file descriptor to a +// previous state. +func Restore(fd int, state *State) error { + return fmt.Errorf("terminal: Restore not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} + +// GetSize returns the dimensions of the given terminal. +func GetSize(fd int) (width, height int, err error) { + return 0, 0, fmt.Errorf("terminal: GetSize not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} + +// ReadPassword reads a line of input from a terminal without local echo. This +// is commonly used for inputting passwords and other sensitive data. The slice +// returned does not include the \n. +func ReadPassword(fd int) ([]byte, error) { + return nil, fmt.Errorf("terminal: ReadPassword not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} diff --git a/lib/ssh/terminal/util_solaris.go b/lib/ssh/terminal/util_solaris.go new file mode 100644 index 0000000..431cafe --- /dev/null +++ b/lib/ssh/terminal/util_solaris.go @@ -0,0 +1,73 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build solaris + +package terminal // import "github.com/zmap/zgrab/ztools/xssh/terminal" + +import ( + "golang.org/x/sys/unix" + "io" + "syscall" +) + +// State contains the state of a terminal. +type State struct { + termios syscall.Termios +} + +// IsTerminal returns true if the given file descriptor is a terminal. +func IsTerminal(fd int) bool { + // see: http://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libbc/libc/gen/common/isatty.c + var termio unix.Termio + err := unix.IoctlSetTermio(fd, unix.TCGETA, &termio) + return err == nil +} + +// ReadPassword reads a line of input from a terminal without local echo. This +// is commonly used for inputting passwords and other sensitive data. The slice +// returned does not include the \n. +func ReadPassword(fd int) ([]byte, error) { + // see also: http://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libast/common/uwin/getpass.c + val, err := unix.IoctlGetTermios(fd, unix.TCGETS) + if err != nil { + return nil, err + } + oldState := *val + + newState := oldState + newState.Lflag &^= syscall.ECHO + newState.Lflag |= syscall.ICANON | syscall.ISIG + newState.Iflag |= syscall.ICRNL + err = unix.IoctlSetTermios(fd, unix.TCSETS, &newState) + if err != nil { + return nil, err + } + + defer unix.IoctlSetTermios(fd, unix.TCSETS, &oldState) + + var buf [16]byte + var ret []byte + for { + n, err := syscall.Read(fd, buf[:]) + if err != nil { + return nil, err + } + if n == 0 { + if len(ret) == 0 { + return nil, io.EOF + } + break + } + if buf[n-1] == '\n' { + n-- + } + ret = append(ret, buf[:n]...) + if n < len(buf) { + break + } + } + + return ret, nil +} diff --git a/lib/ssh/terminal/util_windows.go b/lib/ssh/terminal/util_windows.go new file mode 100644 index 0000000..ae9fa9e --- /dev/null +++ b/lib/ssh/terminal/util_windows.go @@ -0,0 +1,174 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows + +// Package terminal provides support functions for dealing with terminals, as +// commonly found on UNIX systems. +// +// Putting a terminal into raw mode is the most common requirement: +// +// oldState, err := terminal.MakeRaw(0) +// if err != nil { +// panic(err) +// } +// defer terminal.Restore(0, oldState) +package terminal + +import ( + "io" + "syscall" + "unsafe" +) + +const ( + enableLineInput = 2 + enableEchoInput = 4 + enableProcessedInput = 1 + enableWindowInput = 8 + enableMouseInput = 16 + enableInsertMode = 32 + enableQuickEditMode = 64 + enableExtendedFlags = 128 + enableAutoPosition = 256 + enableProcessedOutput = 1 + enableWrapAtEolOutput = 2 +) + +var kernel32 = syscall.NewLazyDLL("kernel32.dll") + +var ( + procGetConsoleMode = kernel32.NewProc("GetConsoleMode") + procSetConsoleMode = kernel32.NewProc("SetConsoleMode") + procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo") +) + +type ( + short int16 + word uint16 + + coord struct { + x short + y short + } + smallRect struct { + left short + top short + right short + bottom short + } + consoleScreenBufferInfo struct { + size coord + cursorPosition coord + attributes word + window smallRect + maximumWindowSize coord + } +) + +type State struct { + mode uint32 +} + +// IsTerminal returns true if the given file descriptor is a terminal. +func IsTerminal(fd int) bool { + var st uint32 + r, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) + return r != 0 && e == 0 +} + +// MakeRaw put the terminal connected to the given file descriptor into raw +// mode and returns the previous state of the terminal so that it can be +// restored. +func MakeRaw(fd int) (*State, error) { + var st uint32 + _, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) + if e != 0 { + return nil, error(e) + } + raw := st &^ (enableEchoInput | enableProcessedInput | enableLineInput | enableProcessedOutput) + _, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(raw), 0) + if e != 0 { + return nil, error(e) + } + return &State{st}, nil +} + +// GetState returns the current state of a terminal which may be useful to +// restore the terminal after a signal. +func GetState(fd int) (*State, error) { + var st uint32 + _, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) + if e != 0 { + return nil, error(e) + } + return &State{st}, nil +} + +// Restore restores the terminal connected to the given file descriptor to a +// previous state. +func Restore(fd int, state *State) error { + _, _, err := syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(state.mode), 0) + return err +} + +// GetSize returns the dimensions of the given terminal. +func GetSize(fd int) (width, height int, err error) { + var info consoleScreenBufferInfo + _, _, e := syscall.Syscall(procGetConsoleScreenBufferInfo.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&info)), 0) + if e != 0 { + return 0, 0, error(e) + } + return int(info.size.x), int(info.size.y), nil +} + +// ReadPassword reads a line of input from a terminal without local echo. This +// is commonly used for inputting passwords and other sensitive data. The slice +// returned does not include the \n. +func ReadPassword(fd int) ([]byte, error) { + var st uint32 + _, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) + if e != 0 { + return nil, error(e) + } + old := st + + st &^= (enableEchoInput) + st |= (enableProcessedInput | enableLineInput | enableProcessedOutput) + _, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0) + if e != 0 { + return nil, error(e) + } + + defer func() { + syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(old), 0) + }() + + var buf [16]byte + var ret []byte + for { + n, err := syscall.Read(syscall.Handle(fd), buf[:]) + if err != nil { + return nil, err + } + if n == 0 { + if len(ret) == 0 { + return nil, io.EOF + } + break + } + if buf[n-1] == '\n' { + n-- + } + if n > 0 && buf[n-1] == '\r' { + n-- + } + ret = append(ret, buf[:n]...) + if n < len(buf) { + break + } + } + + return ret, nil +} diff --git a/lib/ssh/test/agent_unix_test.go b/lib/ssh/test/agent_unix_test.go new file mode 100644 index 0000000..716bfa3 --- /dev/null +++ b/lib/ssh/test/agent_unix_test.go @@ -0,0 +1,59 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd + +package test + +import ( + "bytes" + "testing" + + "github.com/zmap/zgrab/ztools/xssh" + "github.com/zmap/zgrab/ztools/xssh/agent" +) + +func DISABLED_TestAgentForward(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + keyring := agent.NewKeyring() + if err := keyring.Add(agent.AddedKey{PrivateKey: testPrivateKeys["dsa"]}); err != nil { + t.Fatalf("Error adding key: %s", err) + } + if err := keyring.Add(agent.AddedKey{ + PrivateKey: testPrivateKeys["dsa"], + ConfirmBeforeUse: true, + LifetimeSecs: 3600, + }); err != nil { + t.Fatalf("Error adding key with constraints: %s", err) + } + pub := testPublicKeys["dsa"] + + sess, err := conn.NewSession() + if err != nil { + t.Fatalf("NewSession: %v", err) + } + if err := agent.RequestAgentForwarding(sess); err != nil { + t.Fatalf("RequestAgentForwarding: %v", err) + } + + if err := agent.ForwardToAgent(conn, keyring); err != nil { + t.Fatalf("SetupForwardKeyring: %v", err) + } + out, err := sess.CombinedOutput("ssh-add -L") + if err != nil { + t.Fatalf("running ssh-add: %v, out %s", err, out) + } + key, _, _, _, err := xssh.ParseAuthorizedKey(out) + if err != nil { + t.Fatalf("ParseAuthorizedKey(%q): %v", out, err) + } + + if !bytes.Equal(key.Marshal(), pub.Marshal()) { + t.Fatalf("got key %s, want %s", xssh.MarshalAuthorizedKey(key), xssh.MarshalAuthorizedKey(pub)) + } +} diff --git a/lib/ssh/test/cert_test.go b/lib/ssh/test/cert_test.go new file mode 100644 index 0000000..d1f5aaf --- /dev/null +++ b/lib/ssh/test/cert_test.go @@ -0,0 +1,47 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd + +package test + +import ( + "crypto/rand" + "testing" + + "github.com/zmap/zgrab/ztools/xssh" +) + +func DISABLED_TestCertLogin(t *testing.T) { + s := newServer(t) + defer s.Shutdown() + + // Use a key different from the default. + clientKey := testSigners["dsa"] + caAuthKey := testSigners["ecdsa"] + cert := &xssh.Certificate{ + Key: clientKey.PublicKey(), + ValidPrincipals: []string{username()}, + CertType: xssh.UserCert, + ValidBefore: xssh.CertTimeInfinity, + } + if err := cert.SignCert(rand.Reader, caAuthKey); err != nil { + t.Fatalf("SetSignature: %v", err) + } + + certSigner, err := xssh.NewCertSigner(cert, clientKey) + if err != nil { + t.Fatalf("NewCertSigner: %v", err) + } + + conf := &xssh.ClientConfig{ + User: username(), + } + conf.Auth = append(conf.Auth, xssh.PublicKeys(certSigner)) + client, err := s.TryDial(conf) + if err != nil { + t.Fatalf("TryDial: %v", err) + } + client.Close() +} diff --git a/lib/ssh/test/doc.go b/lib/ssh/test/doc.go new file mode 100644 index 0000000..5693273 --- /dev/null +++ b/lib/ssh/test/doc.go @@ -0,0 +1,7 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This package contains integration tests for the +// github.com/zmap/zgrab/ztools/xssh. +package test // import "github.com/zmap/zgrab/ztools/xssh/test" diff --git a/lib/ssh/test/forward_unix_test.go b/lib/ssh/test/forward_unix_test.go new file mode 100644 index 0000000..69c3f24 --- /dev/null +++ b/lib/ssh/test/forward_unix_test.go @@ -0,0 +1,160 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd + +package test + +import ( + "bytes" + "io" + "io/ioutil" + "math/rand" + "net" + "testing" + "time" +) + +func DISABLED_TestPortForward(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + sshListener, err := conn.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + + go func() { + sshConn, err := sshListener.Accept() + if err != nil { + t.Fatalf("listen.Accept failed: %v", err) + } + + _, err = io.Copy(sshConn, sshConn) + if err != nil && err != io.EOF { + t.Fatalf("ssh client copy: %v", err) + } + sshConn.Close() + }() + + forwardedAddr := sshListener.Addr().String() + tcpConn, err := net.Dial("tcp", forwardedAddr) + if err != nil { + t.Fatalf("TCP dial failed: %v", err) + } + + readChan := make(chan []byte) + go func() { + data, _ := ioutil.ReadAll(tcpConn) + readChan <- data + }() + + // Invent some data. + data := make([]byte, 100*1000) + for i := range data { + data[i] = byte(i % 255) + } + + var sent []byte + for len(sent) < 1000*1000 { + // Send random sized chunks + m := rand.Intn(len(data)) + n, err := tcpConn.Write(data[:m]) + if err != nil { + break + } + sent = append(sent, data[:n]...) + } + if err := tcpConn.(*net.TCPConn).CloseWrite(); err != nil { + t.Errorf("tcpConn.CloseWrite: %v", err) + } + + read := <-readChan + + if len(sent) != len(read) { + t.Fatalf("got %d bytes, want %d", len(read), len(sent)) + } + if bytes.Compare(sent, read) != 0 { + t.Fatalf("read back data does not match") + } + + if err := sshListener.Close(); err != nil { + t.Fatalf("sshListener.Close: %v", err) + } + + // Check that the forward disappeared. + tcpConn, err = net.Dial("tcp", forwardedAddr) + if err == nil { + tcpConn.Close() + t.Errorf("still listening to %s after closing", forwardedAddr) + } +} + +func DISABLED_TestAcceptClose(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + + sshListener, err := conn.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + + quit := make(chan error, 1) + go func() { + for { + c, err := sshListener.Accept() + if err != nil { + quit <- err + break + } + c.Close() + } + }() + sshListener.Close() + + select { + case <-time.After(1 * time.Second): + t.Errorf("timeout: listener did not close.") + case err := <-quit: + t.Logf("quit as expected (error %v)", err) + } +} + +// Check that listeners exit if the underlying client transport dies. +func DISABLED_TestPortForwardConnectionClose(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + + sshListener, err := conn.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + + quit := make(chan error, 1) + go func() { + for { + c, err := sshListener.Accept() + if err != nil { + quit <- err + break + } + c.Close() + } + }() + + // It would be even nicer if we closed the server side, but it + // is more involved as the fd for that side is dup()ed. + server.clientConn.Close() + + select { + case <-time.After(1 * time.Second): + t.Errorf("timeout: listener did not close.") + case err := <-quit: + t.Logf("quit as expected (error %v)", err) + } +} diff --git a/lib/ssh/test/session_test.go b/lib/ssh/test/session_test.go new file mode 100644 index 0000000..d80ec2a --- /dev/null +++ b/lib/ssh/test/session_test.go @@ -0,0 +1,365 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !windows + +package test + +// Session functional tests. + +import ( + "bytes" + "errors" + "io" + "strings" + "testing" + + "github.com/zmap/zgrab/ztools/xssh" +) + +func DISABLED_TestRunCommandSuccess(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + defer session.Close() + err = session.Run("true") + if err != nil { + t.Fatalf("session failed: %v", err) + } +} + +func DISABLED_TestHostKeyCheck(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + + conf := clientConfig() + hostDB := hostKeyDB() + conf.HostKeyCallback = hostDB.Check + + // change the keys. + hostDB.keys[xssh.KeyAlgoRSA][25]++ + hostDB.keys[xssh.KeyAlgoDSA][25]++ + hostDB.keys[xssh.KeyAlgoECDSA256][25]++ + + conn, err := server.TryDial(conf) + if err == nil { + conn.Close() + t.Fatalf("dial should have failed.") + } else if !strings.Contains(err.Error(), "host key mismatch") { + t.Fatalf("'host key mismatch' not found in %v", err) + } +} + +func DISABLED_TestRunCommandStdin(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + defer session.Close() + + r, w := io.Pipe() + defer r.Close() + defer w.Close() + session.Stdin = r + + err = session.Run("true") + if err != nil { + t.Fatalf("session failed: %v", err) + } +} + +func DISABLED_TestRunCommandStdinError(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + defer session.Close() + + r, w := io.Pipe() + defer r.Close() + session.Stdin = r + pipeErr := errors.New("closing write end of pipe") + w.CloseWithError(pipeErr) + + err = session.Run("true") + if err != pipeErr { + t.Fatalf("expected %v, found %v", pipeErr, err) + } +} + +func DISABLED_TestRunCommandFailed(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + defer session.Close() + err = session.Run(`bash -c "kill -9 $$"`) + if err == nil { + t.Fatalf("session succeeded: %v", err) + } +} + +func DISABLED_TestRunCommandWeClosed(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + err = session.Shell() + if err != nil { + t.Fatalf("shell failed: %v", err) + } + err = session.Close() + if err != nil { + t.Fatalf("shell failed: %v", err) + } +} + +func DISABLED_TestFuncLargeRead(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("unable to create new session: %s", err) + } + + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("unable to acquire stdout pipe: %s", err) + } + + err = session.Start("dd if=/dev/urandom bs=2048 count=1024") + if err != nil { + t.Fatalf("unable to execute remote command: %s", err) + } + + buf := new(bytes.Buffer) + n, err := io.Copy(buf, stdout) + if err != nil { + t.Fatalf("error reading from remote stdout: %s", err) + } + + if n != 2048*1024 { + t.Fatalf("Expected %d bytes but read only %d from remote command", 2048, n) + } +} + +func DISABLED_TestKeyChange(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conf := clientConfig() + hostDB := hostKeyDB() + conf.HostKeyCallback = hostDB.Check + conf.RekeyThreshold = 1024 + conn := server.Dial(conf) + defer conn.Close() + + for i := 0; i < 4; i++ { + session, err := conn.NewSession() + if err != nil { + t.Fatalf("unable to create new session: %s", err) + } + + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("unable to acquire stdout pipe: %s", err) + } + + err = session.Start("dd if=/dev/urandom bs=1024 count=1") + if err != nil { + t.Fatalf("unable to execute remote command: %s", err) + } + buf := new(bytes.Buffer) + n, err := io.Copy(buf, stdout) + if err != nil { + t.Fatalf("error reading from remote stdout: %s", err) + } + + want := int64(1024) + if n != want { + t.Fatalf("Expected %d bytes but read only %d from remote command", want, n) + } + } + + if changes := hostDB.checkCount; changes < 4 { + t.Errorf("got %d key changes, want 4", changes) + } +} + +func DISABLED_TestInvalidTerminalMode(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + defer session.Close() + + if err = session.RequestPty("vt100", 80, 40, xssh.TerminalModes{255: 1984}); err == nil { + t.Fatalf("req-pty failed: successful request with invalid mode") + } +} + +func DISABLED_TestValidTerminalMode(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + defer session.Close() + + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("unable to acquire stdout pipe: %s", err) + } + + stdin, err := session.StdinPipe() + if err != nil { + t.Fatalf("unable to acquire stdin pipe: %s", err) + } + + tm := xssh.TerminalModes{xssh.ECHO: 0} + if err = session.RequestPty("xterm", 80, 40, tm); err != nil { + t.Fatalf("req-pty failed: %s", err) + } + + err = session.Shell() + if err != nil { + t.Fatalf("session failed: %s", err) + } + + stdin.Write([]byte("stty -a && exit\n")) + + var buf bytes.Buffer + if _, err := io.Copy(&buf, stdout); err != nil { + t.Fatalf("reading failed: %s", err) + } + + if sttyOutput := buf.String(); !strings.Contains(sttyOutput, "-echo ") { + t.Fatalf("terminal mode failure: expected -echo in stty output, got %s", sttyOutput) + } +} + +func DISABLED_TestCiphers(t *testing.T) { + var config xssh.Config + config.SetDefaults() + cipherOrder := config.Ciphers + // These ciphers will not be tested when commented out in cipher.go it will + // fallback to the next available as per line 292. + cipherOrder = append(cipherOrder, "aes128-cbc", "3des-cbc") + + for _, ciph := range cipherOrder { + server := newServer(t) + defer server.Shutdown() + conf := clientConfig() + conf.Ciphers = []string{ciph} + // Don't fail if sshd doesn't have the cipher. + conf.Ciphers = append(conf.Ciphers, cipherOrder...) + conn, err := server.TryDial(conf) + if err == nil { + conn.Close() + } else { + t.Fatalf("failed for cipher %q", ciph) + } + } +} + +func DISABLED_TestMACs(t *testing.T) { + var config xssh.Config + config.SetDefaults() + macOrder := config.MACs + + for _, mac := range macOrder { + server := newServer(t) + defer server.Shutdown() + conf := clientConfig() + conf.MACs = []string{mac} + // Don't fail if sshd doesn't have the MAC. + conf.MACs = append(conf.MACs, macOrder...) + if conn, err := server.TryDial(conf); err == nil { + conn.Close() + } else { + t.Fatalf("failed for MAC %q", mac) + } + } +} + +func DISABLED_TestKeyExchanges(t *testing.T) { + var config xssh.Config + config.SetDefaults() + kexOrder := config.KeyExchanges + for _, kex := range kexOrder { + server := newServer(t) + defer server.Shutdown() + conf := clientConfig() + // Don't fail if sshd doesn't have the kex. + conf.KeyExchanges = append([]string{kex}, kexOrder...) + conn, err := server.TryDial(conf) + if err == nil { + conn.Close() + } else { + t.Errorf("failed for kex %q", kex) + } + } +} + +func DISABLED_TestClientAuthAlgorithms(t *testing.T) { + for _, key := range []string{ + "rsa", + "dsa", + "ecdsa", + "ed25519", + } { + server := newServer(t) + conf := clientConfig() + conf.SetDefaults() + conf.Auth = []xssh.AuthMethod{ + xssh.PublicKeys(testSigners[key]), + } + + conn, err := server.TryDial(conf) + if err == nil { + conn.Close() + } else { + t.Errorf("failed for key %q", key) + } + + server.Shutdown() + } +} diff --git a/lib/ssh/test/tcpip_test.go b/lib/ssh/test/tcpip_test.go new file mode 100644 index 0000000..2fd6fd6 --- /dev/null +++ b/lib/ssh/test/tcpip_test.go @@ -0,0 +1,46 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !windows + +package test + +// direct-tcpip functional tests + +import ( + "io" + "net" + "testing" +) + +func DISABLED_TestDial(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + sshConn := server.Dial(clientConfig()) + defer sshConn.Close() + + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen: %v", err) + } + defer l.Close() + + go func() { + for { + c, err := l.Accept() + if err != nil { + break + } + + io.WriteString(c, c.RemoteAddr().String()) + c.Close() + } + }() + + conn, err := sshConn.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() +} diff --git a/lib/ssh/test/test_unix_test.go b/lib/ssh/test/test_unix_test.go new file mode 100644 index 0000000..eae841a --- /dev/null +++ b/lib/ssh/test/test_unix_test.go @@ -0,0 +1,268 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd plan9 + +package test + +// functional test harness for unix. + +import ( + "bytes" + "fmt" + "io/ioutil" + "log" + "net" + "os" + "os/exec" + "os/user" + "path/filepath" + "testing" + "text/template" + + "github.com/zmap/zgrab/ztools/xssh" + "github.com/zmap/zgrab/ztools/xssh/testdata" +) + +const sshd_config = ` +Protocol 2 +HostKey {{.Dir}}/id_rsa +HostKey {{.Dir}}/id_dsa +HostKey {{.Dir}}/id_ecdsa +Pidfile {{.Dir}}/sshd.pid +#UsePrivilegeSeparation no +KeyRegenerationInterval 3600 +ServerKeyBits 768 +SyslogFacility AUTH +LogLevel DEBUG2 +LoginGraceTime 120 +PermitRootLogin no +StrictModes no +RSAAuthentication yes +PubkeyAuthentication yes +AuthorizedKeysFile {{.Dir}}/authorized_keys +TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub +IgnoreRhosts yes +RhostsRSAAuthentication no +HostbasedAuthentication no +PubkeyAcceptedKeyTypes=* +` + +var configTmpl = template.Must(template.New("").Parse(sshd_config)) + +type server struct { + t *testing.T + cleanup func() // executed during Shutdown + configfile string + cmd *exec.Cmd + output bytes.Buffer // holds stderr from sshd process + + // Client half of the network connection. + clientConn net.Conn +} + +func username() string { + var username string + if user, err := user.Current(); err == nil { + username = user.Username + } else { + // user.Current() currently requires cgo. If an error is + // returned attempt to get the username from the environment. + log.Printf("user.Current: %v; falling back on $USER", err) + username = os.Getenv("USER") + } + if username == "" { + panic("Unable to get username") + } + return username +} + +type storedHostKey struct { + // keys map from an algorithm string to binary key data. + keys map[string][]byte + + // checkCount counts the Check calls. Used for testing + // rekeying. + checkCount int +} + +func (k *storedHostKey) Add(key xssh.PublicKey) { + if k.keys == nil { + k.keys = map[string][]byte{} + } + k.keys[key.Type()] = key.Marshal() +} + +func (k *storedHostKey) Check(addr string, remote net.Addr, key xssh.PublicKey) error { + k.checkCount++ + algo := key.Type() + + if k.keys == nil || bytes.Compare(key.Marshal(), k.keys[algo]) != 0 { + return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo]) + } + return nil +} + +func hostKeyDB() *storedHostKey { + keyChecker := &storedHostKey{} + keyChecker.Add(testPublicKeys["ecdsa"]) + keyChecker.Add(testPublicKeys["rsa"]) + keyChecker.Add(testPublicKeys["dsa"]) + return keyChecker +} + +func clientConfig() *xssh.ClientConfig { + config := &xssh.ClientConfig{ + User: username(), + Auth: []xssh.AuthMethod{ + xssh.PublicKeys(testSigners["user"]), + }, + HostKeyCallback: hostKeyDB().Check, + } + return config +} + +// unixConnection creates two halves of a connected net.UnixConn. It +// is used for connecting the Go SSH client with sshd without opening +// ports. +func unixConnection() (*net.UnixConn, *net.UnixConn, error) { + dir, err := ioutil.TempDir("", "unixConnection") + if err != nil { + return nil, nil, err + } + defer os.Remove(dir) + + addr := filepath.Join(dir, "ssh") + listener, err := net.Listen("unix", addr) + if err != nil { + return nil, nil, err + } + defer listener.Close() + c1, err := net.Dial("unix", addr) + if err != nil { + return nil, nil, err + } + + c2, err := listener.Accept() + if err != nil { + c1.Close() + return nil, nil, err + } + + return c1.(*net.UnixConn), c2.(*net.UnixConn), nil +} + +func (s *server) TryDial(config *xssh.ClientConfig) (*xssh.Client, error) { + sshd, err := exec.LookPath("sshd") + if err != nil { + s.t.Skipf("skipping test: %v", err) + } + + c1, c2, err := unixConnection() + if err != nil { + s.t.Fatalf("unixConnection: %v", err) + } + + s.cmd = exec.Command(sshd, "-f", s.configfile, "-i", "-e") + f, err := c2.File() + if err != nil { + s.t.Fatalf("UnixConn.File: %v", err) + } + defer f.Close() + s.cmd.Stdin = f + s.cmd.Stdout = f + s.cmd.Stderr = &s.output + if err := s.cmd.Start(); err != nil { + s.t.Fail() + s.Shutdown() + s.t.Fatalf("s.cmd.Start: %v", err) + } + s.clientConn = c1 + conn, chans, reqs, err := xssh.NewClientConn(c1, "", config) + if err != nil { + return nil, err + } + return xssh.NewClient(conn, chans, reqs), nil +} + +func (s *server) Dial(config *xssh.ClientConfig) *xssh.Client { + conn, err := s.TryDial(config) + if err != nil { + s.t.Fail() + s.Shutdown() + s.t.Fatalf("xssh.Client: %v", err) + } + return conn +} + +func (s *server) Shutdown() { + if s.cmd != nil && s.cmd.Process != nil { + // Don't check for errors; if it fails it's most + // likely "os: process already finished", and we don't + // care about that. Use os.Interrupt, so child + // processes are killed too. + s.cmd.Process.Signal(os.Interrupt) + s.cmd.Wait() + } + if s.t.Failed() { + // log any output from sshd process + s.t.Logf("sshd: %s", s.output.String()) + } + s.cleanup() +} + +func writeFile(path string, contents []byte) { + f, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600) + if err != nil { + panic(err) + } + defer f.Close() + if _, err := f.Write(contents); err != nil { + panic(err) + } +} + +// newServer returns a new mock ssh server. +func newServer(t *testing.T) *server { + if testing.Short() { + t.Skip("skipping test due to -short") + } + dir, err := ioutil.TempDir("", "sshtest") + if err != nil { + t.Fatal(err) + } + f, err := os.Create(filepath.Join(dir, "sshd_config")) + if err != nil { + t.Fatal(err) + } + err = configTmpl.Execute(f, map[string]string{ + "Dir": dir, + }) + if err != nil { + t.Fatal(err) + } + f.Close() + + for k, v := range testdata.PEMBytes { + filename := "id_" + k + writeFile(filepath.Join(dir, filename), v) + writeFile(filepath.Join(dir, filename+".pub"), xssh.MarshalAuthorizedKey(testPublicKeys[k])) + } + + var authkeys bytes.Buffer + for k := range testdata.PEMBytes { + authkeys.Write(xssh.MarshalAuthorizedKey(testPublicKeys[k])) + } + writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes()) + + return &server{ + t: t, + configfile: f.Name(), + cleanup: func() { + if err := os.RemoveAll(dir); err != nil { + t.Error(err) + } + }, + } +} diff --git a/lib/ssh/test/testdata_test.go b/lib/ssh/test/testdata_test.go new file mode 100644 index 0000000..cdf93b7 --- /dev/null +++ b/lib/ssh/test/testdata_test.go @@ -0,0 +1,64 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IMPLEMENTATION NOTE: To avoid a package loop, this file is in three places: +// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three +// instances. + +package test + +import ( + "crypto/rand" + "fmt" + + "github.com/zmap/zgrab/ztools/xssh" + "github.com/zmap/zgrab/ztools/xssh/testdata" +) + +var ( + testPrivateKeys map[string]interface{} + testSigners map[string]xssh.Signer + testPublicKeys map[string]xssh.PublicKey +) + +func init() { + var err error + + n := len(testdata.PEMBytes) + testPrivateKeys = make(map[string]interface{}, n) + testSigners = make(map[string]xssh.Signer, n) + testPublicKeys = make(map[string]xssh.PublicKey, n) + for t, k := range testdata.PEMBytes { + testPrivateKeys[t], err = xssh.ParseRawPrivateKey(k) + if err != nil { + panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) + } + testSigners[t], err = xssh.NewSignerFromKey(testPrivateKeys[t]) + if err != nil { + panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) + } + testPublicKeys[t] = testSigners[t].PublicKey() + } + + // Create a cert and sign it for use in tests. + testCert := &xssh.Certificate{ + Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage + ValidAfter: 0, // unix epoch + ValidBefore: xssh.CertTimeInfinity, // The end of currently representable time. + Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + Key: testPublicKeys["ecdsa"], + SignatureKey: testPublicKeys["rsa"], + Permissions: xssh.Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + }, + } + testCert.SignCert(rand.Reader, testSigners["rsa"]) + testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] + testSigners["cert"], err = xssh.NewCertSigner(testCert, testSigners["ecdsa"]) + if err != nil { + panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) + } +} diff --git a/lib/ssh/testdata/doc.go b/lib/ssh/testdata/doc.go new file mode 100644 index 0000000..163c448 --- /dev/null +++ b/lib/ssh/testdata/doc.go @@ -0,0 +1,8 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This package contains test data shared between the various subpackages of +// the github.com/zmap/zgrab/ztools/xssh package. Under no circumstance should +// this data be used for production code. +package testdata // import "github.com/zmap/zgrab/ztools/xssh/testdata" diff --git a/lib/ssh/testdata/keys.go b/lib/ssh/testdata/keys.go new file mode 100644 index 0000000..736dad9 --- /dev/null +++ b/lib/ssh/testdata/keys.go @@ -0,0 +1,120 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package testdata + +var PEMBytes = map[string][]byte{ + "dsa": []byte(`-----BEGIN DSA PRIVATE KEY----- +MIIBuwIBAAKBgQD6PDSEyXiI9jfNs97WuM46MSDCYlOqWw80ajN16AohtBncs1YB +lHk//dQOvCYOsYaE+gNix2jtoRjwXhDsc25/IqQbU1ahb7mB8/rsaILRGIbA5WH3 +EgFtJmXFovDz3if6F6TzvhFpHgJRmLYVR8cqsezL3hEZOvvs2iH7MorkxwIVAJHD +nD82+lxh2fb4PMsIiaXudAsBAoGAQRf7Q/iaPRn43ZquUhd6WwvirqUj+tkIu6eV +2nZWYmXLlqFQKEy4Tejl7Wkyzr2OSYvbXLzo7TNxLKoWor6ips0phYPPMyXld14r +juhT24CrhOzuLMhDduMDi032wDIZG4Y+K7ElU8Oufn8Sj5Wge8r6ANmmVgmFfynr +FhdYCngCgYEA3ucGJ93/Mx4q4eKRDxcWD3QzWyqpbRVRRV1Vmih9Ha/qC994nJFz +DQIdjxDIT2Rk2AGzMqFEB68Zc3O+Wcsmz5eWWzEwFxaTwOGWTyDqsDRLm3fD+QYj +nOwuxb0Kce+gWI8voWcqC9cyRm09jGzu2Ab3Bhtpg8JJ8L7gS3MRZK4CFEx4UAfY +Fmsr0W6fHB9nhS4/UXM8 +-----END DSA PRIVATE KEY----- +`), + "ecdsa": []byte(`-----BEGIN EC PRIVATE KEY----- +MHcCAQEEINGWx0zo6fhJ/0EAfrPzVFyFC9s18lBt3cRoEDhS3ARooAoGCCqGSM49 +AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+ +6/ZOXRnPmNAlLUqjShUsUBBngG0u2fqEqA== +-----END EC PRIVATE KEY----- +`), + "rsa": []byte(`-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2 +a6ySEJQb3IYquw7HlJgu6fg3WIWhOmHCjfpG0PrL4CRwbqQ2LaPPXhJErWYejcD8 +Di00cF3677+G10KMZk9RXbmHtuBFZT98wxg8j+ZsBMqGM1+7yrWUvynswQIDAQAB +AoGAJMCk5vqfSRzyXOTXLGIYCuR4Kj6pdsbNSeuuRGfYBeR1F2c/XdFAg7D/8s5R +38p/Ih52/Ty5S8BfJtwtvgVY9ecf/JlU/rl/QzhG8/8KC0NG7KsyXklbQ7gJT8UT +Ojmw5QpMk+rKv17ipDVkQQmPaj+gJXYNAHqImke5mm/K/h0CQQDciPmviQ+DOhOq +2ZBqUfH8oXHgFmp7/6pXw80DpMIxgV3CwkxxIVx6a8lVH9bT/AFySJ6vXq4zTuV9 +6QmZcZzDAkEA2j/UXJPIs1fQ8z/6sONOkU/BjtoePFIWJlRxdN35cZjXnBraX5UR +fFHkePv4YwqmXNqrBOvSu+w2WdSDci+IKwJAcsPRc/jWmsrJW1q3Ha0hSf/WG/Bu +X7MPuXaKpP/DkzGoUmb8ks7yqj6XWnYkPNLjCc8izU5vRwIiyWBRf4mxMwJBAILa +NDvRS0rjwt6lJGv7zPZoqDc65VfrK2aNyHx2PgFyzwrEOtuF57bu7pnvEIxpLTeM +z26i6XVMeYXAWZMTloMCQBbpGgEERQpeUknLBqUHhg/wXF6+lFA+vEGnkY+Dwab2 +KCXFGd+SQ5GdUcEMe9isUH6DYj/6/yCDoFrXXmpQb+M= +-----END RSA PRIVATE KEY----- +`), + "ed25519": []byte(`-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW +QyNTUxOQAAACA+3f7hS7g5UWwXOGVTrMfhmxyrjqz7Sxxbx7I1j8DvvwAAAJhAFfkOQBX5 +DgAAAAtzc2gtZWQyNTUxOQAAACA+3f7hS7g5UWwXOGVTrMfhmxyrjqz7Sxxbx7I1j8Dvvw +AAAEAaYmXltfW6nhRo3iWGglRB48lYq0z0Q3I3KyrdutEr6j7d/uFLuDlRbBc4ZVOsx+Gb +HKuOrPtLHFvHsjWPwO+/AAAAE2dhcnRvbm1AZ2FydG9ubS14cHMBAg== +-----END OPENSSH PRIVATE KEY----- +`), + "user": []byte(`-----BEGIN EC PRIVATE KEY----- +MHcCAQEEILYCAeq8f7V4vSSypRw7pxy8yz3V5W4qg8kSC3zJhqpQoAoGCCqGSM49 +AwEHoUQDQgAEYcO2xNKiRUYOLEHM7VYAp57HNyKbOdYtHD83Z4hzNPVC4tM5mdGD +PLL8IEwvYu2wq+lpXfGQnNMbzYf9gspG0w== +-----END EC PRIVATE KEY----- +`), +} + +var PEMEncryptedKeys = []struct { + Name string + EncryptionKey string + PEMBytes []byte +}{ + 0: { + Name: "rsa-encrypted", + EncryptionKey: "r54-G0pher_t3st$", + PEMBytes: []byte(`-----BEGIN RSA PRIVATE KEY----- +Proc-Type: 4,ENCRYPTED +DEK-Info: AES-128-CBC,3E1714DE130BC5E81327F36564B05462 + +MqW88sud4fnWk/Jk3fkjh7ydu51ZkHLN5qlQgA4SkAXORPPMj2XvqZOv1v2LOgUV +dUevUn8PZK7a9zbZg4QShUSzwE5k6wdB7XKPyBgI39mJ79GBd2U4W3h6KT6jIdWA +goQpluxkrzr2/X602IaxLEre97FT9mpKC6zxKCLvyFWVIP9n3OSFS47cTTXyFr+l +7PdRhe60nn6jSBgUNk/Q1lAvEQ9fufdPwDYY93F1wyJ6lOr0F1+mzRrMbH67NyKs +rG8J1Fa7cIIre7ueKIAXTIne7OAWqpU9UDgQatDtZTbvA7ciqGsSFgiwwW13N+Rr +hN8MkODKs9cjtONxSKi05s206A3NDU6STtZ3KuPDjFE1gMJODotOuqSM+cxKfyFq +wxpk/CHYCDdMAVBSwxb/vraOHamylL4uCHpJdBHypzf2HABt+lS8Su23uAmL87DR +yvyCS/lmpuNTndef6qHPRkoW2EV3xqD3ovosGf7kgwGJUk2ZpCLVteqmYehKlZDK +r/Jy+J26ooI2jIg9bjvD1PZq+Mv+2dQ1RlDrPG3PB+rEixw6vBaL9x3jatCd4ej7 +XG7lb3qO9xFpLsx89tkEcvpGR+broSpUJ6Mu5LBCVmrvqHjvnDhrZVz1brMiQtU9 +iMZbgXqDLXHd6ERWygk7OTU03u+l1gs+KGMfmS0h0ZYw6KGVLgMnsoxqd6cFSKNB +8Ohk9ZTZGCiovlXBUepyu8wKat1k8YlHSfIHoRUJRhhcd7DrmojC+bcbMIZBU22T +Pl2ftVRGtcQY23lYd0NNKfebF7ncjuLWQGy+vZW+7cgfI6wPIbfYfP6g7QAutk6W +KQx0AoX5woZ6cNxtpIrymaVjSMRRBkKQrJKmRp3pC/lul5E5P2cueMs1fj4OHTbJ +lAUv88ywr+R+mRgYQlFW/XQ653f6DT4t6+njfO9oBcPrQDASZel3LjXLpjjYG/N5 ++BWnVexuJX9ika8HJiFl55oqaKb+WknfNhk5cPY+x7SDV9ywQeMiDZpr0ffeYAEP +LlwwiWRDYpO+uwXHSFF3+JjWwjhs8m8g99iFb7U93yKgBB12dCEPPa2ZeH9wUHMJ +sreYhNuq6f4iWWSXpzN45inQqtTi8jrJhuNLTT543ErW7DtntBO2rWMhff3aiXbn +Uy3qzZM1nPbuCGuBmP9L2dJ3Z5ifDWB4JmOyWY4swTZGt9AVmUxMIKdZpRONx8vz +I9u9nbVPGZBcou50Pa0qTLbkWsSL94MNXrARBxzhHC9Zs6XNEtwN7mOuii7uMkVc +adrxgknBH1J1N+NX/eTKzUwJuPvDtA+Z5ILWNN9wpZT/7ed8zEnKHPNUexyeT5g3 +uw9z9jH7ffGxFYlx87oiVPHGOrCXYZYW5uoZE31SCBkbtNuffNRJRKIFeipmpJ3P +7bpAG+kGHMelQH6b+5K1Qgsv4tpuSyKeTKpPFH9Av5nN4P1ZBm9N80tzbNWqjSJm +S7rYdHnuNEVnUGnRmEUMmVuYZnNBEVN/fP2m2SEwXcP3Uh7TiYlcWw10ygaGmOr7 +MvMLGkYgQ4Utwnd98mtqa0jr0hK2TcOSFir3AqVvXN3XJj4cVULkrXe4Im1laWgp +-----END RSA PRIVATE KEY----- +`), + }, + + 1: { + Name: "dsa-encrypted", + EncryptionKey: "qG0pher-dsa_t3st$", + PEMBytes: []byte(`-----BEGIN DSA PRIVATE KEY----- +Proc-Type: 4,ENCRYPTED +DEK-Info: AES-128-CBC,7CE7A6E4A647DC01AF860210B15ADE3E + +hvnBpI99Hceq/55pYRdOzBLntIEis02JFNXuLEydWL+RJBFDn7tA+vXec0ERJd6J +G8JXlSOAhmC2H4uK3q2xR8/Y3yL95n6OIcjvCBiLsV+o3jj1MYJmErxP6zRtq4w3 +JjIjGHWmaYFSxPKQ6e8fs74HEqaeMV9ONUoTtB+aISmgaBL15Fcoayg245dkBvVl +h5Kqspe7yvOBmzA3zjRuxmSCqKJmasXM7mqs3vIrMxZE3XPo1/fWKcPuExgpVQoT +HkJZEoIEIIPnPMwT2uYbFJSGgPJVMDT84xz7yvjCdhLmqrsXgs5Qw7Pw0i0c0BUJ +b7fDJ2UhdiwSckWGmIhTLlJZzr8K+JpjCDlP+REYBI5meB7kosBnlvCEHdw2EJkH +0QDc/2F4xlVrHOLbPRFyu1Oi2Gvbeoo9EsM/DThpd1hKAlb0sF5Y0y0d+owv0PnE +R/4X3HWfIdOHsDUvJ8xVWZ4BZk9Zk9qol045DcFCehpr/3hslCrKSZHakLt9GI58 +vVQJ4L0aYp5nloLfzhViZtKJXRLkySMKdzYkIlNmW1oVGl7tce5UCNI8Nok4j6yn +IiHM7GBn+0nJoKTXsOGMIBe3ulKlKVxLjEuk9yivh/8= +-----END DSA PRIVATE KEY----- +`), + }, +} diff --git a/lib/ssh/testdata_test.go b/lib/ssh/testdata_test.go new file mode 100644 index 0000000..e0fded6 --- /dev/null +++ b/lib/ssh/testdata_test.go @@ -0,0 +1,63 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IMPLEMENTATION NOTE: To avoid a package loop, this file is in three places: +// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three +// instances. + +package xssh + +import ( + "crypto/rand" + "fmt" + + "github.com/zmap/zgrab/ztools/xssh/testdata" +) + +var ( + testPrivateKeys map[string]interface{} + testSigners map[string]Signer + testPublicKeys map[string]PublicKey +) + +func init() { + var err error + + n := len(testdata.PEMBytes) + testPrivateKeys = make(map[string]interface{}, n) + testSigners = make(map[string]Signer, n) + testPublicKeys = make(map[string]PublicKey, n) + for t, k := range testdata.PEMBytes { + testPrivateKeys[t], err = ParseRawPrivateKey(k) + if err != nil { + panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) + } + testSigners[t], err = NewSignerFromKey(testPrivateKeys[t]) + if err != nil { + panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) + } + testPublicKeys[t] = testSigners[t].PublicKey() + } + + // Create a cert and sign it for use in tests. + testCert := &Certificate{ + Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage + ValidAfter: 0, // unix epoch + ValidBefore: CertTimeInfinity, // The end of currently representable time. + Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + Key: testPublicKeys["ecdsa"], + SignatureKey: testPublicKeys["rsa"], + Permissions: Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + }, + } + testCert.SignCert(rand.Reader, testSigners["rsa"]) + testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] + testSigners["cert"], err = NewCertSigner(testCert, testSigners["ecdsa"]) + if err != nil { + panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) + } +} diff --git a/lib/ssh/transport.go b/lib/ssh/transport.go new file mode 100644 index 0000000..a1c0851 --- /dev/null +++ b/lib/ssh/transport.go @@ -0,0 +1,333 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "bufio" + "errors" + "io" +) + +const ( + gcmCipherID = "aes128-gcm@openssh.com" + aes128cbcID = "aes128-cbc" + tripledescbcID = "3des-cbc" +) + +// packetConn represents a transport that implements packet based +// operations. +type packetConn interface { + // Encrypt and send a packet of data to the remote peer. + writePacket(packet []byte) error + + // Read a packet from the connection + readPacket() ([]byte, error) + + // Close closes the write-side of the connection. + Close() error +} + +// transport is the keyingTransport that implements the SSH packet +// protocol. +type transport struct { + reader connectionState + writer connectionState + + bufReader *bufio.Reader + bufWriter *bufio.Writer + rand io.Reader + + io.Closer +} + +// packetCipher represents a combination of SSH encryption/MAC +// protocol. A single instance should be used for one direction only. +type packetCipher interface { + // writePacket encrypts the packet and writes it to w. The + // contents of the packet are generally scrambled. + writePacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error + + // readPacket reads and decrypts a packet of data. The + // returned packet may be overwritten by future calls of + // readPacket. + readPacket(seqnum uint32, r io.Reader) ([]byte, error) +} + +// connectionState represents one side (read or write) of the +// connection. This is necessary because each direction has its own +// keys, and can even have its own algorithms +type connectionState struct { + packetCipher + seqNum uint32 + dir direction + pendingKeyChange chan packetCipher +} + +// prepareKeyChange sets up key material for a keychange. The key changes in +// both directions are triggered by reading and writing a msgNewKey packet +// respectively. +func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error { + if ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult); err != nil { + return err + } else { + t.reader.pendingKeyChange <- ciph + } + + if ciph, err := newPacketCipher(t.writer.dir, algs.w, kexResult); err != nil { + return err + } else { + t.writer.pendingKeyChange <- ciph + } + + return nil +} + +// Read and decrypt next packet. +func (t *transport) readPacket() ([]byte, error) { + return t.reader.readPacket(t.bufReader) +} + +func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) { + packet, err := s.packetCipher.readPacket(s.seqNum, r) + s.seqNum++ + if err == nil && len(packet) == 0 { + err = errors.New("ssh: zero length packet") + } + + if len(packet) > 0 { + switch packet[0] { + case msgNewKeys: + select { + case cipher := <-s.pendingKeyChange: + s.packetCipher = cipher + default: + return nil, errors.New("ssh: got bogus newkeys message.") + } + + case msgDisconnect: + // Transform a disconnect message into an + // error. Since this is lowest level at which + // we interpret message types, doing it here + // ensures that we don't have to handle it + // elsewhere. + var msg disconnectMsg + if err := Unmarshal(packet, &msg); err != nil { + return nil, err + } + return nil, &msg + } + } + + // The packet may point to an internal buffer, so copy the + // packet out here. + fresh := make([]byte, len(packet)) + copy(fresh, packet) + + return fresh, err +} + +func (t *transport) writePacket(packet []byte) error { + return t.writer.writePacket(t.bufWriter, t.rand, packet) +} + +func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte) error { + changeKeys := len(packet) > 0 && packet[0] == msgNewKeys + + err := s.packetCipher.writePacket(s.seqNum, w, rand, packet) + if err != nil { + return err + } + if err = w.Flush(); err != nil { + return err + } + s.seqNum++ + if changeKeys { + select { + case cipher := <-s.pendingKeyChange: + s.packetCipher = cipher + default: + panic("ssh: no key material for msgNewKeys") + } + } + return err +} + +func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport { + t := &transport{ + bufReader: bufio.NewReader(rwc), + bufWriter: bufio.NewWriter(rwc), + rand: rand, + reader: connectionState{ + packetCipher: &streamPacketCipher{cipher: noneCipher{}}, + pendingKeyChange: make(chan packetCipher, 1), + }, + writer: connectionState{ + packetCipher: &streamPacketCipher{cipher: noneCipher{}}, + pendingKeyChange: make(chan packetCipher, 1), + }, + Closer: rwc, + } + if isClient { + t.reader.dir = serverKeys + t.writer.dir = clientKeys + } else { + t.reader.dir = clientKeys + t.writer.dir = serverKeys + } + + return t +} + +type direction struct { + ivTag []byte + keyTag []byte + macKeyTag []byte +} + +var ( + serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}} + clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}} +) + +// generateKeys generates key material for IV, MAC and encryption. +func generateKeys(d direction, algs directionAlgorithms, kex *kexResult) (iv, key, macKey []byte) { + cipherMode := cipherModes[algs.Cipher] + macMode := macModes[algs.MAC] + + iv = make([]byte, cipherMode.ivSize) + key = make([]byte, cipherMode.keySize) + macKey = make([]byte, macMode.keySize) + + generateKeyMaterial(iv, d.ivTag, kex) + generateKeyMaterial(key, d.keyTag, kex) + generateKeyMaterial(macKey, d.macKeyTag, kex) + return +} + +// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as +// described in RFC 4253, section 6.4. direction should either be serverKeys +// (to setup server->client keys) or clientKeys (for client->server keys). +func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) { + iv, key, macKey := generateKeys(d, algs, kex) + + if algs.Cipher == gcmCipherID { + return newGCMCipher(iv, key, macKey) + } + + if algs.Cipher == aes128cbcID { + return newAESCBCCipher(iv, key, macKey, algs) + } + + if algs.Cipher == tripledescbcID { + return newTripleDESCBCCipher(iv, key, macKey, algs) + } + + c := &streamPacketCipher{ + mac: macModes[algs.MAC].new(macKey), + } + c.macResult = make([]byte, c.mac.Size()) + + var err error + c.cipher, err = cipherModes[algs.Cipher].createStream(key, iv) + if err != nil { + return nil, err + } + + return c, nil +} + +// generateKeyMaterial fills out with key material generated from tag, K, H +// and sessionId, as specified in RFC 4253, section 7.2. +func generateKeyMaterial(out, tag []byte, r *kexResult) { + var digestsSoFar []byte + + h := r.Hash.New() + for len(out) > 0 { + h.Reset() + h.Write(r.K) + h.Write(r.H) + + if len(digestsSoFar) == 0 { + h.Write(tag) + h.Write(r.SessionID) + } else { + h.Write(digestsSoFar) + } + + digest := h.Sum(nil) + n := copy(out, digest) + out = out[n:] + if len(out) > 0 { + digestsSoFar = append(digestsSoFar, digest...) + } + } +} + +const packageVersion = "SSH-2.0-Go" + +// Sends and receives a version line. The versionLine string should +// be US ASCII, start with "SSH-2.0-", and should not include a +// newline. exchangeVersions returns the other side's version line. +func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) { + // Contrary to the RFC, we do not ignore lines that don't + // start with "SSH-2.0-" to make the library usable with + // nonconforming servers. + for _, c := range versionLine { + // The spec disallows non US-ASCII chars, and + // specifically forbids null chars. + if c < 32 { + return nil, errors.New("ssh: junk character in version line") + } + } + if _, err = rw.Write(append(versionLine, '\r', '\n')); err != nil { + return + } + + them, err = readVersion(rw) + return them, err +} + +// maxVersionStringBytes is the maximum number of bytes that we'll +// accept as a version string. RFC 4253 section 4.2 limits this at 255 +// chars +const maxVersionStringBytes = 255 + +// Read version string as specified by RFC 4253, section 4.2. +func readVersion(r io.Reader) ([]byte, error) { + versionString := make([]byte, 0, 64) + var ok bool + var buf [1]byte + + for len(versionString) < maxVersionStringBytes { + _, err := io.ReadFull(r, buf[:]) + if err != nil { + return nil, err + } + // The RFC says that the version should be terminated with \r\n + // but several SSH servers actually only send a \n. + if buf[0] == '\n' { + ok = true + break + } + + // non ASCII chars are disallowed, but we are lenient, + // since Go doesn't use null-terminated strings. + + // The RFC allows a comment after a space, however, + // all of it (version and comments) goes into the + // session hash. + versionString = append(versionString, buf[0]) + } + + if !ok { + return nil, errors.New("ssh: overflow reading version string") + } + + // There might be a '\r' on the end which we should remove. + if len(versionString) > 0 && versionString[len(versionString)-1] == '\r' { + versionString = versionString[:len(versionString)-1] + } + return versionString, nil +} diff --git a/lib/ssh/transport_test.go b/lib/ssh/transport_test.go new file mode 100644 index 0000000..d77703f --- /dev/null +++ b/lib/ssh/transport_test.go @@ -0,0 +1,109 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xssh + +import ( + "bytes" + "crypto/rand" + "encoding/binary" + "strings" + "testing" +) + +func TestReadVersion(t *testing.T) { + longversion := strings.Repeat("SSH-2.0-bla", 50)[:253] + cases := map[string]string{ + "SSH-2.0-bla\r\n": "SSH-2.0-bla", + "SSH-2.0-bla\n": "SSH-2.0-bla", + longversion + "\r\n": longversion, + } + + for in, want := range cases { + result, err := readVersion(bytes.NewBufferString(in)) + if err != nil { + t.Errorf("readVersion(%q): %s", in, err) + } + got := string(result) + if got != want { + t.Errorf("got %q, want %q", got, want) + } + } +} + +func TestReadVersionError(t *testing.T) { + longversion := strings.Repeat("SSH-2.0-bla", 50)[:253] + cases := []string{ + longversion + "too-long\r\n", + } + for _, in := range cases { + if _, err := readVersion(bytes.NewBufferString(in)); err == nil { + t.Errorf("readVersion(%q) should have failed", in) + } + } +} + +func TestExchangeVersionsBasic(t *testing.T) { + v := "SSH-2.0-bla" + buf := bytes.NewBufferString(v + "\r\n") + them, err := exchangeVersions(buf, []byte("xyz")) + if err != nil { + t.Errorf("exchangeVersions: %v", err) + } + + if want := "SSH-2.0-bla"; string(them) != want { + t.Errorf("got %q want %q for our version", them, want) + } +} + +func TestExchangeVersions(t *testing.T) { + cases := []string{ + "not\x000allowed", + "not allowed\n", + } + for _, c := range cases { + buf := bytes.NewBufferString("SSH-2.0-bla\r\n") + if _, err := exchangeVersions(buf, []byte(c)); err == nil { + t.Errorf("exchangeVersions(%q): should have failed", c) + } + } +} + +type closerBuffer struct { + bytes.Buffer +} + +func (b *closerBuffer) Close() error { + return nil +} + +func TestTransportMaxPacketWrite(t *testing.T) { + buf := &closerBuffer{} + tr := newTransport(buf, rand.Reader, true) + huge := make([]byte, maxPacket+1) + err := tr.writePacket(huge) + if err == nil { + t.Errorf("transport accepted write for a huge packet.") + } +} + +func TestTransportMaxPacketReader(t *testing.T) { + var header [5]byte + huge := make([]byte, maxPacket+128) + binary.BigEndian.PutUint32(header[0:], uint32(len(huge))) + // padding. + header[4] = 0 + + buf := &closerBuffer{} + buf.Write(header[:]) + buf.Write(huge) + + tr := newTransport(buf, rand.Reader, true) + _, err := tr.readPacket() + if err == nil { + t.Errorf("transport succeeded reading huge packet.") + } else if !strings.Contains(err.Error(), "large") { + t.Errorf("got %q, should mention %q", err.Error(), "large") + } +}