diff --git a/config.go b/config.go index 1903e89..12402be 100644 --- a/config.go +++ b/config.go @@ -21,6 +21,7 @@ type Config struct { Debug bool `long:"debug" description:"Include debug fields in the output."` GOMAXPROCS int `long:"gomaxprocs" default:"0" description:"Set GOMAXPROCS"` ConnectionsPerHost int `long:"connections-per-host" default:"1" description:"Number of times to connect to each host (results in more output)"` + ReadLimitPerHost int `long:"read-limit-per-host" default:"96" description:"Maximum total kilobytes to read for a single host (default 96kb)"` Prometheus string `long:"prometheus" description:"Address to use for Prometheus server (e.g. localhost:8080). If empty, Prometheus is disabled."` Multiple MultipleCommand `command:"multiple" description:"Multiple module actions"` inputFile *os.File @@ -118,6 +119,11 @@ func validateFrameworkConfiguration() { if config.ConnectionsPerHost > 50 { log.Fatalf("connectionsPerHost must be in the range [0,50]") } + + // Stop even third-party libraries from performing unbounded reads on untrusted hosts + if config.ReadLimitPerHost > 0 { + DefaultBytesReadLimit = config.ReadLimitPerHost * 1024 + } } // GetMetaFile returns the file to which metadata should be output diff --git a/conn.go b/conn.go index 0092818..2208c9d 100644 --- a/conn.go +++ b/conn.go @@ -1,49 +1,127 @@ package zgrab2 import ( + "context" + "errors" + "io" "net" "time" - "context" + + "github.com/sirupsen/logrus" ) +// ReadLimitExceededAction describes how the connection reacts to an attempt to read more data than permitted. +type ReadLimitExceededAction string + +const ( + // ReadLimitExceededActionNotSet is a placeholder for the zero value, so that explicitly set values can be + // distinguished from the empty default. + ReadLimitExceededActionNotSet = ReadLimitExceededAction("") + + // ReadLimitExceededActionTruncate causes the connection to truncate at BytesReadLimit bytes and return a bogus + // io.EOF error. The fact that a truncation took place is logged at debug level. + ReadLimitExceededActionTruncate = ReadLimitExceededAction("truncate") + + // ReadLimitExceededActionError causes the Read call to return n, ErrReadLimitExceeded (in addition to truncating). + ReadLimitExceededActionError = ReadLimitExceededAction("error") + + // ReadLimitExceededActionPanic causes the Read call to panic(ErrReadLimitExceeded). + ReadLimitExceededActionPanic = ReadLimitExceededAction("panic") +) + +var ( + // DefaultBytesReadLimit is the maximum number of bytes to read per connection when no explicit value is provided. + DefaultBytesReadLimit = 256 * 1024 * 1024 + + // DefaultReadLimitExceededAction is the action used when no explicit action is set. + DefaultReadLimitExceededAction = ReadLimitExceededActionTruncate + + // DefaultSessionTimeout is the default maximum time a connection may be used when no explicit value is provided. + DefaultSessionTimeout = 1 * time.Minute +) + +// ErrReadLimitExceeded is returned / panic'd from Read if the read limit is exceeded when the +// ReadLimitExceededAction is error / panic. +var ErrReadLimitExceeded = errors.New("read limit exceeded") + // TimeoutConnection wraps an existing net.Conn connection, overriding the Read/Write methods to use the configured timeouts +// TODO: Refactor this into TimeoutConnection, BoundedReader, LoggedReader, etc type TimeoutConnection struct { net.Conn - Timeout time.Duration - explicitReadDeadline bool - explicitWriteDeadline bool - explicitDeadline bool + ctx context.Context + Timeout time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration + BytesRead int + BytesWritten int + BytesReadLimit int + ReadLimitExceededAction ReadLimitExceededAction + Cancel context.CancelFunc + explicitReadDeadline bool + explicitWriteDeadline bool + explicitDeadline bool } // TimeoutConnection.Read calls Read() on the underlying connection, using any configured deadlines func (c *TimeoutConnection) Read(b []byte) (n int, err error) { + if err := c.checkContext(); err != nil { + return 0, err + } + origSize := len(b) + if c.BytesRead+len(b) >= c.BytesReadLimit { + b = b[0 : c.BytesReadLimit-c.BytesRead] + } if c.explicitReadDeadline || c.explicitDeadline { c.explicitReadDeadline = false c.explicitDeadline = false - } else if c.Timeout > 0 { - if err = c.Conn.SetReadDeadline(time.Now().Add(c.Timeout)); err != nil { + } else if readTimeout := c.getTimeout(c.ReadTimeout); readTimeout > 0 { + if err = c.Conn.SetReadDeadline(time.Now().Add(readTimeout)); err != nil { return 0, err } } - return c.Conn.Read(b) + n, err = c.Conn.Read(b) + c.BytesRead += n + if err == nil && origSize != len(b) && n == len(b) { + // we had to shrink the output buffer AND we used up the whole shrunk size, AND we're not at EOF + switch c.ReadLimitExceededAction { + case ReadLimitExceededActionTruncate: + logrus.Debug("Truncated read from %d bytes to %d bytes (hit limit of %d bytes)", origSize, n, c.BytesReadLimit) + err = io.EOF + case ReadLimitExceededActionError: + return n, ErrReadLimitExceeded + case ReadLimitExceededActionPanic: + panic(ErrReadLimitExceeded) + default: + logrus.Fatalf("Unrecognized ReadLimitExceededAction: %s", c.ReadLimitExceededAction) + } + } + return n, err } -// TimeoutConnection.Write calls Write() on the underlying connection, using any configured deadlines +// TimeoutConnection.Write calls Write() on the underlying connection, using any configured deadlines. func (c *TimeoutConnection) Write(b []byte) (n int, err error) { + if err := c.checkContext(); err != nil { + return 0, err + } if c.explicitWriteDeadline || c.explicitDeadline { c.explicitWriteDeadline = false c.explicitDeadline = false - } else if c.Timeout > 0 { - if err = c.Conn.SetWriteDeadline(time.Now().Add(c.Timeout)); err != nil { + } else if writeTimeout := c.getTimeout(c.WriteTimeout); writeTimeout > 0 { + if err = c.Conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil { return 0, err } } - return c.Conn.Write(b) + n, err = c.Conn.Write(b) + c.BytesWritten += n + return n, err } // SetReadDeadline sets an explicit ReadDeadline that will override the timeout // for one read. Use deadline = 0 to clear the deadline. func (c *TimeoutConnection) SetReadDeadline(deadline time.Time) error { + if err := c.checkContext(); err != nil { + return err + } if !deadline.IsZero() { err := c.Conn.SetReadDeadline(deadline) if err != nil { @@ -57,6 +135,9 @@ func (c *TimeoutConnection) SetReadDeadline(deadline time.Time) error { // SetWriteDeadline sets an explicit WriteDeadline that will override the // WriteDeadline for one write. Use deadline = 0 to clear the deadline. func (c *TimeoutConnection) SetWriteDeadline(deadline time.Time) error { + if err := c.checkContext(); err != nil { + return err + } if !deadline.IsZero() { err := c.Conn.SetWriteDeadline(deadline) if err != nil { @@ -70,6 +151,9 @@ func (c *TimeoutConnection) SetWriteDeadline(deadline time.Time) error { // SetDeadline sets a read / write deadline that will override the deadline for // a single read/write. Use deadline = 0 to clear the deadline. func (c *TimeoutConnection) SetDeadline(deadline time.Time) error { + if err := c.checkContext(); err != nil { + return err + } if !deadline.IsZero() { err := c.Conn.SetDeadline(deadline) if err != nil { @@ -80,8 +164,8 @@ func (c *TimeoutConnection) SetDeadline(deadline time.Time) error { return nil } -// GetTimeoutDialer returns a Dialer function that dials with the given timeout -func GetTimeoutDialer(timeout time.Duration) func(string, string) (net.Conn, error) { +// GetTimeoutDialFunc returns a DialFunc that dials with the given timeout +func GetTimeoutDialFunc(timeout time.Duration) func(string, string) (net.Conn, error) { return func(proto, target string) (net.Conn, error) { return DialTimeoutConnection(proto, target, timeout) } @@ -92,14 +176,69 @@ func (c *TimeoutConnection) Close() error { return c.Conn.Close() } -// DialTimeoutConnection dials the target and returns a net.Conn that uses the configured timeouts for Read/Write operations. -func DialTimeoutConnection(proto string, target string, timeout time.Duration) (net.Conn, error) { +// Get the timeout for the given field, falling back to the global timeout. +func (c *TimeoutConnection) getTimeout(field time.Duration) time.Duration { + if field == 0 { + return c.Timeout + } + return field +} + +// Check if the context has been cancelled, and if so, return an error (either the context error, or +// if the context error is nil, ErrTotalTimeout). +func (c *TimeoutConnection) checkContext() error { + if c.ctx == nil { + return nil + } + select { + case <-c.ctx.Done(): + if err := c.ctx.Err(); err != nil { + return err + } else { + return ErrTotalTimeout + } + default: + return nil + } +} + +// SetDefaults on the connection. +func (c *TimeoutConnection) SetDefaults() *TimeoutConnection { + if c.BytesReadLimit == 0 { + c.BytesReadLimit = DefaultBytesReadLimit + } + if c.ReadLimitExceededAction == ReadLimitExceededActionNotSet { + c.ReadLimitExceededAction = DefaultReadLimitExceededAction + } + if c.Timeout == 0 { + c.Timeout = DefaultSessionTimeout + } + return c +} + +// NewTimeoutConnection returns a new TimeoutConnection with the appropriate defaults. +func NewTimeoutConnection(ctx context.Context, conn net.Conn, timeout, readTimeout, writeTimeout time.Duration) *TimeoutConnection { + ret := (&TimeoutConnection{ + Conn: conn, + Timeout: timeout, + ReadTimeout: readTimeout, + WriteTimeout: writeTimeout, + }).SetDefaults() + if ctx == nil { + ctx = context.Background() + } + ret.ctx, ret.Cancel = context.WithTimeout(ctx, timeout) + return ret +} + +// DialTimeoutConnectionEx dials the target and returns a net.Conn that uses the configured timeouts for Read/Write operations. +func DialTimeoutConnectionEx(proto string, target string, dialTimeout, sessionTimeout, readTimeout, writeTimeout time.Duration) (net.Conn, error) { var conn net.Conn var err error - if timeout > 0 { - conn, err = net.DialTimeout(proto, target, timeout) + if dialTimeout > 0 { + conn, err = net.DialTimeout(proto, target, dialTimeout) } else { - conn, err = net.Dial(proto, target) + conn, err = net.DialTimeout(proto, target, sessionTimeout) } if err != nil { if conn != nil { @@ -107,49 +246,104 @@ func DialTimeoutConnection(proto string, target string, timeout time.Duration) ( } return nil, err } - return &TimeoutConnection{ - Conn: conn, - Timeout: timeout, - }, nil + return NewTimeoutConnection(context.Background(), conn, sessionTimeout, readTimeout, writeTimeout), nil +} + +// DialTimeoutConnection dials the target and returns a net.Conn that uses the configured single timeout for all operations. +func DialTimeoutConnection(proto string, target string, timeout time.Duration) (net.Conn, error) { + return DialTimeoutConnectionEx(proto, target, timeout, timeout, timeout, timeout) } // Dialer provides Dial and DialContext methods to get connections with the given timeout. type Dialer struct { - // Timeout is the maximum time to wait for a connection or I/O. + // Timeout is the maximum time to wait for the entire session, after which any operations on the + // connection will fail. Timeout time.Duration - // dialer is an auxiliary dialer used for DialContext (the result gets wrapped in a TimeoutConnection). + // ConnectTimeout is the maximum time to wait for a connection. + ConnectTimeout time.Duration + + // ReadTimeout is the maximum time to wait for a Read + ReadTimeout time.Duration + + // WriteTimeout is the maximum time to wait for a Write + WriteTimeout time.Duration + + // Dialer is an auxiliary dialer used for DialContext (the result gets wrapped in a + // TimeoutConnection). Dialer *net.Dialer + + // BytesReadLimit is the maximum number of bytes that connections dialed with this dialer will + // read before erroring. + BytesReadLimit int + + // ReadLimitExceededAction describes how connections dialed with this dialer deal with exceeding + // the BytesReadLimit. + ReadLimitExceededAction ReadLimitExceededAction +} + +func (d *Dialer) getTimeout(field time.Duration) time.Duration { + if field == 0 { + return d.Timeout + } + return field } // DialContext wraps the connection returned by net.Dialer.DialContext() with a TimeoutConnection. func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - // ensure that our aux dialer is up-to-date - d.Dialer.Timeout = d.Timeout + if d.Timeout != 0 { + ctx, _ = context.WithTimeout(ctx, d.Timeout) + } + // ensure that our aux dialer is up-to-date; copied from http/transport.go + d.Dialer.Timeout = d.getTimeout(d.ConnectTimeout) d.Dialer.KeepAlive = d.Timeout - ret, err := d.Dialer.DialContext(ctx, network, address) + dialContext, cancelDial := context.WithTimeout(ctx, d.Dialer.Timeout) + defer cancelDial() + conn, err := d.Dialer.DialContext(dialContext, network, address) if err != nil { return nil, err } - return &TimeoutConnection{ - Conn: ret, - Timeout: d.Timeout, - }, nil + ret := NewTimeoutConnection(ctx, conn, d.Timeout, d.ReadTimeout, d.WriteTimeout) + ret.BytesReadLimit = d.BytesReadLimit + ret.ReadLimitExceededAction = d.ReadLimitExceededAction + return ret, nil } // Dial returns a connection with the configured timeout. func (d *Dialer) Dial(proto string, target string) (net.Conn, error) { - return DialTimeoutConnection(proto, target, d.Timeout) + return DialTimeoutConnectionEx(proto, target, d.ConnectTimeout, d.Timeout, d.ReadTimeout, d.WriteTimeout) } // GetTimeoutConnectionDialer gets a Dialer that dials connections with the given timeout. func GetTimeoutConnectionDialer(timeout time.Duration) *Dialer { - return &Dialer{ - Timeout: timeout, - Dialer: &net.Dialer{ - Timeout: timeout, - KeepAlive: timeout, - DualStack: true, - }, + return NewDialer(&Dialer{Timeout: timeout}) +} + +// SetDefaults for the Dialer. +func (d *Dialer) SetDefaults() *Dialer { + if d.Timeout == 0 { + d.Timeout = DefaultSessionTimeout } -} \ No newline at end of file + if d.ReadLimitExceededAction == ReadLimitExceededActionNotSet { + d.ReadLimitExceededAction = DefaultReadLimitExceededAction + } + if d.BytesReadLimit == 0 { + d.BytesReadLimit = DefaultBytesReadLimit + } + if d.Dialer == nil { + d.Dialer = &net.Dialer{ + Timeout: d.Timeout, + KeepAlive: d.Timeout, + DualStack: true, + } + } + return d +} + +// NewDialer creates a new Dialer with default settings. +func NewDialer(value *Dialer) *Dialer { + if value == nil { + value = &Dialer{} + } + return value.SetDefaults() +} diff --git a/conn_bytelimit_test.go b/conn_bytelimit_test.go new file mode 100644 index 0000000..1e93379 --- /dev/null +++ b/conn_bytelimit_test.go @@ -0,0 +1,396 @@ +package zgrab2 + +import ( + "context" + "fmt" + "io" + "net" + "strings" + "testing" + "time" +) + +// Start a local echo server on port. +func runEchoServer(t *testing.T, port int) { + endpoint := fmt.Sprintf("127.0.0.1:%d", port) + listener, err := net.Listen("tcp", endpoint) + if err != nil { + t.Fatal(err) + } + go func() { + defer listener.Close() + sock, err := listener.Accept() + if err != nil { + t.Fatal(err) + } + defer sock.Close() + + buf := make([]byte, 1024) + for { + n, err := sock.Read(buf) + if err != nil { + if err != io.EOF && !strings.Contains(err.Error(), "connection reset") { + t.Fatal(err) + } + return + } + sock.SetWriteDeadline(time.Now().Add(time.Millisecond * 250)) + n, err = sock.Write(buf[0:n]) + if err != nil { + if err != io.EOF && !strings.Contains(err.Error(), "connection reset") && !strings.Contains(err.Error(), "broken pipe") { + t.Logf("Unexpected error writing to client: %v", err) + } + return + } + } + }() +} + +// Interface for getting a TimeoutConnection; we want to test both the dialer and the direct Dial functions. +type timeoutConnector interface { + connect(ctx context.Context, t *testing.T, port int, idx int) (*TimeoutConnection, error) + getConfig() readLimitTestConfig +} + +// Config for a test case +type readLimitTestConfig struct { + // The maximum bytes the connection should read + limit int + // The number of bytes that should be sent (so iff sendSize > limit, the action should be triggered) + sendSize int + // The action to run when too much data is sent + action ReadLimitExceededAction +} + +// Call sendReceive(), and check that the input/output match, and that any expected errors / truncation occurs. +func checkedSendReceive(t *testing.T, conn *TimeoutConnection, size int) (result error) { + // helper to report + return an error + tErrorf := func(format string, args ...interface{}) error { + result = fmt.Errorf(format, args) + t.Error(result) + return result + } + + // We will check that this increases by the correct size + before := conn.BytesRead + + // This is true if we expect an overflow to occur (and so the ReadLimitExceededAction should fire) + overflowed := (before + size) > conn.BytesReadLimit + + // Don't want to keep re-typing this + action := conn.ReadLimitExceededAction + + defer func() { + if result != nil { + // log any previous error -- more may still follow + t.Error(result) + } + err := recover() + if err != nil { + if action != ReadLimitExceededActionPanic { + // no reason to panic unless that is the action + panic(err) + } + if !overflowed { + tErrorf("panicked early: only sent %d bytes so far, but limit=%d", before+size, conn.BytesReadLimit) + return + } + if err == ErrReadLimitExceeded { + // We read too much data and this is the right error: silently succeed + return + } + tErrorf("wrong panic error: got %v, expected ErrReadlimitExceeded", err) + return + } + + if action != ReadLimitExceededActionPanic { + // other action -- fine that we didn't panic + return + } + if !overflowed { + // not enough bytes read to overflow -- fine that we didn't panic + return + } + // ReadLimitExceededActionPanic, read too many bytes: should have panicked but didn't + tErrorf("should have panicked: action=ReadLimitExceededActionPanic, but sent without issue") + }() + + ret, err := sendReceive(t, conn, size) + + if err != nil { + if !overflowed { + // If there is no overflow, there should be no error + return tErrorf("read: unexpected error: %v", err) + } + if err != io.EOF && err != ErrReadLimitExceeded { + // EOF and ErrReadLimitExceeded are the only errors that should be returned + return tErrorf("read: wrong error: %v", err) + } + if err == io.EOF && action != ReadLimitExceededActionTruncate { + // EOF should only occur with truncation + return tErrorf("read: unexpected EOF") + } + if err == ErrReadLimitExceeded && action != ReadLimitExceededActionError { + // ErrReadLimitExceeded should only occur with ReadLimitExceededActionError + return tErrorf("read: unexpected ErrReadLimitExceeded") + } + // Otherwise, fall through -- we still need to check that the data matches + } else { + if overflowed && action == ReadLimitExceededActionError { + return tErrorf("read: should have gotten an error, but did not") + } + } + expectedSize := size + if overflowed { + expectedSize = conn.BytesReadLimit - before + } + + if conn.BytesRead != before+expectedSize { + return tErrorf("check: BytesRead value inconsistent; expected %d, got %d", before+expectedSize, conn.BytesRead) + } + if len(ret) != expectedSize { + return tErrorf("check: expected %d bytes, got %d", expectedSize, len(ret)) + } + if expectedSize > 0 && !checkTestBuffer(ret) { + return tErrorf("Got back invalid data (%x)", ret) + } + return nil +} + +// Send size testBuffer bytes to conn, then perform a read, and return the result/error. +func sendReceive(t *testing.T, conn *TimeoutConnection, size int) ([]byte, error) { + toSend := getTestBuffer(size) + n, err := conn.Write(toSend) + if err != nil { + t.Fatalf("Send failed: %v", err) + return nil, err + } + if n != len(toSend) { + t.Fatalf("Short write: expected to send %d bytes, returned %d", len(toSend), n) + return nil, io.ErrShortWrite + } + readBuf := make([]byte, size) + n, err = conn.Read(readBuf) + return readBuf[0:n], err +} + +// Get a size-byte slice of sequential bytes (mod 256), starting from 0 +func getTestBuffer(size int) []byte { + ret := make([]byte, size) + for i := 0; i < size; i++ { + ret[i] = byte(i & 0xff) + } + return ret +} + +// Check that buf is of the type returned by getTestBuffer. +func checkTestBuffer(buf []byte) bool { + if buf == nil || len(buf) == 0 { + return false + } + for i, v := range buf { + if v != byte(i&0xff) { + return false + } + } + return true +} + +// Send / receive cfg.sendSize bytes in a single shot and check that it behaves appropriately. +func (cfg readLimitTestConfig) runSingleSend(t *testing.T, conn *TimeoutConnection, idx int) error { + if err := checkedSendReceive(t, conn, cfg.sendSize); err != nil { + return err + } + return nil +} + +// Send / receive cfg.sendSize bytes, split over five sends, and check that it behaves appropriately. +func (cfg readLimitTestConfig) runMultiSend(t *testing.T, conn *TimeoutConnection, idx int) error { + for i := 0; i < 5; i++ { + if err := checkedSendReceive(t, conn, cfg.sendSize/5); err != nil { + return err + } + } + return nil +} + +// A timeoutConnector that uses a dialer to dial the connections +type dialerConnector struct { + readLimitTestConfig + + // This is lazily inited + dialer *Dialer +} + +// Function that returns a connector +type timeoutConnectorFactory func(readLimitTestConfig) timeoutConnector + +// Dial the connection using the dialer (creating the dialer if necessary) +func (d *dialerConnector) connect(ctx context.Context, t *testing.T, port int, idx int) (*TimeoutConnection, error) { + if d.dialer == nil { + d.dialer = NewDialer(&Dialer{ + BytesReadLimit: d.limit, + ReadLimitExceededAction: d.action, + }) + } + var ret *TimeoutConnection + conn, err := d.dialer.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", port)) + if conn != nil { + ret = conn.(*TimeoutConnection) + } + return ret, err +} + +func (d *dialerConnector) getConfig() readLimitTestConfig { + return d.readLimitTestConfig +} + +func dialerTimeoutConnectorFactory(cfg readLimitTestConfig) timeoutConnector { + return &dialerConnector{ + readLimitTestConfig: cfg, + } +} + +// Dial using a direct call to DialTimeoutConnectionEx +type directDial struct { + readLimitTestConfig +} + +func (d *directDial) connect(ctx context.Context, t *testing.T, port int, idx int) (*TimeoutConnection, error) { + conn, err := DialTimeoutConnectionEx("tcp", fmt.Sprintf("127.0.0.1:%d", port), time.Second, time.Second, time.Second, time.Second) + var ret *TimeoutConnection + if conn != nil { + ret = conn.(*TimeoutConnection) + ret.BytesReadLimit = d.limit + ret.ReadLimitExceededAction = d.action + } + return ret, err +} + +func (d *directDial) getConfig() readLimitTestConfig { + return d.readLimitTestConfig +} + +func directDialFactory(cfg readLimitTestConfig) timeoutConnector { + return &directDial{cfg} +} + +var readLimitTestConfigs = map[string]readLimitTestConfig{ + // Check that a 2000-byte read gets truncated at 1000 bytes + "truncate": { + limit: 1000, + sendSize: 2000, + action: ReadLimitExceededActionTruncate, + }, + + // Check that a 1005-byte read gets truncated at 1000 bytes + "truncate_close": { + limit: 1000, + sendSize: 1005, + action: ReadLimitExceededActionTruncate, + }, + + // Check that a 2000-byte read errors after reading the first 1000 bytes + "error": { + limit: 1000, + sendSize: 2000, + action: ReadLimitExceededActionError, + }, + + // Check that a 2000-byte read panics after reading the first 1000 bytes + "panic": { + limit: 1000, + sendSize: 2000, + action: ReadLimitExceededActionPanic, + }, + + // Check that the default settings pass (backwards compatibility) + "default": {}, + + // Check that a 100-byte read succeeds / is not truncated + "happy": { + limit: 1000, + sendSize: 100, + action: ReadLimitExceededActionPanic, + }, + + // Check that a 1000-byte read succeeds / is not truncated + "closeCall": { + limit: 1000, + sendSize: 1000, + action: ReadLimitExceededActionPanic, + }, +} + +// Each of these gets run with each readLimitTestConfig +var connTestConnectors = map[string]timeoutConnectorFactory{ + "directDial": directDialFactory, + "dialerConnector": dialerTimeoutConnectorFactory, +} + +// Run a single full trial with the given connector: connect, send/receive the configured bytes, and +// check that the response was properly truncated (or not), and that the bytes read total is +// correctly tabulated. +func runBytesReadLimitTrial(t *testing.T, connector timeoutConnector, idx int, method func(readLimitTestConfig, *testing.T, *TimeoutConnection, int) error) (result error) { + cfg := connector.getConfig() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + port := 0x1234 + idx + runEchoServer(t, port) + conn, err := connector.connect(ctx, t, port, idx) + if err != nil { + t.Fatalf("Error dialing: %v", err) + } + expectedSize := cfg.sendSize + if expectedSize > cfg.limit { + expectedSize = cfg.limit + } + defer func() { + if conn.BytesRead != expectedSize { + result = fmt.Errorf("BytesRead(%d) != expected(%d)", conn.BytesRead, expectedSize) + t.Error(result) + } + }() + defer conn.Close() + return method(cfg, t, conn, idx) +} + +// Run a full set of trials on the connector -- ten with a single send, and ten with multiple sends. +func testBytesReadLimitOn(t *testing.T, connector timeoutConnector) error { + for i := 0; i < 10; i++ { + if err := runBytesReadLimitTrial(t, connector, i, readLimitTestConfig.runSingleSend); err != nil { + return err + } + } + for i := 0; i < 10; i++ { + if err := runBytesReadLimitTrial(t, connector, i, readLimitTestConfig.runMultiSend); err != nil { + return err + } + } + return nil +} + +// Check that the BytesReadLimit is enforced (or not) as expected: +// 1. Create an echo server +// 2. Dial a fresh TimeoutConnection to the echo server with the given BytesReadLimit / ReadLimitExceededAction +// 3. Send the configured number of bytes in a single packet +// 4. Check that it (succeeds / truncates / errors / panics) according to the config +// 5. Repeat 10 times +// 6. Repeat the above 10 more times, except in #3, split the send across five packets +func TestBytesReadLimit(t *testing.T) { + connectors := make(map[string]timeoutConnector) + // Create a fresh connector for each configuration + for cfgName, cfg := range readLimitTestConfigs { + for connectorName, factory := range connTestConnectors { + connectors[connectorName+"_"+cfgName] = factory(cfg) + } + } + + // Run each connector + for name, connector := range connectors { + t.Logf("Running %s", name) + if err := testBytesReadLimitOn(t, connector); err != nil { + t.Logf("Failed running %s: %v", name, err) + } + } +} diff --git a/conn_timeout_test.go b/conn_timeout_test.go new file mode 100644 index 0000000..f9f1040 --- /dev/null +++ b/conn_timeout_test.go @@ -0,0 +1,422 @@ +package zgrab2 + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "strings" + "testing" + "time" + + "github.com/sirupsen/logrus" +) + +// Config for a single timeout test +type connTimeoutTestConfig struct { + // Name for the test for logging purposes + name string + + // Optional explicit endpoint to connect to (if absent, use 127.0.0.1) + endpoint string + + // TCP port number to communicate on + port int + + // Function used to dial new connections + dialer func() (*TimeoutConnection, error) + + // Client timeout values + timeout time.Duration + connectTimeout time.Duration + readTimeout time.Duration + writeTimeout time.Duration + + // Time for server to wait after listening before accepting a connection + acceptDelay time.Duration + + // Time for server to wait after accepting before writing payload + writeDelay time.Duration + + // Time for server to wait before reading payload + readDelay time.Duration + + // Payload for server to send client after connecting + serverToClientPayload []byte + + // Payload for client to send server after reading the previous payload + clientToServerPayload []byte + + // Step when the client is expected to fail + failStep testStep + + // If non-empty, the error string returned by the client should contain this + failError string +} + +// Standardized time units, separated by factors of 10. +const ( + short = 100 * time.Millisecond + medium = 1000 * time.Millisecond + long = 10000 * time.Millisecond +) + +// enum type for the various locations where the test can fail +type testStep string + +const ( + testStepConnect = testStep("connect") + testStepRead = testStep("read") + testStepWrite = testStep("write") + testStepDone = testStep("done") +) + +// Encapsulates a source for an error (client/server/???), the step where it occurred, and an +// optional cause. +type timeoutTestError struct { + source string + step testStep + cause error +} + +func (err *timeoutTestError) Error() string { + return fmt.Sprintf("%s error at %s: %v", err.source, err.step, err.cause) +} + +func serverError(step testStep, err error) *timeoutTestError { + return &timeoutTestError{ + source: "server", + step: step, + cause: err, + } +} + +func clientError(step testStep, err error) *timeoutTestError { + return &timeoutTestError{ + source: "client", + step: step, + cause: err, + } +} + +// Helper to ensure all data is written to a socket +func _write(writer io.Writer, data []byte) error { + n, err := writer.Write(data) + if err == nil && n != len(data) { + err = io.ErrShortWrite + } + return err +} + +// Run the configured server. As soon as it returns, it is listening. +// Returns a channel that receives a timeoutTestError on error, or is closed on successful completion. +func (cfg *connTimeoutTestConfig) runServer(t *testing.T) (chan *timeoutTestError) { + errorChan := make(chan *timeoutTestError) + if cfg.endpoint != "" { + // Only listen on localhost + return errorChan + } + listener, err := net.Listen("tcp", cfg.getEndpoint()) + if err != nil { + logrus.Fatalf("Error listening: %v", err) + } + go func() { + defer listener.Close() + defer close(errorChan) + time.Sleep(cfg.acceptDelay) + sock, err := listener.Accept() + if err != nil { + errorChan <- serverError(testStepConnect, err) + return + } + defer sock.Close() + time.Sleep(cfg.writeDelay) + if err := _write(sock, cfg.serverToClientPayload); err != nil { + errorChan <- serverError(testStepWrite, err) + return + } + time.Sleep(cfg.readDelay) + buf := make([]byte, len(cfg.clientToServerPayload)) + n, err := io.ReadFull(sock, buf) + if err != nil && err != io.EOF { + errorChan <- serverError(testStepRead, err) + return + } + if err == io.EOF && n < len(buf) { + errorChan <- serverError(testStepRead, err) + return + } + if !bytes.Equal(buf, cfg.clientToServerPayload) { + t.Errorf("%s: clientToServerPayload mismatch", cfg.name) + } + return + }() + return errorChan +} + +// Get the configured endpoint +func (cfg *connTimeoutTestConfig) getEndpoint() string { + if cfg.endpoint != "" { + return cfg.endpoint + } + return fmt.Sprintf("127.0.0.1:%d", cfg.port) +} + +// Dial a connection to the configured endpoint using a Dialer +func (cfg *connTimeoutTestConfig) dialerDial() (*TimeoutConnection, error) { + dialer := NewDialer(&Dialer{ + Timeout: cfg.timeout, + ConnectTimeout: cfg.connectTimeout, + ReadTimeout: cfg.readTimeout, + WriteTimeout: cfg.writeTimeout, + }) + ret, err := dialer.Dial("tcp", cfg.getEndpoint()) + if err != nil { + return nil, err + } + return ret.(*TimeoutConnection), err +} + +// Dial a connection to the configured endpoint using a DialTimeoutConnectionEx +func (cfg *connTimeoutTestConfig) directDial() (*TimeoutConnection, error) { + ret, err := DialTimeoutConnectionEx("tcp", cfg.getEndpoint(), cfg.connectTimeout, cfg.timeout, cfg.readTimeout, cfg.writeTimeout) + if err != nil { + return nil, err + } + return ret.(*TimeoutConnection), err +} + +// Dial a connection to the configured endpoint using Dialer.DialContext +func (cfg *connTimeoutTestConfig) contextDial() (*TimeoutConnection, error) { + dialer := NewDialer(&Dialer{ + Timeout: cfg.timeout, + ConnectTimeout: cfg.connectTimeout, + ReadTimeout: cfg.readTimeout, + WriteTimeout: cfg.writeTimeout, + }) + ret, err := dialer.DialContext(context.Background(), "tcp", cfg.getEndpoint()) + if err != nil { + return nil, err + } + return ret.(*TimeoutConnection), err +} + +// Run the client: connect to the server, read the payload, write the payload, disconnect. +func (cfg *connTimeoutTestConfig) runClient(t *testing.T) (testStep, error) { + conn, err := cfg.dialer() + if err != nil { + return testStepConnect, err + } + defer conn.Close() + buf := make([]byte, len(cfg.serverToClientPayload)) + _, err = io.ReadFull(conn, buf) + if err != nil { + return testStepRead, err + } + if !bytes.Equal(cfg.serverToClientPayload, buf) { + t.Errorf("%s: serverToClientPayload payload mismatch", cfg.name) + } + if err := _write(conn, cfg.clientToServerPayload); err != nil { + return testStepWrite, err + } + return testStepDone, nil +} + +// Run the configured test -- start a server and a client to connect to it. +func (cfg *connTimeoutTestConfig) run(t *testing.T) { + done := make(chan *timeoutTestError) + serverError := cfg.runServer(t) + go func() { + defer func() { + if err := recover(); err != nil { + close(done) + panic(err) + } + }() + step, err := cfg.runClient(t) + done <- clientError(step, err) + }() + go func() { + time.Sleep(long + medium + short) + done <- &timeoutTestError{source: "timeout"} + }() + var ret *timeoutTestError + select { + case err := <-serverError: + t.Fatalf("%s: Server error: %v", cfg.name, err) + case ret = <-done: + if ret == nil { + t.Fatalf("Channel unexpectedly closed") + } + } + if ret.source != "client" { + t.Fatalf("%s: Unexpected error from %s: %v", cfg.name, ret.source, ret.cause) + } + if ret.step != cfg.failStep { + t.Errorf("%s: Failed at step %s, but expected to fail at step %s (error=%v)", cfg.name, ret.step, cfg.failStep, ret.cause) + return + } + if cfg.failError != "" { + errString := "none" + if ret.cause != nil { + errString = ret.cause.Error() + } + if !strings.Contains(errString, cfg.failError) { + t.Errorf("%s: Expected an error (%s) at step %s, got %s", cfg.name, cfg.failError, cfg.failStep, errString) + return + } + } else if ret.cause != nil { + t.Errorf("%s: expected no error at step %s, but got %v", cfg.name, cfg.failStep, ret.cause) + } +} + +var connTestConfigs = []connTimeoutTestConfig{ + // Long timeouts, short delays -- should succeed + { + name: "happy", + port: 0x5613, + timeout: long, + connectTimeout: medium, + readTimeout: medium, + writeTimeout: medium, + + acceptDelay: short, + writeDelay: short, + readDelay: short, + + serverToClientPayload: []byte("abc"), + clientToServerPayload: []byte("defghi"), + + failStep: testStepDone, + }, + // long session timeout, short connectTimeout. Use a non-local, nonexistent endpoint (localhost + // would return "connection refused" immediately) + { + name: "connect_timeout", + endpoint: "10.0.254.254:41591", + timeout: long, + connectTimeout: short, + readTimeout: medium, + writeTimeout: medium, + + acceptDelay: short, + writeDelay: short, + readDelay: short, + + serverToClientPayload: []byte("abc"), + clientToServerPayload: []byte("defghi"), + + failStep: testStepConnect, + failError: "i/o timeout", + }, + // short session timeout, medium connect timeout, with connect to nonexistent endpoint. + { + name: "session_connect_timeout", + endpoint: "10.0.254.254:41591", + timeout: short, + connectTimeout: medium, + readTimeout: medium, + writeTimeout: medium, + + acceptDelay: short, + writeDelay: short, + readDelay: short, + + serverToClientPayload: []byte("abc"), + clientToServerPayload: []byte("defghi"), + + failStep: testStepConnect, + failError: "i/o timeout", + }, + // Get an IO timeout on the read. + // sessionTimeout > acceptDelay + writeDelay > writeDelay > readTimeout + { + name: "read_timeout", + port: 0x5614, + timeout: long, + connectTimeout: short, + readTimeout: short, + writeTimeout: short, + + acceptDelay: short, + writeDelay: medium, + readDelay: short, + + serverToClientPayload: []byte("abc"), + clientToServerPayload: []byte("defghi"), + + failStep: testStepRead, + failError: "i/o timeout", + }, + // Get a context timeout on a read. + // readTimeout > writeDelay > timeout > acceptDelay + { + name: "session_read_timeout", + port: 0x5615, + timeout: short, + connectTimeout: long, + readTimeout: long, + writeTimeout: long, + + acceptDelay: 0, + writeDelay: medium * 2, + readDelay: 0, + + serverToClientPayload: []byte("abc"), + clientToServerPayload: []byte("defghi"), + + failStep: testStepWrite, + failError: "context deadline exceeded", + }, + // Use a session timeout that is longer than any individual action's timeout. + // acceptDelay+writeDelay+readDelay > timeout > acceptDelay >= writeDelay >= readDelay + { + name: "session_timeout", + port: 0x5616, + timeout: medium, + connectTimeout: long, + readTimeout: long, + writeTimeout: long, + + acceptDelay: time.Nanosecond * time.Duration(medium.Nanoseconds()/2+short.Nanoseconds()), + writeDelay: time.Nanosecond * time.Duration(medium.Nanoseconds()/2+short.Nanoseconds()), + readDelay: time.Nanosecond * time.Duration(medium.Nanoseconds()/2+short.Nanoseconds()), + + serverToClientPayload: []byte("abc"), + clientToServerPayload: []byte("defghi"), + + failStep: testStepWrite, + failError: "context deadline exceeded", + }, + // TODO: How to test write timeout? +} + +// TestTimeoutConnectionTimeouts tests that the TimeoutConnection behaves as expected with respect +// to timeouts. +func TestTimeoutConnectionTimeouts(t *testing.T) { + temp := make([]connTimeoutTestConfig, 0, len(connTestConfigs)*3) + // Make three copies of connTestConfigs, one with each dial method. + for _, cfg := range connTestConfigs { + direct := cfg + dialer := cfg + ctxDialer := cfg + + dialer.name = dialer.name + "_dialer" + dialer.port = dialer.port + 100 + dialer.dialer = dialer.dialerDial + + direct.name = direct.name + "_direct" + direct.port = direct.port + 200 + direct.dialer = direct.directDial + + ctxDialer.name = ctxDialer.name + "_context" + ctxDialer.port = ctxDialer.port + 300 + ctxDialer.dialer = ctxDialer.contextDial + temp = append(temp, direct, dialer, ctxDialer) + } + for _, cfg := range temp { + t.Logf("Running %s", cfg.name) + cfg.run(t) + } +} diff --git a/modules/http/http_readlimit_test.go b/modules/http/http_readlimit_test.go new file mode 100644 index 0000000..7ed76ca --- /dev/null +++ b/modules/http/http_readlimit_test.go @@ -0,0 +1,347 @@ +package http + +import ( + "crypto/rsa" + "encoding/hex" + "fmt" + "io" + "math/big" + "net" + "strings" + "testing" + "time" + + "github.com/zmap/zcrypto/tls" + "github.com/zmap/zgrab2" +) + +// BEGIN Taken from handshake_server_test.go -- certs for TLS server +func fromHex(s string) []byte { + b, _ := hex.DecodeString(s) + return b +} + +func bigFromString(s string) *big.Int { + ret := new(big.Int) + ret.SetString(s, 10) + return ret +} + +var testRSACertificate = fromHex("308202b030820219a00302010202090085b0bba48a7fb8ca300d06092a864886f70d01010505003045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464301e170d3130303432343039303933385a170d3131303432343039303933385a3045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c746430819f300d06092a864886f70d010101050003818d0030818902818100bb79d6f517b5e5bf4610d0dc69bee62b07435ad0032d8a7a4385b71452e7a5654c2c78b8238cb5b482e5de1f953b7e62a52ca533d6fe125c7a56fcf506bffa587b263fb5cd04d3d0c921964ac7f4549f5abfef427100fe1899077f7e887d7df10439c4a22edb51c97ce3c04c3b326601cfafb11db8719a1ddbdb896baeda2d790203010001a381a73081a4301d0603551d0e04160414b1ade2855acfcb28db69ce2369ded3268e18883930750603551d23046e306c8014b1ade2855acfcb28db69ce2369ded3268e188839a149a4473045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c746482090085b0bba48a7fb8ca300c0603551d13040530030101ff300d06092a864886f70d010105050003818100086c4524c76bb159ab0c52ccf2b014d7879d7a6475b55a9566e4c52b8eae12661feb4f38b36e60d392fdf74108b52513b1187a24fb301dbaed98b917ece7d73159db95d31d78ea50565cd5825a2d5a5f33c4b6d8c97590968c0f5298b5cd981f89205ff2a01ca31b9694dda9fd57e970e8266d71999b266e3850296c90a7bdd9") +var testSNICertificate = fromHex("308201f23082015da003020102020100300b06092a864886f70d01010530283110300e060355040a130741636d6520436f311430120603550403130b736e69746573742e636f6d301e170d3132303431313137343033355a170d3133303431313137343533355a30283110300e060355040a130741636d6520436f311430120603550403130b736e69746573742e636f6d30819d300b06092a864886f70d01010103818d0030818902818100bb79d6f517b5e5bf4610d0dc69bee62b07435ad0032d8a7a4385b71452e7a5654c2c78b8238cb5b482e5de1f953b7e62a52ca533d6fe125c7a56fcf506bffa587b263fb5cd04d3d0c921964ac7f4549f5abfef427100fe1899077f7e887d7df10439c4a22edb51c97ce3c04c3b326601cfafb11db8719a1ddbdb896baeda2d790203010001a3323030300e0603551d0f0101ff0404030200a0300d0603551d0e0406040401020304300f0603551d2304083006800401020304300b06092a864886f70d0101050381810089c6455f1c1f5ef8eb1ab174ee2439059f5c4259bb1a8d86cdb1d056f56a717da40e95ab90f59e8deaf627c157995094db0802266eb34fc6842dea8a4b68d9c1389103ab84fb9e1f85d9b5d23ff2312c8670fbb540148245a4ebafe264d90c8a4cf4f85b0fac12ac2fc4a3154bad52462868af96c62c6525d652b6e31845bdcc") + +var testRSAPrivateKey = &rsa.PrivateKey{ + PublicKey: rsa.PublicKey{ + N: bigFromString("131650079503776001033793877885499001334664249354723305978524647182322416328664556247316495448366990052837680518067798333412266673813370895702118944398081598789828837447552603077848001020611640547221687072142537202428102790818451901395596882588063427854225330436740647715202971973145151161964464812406232198521"), + E: 65537, + }, + D: bigFromString("29354450337804273969007277378287027274721892607543397931919078829901848876371746653677097639302788129485893852488285045793268732234230875671682624082413996177431586734171663258657462237320300610850244186316880055243099640544518318093544057213190320837094958164973959123058337475052510833916491060913053867729"), + Primes: []*big.Int{ + bigFromString("11969277782311800166562047708379380720136961987713178380670422671426759650127150688426177829077494755200794297055316163155755835813760102405344560929062149"), + bigFromString("10998999429884441391899182616418192492905073053684657075974935218461686523870125521822756579792315215543092255516093840728890783887287417039645833477273829"), + }, +} + +// END Taken from handshake_server_test.go -- certs for TLS server + +// Get the tls.Config object for the server; adapted from handshake_server_test.go. +func getTLSConfig() *tls.Config { + testConfig := &tls.Config{ + Time: func() time.Time { return time.Unix(0, 0) }, + Certificates: make([]tls.Certificate, 2), + InsecureSkipVerify: true, + MinVersion: tls.VersionSSL30, + MaxVersion: tls.VersionTLS12, + } + testConfig.Certificates[0].Certificate = [][]byte{testRSACertificate} + testConfig.Certificates[0].PrivateKey = testRSAPrivateKey + testConfig.Certificates[1].Certificate = [][]byte{testSNICertificate} + testConfig.Certificates[1].PrivateKey = testRSAPrivateKey + testConfig.BuildNameToCertificate() + return testConfig +} + +// Helper function to write and check for short writes +func _write(writer io.Writer, data []byte) error { + n, err := writer.Write(data) + if err == nil && len(data) != n { + err = io.ErrShortWrite + } + return err +} + +// Start a local server that sends responds to any requests with a cfg.headerSize-byte set of +// headers followed by a cfg.bodySize-byte body. +// The response ends up looking like this: +// HTTP/1.0 200 OK +// Bogus-Header: XXX... +// Content-Length: +// +// XXXX.... +func (cfg *readLimitTestConfig) runFakeHTTPServer(t *testing.T) { + endpoint := fmt.Sprintf("127.0.0.1:%d", cfg.port) + listener, err := net.Listen("tcp", endpoint) + if err != nil { + t.Fatal(err) + } + go func() { + defer listener.Close() + sock, err := listener.Accept() + if err != nil { + t.Fatal(err) + } + defer sock.Close() + if cfg.tls { + tlsSock := tls.Server(sock, getTLSConfig()) + if err := tlsSock.Handshake(); err != nil { + t.Fatalf("server handshake error: %v", err) + } + sock = tlsSock + } + // don't care what the client sends, always respond with a HTTP-like response + buf := make([]byte, 1) + _, err = sock.Read(buf) + if err != nil { + // any error, including EOF, is unexpected -- the client should send something + t.Fatalf("Unexpected error reading from client: %v", err) + } + + head := "HTTP/1.0 200 OK\r\nBogus-Header: X" + headSuffix := fmt.Sprintf("\r\nContent-Length: %d\r\n\r\n", cfg.bodySize) + size := cfg.headerSize - len(head) - len(headSuffix) + if size < 0 { + t.Fatalf("Header size %d too small: must be at least %d bytes", cfg.headerSize, len(head)+len(headSuffix)) + } + if err := _write(sock, []byte(head)); err != nil { + t.Fatalf("write error: %v", err) + } + chunkSize := 256 + sent := len(head) + chunk := []byte(strings.Repeat("X", chunkSize)) + for i := 0; i < size; i += chunkSize { + if i+chunkSize > size { + chunk = []byte(strings.Repeat("X", size-i)) + } + if err := _write(sock, chunk); err != nil { + t.Logf("Failed writing to client after %d bytes: %v", sent, err) + return + } + sent += len(chunk) + } + + if err := _write(sock, []byte(headSuffix)); err != nil { + t.Logf("Failed writing foot to client: %v", err) + return + } + sent += len(headSuffix) + body := strings.Repeat("X", cfg.bodySize) + if err := _write(sock, []byte(body)); err != nil { + t.Logf("Failed writing body to client: %v", err) + return + } + }() +} + +// Get an HTTP scanner module with the desired config +func (cfg *readLimitTestConfig) getScanner(t *testing.T) *Scanner { + var module Module + flags := module.NewFlags().(*Flags) + flags.Endpoint = "/" + flags.Method = "GET" + flags.UserAgent = "Mozilla/5.0 zgrab/0.x" + if cfg.maxBodySize&0x03ff != 0 { + t.Fatalf("%d is not a valid maxBodySize (must be a multiple of 1024)", cfg.maxBodySize) + } + flags.MaxSize = cfg.maxBodySize / 1024 + flags.MaxRedirects = 0 + flags.Timeout = 1 * time.Second + flags.Port = uint(cfg.port) + flags.UseHTTPS = cfg.tls + zgrab2.DefaultBytesReadLimit = cfg.maxReadSize + scanner := module.NewScanner() + scanner.Init(flags) + return scanner.(*Scanner) +} + +// Configuration for a single test run +type readLimitTestConfig struct { + // if true, the client/server will use TLS. NOTE: the limits are on the *raw* connection. + tls bool + + // port where the server listens. + port int + + // Bodies larger than this are truncated. NOTE: this must be a multiple of 1024, since MaxSize + // is given in kilobytes. + maxBodySize int + + // The maximum number of bytes to read from the (raw) socket. Beyond that data is truncated and + // EOF is returned. + maxReadSize int + + // The size of the HTTP server's "header" (actually, all of the data before the body). Must be + // at least 58 (the size of the static parts of the response). + headerSize int + + // The size of the HTTP body to send (the Content-Length). + bodySize int + + // The status that should be returned by the scan. + expectedStatus zgrab2.ScanStatus + + // If set, the error returned by the scan must contain this. + expectedError string +} + +const ( + readLimitTestConfigHTTPBasePort = 0x7f7f + readLimitTestConfigHTTPSBasePort = 0x7bbc +) + +var readLimitTestConfigs = map[string]*readLimitTestConfig{ + // The socket truncates the connection while reading the body. To the client it looks as if the + // server closed the connection prior to sending Content-Length bytes; the result is success, + // but with a truncated body. + // bodySize + headerSize > maxReadSize > headerSize + "truncate_read_body": { + tls: false, + port: readLimitTestConfigHTTPBasePort, + maxBodySize: 2048, + maxReadSize: 1024, + headerSize: 64, + bodySize: 4096, + expectedStatus: zgrab2.SCAN_SUCCESS, + }, + // NOTE: There is no tls_truncate_read_body, since the truncation will almost certainly occur + // in the middle of a TLS packet -- so the response would always be "unexpected EOF" + + // The HTTP library stops reading the body after reaching its internal limit. It returns success + // and the truncated body. + // maxReadSize > headerSize + bodySize > bodySize > maxBodySize + "truncate_body": { + tls: false, + port: readLimitTestConfigHTTPBasePort + 1, + maxBodySize: 2048, + maxReadSize: 8192, + headerSize: 64, + bodySize: 4096, + expectedStatus: zgrab2.SCAN_SUCCESS, + }, + "tls_truncate_body": { + tls: true, + port: readLimitTestConfigHTTPSBasePort + 1, + maxBodySize: 2048, + maxReadSize: 8192, + headerSize: 64, + bodySize: 4096, + expectedStatus: zgrab2.SCAN_SUCCESS, + }, + + // The socket truncates the connection while reading the headers. The result isn't a valid HTTP + // response, so the library returns an unexpected EOF error. + // headerSize > maxReadSize + "truncate_read_header": { + tls: false, + port: readLimitTestConfigHTTPBasePort + 2, + maxBodySize: 1024, + maxReadSize: 2048, + headerSize: 3072, + bodySize: 8, + expectedError: "unexpected EOF", + expectedStatus: zgrab2.SCAN_UNKNOWN_ERROR, + }, + "tls_truncate_read_header": { + tls: true, + port: readLimitTestConfigHTTPSBasePort + 2, + maxBodySize: 1024, + maxReadSize: 2048, + headerSize: 3072, + bodySize: 8, + expectedError: "unexpected EOF", + expectedStatus: zgrab2.SCAN_UNKNOWN_ERROR, + }, + + // Happy case. None of the limits are hit. + // maxReadSize >= maxBodySize > bodySize + headerSize + "happy_case": { + tls: false, + port: readLimitTestConfigHTTPBasePort + 3, + maxBodySize: 8192, + maxReadSize: 8192, + headerSize: 1024, + bodySize: 1024, + expectedStatus: zgrab2.SCAN_SUCCESS, + }, + "tls_happy_case": { + tls: true, + port: readLimitTestConfigHTTPSBasePort + 3, + maxBodySize: 8192, + maxReadSize: 8192, + headerSize: 1024, + bodySize: 1024, + expectedStatus: zgrab2.SCAN_SUCCESS, + }, +} + +// Try to get the HTTP body from a result; otherwise return the empty string. +func getBody(result interface{}) string { + if result == nil { + return "" + } + httpResult, ok := result.(*Results) + if !ok { + return "" + } + response := httpResult.Response + if response == nil { + return "" + } + return response.BodyText +} + +// Run a single test with the given configuration. +func (cfg *readLimitTestConfig) runTest(t *testing.T, testName string) { + scanner := cfg.getScanner(t) + cfg.runFakeHTTPServer(t) + target := zgrab2.ScanTarget{ + IP: net.ParseIP("127.0.0.1"), + } + status, ret, err := scanner.Scan(target) + + if status != cfg.expectedStatus { + t.Errorf("Wrong status: expected %s, got %s", cfg.expectedStatus, status) + } + if err != nil { + if !strings.Contains(err.Error(), cfg.expectedError) { + t.Errorf("Wrong error: expected %s, got %s", err.Error(), cfg.expectedError) + } + } else if len(cfg.expectedError) > 0 { + t.Errorf("Expected error '%s' but got none", cfg.expectedError) + } + if cfg.expectedStatus == zgrab2.SCAN_SUCCESS { + body := getBody(ret) + if body == "" { + t.Errorf("Expected success, but got no body") + } else { + if len(body) > cfg.maxBodySize || len(body) > cfg.maxReadSize { + t.Errorf("Body exceeds max size: len(body)=%d; maxBodySize=%d, maxReadSize=%d", len(body), cfg.maxBodySize, cfg.maxReadSize) + } + if !cfg.tls { + if len(body)+cfg.headerSize > cfg.maxReadSize { + t.Errorf("Body and header exceed max read size: len(body)=%d, headerSize=%d, maxReadSize=%d", len(body), cfg.headerSize, cfg.maxReadSize) + } + } + } + } +} + +// TestReadLimitHTTP checks that the HTTP scanner works as expected with the default +// ReadLimitExeededAction (specifically, ReadLimnitExceededActionTruncate) defined in conn.go. +func TestReadLimitHTTP(t *testing.T) { + if zgrab2.DefaultReadLimitExceededAction != zgrab2.ReadLimitExceededActionTruncate { + t.Logf("Warning: DefaultReadLimitExceededAction is %s, not %s", zgrab2.DefaultReadLimitExceededAction, zgrab2.ReadLimitExceededActionTruncate) + } + for testName, cfg := range readLimitTestConfigs { + cfg.runTest(t, testName) + } +} diff --git a/modules/http/scanner.go b/modules/http/scanner.go index 52c4ed8..c05dc61 100644 --- a/modules/http/scanner.go +++ b/modules/http/scanner.go @@ -8,12 +8,14 @@ package http import ( "bytes" + "context" "crypto/sha256" "errors" "io" "net" "net/url" "strconv" + "time" log "github.com/sirupsen/logrus" "github.com/zmap/zgrab2" @@ -76,13 +78,14 @@ type Scanner struct { // scan holds the state for a single scan. This may entail multiple connections. // It is used to implement the zgrab2.Scanner interface. type scan struct { - connections []net.Conn - scanner *Scanner - target *zgrab2.ScanTarget - transport *http.Transport - client *http.Client - results Results - url string + connections []net.Conn + scanner *Scanner + target *zgrab2.ScanTarget + transport *http.Transport + client *http.Client + results Results + url string + globalDeadline time.Time } // NewFlags returns an empty Flags object. @@ -142,15 +145,40 @@ func (scan *scan) Cleanup() { } } +// Get a context whose deadline is the earliest of the context's deadline (if it has one) and the +// global scan deadline. +func (scan *scan) withDeadlineContext(ctx context.Context) context.Context { + ctxDeadline, ok := ctx.Deadline() + if !ok || scan.globalDeadline.Before(ctxDeadline) { + ret, _ := context.WithDeadline(ctx, scan.globalDeadline) + return ret + } + return ctx +} + +// Dial a connection using the configured timeouts, as well as the global deadline, and on success, +// add the connection to the list of connections to be cleaned up. +func (scan *scan) dialContext(ctx context.Context, net string, addr string) (net.Conn, error) { + dialer := zgrab2.GetTimeoutConnectionDialer(scan.scanner.config.Timeout) + + timeoutContext, _ := context.WithTimeout(context.Background(), scan.scanner.config.Timeout) + + conn, err := dialer.DialContext(scan.withDeadlineContext(timeoutContext), net, addr) + if err != nil { + return nil, err + } + scan.connections = append(scan.connections, conn) + return conn, nil +} + // getTLSDialer returns a Dial function that connects using the // zgrab2.GetTLSConnection() func (scan *scan) getTLSDialer() func(net, addr string) (net.Conn, error) { return func(net, addr string) (net.Conn, error) { - outer, err := zgrab2.DialTimeoutConnection(net, addr, scan.scanner.config.Timeout) + outer, err := scan.dialContext(context.Background(), net, addr) if err != nil { return nil, err } - scan.connections = append(scan.connections, outer) tlsConn, err := scan.scanner.config.TLSFlags.GetTLSConnection(outer) if err != nil { return nil, err @@ -242,10 +270,11 @@ func (scanner *Scanner) newHTTPScan(t *zgrab2.ScanTarget) *scan { DisableCompression: false, MaxIdleConnsPerHost: scanner.config.MaxRedirects, }, - client: http.MakeNewClient(), + client: http.MakeNewClient(), + globalDeadline: time.Now().Add(scanner.config.Timeout), } ret.transport.DialTLS = ret.getTLSDialer() - ret.transport.DialContext = zgrab2.GetTimeoutConnectionDialer(scanner.config.Timeout).DialContext + ret.transport.DialContext = ret.dialContext ret.client.UserAgent = scanner.config.UserAgent ret.client.CheckRedirect = ret.getCheckRedirect() ret.client.Transport = ret.transport @@ -334,6 +363,7 @@ func (scanner *Scanner) Scan(t zgrab2.ScanTarget) (zgrab2.ScanStatus, interface{ // zgrab2 framework. func RegisterModule() { var module Module + _, err := zgrab2.AddCommand("http", "HTTP Banner Grab", "Grab a banner over HTTP", 80, &module) if err != nil { log.Fatal(err) diff --git a/modules/mssql/connection.go b/modules/mssql/connection.go index 58e5e83..7cce7d4 100644 --- a/modules/mssql/connection.go +++ b/modules/mssql/connection.go @@ -11,7 +11,7 @@ import ( "strings" "time" - logrus "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" "github.com/zmap/zgrab2" ) @@ -401,6 +401,9 @@ func (options PreloginOptions) Encode() ([]byte, error) { for _, ik := range sortedKeys { k := PreloginOptionToken(ik) v := options[k] + if len(cursor) < 5 { + return nil, fmt.Errorf("encode: size mismatch (options.Size()=%d)", options.Size()) + } cursor[0] = byte(k) if offset > 0xffff { return nil, ErrTooLarge @@ -411,6 +414,9 @@ func (options PreloginOptions) Encode() ([]byte, error) { offset += len(v) cursor = cursor[5:] } + if len(cursor) < 1 { + return nil, fmt.Errorf("encode: size mismatch (options.Size()=%d, len(sortedKeys)=%d)", options.Size(), len(sortedKeys)) + } // Write the terminator after the last PL_OPTION header // (and just before the first value) cursor[0] = 0xff @@ -421,6 +427,9 @@ func (options PreloginOptions) Encode() ([]byte, error) { // returned in rest. // If body can't be decoded as a PRELOGIN body, returns nil, nil, ErrInvalidData func decodePreloginOptions(body []byte) (result *PreloginOptions, rest []byte, err error) { + if len(body) < 1 { + return nil, nil, ErrInvalidData + } cursor := body[:] options := make(PreloginOptions) max := 0 @@ -506,7 +515,13 @@ func (options PreloginOptions) MarshalJSON() ([]byte, error) { fedAuthRequired, hasFedAuthRequired := opts[PreloginFedAuthRequired] if hasFedAuthRequired { - aux.FedAuthRequired = &fedAuthRequired[0] + temp := uint8(0) + if len(fedAuthRequired) > 0 { + temp = fedAuthRequired[0] + } else { + logrus.Debugf("fedAuthRequired was present but empty (options=%#v)", options) + } + aux.FedAuthRequired = &temp } nonce, hasNonce := opts[PreloginNonce] @@ -632,6 +647,7 @@ func (connection *Connection) readPreloginPacket() (*TDSPacket, *PreloginOptions if packet.Type != TDSPacketTypeTabularResult { return packet, nil, &zgrab2.ScanError{Status: zgrab2.SCAN_APPLICATION_ERROR, Err: err} } + defer zgrab2.LogPanic("Error decoding Prelogin packet %#v", packet.Body) plOptions, rest, err := decodePreloginOptions(packet.Body) if err != nil { return packet, nil, err diff --git a/processing.go b/processing.go index 7c7bd49..bb4cf4f 100644 --- a/processing.go +++ b/processing.go @@ -82,10 +82,7 @@ func (target *ScanTarget) OpenUDP(flags *BaseFlags, udp *UDPFlags) (net.Conn, er if err != nil { return nil, err } - return &TimeoutConnection{ - Conn: conn, - Timeout: flags.Timeout, - }, nil + return NewTimeoutConnection(nil, conn, flags.Timeout, 0, 0), nil } // grabTarget calls handler for each action diff --git a/utility.go b/utility.go index 8ac1e05..44fd536 100644 --- a/utility.go +++ b/utility.go @@ -10,6 +10,8 @@ import ( "time" "github.com/zmap/zflags" + "github.com/sirupsen/logrus" + "runtime/debug" ) var parser *flags.Parser @@ -209,3 +211,18 @@ func IsTimeoutError(err error) bool { return false } + +// LogPanic is intended to be called from within defer -- if there was no panic, it returns without +// doing anything. Otherwise, it logs the stacktrace, the panic error, and the provided message +// before re-raising the original panic. +// Example: +// defer zgrab2.LogPanic("Error decoding body '%x'", body) +func LogPanic(format string, args...interface{}) { + err := recover() + if err == nil { + return + } + logrus.Errorf("Uncaught panic at %s: %v", string(debug.Stack()), err) + logrus.Errorf(format, args...) + panic(err) +}