package postgres

import (
	"encoding/binary"
	"encoding/hex"
	"fmt"
	"io"
	"net"
	"strconv"
	"strings"

	log "github.com/sirupsen/logrus"
	"github.com/zmap/zgrab2"
)

// maxPacketSize caps the number of bytes read for a single packet, so
// that a server cannot make us perform unbounded reads.
const maxPacketSize = 512 * 1024

// maxReadAllPackets caps the number of packets ReadAll collects, so that
// a server cannot feed us an unlimited number of tag/value pairs.
const maxReadAllPackets = 64

// Connection wraps the state of a given connection to a server.
type Connection struct {
	// Target is the requested scan target.
	Target *zgrab2.ScanTarget

	// Connection is the underlying TCP (or TLS) stream.
	Connection net.Conn

	// Config contains the flags from the command line.
	Config *Flags

	// IsSSL is true if Connection is a TLS connection.
	IsSSL bool
}

// ServerPacket is a direct representation of the response packet
// returned by the server.
// See e.g. https://www.postgresql.org/docs/9.6/static/protocol-message-formats.html
// The first byte is a message type, an alphanumeric character.
// The following four bytes are the big-endian length of the message;
// the length counts its own four bytes but not the type byte.
// The following bytes are the message itself.
// In certain special cases, the Length can be 0; for instance, a
// response to an SSLRequest is only a S/N Type with no length / body,
// while pre-startup errors can be a E Type followed by a \n\0-
// terminated string.
type ServerPacket struct {
	Type   byte
	Length uint32
	Body   []byte
}

// ToString is used in logging, to get a human-readable representation
// of the packet. Only the size of the body is included, not its content.
func (p *ServerPacket) ToString() string {
	// TODO: Don't hex-encode human-readable bodies?
	// Format the packet pointer itself; taking the address of the receiver
	// variable would print the location of a local, not of the packet.
	return fmt.Sprintf("{ ServerPacket(%p): { Type: '%c', Length: %d, Body: [[%d bytes]] } }", p, p.Type, p.Length, len(p.Body))
}

// Send a client packet: a big-endian uint32 length followed by a body.
// It returns the error, if any, from writing to the underlying connection.
func (c *Connection) Send(body []byte) error {
	toSend := make([]byte, len(body)+4)
	copy(toSend[4:], body)
	// The length contains the length of the length, hence the +4.
	binary.BigEndian.PutUint32(toSend[0:], uint32(len(body)+4))

	// @TODO: Buffered send?
	_, err := c.Connection.Write(toSend)
	return err
}

// SendU32 sends an uint32 packet to the server: a length field of 8
// followed by val, both big-endian.
func (c *Connection) SendU32(val uint32) error {
	toSend := make([]byte, 8)
	binary.BigEndian.PutUint32(toSend[0:], uint32(8))
	binary.BigEndian.PutUint32(toSend[4:], val)

	// @TODO: Buffered send?
	_, err := c.Connection.Write(toSend)
	return err
}

// Close out the underlying TCP connection to the server.
func (c *Connection) Close() error {
	return c.Connection.Close()
}

// tryReadPacket tries to read a length + body from the connection; the
// type byte (header) has already been consumed by the caller.
//
// If the first length byte is non-zero the advertised length is at least
// 2^24, which is treated as a probable \n\0-terminated error string: up
// to 1024 further bytes are read and returned (prefixed by the four
// "length" bytes) in a packet with Length 0, or a protocol error is
// returned if the data does not end in \n\0.
//
// Otherwise at most maxPacketSize bytes are read. An advertised length
// below 4 (which cannot even cover the length field itself, and which is
// also what an immediate EOF after the type byte produces) yields a
// packet with an empty body.
func (c *Connection) tryReadPacket(header byte) (*ServerPacket, *zgrab2.ScanError) {
	var length [4]byte
	// io.EOF means nothing at all followed the type byte; that is
	// tolerated, and length simply stays zero.
	_, err := io.ReadFull(c.Connection, length[:])
	if err != nil && err != io.EOF {
		return nil, zgrab2.DetectScanError(err)
	}
	bodyLen := binary.BigEndian.Uint32(length[:])

	if length[0] > 0x00 {
		// For scanning purposes, there is no reason we want to read more than 2^24 bytes
		// But in practice, it probably means we have a null-terminated error string
		var buf [1024]byte
		n, err := c.Connection.Read(buf[:])
		if err != nil && err != io.EOF {
			return nil, zgrab2.DetectScanError(err)
		}
		if n < 2 {
			return nil, zgrab2.NewScanError(zgrab2.SCAN_PROTOCOL_ERROR, fmt.Errorf("Server returned too little data (%d bytes: %s)", n, hex.EncodeToString(buf[:n])))
		}
		if string(buf[n-2:n]) == "\x0a\x00" {
			// The four bytes consumed as a "length" are really the start
			// of the error string, so they are put back in front.
			return &ServerPacket{
				Type:   header,
				Length: 0,
				Body:   append(length[:], buf[:n]...),
			}, nil
		}
		return nil, zgrab2.NewScanError(zgrab2.SCAN_PROTOCOL_ERROR, fmt.Errorf("Server returned too much data: length = 0x%x; first %d bytes = %s", bodyLen, n, hex.EncodeToString(buf[:n])))
	}

	sizeToRead := bodyLen
	if sizeToRead > maxPacketSize {
		log.Debugf("postgres server %s reported packet size of %d bytes; only reading %d bytes.", c.Target.String(), bodyLen, maxPacketSize)
		sizeToRead = maxPacketSize
	}

	// Length includes the length of the Length uint32. Guard the
	// subtraction: sizeToRead is unsigned, so a value below 4 would wrap
	// around to roughly 4 GiB and trigger a huge allocation.
	var bodySize uint32
	if sizeToRead > 4 {
		bodySize = sizeToRead - 4
	}
	body := make([]byte, bodySize)
	n, err := io.ReadFull(c.Connection, body)
	if err != nil && err != io.EOF {
		return nil, zgrab2.DetectScanError(err)
	}
	if sizeToRead < bodyLen && n+4 >= maxPacketSize {
		// Warn if we actually truncate (as opposed getting a huge length but only a few bytes are actually available)
		log.Warnf("Truncated postgres packet from %s: advertised size = %d bytes, read size = %d bytes", c.Target.String(), bodyLen, n)
	}
	return &ServerPacket{
		Type:   header,
		Length: bodyLen,
		Body:   body,
	}, nil
}

// RequestSSL sends an SSLRequest packet to the server, and returns true
// if and only if the server reports that it is SSL-capable. Otherwise
// it returns false and possibly an error: a single 'N' response yields
// (false, nil), an 'E' packet yields an application error, and anything
// else yields a protocol error.
func (c *Connection) RequestSSL() (bool, *zgrab2.ScanError) {
	// NOTE: The SSLRequest request type was introduced in version 7.2, released in 2002 (though the oldest supported version is 9.3, released 2013-09-09)
	if err := c.SendU32(postgresSSLRequest); err != nil {
		return false, zgrab2.DetectScanError(err)
	}
	var header [1]byte
	_, err := io.ReadFull(c.Connection, header[0:1])
	if err != nil {
		return false, zgrab2.DetectScanError(err)
	}
	if header[0] < '0' || header[0] > 'z' {
		// Back-end messages always start with the alphanumeric Byte1 value
		// We could further constrain this to currently-valid message types, but then we may incorrectly reject future versions
		return false, zgrab2.NewScanError(zgrab2.SCAN_PROTOCOL_ERROR, fmt.Errorf("Response message type 0x%02x was not an alphanumeric character", header[0]))
	}
	switch header[0] {
	case 'N':
		return false, nil
	case 'S':
		return true, nil
	}
	// It was neither a single 'N' / 'S', so it's a failure -- at this point it's just a question of determining if it's an application error (valid packet) or a protocol error
	packet, scanError := c.tryReadPacket(header[0])
	if scanError != nil {
		return false, scanError
	}
	switch packet.Type {
	case 'E':
		return false, zgrab2.NewScanError(zgrab2.SCAN_APPLICATION_ERROR, fmt.Errorf("Application rejected SSLRequest packet -- response = %s", packet.ToString()))
	default:
		// Returning PROTOCOL_ERROR here since any garbage data that starts with a small-ish u32 could be a valid packet, and no known server versions return anything beyond S/N/E.
		return false, zgrab2.NewScanError(zgrab2.SCAN_PROTOCOL_ERROR, fmt.Errorf("Unexpected response type '%c' from server (full response = %s)", packet.Type, packet.ToString()))
	}
}

// ReadPacket reads a ServerPacket from the server: one type byte, which
// must lie in the range '0'..'z', followed by whatever tryReadPacket
// reads for the length and body.
func (c *Connection) ReadPacket() (*ServerPacket, *zgrab2.ScanError) {
	var header [1]byte
	_, err := io.ReadFull(c.Connection, header[0:1])
	if err != nil {
		return nil, zgrab2.DetectScanError(err)
	}
	if header[0] < '0' || header[0] > 'z' {
		// Back-end messages always start with the alphanumeric Byte1 value
		// We could further constrain this to currently-valid message types, but then we may incorrectly reject future versions
		return nil, zgrab2.NewScanError(zgrab2.SCAN_PROTOCOL_ERROR, fmt.Errorf("Response message type 0x%02x was not an alphanumeric character", header[0]))
	}
	return c.tryReadPacket(header[0])
}

// GetTLSLog gets the connection's TLSLog, or nil if the connection has
// not yet been set up as TLS.
func (c *Connection) GetTLSLog() *zgrab2.TLSLog {
	if !c.IsSSL {
		return nil
	}
	return c.Connection.(*zgrab2.TLSConnection).GetLog()
}

// encodeMap encodes a map into a byte array of the form
// "key0\0value\0key1\0value1\0...keyN\0valueN\0\0"
// The order of the pairs follows Go's map iteration order, which is not
// deterministic.
func encodeMap(dict map[string]string) []byte {
	// Two entries (key and value) per map element.
	strs := make([]string, 0, 2*len(dict))
	for k, v := range dict {
		strs = append(strs, k, v)
	}
	return append([]byte(strings.Join(strs, "\x00")), 0x00, 0x00)
}

// SendStartupMessage creates and sends a StartupMessage.
// The format is uint16 Major + uint16 Minor + (key/value pairs).
func (c *Connection) SendStartupMessage(version string, kvps map[string]string) error { dict := encodeMap(kvps) ret := make([]byte, len(dict)+4) parts := strings.Split(version, ".") if len(parts) == 1 { parts = []string{parts[0], "0"} } major, err := strconv.ParseUint(parts[0], 0, 16) if err != nil { log.Fatalf("Error parsing major version %s as a uint16: %v", parts[0], err) } minor, err := strconv.ParseUint(parts[1], 0, 16) if err != nil { log.Fatalf("Error parsing minor version %s as a uint16: %v", parts[1], err) } binary.BigEndian.PutUint16(ret[0:2], uint16(major)) binary.BigEndian.PutUint16(ret[2:4], uint16(minor)) copy(ret[4:], dict) return c.Send(ret) } // ReadAll reads packets from the given connection until it hits a // timeout, EOF, or a 'Z' packet. func (c *Connection) ReadAll() ([]*ServerPacket, *zgrab2.ScanError) { var ret []*ServerPacket for { response, readError := c.ReadPacket() if readError != nil { if readError.Status == zgrab2.SCAN_IO_TIMEOUT || readError.Err == io.EOF { return ret, nil } return ret, readError } ret = append(ret, response) if response.Type == 'Z' { return ret, nil } if len(ret) > maxReadAllPackets { log.Warnf("Server %s returned more than %d packets -- truncating.", c.Target.String(), maxReadAllPackets) return ret, nil } } } // connectionManager is a utility for getting connections and ensuring // that they all get closed. // TODO: Is there something like this in the standard libraries? type connectionManager struct { connections map[io.Closer]bool } // addConnection adds a managed connection. func (m *connectionManager) addConnection(c io.Closer) { m.connections[c] = true } func (m *connectionManager) closeConnection(c io.Closer) { if m.connections[c] { m.connections[c] = false err := c.Close() if err != nil { log.Debugf("Got error closing connection: %v", err) } } } // cleanUp closes all managed connections. 
func (m *connectionManager) cleanUp() { // first in, last out: empty out the map defer func() { for conn, _ := range m.connections { delete(m.connections, conn) } }() for connection, _ := range m.connections { // Close them all even if there is a panic with one defer func(c io.Closer) { m.closeConnection(c) }(connection) } } // Get a new connectionmanager instance. func newConnectionManager() *connectionManager { return &connectionManager{ connections: make(map[io.Closer]bool), } }