Merge pull request #169 from zmap/jb/mssqlBoundsChecking

add some tighter bounds checking in MSSQL scanner
This commit is contained in:
justinbastress 2018-10-04 11:19:15 -04:00 committed by GitHub
commit 15127f1b89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 1483 additions and 58 deletions

@ -21,6 +21,7 @@ type Config struct {
Debug bool `long:"debug" description:"Include debug fields in the output."`
GOMAXPROCS int `long:"gomaxprocs" default:"0" description:"Set GOMAXPROCS"`
ConnectionsPerHost int `long:"connections-per-host" default:"1" description:"Number of times to connect to each host (results in more output)"`
ReadLimitPerHost int `long:"read-limit-per-host" default:"96" description:"Maximum total kilobytes to read for a single host (default 96kb)"`
Prometheus string `long:"prometheus" description:"Address to use for Prometheus server (e.g. localhost:8080). If empty, Prometheus is disabled."`
Multiple MultipleCommand `command:"multiple" description:"Multiple module actions"`
inputFile *os.File
@ -118,6 +119,11 @@ func validateFrameworkConfiguration() {
if config.ConnectionsPerHost > 50 {
log.Fatalf("connectionsPerHost must be in the range [0,50]")
}
// Stop even third-party libraries from performing unbounded reads on untrusted hosts
if config.ReadLimitPerHost > 0 {
DefaultBytesReadLimit = config.ReadLimitPerHost * 1024
}
}
// GetMetaFile returns the file to which metadata should be output

276
conn.go

@ -1,49 +1,127 @@
package zgrab2
import (
"context"
"errors"
"io"
"net"
"time"
"context"
"github.com/sirupsen/logrus"
)
// ReadLimitExceededAction describes how the connection reacts to an attempt to read more data than permitted.
type ReadLimitExceededAction string
const (
// ReadLimitExceededActionNotSet is a placeholder for the zero value, so that explicitly set values can be
// distinguished from the empty default.
ReadLimitExceededActionNotSet = ReadLimitExceededAction("")
// ReadLimitExceededActionTruncate causes the connection to truncate at BytesReadLimit bytes and return a bogus
// io.EOF error. The fact that a truncation took place is logged at debug level.
ReadLimitExceededActionTruncate = ReadLimitExceededAction("truncate")
// ReadLimitExceededActionError causes the Read call to return n, ErrReadLimitExceeded (in addition to truncating).
ReadLimitExceededActionError = ReadLimitExceededAction("error")
// ReadLimitExceededActionPanic causes the Read call to panic(ErrReadLimitExceeded).
ReadLimitExceededActionPanic = ReadLimitExceededAction("panic")
)
var (
// DefaultBytesReadLimit is the maximum number of bytes to read per connection when no explicit value is provided.
DefaultBytesReadLimit = 256 * 1024 * 1024
// DefaultReadLimitExceededAction is the action used when no explicit action is set.
DefaultReadLimitExceededAction = ReadLimitExceededActionTruncate
// DefaultSessionTimeout is the default maximum time a connection may be used when no explicit value is provided.
DefaultSessionTimeout = 1 * time.Minute
)
// ErrReadLimitExceeded is returned / panic'd from Read if the read limit is exceeded when the
// ReadLimitExceededAction is error / panic.
var ErrReadLimitExceeded = errors.New("read limit exceeded")
// TimeoutConnection wraps an existing net.Conn connection, overriding the Read/Write methods to use the configured timeouts
// TODO: Refactor this into TimeoutConnection, BoundedReader, LoggedReader, etc
type TimeoutConnection struct {
net.Conn
Timeout time.Duration
explicitReadDeadline bool
explicitWriteDeadline bool
explicitDeadline bool
ctx context.Context
Timeout time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
BytesRead int
BytesWritten int
BytesReadLimit int
ReadLimitExceededAction ReadLimitExceededAction
Cancel context.CancelFunc
explicitReadDeadline bool
explicitWriteDeadline bool
explicitDeadline bool
}
// TimeoutConnection.Read calls Read() on the underlying connection, using any configured deadlines
func (c *TimeoutConnection) Read(b []byte) (n int, err error) {
if err := c.checkContext(); err != nil {
return 0, err
}
origSize := len(b)
if c.BytesRead+len(b) >= c.BytesReadLimit {
b = b[0 : c.BytesReadLimit-c.BytesRead]
}
if c.explicitReadDeadline || c.explicitDeadline {
c.explicitReadDeadline = false
c.explicitDeadline = false
} else if c.Timeout > 0 {
if err = c.Conn.SetReadDeadline(time.Now().Add(c.Timeout)); err != nil {
} else if readTimeout := c.getTimeout(c.ReadTimeout); readTimeout > 0 {
if err = c.Conn.SetReadDeadline(time.Now().Add(readTimeout)); err != nil {
return 0, err
}
}
return c.Conn.Read(b)
n, err = c.Conn.Read(b)
c.BytesRead += n
if err == nil && origSize != len(b) && n == len(b) {
// we had to shrink the output buffer AND we used up the whole shrunk size, AND we're not at EOF
switch c.ReadLimitExceededAction {
case ReadLimitExceededActionTruncate:
logrus.Debug("Truncated read from %d bytes to %d bytes (hit limit of %d bytes)", origSize, n, c.BytesReadLimit)
err = io.EOF
case ReadLimitExceededActionError:
return n, ErrReadLimitExceeded
case ReadLimitExceededActionPanic:
panic(ErrReadLimitExceeded)
default:
logrus.Fatalf("Unrecognized ReadLimitExceededAction: %s", c.ReadLimitExceededAction)
}
}
return n, err
}
// TimeoutConnection.Write calls Write() on the underlying connection, using any configured deadlines
// TimeoutConnection.Write calls Write() on the underlying connection, using any configured deadlines.
func (c *TimeoutConnection) Write(b []byte) (n int, err error) {
if err := c.checkContext(); err != nil {
return 0, err
}
if c.explicitWriteDeadline || c.explicitDeadline {
c.explicitWriteDeadline = false
c.explicitDeadline = false
} else if c.Timeout > 0 {
if err = c.Conn.SetWriteDeadline(time.Now().Add(c.Timeout)); err != nil {
} else if writeTimeout := c.getTimeout(c.WriteTimeout); writeTimeout > 0 {
if err = c.Conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
return 0, err
}
}
return c.Conn.Write(b)
n, err = c.Conn.Write(b)
c.BytesWritten += n
return n, err
}
// SetReadDeadline sets an explicit ReadDeadline that will override the timeout
// for one read. Use deadline = 0 to clear the deadline.
func (c *TimeoutConnection) SetReadDeadline(deadline time.Time) error {
if err := c.checkContext(); err != nil {
return err
}
if !deadline.IsZero() {
err := c.Conn.SetReadDeadline(deadline)
if err != nil {
@ -57,6 +135,9 @@ func (c *TimeoutConnection) SetReadDeadline(deadline time.Time) error {
// SetWriteDeadline sets an explicit WriteDeadline that will override the
// WriteDeadline for one write. Use deadline = 0 to clear the deadline.
func (c *TimeoutConnection) SetWriteDeadline(deadline time.Time) error {
if err := c.checkContext(); err != nil {
return err
}
if !deadline.IsZero() {
err := c.Conn.SetWriteDeadline(deadline)
if err != nil {
@ -70,6 +151,9 @@ func (c *TimeoutConnection) SetWriteDeadline(deadline time.Time) error {
// SetDeadline sets a read / write deadline that will override the deadline for
// a single read/write. Use deadline = 0 to clear the deadline.
func (c *TimeoutConnection) SetDeadline(deadline time.Time) error {
if err := c.checkContext(); err != nil {
return err
}
if !deadline.IsZero() {
err := c.Conn.SetDeadline(deadline)
if err != nil {
@ -80,8 +164,8 @@ func (c *TimeoutConnection) SetDeadline(deadline time.Time) error {
return nil
}
// GetTimeoutDialer returns a Dialer function that dials with the given timeout
func GetTimeoutDialer(timeout time.Duration) func(string, string) (net.Conn, error) {
// GetTimeoutDialFunc returns a DialFunc that dials with the given timeout
func GetTimeoutDialFunc(timeout time.Duration) func(string, string) (net.Conn, error) {
return func(proto, target string) (net.Conn, error) {
return DialTimeoutConnection(proto, target, timeout)
}
@ -92,14 +176,69 @@ func (c *TimeoutConnection) Close() error {
return c.Conn.Close()
}
// DialTimeoutConnection dials the target and returns a net.Conn that uses the configured timeouts for Read/Write operations.
func DialTimeoutConnection(proto string, target string, timeout time.Duration) (net.Conn, error) {
// Get the timeout for the given field, falling back to the global timeout.
func (c *TimeoutConnection) getTimeout(field time.Duration) time.Duration {
if field == 0 {
return c.Timeout
}
return field
}
// Check if the context has been cancelled, and if so, return an error (either the context error, or
// if the context error is nil, ErrTotalTimeout).
func (c *TimeoutConnection) checkContext() error {
if c.ctx == nil {
return nil
}
select {
case <-c.ctx.Done():
if err := c.ctx.Err(); err != nil {
return err
} else {
return ErrTotalTimeout
}
default:
return nil
}
}
// SetDefaults on the connection.
func (c *TimeoutConnection) SetDefaults() *TimeoutConnection {
if c.BytesReadLimit == 0 {
c.BytesReadLimit = DefaultBytesReadLimit
}
if c.ReadLimitExceededAction == ReadLimitExceededActionNotSet {
c.ReadLimitExceededAction = DefaultReadLimitExceededAction
}
if c.Timeout == 0 {
c.Timeout = DefaultSessionTimeout
}
return c
}
// NewTimeoutConnection returns a new TimeoutConnection with the appropriate defaults.
func NewTimeoutConnection(ctx context.Context, conn net.Conn, timeout, readTimeout, writeTimeout time.Duration) *TimeoutConnection {
ret := (&TimeoutConnection{
Conn: conn,
Timeout: timeout,
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
}).SetDefaults()
if ctx == nil {
ctx = context.Background()
}
ret.ctx, ret.Cancel = context.WithTimeout(ctx, timeout)
return ret
}
// DialTimeoutConnectionEx dials the target and returns a net.Conn that uses the configured timeouts for Read/Write operations.
func DialTimeoutConnectionEx(proto string, target string, dialTimeout, sessionTimeout, readTimeout, writeTimeout time.Duration) (net.Conn, error) {
var conn net.Conn
var err error
if timeout > 0 {
conn, err = net.DialTimeout(proto, target, timeout)
if dialTimeout > 0 {
conn, err = net.DialTimeout(proto, target, dialTimeout)
} else {
conn, err = net.Dial(proto, target)
conn, err = net.DialTimeout(proto, target, sessionTimeout)
}
if err != nil {
if conn != nil {
@ -107,49 +246,104 @@ func DialTimeoutConnection(proto string, target string, timeout time.Duration) (
}
return nil, err
}
return &TimeoutConnection{
Conn: conn,
Timeout: timeout,
}, nil
return NewTimeoutConnection(context.Background(), conn, sessionTimeout, readTimeout, writeTimeout), nil
}
// DialTimeoutConnection dials the target and returns a net.Conn that uses the configured single timeout for all operations.
func DialTimeoutConnection(proto string, target string, timeout time.Duration) (net.Conn, error) {
return DialTimeoutConnectionEx(proto, target, timeout, timeout, timeout, timeout)
}
// Dialer provides Dial and DialContext methods to get connections with the given timeout.
type Dialer struct {
// Timeout is the maximum time to wait for a connection or I/O.
// Timeout is the maximum time to wait for the entire session, after which any operations on the
// connection will fail.
Timeout time.Duration
// dialer is an auxiliary dialer used for DialContext (the result gets wrapped in a TimeoutConnection).
// ConnectTimeout is the maximum time to wait for a connection.
ConnectTimeout time.Duration
// ReadTimeout is the maximum time to wait for a Read
ReadTimeout time.Duration
// WriteTimeout is the maximum time to wait for a Write
WriteTimeout time.Duration
// Dialer is an auxiliary dialer used for DialContext (the result gets wrapped in a
// TimeoutConnection).
Dialer *net.Dialer
// BytesReadLimit is the maximum number of bytes that connections dialed with this dialer will
// read before erroring.
BytesReadLimit int
// ReadLimitExceededAction describes how connections dialed with this dialer deal with exceeding
// the BytesReadLimit.
ReadLimitExceededAction ReadLimitExceededAction
}
func (d *Dialer) getTimeout(field time.Duration) time.Duration {
if field == 0 {
return d.Timeout
}
return field
}
// DialContext wraps the connection returned by net.Dialer.DialContext() with a TimeoutConnection.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
// ensure that our aux dialer is up-to-date
d.Dialer.Timeout = d.Timeout
if d.Timeout != 0 {
ctx, _ = context.WithTimeout(ctx, d.Timeout)
}
// ensure that our aux dialer is up-to-date; copied from http/transport.go
d.Dialer.Timeout = d.getTimeout(d.ConnectTimeout)
d.Dialer.KeepAlive = d.Timeout
ret, err := d.Dialer.DialContext(ctx, network, address)
dialContext, cancelDial := context.WithTimeout(ctx, d.Dialer.Timeout)
defer cancelDial()
conn, err := d.Dialer.DialContext(dialContext, network, address)
if err != nil {
return nil, err
}
return &TimeoutConnection{
Conn: ret,
Timeout: d.Timeout,
}, nil
ret := NewTimeoutConnection(ctx, conn, d.Timeout, d.ReadTimeout, d.WriteTimeout)
ret.BytesReadLimit = d.BytesReadLimit
ret.ReadLimitExceededAction = d.ReadLimitExceededAction
return ret, nil
}
// Dial returns a connection with the configured timeout.
func (d *Dialer) Dial(proto string, target string) (net.Conn, error) {
return DialTimeoutConnection(proto, target, d.Timeout)
return DialTimeoutConnectionEx(proto, target, d.ConnectTimeout, d.Timeout, d.ReadTimeout, d.WriteTimeout)
}
// GetTimeoutConnectionDialer gets a Dialer that dials connections with the given timeout.
func GetTimeoutConnectionDialer(timeout time.Duration) *Dialer {
return &Dialer{
Timeout: timeout,
Dialer: &net.Dialer{
Timeout: timeout,
KeepAlive: timeout,
DualStack: true,
},
return NewDialer(&Dialer{Timeout: timeout})
}
// SetDefaults for the Dialer.
func (d *Dialer) SetDefaults() *Dialer {
if d.Timeout == 0 {
d.Timeout = DefaultSessionTimeout
}
}
if d.ReadLimitExceededAction == ReadLimitExceededActionNotSet {
d.ReadLimitExceededAction = DefaultReadLimitExceededAction
}
if d.BytesReadLimit == 0 {
d.BytesReadLimit = DefaultBytesReadLimit
}
if d.Dialer == nil {
d.Dialer = &net.Dialer{
Timeout: d.Timeout,
KeepAlive: d.Timeout,
DualStack: true,
}
}
return d
}
// NewDialer creates a new Dialer with default settings.
func NewDialer(value *Dialer) *Dialer {
if value == nil {
value = &Dialer{}
}
return value.SetDefaults()
}

396
conn_bytelimit_test.go Normal file

@ -0,0 +1,396 @@
package zgrab2
import (
"context"
"fmt"
"io"
"net"
"strings"
"testing"
"time"
)
// Start a local echo server on port.
func runEchoServer(t *testing.T, port int) {
endpoint := fmt.Sprintf("127.0.0.1:%d", port)
listener, err := net.Listen("tcp", endpoint)
if err != nil {
t.Fatal(err)
}
go func() {
defer listener.Close()
sock, err := listener.Accept()
if err != nil {
t.Fatal(err)
}
defer sock.Close()
buf := make([]byte, 1024)
for {
n, err := sock.Read(buf)
if err != nil {
if err != io.EOF && !strings.Contains(err.Error(), "connection reset") {
t.Fatal(err)
}
return
}
sock.SetWriteDeadline(time.Now().Add(time.Millisecond * 250))
n, err = sock.Write(buf[0:n])
if err != nil {
if err != io.EOF && !strings.Contains(err.Error(), "connection reset") && !strings.Contains(err.Error(), "broken pipe") {
t.Logf("Unexpected error writing to client: %v", err)
}
return
}
}
}()
}
// Interface for getting a TimeoutConnection; we want to test both the dialer and the direct Dial functions.
type timeoutConnector interface {
connect(ctx context.Context, t *testing.T, port int, idx int) (*TimeoutConnection, error)
getConfig() readLimitTestConfig
}
// Config for a test case
type readLimitTestConfig struct {
// The maximum bytes the connection should read
limit int
// The number of bytes that should be sent (so iff sendSize > limit, the action should be triggered)
sendSize int
// The action to run when too much data is sent
action ReadLimitExceededAction
}
// Call sendReceive(), and check that the input/output match, and that any expected errors / truncation occurs.
func checkedSendReceive(t *testing.T, conn *TimeoutConnection, size int) (result error) {
// helper to report + return an error
tErrorf := func(format string, args ...interface{}) error {
result = fmt.Errorf(format, args)
t.Error(result)
return result
}
// We will check that this increases by the correct size
before := conn.BytesRead
// This is true if we expect an overflow to occur (and so the ReadLimitExceededAction should fire)
overflowed := (before + size) > conn.BytesReadLimit
// Don't want to keep re-typing this
action := conn.ReadLimitExceededAction
defer func() {
if result != nil {
// log any previous error -- more may still follow
t.Error(result)
}
err := recover()
if err != nil {
if action != ReadLimitExceededActionPanic {
// no reason to panic unless that is the action
panic(err)
}
if !overflowed {
tErrorf("panicked early: only sent %d bytes so far, but limit=%d", before+size, conn.BytesReadLimit)
return
}
if err == ErrReadLimitExceeded {
// We read too much data and this is the right error: silently succeed
return
}
tErrorf("wrong panic error: got %v, expected ErrReadlimitExceeded", err)
return
}
if action != ReadLimitExceededActionPanic {
// other action -- fine that we didn't panic
return
}
if !overflowed {
// not enough bytes read to overflow -- fine that we didn't panic
return
}
// ReadLimitExceededActionPanic, read too many bytes: should have panicked but didn't
tErrorf("should have panicked: action=ReadLimitExceededActionPanic, but sent without issue")
}()
ret, err := sendReceive(t, conn, size)
if err != nil {
if !overflowed {
// If there is no overflow, there should be no error
return tErrorf("read: unexpected error: %v", err)
}
if err != io.EOF && err != ErrReadLimitExceeded {
// EOF and ErrReadLimitExceeded are the only errors that should be returned
return tErrorf("read: wrong error: %v", err)
}
if err == io.EOF && action != ReadLimitExceededActionTruncate {
// EOF should only occur with truncation
return tErrorf("read: unexpected EOF")
}
if err == ErrReadLimitExceeded && action != ReadLimitExceededActionError {
// ErrReadLimitExceeded should only occur with ReadLimitExceededActionError
return tErrorf("read: unexpected ErrReadLimitExceeded")
}
// Otherwise, fall through -- we still need to check that the data matches
} else {
if overflowed && action == ReadLimitExceededActionError {
return tErrorf("read: should have gotten an error, but did not")
}
}
expectedSize := size
if overflowed {
expectedSize = conn.BytesReadLimit - before
}
if conn.BytesRead != before+expectedSize {
return tErrorf("check: BytesRead value inconsistent; expected %d, got %d", before+expectedSize, conn.BytesRead)
}
if len(ret) != expectedSize {
return tErrorf("check: expected %d bytes, got %d", expectedSize, len(ret))
}
if expectedSize > 0 && !checkTestBuffer(ret) {
return tErrorf("Got back invalid data (%x)", ret)
}
return nil
}
// Send size testBuffer bytes to conn, then perform a read, and return the result/error.
func sendReceive(t *testing.T, conn *TimeoutConnection, size int) ([]byte, error) {
toSend := getTestBuffer(size)
n, err := conn.Write(toSend)
if err != nil {
t.Fatalf("Send failed: %v", err)
return nil, err
}
if n != len(toSend) {
t.Fatalf("Short write: expected to send %d bytes, returned %d", len(toSend), n)
return nil, io.ErrShortWrite
}
readBuf := make([]byte, size)
n, err = conn.Read(readBuf)
return readBuf[0:n], err
}
// Get a size-byte slice of sequential bytes (mod 256), starting from 0
func getTestBuffer(size int) []byte {
ret := make([]byte, size)
for i := 0; i < size; i++ {
ret[i] = byte(i & 0xff)
}
return ret
}
// Check that buf is of the type returned by getTestBuffer.
func checkTestBuffer(buf []byte) bool {
if buf == nil || len(buf) == 0 {
return false
}
for i, v := range buf {
if v != byte(i&0xff) {
return false
}
}
return true
}
// Send / receive cfg.sendSize bytes in a single shot and check that it behaves appropriately.
func (cfg readLimitTestConfig) runSingleSend(t *testing.T, conn *TimeoutConnection, idx int) error {
if err := checkedSendReceive(t, conn, cfg.sendSize); err != nil {
return err
}
return nil
}
// Send / receive cfg.sendSize bytes, split over five sends, and check that it behaves appropriately.
func (cfg readLimitTestConfig) runMultiSend(t *testing.T, conn *TimeoutConnection, idx int) error {
for i := 0; i < 5; i++ {
if err := checkedSendReceive(t, conn, cfg.sendSize/5); err != nil {
return err
}
}
return nil
}
// A timeoutConnector that uses a dialer to dial the connections
type dialerConnector struct {
readLimitTestConfig
// This is lazily inited
dialer *Dialer
}
// Function that returns a connector
type timeoutConnectorFactory func(readLimitTestConfig) timeoutConnector
// Dial the connection using the dialer (creating the dialer if necessary)
func (d *dialerConnector) connect(ctx context.Context, t *testing.T, port int, idx int) (*TimeoutConnection, error) {
if d.dialer == nil {
d.dialer = NewDialer(&Dialer{
BytesReadLimit: d.limit,
ReadLimitExceededAction: d.action,
})
}
var ret *TimeoutConnection
conn, err := d.dialer.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", port))
if conn != nil {
ret = conn.(*TimeoutConnection)
}
return ret, err
}
func (d *dialerConnector) getConfig() readLimitTestConfig {
return d.readLimitTestConfig
}
func dialerTimeoutConnectorFactory(cfg readLimitTestConfig) timeoutConnector {
return &dialerConnector{
readLimitTestConfig: cfg,
}
}
// Dial using a direct call to DialTimeoutConnectionEx
type directDial struct {
readLimitTestConfig
}
func (d *directDial) connect(ctx context.Context, t *testing.T, port int, idx int) (*TimeoutConnection, error) {
conn, err := DialTimeoutConnectionEx("tcp", fmt.Sprintf("127.0.0.1:%d", port), time.Second, time.Second, time.Second, time.Second)
var ret *TimeoutConnection
if conn != nil {
ret = conn.(*TimeoutConnection)
ret.BytesReadLimit = d.limit
ret.ReadLimitExceededAction = d.action
}
return ret, err
}
func (d *directDial) getConfig() readLimitTestConfig {
return d.readLimitTestConfig
}
func directDialFactory(cfg readLimitTestConfig) timeoutConnector {
return &directDial{cfg}
}
var readLimitTestConfigs = map[string]readLimitTestConfig{
// Check that a 2000-byte read gets truncated at 1000 bytes
"truncate": {
limit: 1000,
sendSize: 2000,
action: ReadLimitExceededActionTruncate,
},
// Check that a 1005-byte read gets truncated at 1000 bytes
"truncate_close": {
limit: 1000,
sendSize: 1005,
action: ReadLimitExceededActionTruncate,
},
// Check that a 2000-byte read errors after reading the first 1000 bytes
"error": {
limit: 1000,
sendSize: 2000,
action: ReadLimitExceededActionError,
},
// Check that a 2000-byte read panics after reading the first 1000 bytes
"panic": {
limit: 1000,
sendSize: 2000,
action: ReadLimitExceededActionPanic,
},
// Check that the default settings pass (backwards compatibility)
"default": {},
// Check that a 100-byte read succeeds / is not truncated
"happy": {
limit: 1000,
sendSize: 100,
action: ReadLimitExceededActionPanic,
},
// Check that a 1000-byte read succeeds / is not truncated
"closeCall": {
limit: 1000,
sendSize: 1000,
action: ReadLimitExceededActionPanic,
},
}
// Each of these gets run with each readLimitTestConfig
var connTestConnectors = map[string]timeoutConnectorFactory{
"directDial": directDialFactory,
"dialerConnector": dialerTimeoutConnectorFactory,
}
// Run a single full trial with the given connector: connect, send/receive the configured bytes, and
// check that the response was properly truncated (or not), and that the bytes read total is
// correctly tabulated.
func runBytesReadLimitTrial(t *testing.T, connector timeoutConnector, idx int, method func(readLimitTestConfig, *testing.T, *TimeoutConnection, int) error) (result error) {
cfg := connector.getConfig()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
port := 0x1234 + idx
runEchoServer(t, port)
conn, err := connector.connect(ctx, t, port, idx)
if err != nil {
t.Fatalf("Error dialing: %v", err)
}
expectedSize := cfg.sendSize
if expectedSize > cfg.limit {
expectedSize = cfg.limit
}
defer func() {
if conn.BytesRead != expectedSize {
result = fmt.Errorf("BytesRead(%d) != expected(%d)", conn.BytesRead, expectedSize)
t.Error(result)
}
}()
defer conn.Close()
return method(cfg, t, conn, idx)
}
// Run a full set of trials on the connector -- ten with a single send, and ten with multiple sends.
func testBytesReadLimitOn(t *testing.T, connector timeoutConnector) error {
for i := 0; i < 10; i++ {
if err := runBytesReadLimitTrial(t, connector, i, readLimitTestConfig.runSingleSend); err != nil {
return err
}
}
for i := 0; i < 10; i++ {
if err := runBytesReadLimitTrial(t, connector, i, readLimitTestConfig.runMultiSend); err != nil {
return err
}
}
return nil
}
// Check that the BytesReadLimit is enforced (or not) as expected:
// 1. Create an echo server
// 2. Dial a fresh TimeoutConnection to the echo server with the given BytesReadLimit / ReadLimitExceededAction
// 3. Send the configured number of bytes in a single packet
// 4. Check that it (succeeds / truncates / errors / panics) according to the config
// 5. Repeat 10 times
// 6. Repeat the above 10 more times, except in #3, split the send across five packets
func TestBytesReadLimit(t *testing.T) {
connectors := make(map[string]timeoutConnector)
// Create a fresh connector for each configuration
for cfgName, cfg := range readLimitTestConfigs {
for connectorName, factory := range connTestConnectors {
connectors[connectorName+"_"+cfgName] = factory(cfg)
}
}
// Run each connector
for name, connector := range connectors {
t.Logf("Running %s", name)
if err := testBytesReadLimitOn(t, connector); err != nil {
t.Logf("Failed running %s: %v", name, err)
}
}
}

422
conn_timeout_test.go Normal file

@ -0,0 +1,422 @@
package zgrab2
import (
"bytes"
"context"
"fmt"
"io"
"net"
"strings"
"testing"
"time"
"github.com/sirupsen/logrus"
)
// Config for a single timeout test
type connTimeoutTestConfig struct {
// Name for the test for logging purposes
name string
// Optional explicit endpoint to connect to (if absent, use 127.0.0.1)
endpoint string
// TCP port number to communicate on
port int
// Function used to dial new connections
dialer func() (*TimeoutConnection, error)
// Client timeout values
timeout time.Duration
connectTimeout time.Duration
readTimeout time.Duration
writeTimeout time.Duration
// Time for server to wait after listening before accepting a connection
acceptDelay time.Duration
// Time for server to wait after accepting before writing payload
writeDelay time.Duration
// Time for server to wait before reading payload
readDelay time.Duration
// Payload for server to send client after connecting
serverToClientPayload []byte
// Payload for client to send server after reading the previous payload
clientToServerPayload []byte
// Step when the client is expected to fail
failStep testStep
// If non-empty, the error string returned by the client should contain this
failError string
}
// Standardized time units, separated by factors of 10.
const (
short = 100 * time.Millisecond
medium = 1000 * time.Millisecond
long = 10000 * time.Millisecond
)
// enum type for the various locations where the test can fail
type testStep string
const (
testStepConnect = testStep("connect")
testStepRead = testStep("read")
testStepWrite = testStep("write")
testStepDone = testStep("done")
)
// Encapsulates a source for an error (client/server/???), the step where it occurred, and an
// optional cause.
type timeoutTestError struct {
source string
step testStep
cause error
}
func (err *timeoutTestError) Error() string {
return fmt.Sprintf("%s error at %s: %v", err.source, err.step, err.cause)
}
func serverError(step testStep, err error) *timeoutTestError {
return &timeoutTestError{
source: "server",
step: step,
cause: err,
}
}
func clientError(step testStep, err error) *timeoutTestError {
return &timeoutTestError{
source: "client",
step: step,
cause: err,
}
}
// Helper to ensure all data is written to a socket
func _write(writer io.Writer, data []byte) error {
n, err := writer.Write(data)
if err == nil && n != len(data) {
err = io.ErrShortWrite
}
return err
}
// Run the configured server. As soon as it returns, it is listening.
// Returns a channel that receives a timeoutTestError on error, or is closed on successful completion.
func (cfg *connTimeoutTestConfig) runServer(t *testing.T) (chan *timeoutTestError) {
errorChan := make(chan *timeoutTestError)
if cfg.endpoint != "" {
// Only listen on localhost
return errorChan
}
listener, err := net.Listen("tcp", cfg.getEndpoint())
if err != nil {
logrus.Fatalf("Error listening: %v", err)
}
go func() {
defer listener.Close()
defer close(errorChan)
time.Sleep(cfg.acceptDelay)
sock, err := listener.Accept()
if err != nil {
errorChan <- serverError(testStepConnect, err)
return
}
defer sock.Close()
time.Sleep(cfg.writeDelay)
if err := _write(sock, cfg.serverToClientPayload); err != nil {
errorChan <- serverError(testStepWrite, err)
return
}
time.Sleep(cfg.readDelay)
buf := make([]byte, len(cfg.clientToServerPayload))
n, err := io.ReadFull(sock, buf)
if err != nil && err != io.EOF {
errorChan <- serverError(testStepRead, err)
return
}
if err == io.EOF && n < len(buf) {
errorChan <- serverError(testStepRead, err)
return
}
if !bytes.Equal(buf, cfg.clientToServerPayload) {
t.Errorf("%s: clientToServerPayload mismatch", cfg.name)
}
return
}()
return errorChan
}
// Get the configured endpoint
func (cfg *connTimeoutTestConfig) getEndpoint() string {
if cfg.endpoint != "" {
return cfg.endpoint
}
return fmt.Sprintf("127.0.0.1:%d", cfg.port)
}
// Dial a connection to the configured endpoint using a Dialer
func (cfg *connTimeoutTestConfig) dialerDial() (*TimeoutConnection, error) {
dialer := NewDialer(&Dialer{
Timeout: cfg.timeout,
ConnectTimeout: cfg.connectTimeout,
ReadTimeout: cfg.readTimeout,
WriteTimeout: cfg.writeTimeout,
})
ret, err := dialer.Dial("tcp", cfg.getEndpoint())
if err != nil {
return nil, err
}
return ret.(*TimeoutConnection), err
}
// Dial a connection to the configured endpoint using a DialTimeoutConnectionEx
func (cfg *connTimeoutTestConfig) directDial() (*TimeoutConnection, error) {
ret, err := DialTimeoutConnectionEx("tcp", cfg.getEndpoint(), cfg.connectTimeout, cfg.timeout, cfg.readTimeout, cfg.writeTimeout)
if err != nil {
return nil, err
}
return ret.(*TimeoutConnection), err
}
// Dial a connection to the configured endpoint using Dialer.DialContext
func (cfg *connTimeoutTestConfig) contextDial() (*TimeoutConnection, error) {
dialer := NewDialer(&Dialer{
Timeout: cfg.timeout,
ConnectTimeout: cfg.connectTimeout,
ReadTimeout: cfg.readTimeout,
WriteTimeout: cfg.writeTimeout,
})
ret, err := dialer.DialContext(context.Background(), "tcp", cfg.getEndpoint())
if err != nil {
return nil, err
}
return ret.(*TimeoutConnection), err
}
// Run the client: connect to the server, read the payload, write the payload, disconnect.
func (cfg *connTimeoutTestConfig) runClient(t *testing.T) (testStep, error) {
conn, err := cfg.dialer()
if err != nil {
return testStepConnect, err
}
defer conn.Close()
buf := make([]byte, len(cfg.serverToClientPayload))
_, err = io.ReadFull(conn, buf)
if err != nil {
return testStepRead, err
}
if !bytes.Equal(cfg.serverToClientPayload, buf) {
t.Errorf("%s: serverToClientPayload payload mismatch", cfg.name)
}
if err := _write(conn, cfg.clientToServerPayload); err != nil {
return testStepWrite, err
}
return testStepDone, nil
}
// Run the configured test -- start a server and a client to connect to it.
func (cfg *connTimeoutTestConfig) run(t *testing.T) {
done := make(chan *timeoutTestError)
serverError := cfg.runServer(t)
go func() {
defer func() {
if err := recover(); err != nil {
close(done)
panic(err)
}
}()
step, err := cfg.runClient(t)
done <- clientError(step, err)
}()
go func() {
time.Sleep(long + medium + short)
done <- &timeoutTestError{source: "timeout"}
}()
var ret *timeoutTestError
select {
case err := <-serverError:
t.Fatalf("%s: Server error: %v", cfg.name, err)
case ret = <-done:
if ret == nil {
t.Fatalf("Channel unexpectedly closed")
}
}
if ret.source != "client" {
t.Fatalf("%s: Unexpected error from %s: %v", cfg.name, ret.source, ret.cause)
}
if ret.step != cfg.failStep {
t.Errorf("%s: Failed at step %s, but expected to fail at step %s (error=%v)", cfg.name, ret.step, cfg.failStep, ret.cause)
return
}
if cfg.failError != "" {
errString := "none"
if ret.cause != nil {
errString = ret.cause.Error()
}
if !strings.Contains(errString, cfg.failError) {
t.Errorf("%s: Expected an error (%s) at step %s, got %s", cfg.name, cfg.failError, cfg.failStep, errString)
return
}
} else if ret.cause != nil {
t.Errorf("%s: expected no error at step %s, but got %v", cfg.name, cfg.failStep, ret.cause)
}
}
var connTestConfigs = []connTimeoutTestConfig{
// Long timeouts, short delays -- should succeed
{
name: "happy",
port: 0x5613,
timeout: long,
connectTimeout: medium,
readTimeout: medium,
writeTimeout: medium,
acceptDelay: short,
writeDelay: short,
readDelay: short,
serverToClientPayload: []byte("abc"),
clientToServerPayload: []byte("defghi"),
failStep: testStepDone,
},
// long session timeout, short connectTimeout. Use a non-local, nonexistent endpoint (localhost
// would return "connection refused" immediately)
{
name: "connect_timeout",
endpoint: "10.0.254.254:41591",
timeout: long,
connectTimeout: short,
readTimeout: medium,
writeTimeout: medium,
acceptDelay: short,
writeDelay: short,
readDelay: short,
serverToClientPayload: []byte("abc"),
clientToServerPayload: []byte("defghi"),
failStep: testStepConnect,
failError: "i/o timeout",
},
// short session timeout, medium connect timeout, with connect to nonexistent endpoint.
{
name: "session_connect_timeout",
endpoint: "10.0.254.254:41591",
timeout: short,
connectTimeout: medium,
readTimeout: medium,
writeTimeout: medium,
acceptDelay: short,
writeDelay: short,
readDelay: short,
serverToClientPayload: []byte("abc"),
clientToServerPayload: []byte("defghi"),
failStep: testStepConnect,
failError: "i/o timeout",
},
// Get an IO timeout on the read.
// sessionTimeout > acceptDelay + writeDelay > writeDelay > readTimeout
{
name: "read_timeout",
port: 0x5614,
timeout: long,
connectTimeout: short,
readTimeout: short,
writeTimeout: short,
acceptDelay: short,
writeDelay: medium,
readDelay: short,
serverToClientPayload: []byte("abc"),
clientToServerPayload: []byte("defghi"),
failStep: testStepRead,
failError: "i/o timeout",
},
// Get a context timeout on a read.
// readTimeout > writeDelay > timeout > acceptDelay
{
name: "session_read_timeout",
port: 0x5615,
timeout: short,
connectTimeout: long,
readTimeout: long,
writeTimeout: long,
acceptDelay: 0,
writeDelay: medium * 2,
readDelay: 0,
serverToClientPayload: []byte("abc"),
clientToServerPayload: []byte("defghi"),
failStep: testStepWrite,
failError: "context deadline exceeded",
},
// Use a session timeout that is longer than any individual action's timeout.
// acceptDelay+writeDelay+readDelay > timeout > acceptDelay >= writeDelay >= readDelay
{
name: "session_timeout",
port: 0x5616,
timeout: medium,
connectTimeout: long,
readTimeout: long,
writeTimeout: long,
acceptDelay: time.Nanosecond * time.Duration(medium.Nanoseconds()/2+short.Nanoseconds()),
writeDelay: time.Nanosecond * time.Duration(medium.Nanoseconds()/2+short.Nanoseconds()),
readDelay: time.Nanosecond * time.Duration(medium.Nanoseconds()/2+short.Nanoseconds()),
serverToClientPayload: []byte("abc"),
clientToServerPayload: []byte("defghi"),
failStep: testStepWrite,
failError: "context deadline exceeded",
},
// TODO: How to test write timeout?
}
// TestTimeoutConnectionTimeouts tests that the TimeoutConnection behaves as expected with respect
// to timeouts.
func TestTimeoutConnectionTimeouts(t *testing.T) {
temp := make([]connTimeoutTestConfig, 0, len(connTestConfigs)*3)
// Make three copies of connTestConfigs, one with each dial method.
for _, cfg := range connTestConfigs {
direct := cfg
dialer := cfg
ctxDialer := cfg
dialer.name = dialer.name + "_dialer"
dialer.port = dialer.port + 100
dialer.dialer = dialer.dialerDial
direct.name = direct.name + "_direct"
direct.port = direct.port + 200
direct.dialer = direct.directDial
ctxDialer.name = ctxDialer.name + "_context"
ctxDialer.port = ctxDialer.port + 300
ctxDialer.dialer = ctxDialer.contextDial
temp = append(temp, direct, dialer, ctxDialer)
}
for _, cfg := range temp {
t.Logf("Running %s", cfg.name)
cfg.run(t)
}
}

@ -0,0 +1,347 @@
package http
import (
"crypto/rsa"
"encoding/hex"
"fmt"
"io"
"math/big"
"net"
"strings"
"testing"
"time"
"github.com/zmap/zcrypto/tls"
"github.com/zmap/zgrab2"
)
// BEGIN Taken from handshake_server_test.go -- certs for TLS server
func fromHex(s string) []byte {
b, _ := hex.DecodeString(s)
return b
}
func bigFromString(s string) *big.Int {
ret := new(big.Int)
ret.SetString(s, 10)
return ret
}
var testRSACertificate = fromHex("308202b030820219a00302010202090085b0bba48a7fb8ca300d06092a864886f70d01010505003045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464301e170d3130303432343039303933385a170d3131303432343039303933385a3045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c746430819f300d06092a864886f70d010101050003818d0030818902818100bb79d6f517b5e5bf4610d0dc69bee62b07435ad0032d8a7a4385b71452e7a5654c2c78b8238cb5b482e5de1f953b7e62a52ca533d6fe125c7a56fcf506bffa587b263fb5cd04d3d0c921964ac7f4549f5abfef427100fe1899077f7e887d7df10439c4a22edb51c97ce3c04c3b326601cfafb11db8719a1ddbdb896baeda2d790203010001a381a73081a4301d0603551d0e04160414b1ade2855acfcb28db69ce2369ded3268e18883930750603551d23046e306c8014b1ade2855acfcb28db69ce2369ded3268e188839a149a4473045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c746482090085b0bba48a7fb8ca300c0603551d13040530030101ff300d06092a864886f70d010105050003818100086c4524c76bb159ab0c52ccf2b014d7879d7a6475b55a9566e4c52b8eae12661feb4f38b36e60d392fdf74108b52513b1187a24fb301dbaed98b917ece7d73159db95d31d78ea50565cd5825a2d5a5f33c4b6d8c97590968c0f5298b5cd981f89205ff2a01ca31b9694dda9fd57e970e8266d71999b266e3850296c90a7bdd9")
var testSNICertificate = fromHex("308201f23082015da003020102020100300b06092a864886f70d01010530283110300e060355040a130741636d6520436f311430120603550403130b736e69746573742e636f6d301e170d3132303431313137343033355a170d3133303431313137343533355a30283110300e060355040a130741636d6520436f311430120603550403130b736e69746573742e636f6d30819d300b06092a864886f70d01010103818d0030818902818100bb79d6f517b5e5bf4610d0dc69bee62b07435ad0032d8a7a4385b71452e7a5654c2c78b8238cb5b482e5de1f953b7e62a52ca533d6fe125c7a56fcf506bffa587b263fb5cd04d3d0c921964ac7f4549f5abfef427100fe1899077f7e887d7df10439c4a22edb51c97ce3c04c3b326601cfafb11db8719a1ddbdb896baeda2d790203010001a3323030300e0603551d0f0101ff0404030200a0300d0603551d0e0406040401020304300f0603551d2304083006800401020304300b06092a864886f70d0101050381810089c6455f1c1f5ef8eb1ab174ee2439059f5c4259bb1a8d86cdb1d056f56a717da40e95ab90f59e8deaf627c157995094db0802266eb34fc6842dea8a4b68d9c1389103ab84fb9e1f85d9b5d23ff2312c8670fbb540148245a4ebafe264d90c8a4cf4f85b0fac12ac2fc4a3154bad52462868af96c62c6525d652b6e31845bdcc")
var testRSAPrivateKey = &rsa.PrivateKey{
PublicKey: rsa.PublicKey{
N: bigFromString("131650079503776001033793877885499001334664249354723305978524647182322416328664556247316495448366990052837680518067798333412266673813370895702118944398081598789828837447552603077848001020611640547221687072142537202428102790818451901395596882588063427854225330436740647715202971973145151161964464812406232198521"),
E: 65537,
},
D: bigFromString("29354450337804273969007277378287027274721892607543397931919078829901848876371746653677097639302788129485893852488285045793268732234230875671682624082413996177431586734171663258657462237320300610850244186316880055243099640544518318093544057213190320837094958164973959123058337475052510833916491060913053867729"),
Primes: []*big.Int{
bigFromString("11969277782311800166562047708379380720136961987713178380670422671426759650127150688426177829077494755200794297055316163155755835813760102405344560929062149"),
bigFromString("10998999429884441391899182616418192492905073053684657075974935218461686523870125521822756579792315215543092255516093840728890783887287417039645833477273829"),
},
}
// END Taken from handshake_server_test.go -- certs for TLS server
// Get the tls.Config object for the server; adapted from handshake_server_test.go.
func getTLSConfig() *tls.Config {
testConfig := &tls.Config{
Time: func() time.Time { return time.Unix(0, 0) },
Certificates: make([]tls.Certificate, 2),
InsecureSkipVerify: true,
MinVersion: tls.VersionSSL30,
MaxVersion: tls.VersionTLS12,
}
testConfig.Certificates[0].Certificate = [][]byte{testRSACertificate}
testConfig.Certificates[0].PrivateKey = testRSAPrivateKey
testConfig.Certificates[1].Certificate = [][]byte{testSNICertificate}
testConfig.Certificates[1].PrivateKey = testRSAPrivateKey
testConfig.BuildNameToCertificate()
return testConfig
}
// Helper function to write and check for short writes
func _write(writer io.Writer, data []byte) error {
n, err := writer.Write(data)
if err == nil && len(data) != n {
err = io.ErrShortWrite
}
return err
}
// Start a local server that sends responds to any requests with a cfg.headerSize-byte set of
// headers followed by a cfg.bodySize-byte body.
// The response ends up looking like this:
// HTTP/1.0 200 OK
// Bogus-Header: XXX...
// Content-Length: <bodySize>
//
// XXXX....
func (cfg *readLimitTestConfig) runFakeHTTPServer(t *testing.T) {
endpoint := fmt.Sprintf("127.0.0.1:%d", cfg.port)
listener, err := net.Listen("tcp", endpoint)
if err != nil {
t.Fatal(err)
}
go func() {
defer listener.Close()
sock, err := listener.Accept()
if err != nil {
t.Fatal(err)
}
defer sock.Close()
if cfg.tls {
tlsSock := tls.Server(sock, getTLSConfig())
if err := tlsSock.Handshake(); err != nil {
t.Fatalf("server handshake error: %v", err)
}
sock = tlsSock
}
// don't care what the client sends, always respond with a HTTP-like response
buf := make([]byte, 1)
_, err = sock.Read(buf)
if err != nil {
// any error, including EOF, is unexpected -- the client should send something
t.Fatalf("Unexpected error reading from client: %v", err)
}
head := "HTTP/1.0 200 OK\r\nBogus-Header: X"
headSuffix := fmt.Sprintf("\r\nContent-Length: %d\r\n\r\n", cfg.bodySize)
size := cfg.headerSize - len(head) - len(headSuffix)
if size < 0 {
t.Fatalf("Header size %d too small: must be at least %d bytes", cfg.headerSize, len(head)+len(headSuffix))
}
if err := _write(sock, []byte(head)); err != nil {
t.Fatalf("write error: %v", err)
}
chunkSize := 256
sent := len(head)
chunk := []byte(strings.Repeat("X", chunkSize))
for i := 0; i < size; i += chunkSize {
if i+chunkSize > size {
chunk = []byte(strings.Repeat("X", size-i))
}
if err := _write(sock, chunk); err != nil {
t.Logf("Failed writing to client after %d bytes: %v", sent, err)
return
}
sent += len(chunk)
}
if err := _write(sock, []byte(headSuffix)); err != nil {
t.Logf("Failed writing foot to client: %v", err)
return
}
sent += len(headSuffix)
body := strings.Repeat("X", cfg.bodySize)
if err := _write(sock, []byte(body)); err != nil {
t.Logf("Failed writing body to client: %v", err)
return
}
}()
}
// Get an HTTP scanner module with the desired config
func (cfg *readLimitTestConfig) getScanner(t *testing.T) *Scanner {
var module Module
flags := module.NewFlags().(*Flags)
flags.Endpoint = "/"
flags.Method = "GET"
flags.UserAgent = "Mozilla/5.0 zgrab/0.x"
if cfg.maxBodySize&0x03ff != 0 {
t.Fatalf("%d is not a valid maxBodySize (must be a multiple of 1024)", cfg.maxBodySize)
}
flags.MaxSize = cfg.maxBodySize / 1024
flags.MaxRedirects = 0
flags.Timeout = 1 * time.Second
flags.Port = uint(cfg.port)
flags.UseHTTPS = cfg.tls
zgrab2.DefaultBytesReadLimit = cfg.maxReadSize
scanner := module.NewScanner()
scanner.Init(flags)
return scanner.(*Scanner)
}
// Configuration for a single test run
type readLimitTestConfig struct {
// if true, the client/server will use TLS. NOTE: the limits are on the *raw* connection.
tls bool
// port where the server listens.
port int
// Bodies larger than this are truncated. NOTE: this must be a multiple of 1024, since MaxSize
// is given in kilobytes.
maxBodySize int
// The maximum number of bytes to read from the (raw) socket. Beyond that data is truncated and
// EOF is returned.
maxReadSize int
// The size of the HTTP server's "header" (actually, all of the data before the body). Must be
// at least 58 (the size of the static parts of the response).
headerSize int
// The size of the HTTP body to send (the Content-Length).
bodySize int
// The status that should be returned by the scan.
expectedStatus zgrab2.ScanStatus
// If set, the error returned by the scan must contain this.
expectedError string
}
const (
readLimitTestConfigHTTPBasePort = 0x7f7f
readLimitTestConfigHTTPSBasePort = 0x7bbc
)
var readLimitTestConfigs = map[string]*readLimitTestConfig{
// The socket truncates the connection while reading the body. To the client it looks as if the
// server closed the connection prior to sending Content-Length bytes; the result is success,
// but with a truncated body.
// bodySize + headerSize > maxReadSize > headerSize
"truncate_read_body": {
tls: false,
port: readLimitTestConfigHTTPBasePort,
maxBodySize: 2048,
maxReadSize: 1024,
headerSize: 64,
bodySize: 4096,
expectedStatus: zgrab2.SCAN_SUCCESS,
},
// NOTE: There is no tls_truncate_read_body, since the truncation will almost certainly occur
// in the middle of a TLS packet -- so the response would always be "unexpected EOF"
// The HTTP library stops reading the body after reaching its internal limit. It returns success
// and the truncated body.
// maxReadSize > headerSize + bodySize > bodySize > maxBodySize
"truncate_body": {
tls: false,
port: readLimitTestConfigHTTPBasePort + 1,
maxBodySize: 2048,
maxReadSize: 8192,
headerSize: 64,
bodySize: 4096,
expectedStatus: zgrab2.SCAN_SUCCESS,
},
"tls_truncate_body": {
tls: true,
port: readLimitTestConfigHTTPSBasePort + 1,
maxBodySize: 2048,
maxReadSize: 8192,
headerSize: 64,
bodySize: 4096,
expectedStatus: zgrab2.SCAN_SUCCESS,
},
// The socket truncates the connection while reading the headers. The result isn't a valid HTTP
// response, so the library returns an unexpected EOF error.
// headerSize > maxReadSize
"truncate_read_header": {
tls: false,
port: readLimitTestConfigHTTPBasePort + 2,
maxBodySize: 1024,
maxReadSize: 2048,
headerSize: 3072,
bodySize: 8,
expectedError: "unexpected EOF",
expectedStatus: zgrab2.SCAN_UNKNOWN_ERROR,
},
"tls_truncate_read_header": {
tls: true,
port: readLimitTestConfigHTTPSBasePort + 2,
maxBodySize: 1024,
maxReadSize: 2048,
headerSize: 3072,
bodySize: 8,
expectedError: "unexpected EOF",
expectedStatus: zgrab2.SCAN_UNKNOWN_ERROR,
},
// Happy case. None of the limits are hit.
// maxReadSize >= maxBodySize > bodySize + headerSize
"happy_case": {
tls: false,
port: readLimitTestConfigHTTPBasePort + 3,
maxBodySize: 8192,
maxReadSize: 8192,
headerSize: 1024,
bodySize: 1024,
expectedStatus: zgrab2.SCAN_SUCCESS,
},
"tls_happy_case": {
tls: true,
port: readLimitTestConfigHTTPSBasePort + 3,
maxBodySize: 8192,
maxReadSize: 8192,
headerSize: 1024,
bodySize: 1024,
expectedStatus: zgrab2.SCAN_SUCCESS,
},
}
// Try to get the HTTP body from a result; otherwise return the empty string.
func getBody(result interface{}) string {
if result == nil {
return ""
}
httpResult, ok := result.(*Results)
if !ok {
return ""
}
response := httpResult.Response
if response == nil {
return ""
}
return response.BodyText
}
// Run a single test with the given configuration.
func (cfg *readLimitTestConfig) runTest(t *testing.T, testName string) {
scanner := cfg.getScanner(t)
cfg.runFakeHTTPServer(t)
target := zgrab2.ScanTarget{
IP: net.ParseIP("127.0.0.1"),
}
status, ret, err := scanner.Scan(target)
if status != cfg.expectedStatus {
t.Errorf("Wrong status: expected %s, got %s", cfg.expectedStatus, status)
}
if err != nil {
if !strings.Contains(err.Error(), cfg.expectedError) {
t.Errorf("Wrong error: expected %s, got %s", err.Error(), cfg.expectedError)
}
} else if len(cfg.expectedError) > 0 {
t.Errorf("Expected error '%s' but got none", cfg.expectedError)
}
if cfg.expectedStatus == zgrab2.SCAN_SUCCESS {
body := getBody(ret)
if body == "" {
t.Errorf("Expected success, but got no body")
} else {
if len(body) > cfg.maxBodySize || len(body) > cfg.maxReadSize {
t.Errorf("Body exceeds max size: len(body)=%d; maxBodySize=%d, maxReadSize=%d", len(body), cfg.maxBodySize, cfg.maxReadSize)
}
if !cfg.tls {
if len(body)+cfg.headerSize > cfg.maxReadSize {
t.Errorf("Body and header exceed max read size: len(body)=%d, headerSize=%d, maxReadSize=%d", len(body), cfg.headerSize, cfg.maxReadSize)
}
}
}
}
}
// TestReadLimitHTTP checks that the HTTP scanner works as expected with the default
// ReadLimitExeededAction (specifically, ReadLimnitExceededActionTruncate) defined in conn.go.
func TestReadLimitHTTP(t *testing.T) {
if zgrab2.DefaultReadLimitExceededAction != zgrab2.ReadLimitExceededActionTruncate {
t.Logf("Warning: DefaultReadLimitExceededAction is %s, not %s", zgrab2.DefaultReadLimitExceededAction, zgrab2.ReadLimitExceededActionTruncate)
}
for testName, cfg := range readLimitTestConfigs {
cfg.runTest(t, testName)
}
}

@ -8,12 +8,14 @@ package http
import (
"bytes"
"context"
"crypto/sha256"
"errors"
"io"
"net"
"net/url"
"strconv"
"time"
log "github.com/sirupsen/logrus"
"github.com/zmap/zgrab2"
@ -76,13 +78,14 @@ type Scanner struct {
// scan holds the state for a single scan. This may entail multiple connections.
// It is used to implement the zgrab2.Scanner interface.
type scan struct {
connections []net.Conn
scanner *Scanner
target *zgrab2.ScanTarget
transport *http.Transport
client *http.Client
results Results
url string
connections []net.Conn
scanner *Scanner
target *zgrab2.ScanTarget
transport *http.Transport
client *http.Client
results Results
url string
globalDeadline time.Time
}
// NewFlags returns an empty Flags object.
@ -142,15 +145,40 @@ func (scan *scan) Cleanup() {
}
}
// Get a context whose deadline is the earliest of the context's deadline (if it has one) and the
// global scan deadline.
func (scan *scan) withDeadlineContext(ctx context.Context) context.Context {
ctxDeadline, ok := ctx.Deadline()
if !ok || scan.globalDeadline.Before(ctxDeadline) {
ret, _ := context.WithDeadline(ctx, scan.globalDeadline)
return ret
}
return ctx
}
// Dial a connection using the configured timeouts, as well as the global deadline, and on success,
// add the connection to the list of connections to be cleaned up.
func (scan *scan) dialContext(ctx context.Context, net string, addr string) (net.Conn, error) {
dialer := zgrab2.GetTimeoutConnectionDialer(scan.scanner.config.Timeout)
timeoutContext, _ := context.WithTimeout(context.Background(), scan.scanner.config.Timeout)
conn, err := dialer.DialContext(scan.withDeadlineContext(timeoutContext), net, addr)
if err != nil {
return nil, err
}
scan.connections = append(scan.connections, conn)
return conn, nil
}
// getTLSDialer returns a Dial function that connects using the
// zgrab2.GetTLSConnection()
func (scan *scan) getTLSDialer() func(net, addr string) (net.Conn, error) {
return func(net, addr string) (net.Conn, error) {
outer, err := zgrab2.DialTimeoutConnection(net, addr, scan.scanner.config.Timeout)
outer, err := scan.dialContext(context.Background(), net, addr)
if err != nil {
return nil, err
}
scan.connections = append(scan.connections, outer)
tlsConn, err := scan.scanner.config.TLSFlags.GetTLSConnection(outer)
if err != nil {
return nil, err
@ -242,10 +270,11 @@ func (scanner *Scanner) newHTTPScan(t *zgrab2.ScanTarget) *scan {
DisableCompression: false,
MaxIdleConnsPerHost: scanner.config.MaxRedirects,
},
client: http.MakeNewClient(),
client: http.MakeNewClient(),
globalDeadline: time.Now().Add(scanner.config.Timeout),
}
ret.transport.DialTLS = ret.getTLSDialer()
ret.transport.DialContext = zgrab2.GetTimeoutConnectionDialer(scanner.config.Timeout).DialContext
ret.transport.DialContext = ret.dialContext
ret.client.UserAgent = scanner.config.UserAgent
ret.client.CheckRedirect = ret.getCheckRedirect()
ret.client.Transport = ret.transport
@ -334,6 +363,7 @@ func (scanner *Scanner) Scan(t zgrab2.ScanTarget) (zgrab2.ScanStatus, interface{
// zgrab2 framework.
func RegisterModule() {
var module Module
_, err := zgrab2.AddCommand("http", "HTTP Banner Grab", "Grab a banner over HTTP", 80, &module)
if err != nil {
log.Fatal(err)

@ -11,7 +11,7 @@ import (
"strings"
"time"
logrus "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus"
"github.com/zmap/zgrab2"
)
@ -401,6 +401,9 @@ func (options PreloginOptions) Encode() ([]byte, error) {
for _, ik := range sortedKeys {
k := PreloginOptionToken(ik)
v := options[k]
if len(cursor) < 5 {
return nil, fmt.Errorf("encode: size mismatch (options.Size()=%d)", options.Size())
}
cursor[0] = byte(k)
if offset > 0xffff {
return nil, ErrTooLarge
@ -411,6 +414,9 @@ func (options PreloginOptions) Encode() ([]byte, error) {
offset += len(v)
cursor = cursor[5:]
}
if len(cursor) < 1 {
return nil, fmt.Errorf("encode: size mismatch (options.Size()=%d, len(sortedKeys)=%d)", options.Size(), len(sortedKeys))
}
// Write the terminator after the last PL_OPTION header
// (and just before the first value)
cursor[0] = 0xff
@ -421,6 +427,9 @@ func (options PreloginOptions) Encode() ([]byte, error) {
// returned in rest.
// If body can't be decoded as a PRELOGIN body, returns nil, nil, ErrInvalidData
func decodePreloginOptions(body []byte) (result *PreloginOptions, rest []byte, err error) {
if len(body) < 1 {
return nil, nil, ErrInvalidData
}
cursor := body[:]
options := make(PreloginOptions)
max := 0
@ -506,7 +515,13 @@ func (options PreloginOptions) MarshalJSON() ([]byte, error) {
fedAuthRequired, hasFedAuthRequired := opts[PreloginFedAuthRequired]
if hasFedAuthRequired {
aux.FedAuthRequired = &fedAuthRequired[0]
temp := uint8(0)
if len(fedAuthRequired) > 0 {
temp = fedAuthRequired[0]
} else {
logrus.Debugf("fedAuthRequired was present but empty (options=%#v)", options)
}
aux.FedAuthRequired = &temp
}
nonce, hasNonce := opts[PreloginNonce]
@ -632,6 +647,7 @@ func (connection *Connection) readPreloginPacket() (*TDSPacket, *PreloginOptions
if packet.Type != TDSPacketTypeTabularResult {
return packet, nil, &zgrab2.ScanError{Status: zgrab2.SCAN_APPLICATION_ERROR, Err: err}
}
defer zgrab2.LogPanic("Error decoding Prelogin packet %#v", packet.Body)
plOptions, rest, err := decodePreloginOptions(packet.Body)
if err != nil {
return packet, nil, err

@ -82,10 +82,7 @@ func (target *ScanTarget) OpenUDP(flags *BaseFlags, udp *UDPFlags) (net.Conn, er
if err != nil {
return nil, err
}
return &TimeoutConnection{
Conn: conn,
Timeout: flags.Timeout,
}, nil
return NewTimeoutConnection(nil, conn, flags.Timeout, 0, 0), nil
}
// grabTarget calls handler for each action

@ -10,6 +10,8 @@ import (
"time"
"github.com/zmap/zflags"
"github.com/sirupsen/logrus"
"runtime/debug"
)
var parser *flags.Parser
@ -209,3 +211,18 @@ func IsTimeoutError(err error) bool {
return false
}
// LogPanic is intended to be called from within defer -- if there was no panic, it returns without
// doing anything. Otherwise, it logs the stacktrace, the panic error, and the provided message
// before re-raising the original panic.
// Example:
// defer zgrab2.LogPanic("Error decoding body '%x'", body)
func LogPanic(format string, args...interface{}) {
err := recover()
if err == nil {
return
}
logrus.Errorf("Uncaught panic at %s: %v", string(debug.Stack()), err)
logrus.Errorf(format, args...)
panic(err)
}