2018-01-15 19:24:57 +00:00
package postgres
import (
	"encoding/binary"
	"encoding/hex"
	"fmt"
	"io"
	"net"
	"sort"
	"strconv"
	"strings"

	log "github.com/sirupsen/logrus"
	"github.com/zmap/zgrab2"
)
2018-05-14 15:24:25 +00:00
// Limits that keep a misbehaving server from making the scanner read or
// report unbounded amounts of data.
const (
	// maxPacketSize caps how many bytes of a single packet are read,
	// counting the 4-byte length field.
	maxPacketSize = 512 * 1024

	// maxOutputSize caps how many body bytes of an unexpected packet are
	// hex-encoded into the output.
	maxOutputSize = 1024

	// maxReadAllPackets bounds how many packets (tag/value pairs etc.)
	// ReadAll collects from the server before giving up.
	maxReadAllPackets = 64
)
2018-05-14 15:24:25 +00:00
2018-02-09 16:09:44 +00:00
// Connection wraps the state of a given connection to a server.
2018-01-15 19:24:57 +00:00
type Connection struct {
2018-05-14 15:24:25 +00:00
// Target is the requested scan target.
Target * zgrab2 . ScanTarget
2018-02-09 16:09:44 +00:00
// Connection is the underlying TCP (or TLS) stream.
2018-01-15 19:24:57 +00:00
Connection net . Conn
2018-02-09 16:09:44 +00:00
// Config contains the flags from the command line.
2018-02-09 16:49:15 +00:00
Config * Flags
2018-02-09 16:09:44 +00:00
// IsSSL is true if Connection is a TLS connection.
2018-02-09 16:49:15 +00:00
IsSSL bool
2018-01-15 19:24:57 +00:00
}
2018-02-09 16:09:44 +00:00
// ServerPacket is a direct representation of the response packet
// returned by the server.
// See e.g. https://www.postgresql.org/docs/9.6/static/protocol-message-formats.html
2018-01-15 19:24:57 +00:00
// The first byte is a message type, an alphanumeric character.
// The following four bytes are the length of the message body.
// The following <length> bytes are the message itself.
2018-02-09 16:09:44 +00:00
// In certain special cases, the Length can be 0; for instance, a
// response to an SSLRequest is only a S/N Type with no length / body,
// while pre-startup errors can be a E Type followed by a \n\0-
// terminated string.
2018-01-15 19:24:57 +00:00
type ServerPacket struct {
Type byte
Length uint32
Body [ ] byte
}
2018-02-09 16:09:44 +00:00
// ToString is used in logging, to get a human-readable representation
// of the packet.
2018-01-15 19:24:57 +00:00
func ( p * ServerPacket ) ToString ( ) string {
// TODO: Don't hex-encode human-readable bodies?
2018-05-17 18:36:51 +00:00
return fmt . Sprintf ( "{ ServerPacket(%p): { Type: '%c', Length: %d, Body: [[%d bytes]] } }" , & p , p . Type , p . Length , len ( p . Body ) )
2018-01-15 19:24:57 +00:00
}
2018-05-25 21:01:31 +00:00
// OutputValue is the value that is stored for unexpected / unrecognized data.
func ( p * ServerPacket ) OutputValue ( ) string {
l := len ( p . Body )
if len ( p . Body ) > maxOutputSize {
l = maxOutputSize
}
body := hex . EncodeToString ( p . Body [ : l ] )
if p . Length - 4 > uint32 ( l ) {
body = body + "..."
}
return fmt . Sprintf ( "%c: 0x%08x: %s" , p . Type , p . Length , body )
}
// ToError gets a PostgresError version of OutputValue.
func ( p * ServerPacket ) ToError ( ) * PostgresError {
return & PostgresError {
"severity" : "unexpected" ,
"code" : "unexpected error format" ,
"detail" : p . OutputValue ( ) ,
}
}
2018-02-09 16:09:44 +00:00
// Send a client packet: a big-endian uint32 length followed by a body.
2018-01-15 19:24:57 +00:00
func ( c * Connection ) Send ( body [ ] byte ) error {
toSend := make ( [ ] byte , len ( body ) + 4 )
copy ( toSend [ 4 : ] , body )
// The length contains the length of the length, hence the +4.
binary . BigEndian . PutUint32 ( toSend [ 0 : ] , uint32 ( len ( body ) + 4 ) )
// @TODO: Buffered send?
_ , err := c . Connection . Write ( toSend )
return err
}
2018-02-09 16:09:44 +00:00
// SendU32 sends an uint32 packet to the server.
2018-01-15 19:24:57 +00:00
func ( c * Connection ) SendU32 ( val uint32 ) error {
toSend := make ( [ ] byte , 8 )
binary . BigEndian . PutUint32 ( toSend [ 0 : ] , uint32 ( 8 ) )
binary . BigEndian . PutUint32 ( toSend [ 4 : ] , val )
// @TODO: Buffered send?
_ , err := c . Connection . Write ( toSend )
return err
}
2018-02-09 16:09:44 +00:00
// Close out the underlying TCP connection to the server.
2018-01-15 19:24:57 +00:00
func ( c * Connection ) Close ( ) error {
return c . Connection . Close ( )
}
2018-02-09 16:09:44 +00:00
// tryReadPacket tries to read a length + body from the connection.
2018-01-15 19:24:57 +00:00
func ( c * Connection ) tryReadPacket ( header byte ) ( * ServerPacket , * zgrab2 . ScanError ) {
var length [ 4 ] byte
_ , err := io . ReadFull ( c . Connection , length [ : ] )
if err != nil && err != io . EOF {
return nil , zgrab2 . DetectScanError ( err )
}
2018-05-17 18:36:51 +00:00
bodyLen := binary . BigEndian . Uint32 ( length [ : ] )
2018-01-15 19:24:57 +00:00
if length [ 0 ] > 0x00 {
// For scanning purposes, there is no reason we want to read more than 2^24 bytes
// But in practice, it probably means we have a null-terminated error string
var buf [ 1024 ] byte
n , err := c . Connection . Read ( buf [ : ] )
if err != nil && err != io . EOF {
return nil , zgrab2 . DetectScanError ( err )
}
2018-05-11 16:21:11 +00:00
if n < 2 {
return nil , zgrab2 . NewScanError ( zgrab2 . SCAN_PROTOCOL_ERROR , fmt . Errorf ( "Server returned too little data (%d bytes: %s)" , n , hex . EncodeToString ( buf [ : n ] ) ) )
}
2018-01-15 19:24:57 +00:00
if string ( buf [ n - 2 : n ] ) == "\x0a\x00" {
2018-05-17 18:36:51 +00:00
return & ServerPacket {
Type : header ,
Length : 0 ,
Body : append ( length [ : ] , buf [ : n ] ... ) ,
} , nil
2018-01-15 19:24:57 +00:00
}
2018-05-17 18:36:51 +00:00
return nil , zgrab2 . NewScanError ( zgrab2 . SCAN_PROTOCOL_ERROR , fmt . Errorf ( "Server returned too much data: length = 0x%x; first %d bytes = %s" , bodyLen , n , hex . EncodeToString ( buf [ : n ] ) ) )
2018-01-15 19:24:57 +00:00
}
2018-05-17 18:36:51 +00:00
sizeToRead := bodyLen
2018-05-14 15:24:25 +00:00
if sizeToRead > maxPacketSize {
2018-05-17 18:36:51 +00:00
log . Debugf ( "postgres server %s reported packet size of %d bytes; only reading %d bytes." , c . Target . String ( ) , bodyLen , maxPacketSize )
2018-05-14 15:24:25 +00:00
sizeToRead = maxPacketSize
}
2018-05-17 18:36:51 +00:00
body := make ( [ ] byte , sizeToRead - 4 ) // Length includes the length of the Length uint32
_ , err = io . ReadFull ( c . Connection , body )
2018-01-15 19:24:57 +00:00
if err != nil && err != io . EOF {
return nil , zgrab2 . DetectScanError ( err )
}
2018-05-17 18:36:51 +00:00
if sizeToRead < bodyLen && len ( body ) + 4 >= maxPacketSize {
2018-05-15 18:16:51 +00:00
// Warn if we actually truncate (as opposed getting a huge length but only a few bytes are actually available)
2018-05-17 18:36:51 +00:00
log . Warnf ( "Truncated postgres packet from %s: advertised size = %d bytes, read size = %d bytes" , c . Target . String ( ) , bodyLen , len ( body ) )
2018-05-14 15:24:25 +00:00
}
2018-05-17 18:36:51 +00:00
return & ServerPacket {
Type : header ,
Length : bodyLen ,
Body : body ,
} , nil
2018-01-15 19:24:57 +00:00
}
2018-02-09 16:09:44 +00:00
// RequestSSL sends an SSLRequest packet to the server, and returns true
// if and only if the server reports that it is SSL-capable. Otherwise
// it returns false and possibly an error.
2018-01-15 19:24:57 +00:00
func ( c * Connection ) RequestSSL ( ) ( bool , * zgrab2 . ScanError ) {
// NOTE: The SSLRequest request type was introduced in version 7.2, released in 2002 (though the oldest supported version is 9.3, released 2013-09-09)
if err := c . SendU32 ( postgresSSLRequest ) ; err != nil {
return false , zgrab2 . DetectScanError ( err )
}
var header [ 1 ] byte
_ , err := io . ReadFull ( c . Connection , header [ 0 : 1 ] )
if err != nil {
return false , zgrab2 . DetectScanError ( err )
}
if header [ 0 ] < '0' || header [ 0 ] > 'z' {
// Back-end messages always start with the alphanumeric Byte1 value
// We could further constrain this to currently-valid message types, but then we may incorrectly reject future versions
return false , zgrab2 . NewScanError ( zgrab2 . SCAN_PROTOCOL_ERROR , fmt . Errorf ( "Response message type 0x%02x was not an alphanumeric character" , header [ 0 ] ) )
}
switch header [ 0 ] {
case 'N' :
return false , nil
case 'S' :
return true , nil
}
// It was neither a single 'N' / 'S', so it's a failure -- at this point it's just a question of determining if it's an application error (valid packet) or a protocol error
packet , scanError := c . tryReadPacket ( header [ 0 ] )
if scanError != nil {
return false , scanError
}
switch packet . Type {
case 'E' :
return false , zgrab2 . NewScanError ( zgrab2 . SCAN_APPLICATION_ERROR , fmt . Errorf ( "Application rejected SSLRequest packet -- response = %s" , packet . ToString ( ) ) )
default :
// Returning PROTOCOL_ERROR here since any garbage data that starts with a small-ish u32 could be a valid packet, and no known server versions return anything beyond S/N/E.
return false , zgrab2 . NewScanError ( zgrab2 . SCAN_PROTOCOL_ERROR , fmt . Errorf ( "Unexpected response type '%c' from server (full response = %s)" , packet . Type , packet . ToString ( ) ) )
}
}
2018-02-09 16:09:44 +00:00
// ReadPacket reads a ServerPacket from the server.
2018-01-15 19:24:57 +00:00
func ( c * Connection ) ReadPacket ( ) ( * ServerPacket , * zgrab2 . ScanError ) {
var header [ 1 ] byte
_ , err := io . ReadFull ( c . Connection , header [ 0 : 1 ] )
if err != nil {
return nil , zgrab2 . DetectScanError ( err )
}
if header [ 0 ] < '0' || header [ 0 ] > 'z' {
// Back-end messages always start with the alphanumeric Byte1 value
// We could further constrain this to currently-valid message types, but then we may incorrectly reject future versions
return nil , zgrab2 . NewScanError ( zgrab2 . SCAN_PROTOCOL_ERROR , fmt . Errorf ( "Response message type 0x%02x was not an alphanumeric character" , header [ 0 ] ) )
}
return c . tryReadPacket ( header [ 0 ] )
}
2018-02-09 16:09:44 +00:00
// GetTLSLog gets the connection's TLSLog, or nil if the connection has
// not yet been set up as TLS.
2018-01-15 19:24:57 +00:00
func ( c * Connection ) GetTLSLog ( ) * zgrab2 . TLSLog {
if ! c . IsSSL {
return nil
}
return c . Connection . ( * zgrab2 . TLSConnection ) . GetLog ( )
}
2018-02-09 16:49:15 +00:00
// encodeMap encodes a map into a byte array of the form
// "key0\0value0\0key1\0value1\0...keyN\0valueN\0\0", i.e. NUL-terminated
// key/value strings followed by one extra NUL that terminates the list.
// An empty or nil map encodes as the lone terminator "\0".
//
// Keys are emitted in sorted order so the encoding is deterministic. Keys and
// values must not contain NUL bytes; this is not checked.
func encodeMap(dict map[string]string) []byte {
	keys := make([]string, 0, len(dict))
	for k := range dict {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var out []byte
	for _, k := range keys {
		out = append(out, k...)
		out = append(out, 0x00)
		out = append(out, dict[k]...)
		out = append(out, 0x00)
	}
	// List terminator. Joining an empty list and appending two NULs (as
	// before) produced "\0\0", which the protocol does not allow.
	return append(out, 0x00)
}
2018-02-09 16:09:44 +00:00
// SendStartupMessage creates and sends a StartupMessage.
// The format is uint16 Major + uint16 Minor + (key/value pairs).
2018-01-15 19:24:57 +00:00
func ( c * Connection ) SendStartupMessage ( version string , kvps map [ string ] string ) error {
dict := encodeMap ( kvps )
ret := make ( [ ] byte , len ( dict ) + 4 )
parts := strings . Split ( version , "." )
if len ( parts ) == 1 {
parts = [ ] string { parts [ 0 ] , "0" }
}
major , err := strconv . ParseUint ( parts [ 0 ] , 0 , 16 )
if err != nil {
2018-02-09 16:09:44 +00:00
log . Fatalf ( "Error parsing major version %s as a uint16: %v" , parts [ 0 ] , err )
2018-01-15 19:24:57 +00:00
}
minor , err := strconv . ParseUint ( parts [ 1 ] , 0 , 16 )
if err != nil {
2018-02-09 16:09:44 +00:00
log . Fatalf ( "Error parsing minor version %s as a uint16: %v" , parts [ 1 ] , err )
2018-01-15 19:24:57 +00:00
}
binary . BigEndian . PutUint16 ( ret [ 0 : 2 ] , uint16 ( major ) )
binary . BigEndian . PutUint16 ( ret [ 2 : 4 ] , uint16 ( minor ) )
copy ( ret [ 4 : ] , dict )
return c . Send ( ret )
}
2018-02-09 16:49:15 +00:00
// ReadAll reads packets from the given connection until it hits a
2018-02-09 16:09:44 +00:00
// timeout, EOF, or a 'Z' packet.
2018-01-15 19:24:57 +00:00
func ( c * Connection ) ReadAll ( ) ( [ ] * ServerPacket , * zgrab2 . ScanError ) {
2018-02-09 16:09:44 +00:00
var ret [ ] * ServerPacket
2018-01-15 19:24:57 +00:00
for {
response , readError := c . ReadPacket ( )
if readError != nil {
if readError . Status == zgrab2 . SCAN_IO_TIMEOUT || readError . Err == io . EOF {
return ret , nil
}
2018-02-09 16:09:44 +00:00
return ret , readError
2018-01-15 19:24:57 +00:00
}
ret = append ( ret , response )
if response . Type == 'Z' {
return ret , nil
}
2018-05-15 19:47:14 +00:00
if len ( ret ) > maxReadAllPackets {
log . Warnf ( "Server %s returned more than %d packets -- truncating." , c . Target . String ( ) , maxReadAllPackets )
return ret , nil
}
2018-01-15 19:24:57 +00:00
}
}
2018-02-09 16:09:44 +00:00
// connectionManager is a utility for getting connections and ensuring
// that they all get closed.
// TODO: Is there something like this in the standard libraries?
2018-01-15 19:24:57 +00:00
type connectionManager struct {
2018-05-11 18:01:10 +00:00
connections map [ io . Closer ] bool
2018-01-15 19:24:57 +00:00
}
2018-02-09 16:09:44 +00:00
// addConnection adds a managed connection.
2018-01-15 19:24:57 +00:00
func ( m * connectionManager ) addConnection ( c io . Closer ) {
2018-05-11 18:01:10 +00:00
m . connections [ c ] = true
}
func ( m * connectionManager ) closeConnection ( c io . Closer ) {
if m . connections [ c ] {
m . connections [ c ] = false
err := c . Close ( )
if err != nil {
log . Debugf ( "Got error closing connection: %v" , err )
}
}
2018-01-15 19:24:57 +00:00
}
2018-02-09 16:09:44 +00:00
// cleanUp closes all managed connections.
2018-01-15 19:24:57 +00:00
func ( m * connectionManager ) cleanUp ( ) {
2018-05-14 18:48:48 +00:00
// first in, last out: empty out the map
2018-05-11 18:01:10 +00:00
defer func ( ) {
for conn , _ := range m . connections {
delete ( m . connections , conn )
}
} ( )
for connection , _ := range m . connections {
2018-01-15 19:24:57 +00:00
// Close them all even if there is a panic with one
defer func ( c io . Closer ) {
2018-05-11 18:01:10 +00:00
m . closeConnection ( c )
} ( connection )
2018-01-15 19:24:57 +00:00
}
}
2018-02-09 16:09:44 +00:00
// Get a new connectionmanager instance.
2018-01-15 19:24:57 +00:00
func newConnectionManager ( ) * connectionManager {
2018-05-11 18:01:10 +00:00
return & connectionManager {
connections : make ( map [ io . Closer ] bool ) ,
}
2018-01-15 19:24:57 +00:00
}