Merge branch 'jb/session-wide-timeout' into jb/mssqlBoundsChecking
This commit is contained in:
commit
65a7c8a578
@ -21,6 +21,7 @@ type Config struct {
|
||||
Debug bool `long:"debug" description:"Include debug fields in the output."`
|
||||
GOMAXPROCS int `long:"gomaxprocs" default:"0" description:"Set GOMAXPROCS"`
|
||||
ConnectionsPerHost int `long:"connections-per-host" default:"1" description:"Number of times to connect to each host (results in more output)"`
|
||||
ReadLimitPerHost int `long:"read-limit-per-host" default:"96" description:"Maximum total kilobytes to read for a single host (default 96kb)"`
|
||||
Prometheus string `long:"prometheus" description:"Address to use for Prometheus server (e.g. localhost:8080). If empty, Prometheus is disabled."`
|
||||
Multiple MultipleCommand `command:"multiple" description:"Multiple module actions"`
|
||||
inputFile *os.File
|
||||
@ -118,6 +119,11 @@ func validateFrameworkConfiguration() {
|
||||
if config.ConnectionsPerHost > 50 {
|
||||
log.Fatalf("connectionsPerHost must be in the range [0,50]")
|
||||
}
|
||||
|
||||
// Stop even third-party libraries from performing unbounded reads on untrusted hosts
|
||||
if config.ReadLimitPerHost > 0 {
|
||||
DefaultBytesReadLimit = config.ReadLimitPerHost * 1024
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetaFile returns the file to which metadata should be output
|
||||
|
272
conn.go
272
conn.go
@ -1,49 +1,125 @@
|
||||
package zgrab2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
"context"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ReadLimitExceededAction describes how the connection reacts to an attempt to read more data than permitted.
|
||||
type ReadLimitExceededAction string
|
||||
|
||||
const (
|
||||
// ReadLimitExceededActionNotSet is a placeholder for the zero value, so that explicitly set values can be
|
||||
// distinguished from the empty default.
|
||||
ReadLimitExceededActionNotSet = ReadLimitExceededAction("")
|
||||
|
||||
// ReadLimitExceededActionTruncate causes the connection to truncate at BytesReadLimit bytes and return a bogus
|
||||
// io.EOF error. The fact that a truncation took place is logged at debug level.
|
||||
ReadLimitExceededActionTruncate = ReadLimitExceededAction("truncate")
|
||||
|
||||
// ReadLimitExceededActionError causes the Read call to return n, ErrReadLimitExceeded (in addition to truncating).
|
||||
ReadLimitExceededActionError = ReadLimitExceededAction("error")
|
||||
|
||||
// ReadLimitExceededActionPanic causes the Read call to panic(ErrReadLimitExceeded).
|
||||
ReadLimitExceededActionPanic = ReadLimitExceededAction("panic")
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultBytesReadLimit is the maximum number of bytes to read per connection when no explicit value is provided.
|
||||
DefaultBytesReadLimit = 256 * 1024 * 1024
|
||||
|
||||
// DefaultReadLimitExceededAction is the action used when no explicit action is set.
|
||||
DefaultReadLimitExceededAction = ReadLimitExceededActionTruncate
|
||||
|
||||
// DefaultSessionTimeout is the default maximum time a connection may be used when no explicit value is provided.
|
||||
DefaultSessionTimeout = 1 * time.Minute
|
||||
)
|
||||
|
||||
var ErrReadLimitExceeded error = errors.New("read limit exceeded")
|
||||
|
||||
// TODO: Refactor this into TimeoutConnection, BoundedReader, LoggedReader, etc
|
||||
// TimeoutConnection wraps an existing net.Conn connection, overriding the Read/Write methods to use the configured timeouts
|
||||
type TimeoutConnection struct {
|
||||
net.Conn
|
||||
Timeout time.Duration
|
||||
explicitReadDeadline bool
|
||||
explicitWriteDeadline bool
|
||||
explicitDeadline bool
|
||||
ctx context.Context
|
||||
Timeout time.Duration
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
BytesRead int
|
||||
BytesWritten int
|
||||
BytesReadLimit int
|
||||
ReadLimitExceededAction ReadLimitExceededAction
|
||||
Cancel context.CancelFunc
|
||||
explicitReadDeadline bool
|
||||
explicitWriteDeadline bool
|
||||
explicitDeadline bool
|
||||
}
|
||||
|
||||
// TimeoutConnection.Read calls Read() on the underlying connection, using any configured deadlines
|
||||
func (c *TimeoutConnection) Read(b []byte) (n int, err error) {
|
||||
if err := c.checkContext(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
origSize := len(b)
|
||||
if c.BytesRead+len(b) >= c.BytesReadLimit {
|
||||
b = b[0 : c.BytesReadLimit-c.BytesRead]
|
||||
}
|
||||
if c.explicitReadDeadline || c.explicitDeadline {
|
||||
c.explicitReadDeadline = false
|
||||
c.explicitDeadline = false
|
||||
} else if c.Timeout > 0 {
|
||||
if err = c.Conn.SetReadDeadline(time.Now().Add(c.Timeout)); err != nil {
|
||||
} else if readTimeout := c.getTimeout(c.ReadTimeout); readTimeout > 0 {
|
||||
if err = c.Conn.SetReadDeadline(time.Now().Add(readTimeout)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return c.Conn.Read(b)
|
||||
n, err = c.Conn.Read(b)
|
||||
c.BytesRead += n
|
||||
if err == nil && origSize != len(b) && n == len(b) {
|
||||
// we had to shrink the output buffer AND we used up the whole shrunk size, AND we're not at EOF
|
||||
switch c.ReadLimitExceededAction {
|
||||
case ReadLimitExceededActionTruncate:
|
||||
logrus.Debug("Truncated read from %d bytes to %d bytes (hit limit of %d bytes)", origSize, n, c.BytesReadLimit)
|
||||
err = io.EOF
|
||||
case ReadLimitExceededActionError:
|
||||
return n, ErrReadLimitExceeded
|
||||
case ReadLimitExceededActionPanic:
|
||||
panic(ErrReadLimitExceeded)
|
||||
default:
|
||||
logrus.Fatalf("Unrecognized ReadLimitExceededAction: %s", c.ReadLimitExceededAction)
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// TimeoutConnection.Write calls Write() on the underlying connection, using any configured deadlines
|
||||
// TimeoutConnection.Write calls Write() on the underlying connection, using any configured deadlines.
|
||||
func (c *TimeoutConnection) Write(b []byte) (n int, err error) {
|
||||
if err := c.checkContext(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if c.explicitWriteDeadline || c.explicitDeadline {
|
||||
c.explicitWriteDeadline = false
|
||||
c.explicitDeadline = false
|
||||
} else if c.Timeout > 0 {
|
||||
if err = c.Conn.SetWriteDeadline(time.Now().Add(c.Timeout)); err != nil {
|
||||
} else if writeTimeout := c.getTimeout(c.WriteTimeout); writeTimeout > 0 {
|
||||
if err = c.Conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return c.Conn.Write(b)
|
||||
n, err = c.Conn.Write(b)
|
||||
c.BytesWritten += n
|
||||
return n, err
|
||||
}
|
||||
|
||||
// SetReadDeadline sets an explicit ReadDeadline that will override the timeout
|
||||
// for one read. Use deadline = 0 to clear the deadline.
|
||||
func (c *TimeoutConnection) SetReadDeadline(deadline time.Time) error {
|
||||
if err := c.checkContext(); err != nil {
|
||||
return err
|
||||
}
|
||||
if !deadline.IsZero() {
|
||||
err := c.Conn.SetReadDeadline(deadline)
|
||||
if err != nil {
|
||||
@ -57,6 +133,9 @@ func (c *TimeoutConnection) SetReadDeadline(deadline time.Time) error {
|
||||
// SetWriteDeadline sets an explicit WriteDeadline that will override the
|
||||
// WriteDeadline for one write. Use deadline = 0 to clear the deadline.
|
||||
func (c *TimeoutConnection) SetWriteDeadline(deadline time.Time) error {
|
||||
if err := c.checkContext(); err != nil {
|
||||
return err
|
||||
}
|
||||
if !deadline.IsZero() {
|
||||
err := c.Conn.SetWriteDeadline(deadline)
|
||||
if err != nil {
|
||||
@ -70,6 +149,9 @@ func (c *TimeoutConnection) SetWriteDeadline(deadline time.Time) error {
|
||||
// SetDeadline sets a read / write deadline that will override the deadline for
|
||||
// a single read/write. Use deadline = 0 to clear the deadline.
|
||||
func (c *TimeoutConnection) SetDeadline(deadline time.Time) error {
|
||||
if err := c.checkContext(); err != nil {
|
||||
return err
|
||||
}
|
||||
if !deadline.IsZero() {
|
||||
err := c.Conn.SetDeadline(deadline)
|
||||
if err != nil {
|
||||
@ -80,8 +162,8 @@ func (c *TimeoutConnection) SetDeadline(deadline time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTimeoutDialer returns a Dialer function that dials with the given timeout
|
||||
func GetTimeoutDialer(timeout time.Duration) func(string, string) (net.Conn, error) {
|
||||
// GetTimeoutDialer returns a DialFuncn that dials with the given timeout
|
||||
func GetTimeoutDialFunc(timeout time.Duration) func(string, string) (net.Conn, error) {
|
||||
return func(proto, target string) (net.Conn, error) {
|
||||
return DialTimeoutConnection(proto, target, timeout)
|
||||
}
|
||||
@ -92,14 +174,67 @@ func (c *TimeoutConnection) Close() error {
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// DialTimeoutConnection dials the target and returns a net.Conn that uses the configured timeouts for Read/Write operations.
|
||||
func DialTimeoutConnection(proto string, target string, timeout time.Duration) (net.Conn, error) {
|
||||
// Get the timeout for the given field, falling back to the global timeout.
|
||||
func (c *TimeoutConnection) getTimeout(field time.Duration) time.Duration {
|
||||
if field == 0 {
|
||||
return c.Timeout
|
||||
}
|
||||
return field
|
||||
}
|
||||
|
||||
// Check if the context has been cancelled, and if so, return an error (either the context error, or
|
||||
// if the context error is nil, ErrTotalTimeout).
|
||||
func (c *TimeoutConnection) checkContext() error {
|
||||
if c.ctx == nil {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
if err := c.ctx.Err(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return ErrTotalTimeout
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *TimeoutConnection) SetDefaults() *TimeoutConnection {
|
||||
if c.BytesReadLimit == 0 {
|
||||
c.BytesReadLimit = DefaultBytesReadLimit
|
||||
}
|
||||
if c.ReadLimitExceededAction == ReadLimitExceededActionNotSet {
|
||||
c.ReadLimitExceededAction = DefaultReadLimitExceededAction
|
||||
}
|
||||
if c.Timeout == 0 {
|
||||
c.Timeout = DefaultSessionTimeout
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func NewTimeoutConnection(ctx context.Context, conn net.Conn, timeout, readTimeout, writeTimeout time.Duration) *TimeoutConnection {
|
||||
ret := (&TimeoutConnection{
|
||||
Conn: conn,
|
||||
Timeout: timeout,
|
||||
ReadTimeout: readTimeout,
|
||||
WriteTimeout: writeTimeout,
|
||||
}).SetDefaults()
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
ret.ctx, ret.Cancel = context.WithTimeout(ctx, timeout)
|
||||
return ret
|
||||
}
|
||||
|
||||
// DialTimeoutConnectionEx dials the target and returns a net.Conn that uses the configured timeouts for Read/Write operations.
|
||||
func DialTimeoutConnectionEx(proto string, target string, dialTimeout, sessionTimeout, readTimeout, writeTimeout time.Duration) (net.Conn, error) {
|
||||
var conn net.Conn
|
||||
var err error
|
||||
if timeout > 0 {
|
||||
conn, err = net.DialTimeout(proto, target, timeout)
|
||||
if dialTimeout > 0 {
|
||||
conn, err = net.DialTimeout(proto, target, dialTimeout)
|
||||
} else {
|
||||
conn, err = net.Dial(proto, target)
|
||||
conn, err = net.DialTimeout(proto, target, sessionTimeout)
|
||||
}
|
||||
if err != nil {
|
||||
if conn != nil {
|
||||
@ -107,49 +242,104 @@ func DialTimeoutConnection(proto string, target string, timeout time.Duration) (
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &TimeoutConnection{
|
||||
Conn: conn,
|
||||
Timeout: timeout,
|
||||
}, nil
|
||||
return NewTimeoutConnection(context.Background(), conn, sessionTimeout, readTimeout, writeTimeout), nil
|
||||
}
|
||||
|
||||
// DialTimeoutConnection dials the target and returns a net.Conn that uses the configured single timeout for all operations.
|
||||
func DialTimeoutConnection(proto string, target string, timeout time.Duration) (net.Conn, error) {
|
||||
return DialTimeoutConnectionEx(proto, target, timeout, timeout, timeout, timeout)
|
||||
}
|
||||
|
||||
// Dialer provides Dial and DialContext methods to get connections with the given timeout.
|
||||
type Dialer struct {
|
||||
// Timeout is the maximum time to wait for a connection or I/O.
|
||||
// Timeout is the maximum time to wait for the entire session, after which any operations on the
|
||||
// connection will fail.
|
||||
Timeout time.Duration
|
||||
|
||||
// dialer is an auxiliary dialer used for DialContext (the result gets wrapped in a TimeoutConnection).
|
||||
// ConnectTimeout is the maximum time to wait for a connection.
|
||||
ConnectTimeout time.Duration
|
||||
|
||||
// ReadTimeout is the maximum time to wait for a Read
|
||||
ReadTimeout time.Duration
|
||||
|
||||
// WriteTimeout is the maximum time to wait for a Write
|
||||
WriteTimeout time.Duration
|
||||
|
||||
// Dialer is an auxiliary dialer used for DialContext (the result gets wrapped in a
|
||||
// TimeoutConnection).
|
||||
Dialer *net.Dialer
|
||||
|
||||
// BytesReadLimit is the maximum number of bytes that connections dialed with this dialer will
|
||||
// read before erroring.
|
||||
BytesReadLimit int
|
||||
|
||||
// ReadLimitExceededAction describes how connections dialed with this dialer deal with exceeding
|
||||
// the BytesReadLimit.
|
||||
ReadLimitExceededAction ReadLimitExceededAction
|
||||
}
|
||||
|
||||
func (d *Dialer) getTimeout(field time.Duration) time.Duration {
|
||||
if field == 0 {
|
||||
return d.Timeout
|
||||
}
|
||||
return field
|
||||
}
|
||||
|
||||
// DialContext wraps the connection returned by net.Dialer.DialContext() with a TimeoutConnection.
|
||||
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
// ensure that our aux dialer is up-to-date
|
||||
d.Dialer.Timeout = d.Timeout
|
||||
if d.Timeout != 0 {
|
||||
ctx, _ = context.WithTimeout(ctx, d.Timeout)
|
||||
}
|
||||
// ensure that our aux dialer is up-to-date; copied from http/transport.go
|
||||
d.Dialer.Timeout = d.getTimeout(d.ConnectTimeout)
|
||||
d.Dialer.KeepAlive = d.Timeout
|
||||
ret, err := d.Dialer.DialContext(ctx, network, address)
|
||||
dialContext, cancelDial := context.WithTimeout(ctx, d.Dialer.Timeout)
|
||||
defer cancelDial()
|
||||
conn, err := d.Dialer.DialContext(dialContext, network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TimeoutConnection{
|
||||
Conn: ret,
|
||||
Timeout: d.Timeout,
|
||||
}, nil
|
||||
ret := NewTimeoutConnection(ctx, conn, d.Timeout, d.ReadTimeout, d.WriteTimeout)
|
||||
ret.BytesReadLimit = d.BytesReadLimit
|
||||
ret.ReadLimitExceededAction = d.ReadLimitExceededAction
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// Dial returns a connection with the configured timeout.
|
||||
func (d *Dialer) Dial(proto string, target string) (net.Conn, error) {
|
||||
return DialTimeoutConnection(proto, target, d.Timeout)
|
||||
return DialTimeoutConnectionEx(proto, target, d.ConnectTimeout, d.Timeout, d.ReadTimeout, d.WriteTimeout)
|
||||
}
|
||||
|
||||
// GetTimeoutConnectionDialer gets a Dialer that dials connections with the given timeout.
|
||||
func GetTimeoutConnectionDialer(timeout time.Duration) *Dialer {
|
||||
return &Dialer{
|
||||
Timeout: timeout,
|
||||
Dialer: &net.Dialer{
|
||||
Timeout: timeout,
|
||||
KeepAlive: timeout,
|
||||
DualStack: true,
|
||||
},
|
||||
return NewDialer(&Dialer{Timeout: timeout})
|
||||
}
|
||||
|
||||
// SetDefaults for the Dialer.
|
||||
func (d *Dialer) SetDefaults() *Dialer {
|
||||
if d.Timeout == 0 {
|
||||
d.Timeout = DefaultSessionTimeout
|
||||
}
|
||||
}
|
||||
if d.ReadLimitExceededAction == ReadLimitExceededActionNotSet {
|
||||
d.ReadLimitExceededAction = DefaultReadLimitExceededAction
|
||||
}
|
||||
if d.BytesReadLimit == 0 {
|
||||
d.BytesReadLimit = DefaultBytesReadLimit
|
||||
}
|
||||
if d.Dialer == nil {
|
||||
d.Dialer = &net.Dialer{
|
||||
Timeout: d.Timeout,
|
||||
KeepAlive: d.Timeout,
|
||||
DualStack: true,
|
||||
}
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// Create a new Dialer with default settings.
|
||||
func NewDialer(value *Dialer) *Dialer {
|
||||
if value == nil {
|
||||
value = &Dialer{}
|
||||
}
|
||||
return value.SetDefaults()
|
||||
}
|
||||
|
396
conn_bytelimit_test.go
Normal file
396
conn_bytelimit_test.go
Normal file
@ -0,0 +1,396 @@
|
||||
package zgrab2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Start a local echo server on port.
|
||||
func runEchoServer(t *testing.T, port int) {
|
||||
endpoint := fmt.Sprintf("127.0.0.1:%d", port)
|
||||
listener, err := net.Listen("tcp", endpoint)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
go func() {
|
||||
defer listener.Close()
|
||||
sock, err := listener.Accept()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer sock.Close()
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
for {
|
||||
n, err := sock.Read(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF && !strings.Contains(err.Error(), "connection reset") {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
sock.SetWriteDeadline(time.Now().Add(time.Millisecond * 250))
|
||||
n, err = sock.Write(buf[0:n])
|
||||
if err != nil {
|
||||
if err != io.EOF && !strings.Contains(err.Error(), "connection reset") && !strings.Contains(err.Error(), "broken pipe") {
|
||||
t.Logf("Unexpected error writing to client: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Interface for getting a TimeoutConnection; we want to test both the dialer and the direct Dial functions.
|
||||
type timeoutConnector interface {
|
||||
connect(ctx context.Context, t *testing.T, port int, idx int) (*TimeoutConnection, error)
|
||||
getConfig() readLimitTestConfig
|
||||
}
|
||||
|
||||
// Config for a test case
|
||||
type readLimitTestConfig struct {
|
||||
// The maximum bytes the connection should read
|
||||
limit int
|
||||
// The number of bytes that should be sent (so iff sendSize > limit, the action should be triggered)
|
||||
sendSize int
|
||||
// The action to run when too much data is sent
|
||||
action ReadLimitExceededAction
|
||||
}
|
||||
|
||||
// Call sendReceive(), and check that the input/output match, and that any expected errors / truncation occurs.
|
||||
func checkedSendReceive(t *testing.T, conn *TimeoutConnection, size int) (result error) {
|
||||
// helper to report + return an error
|
||||
tErrorf := func(format string, args ...interface{}) error {
|
||||
result = fmt.Errorf(format, args)
|
||||
t.Error(result)
|
||||
return result
|
||||
}
|
||||
|
||||
// We will check that this increases by the correct size
|
||||
before := conn.BytesRead
|
||||
|
||||
// This is true if we expect an overflow to occur (and so the ReadLimitExceededAction should fire)
|
||||
overflowed := (before + size) > conn.BytesReadLimit
|
||||
|
||||
// Don't want to keep re-typing this
|
||||
action := conn.ReadLimitExceededAction
|
||||
|
||||
defer func() {
|
||||
if result != nil {
|
||||
// log any previous error -- more may still follow
|
||||
t.Error(result)
|
||||
}
|
||||
err := recover()
|
||||
if err != nil {
|
||||
if action != ReadLimitExceededActionPanic {
|
||||
// no reason to panic unless that is the action
|
||||
panic(err)
|
||||
}
|
||||
if !overflowed {
|
||||
tErrorf("panicked early: only sent %d bytes so far, but limit=%d", before+size, conn.BytesReadLimit)
|
||||
return
|
||||
}
|
||||
if err == ErrReadLimitExceeded {
|
||||
// We read too much data and this is the right error: silently succeed
|
||||
return
|
||||
}
|
||||
tErrorf("wrong panic error: got %v, expected ErrReadlimitExceeded", err)
|
||||
return
|
||||
}
|
||||
|
||||
if action != ReadLimitExceededActionPanic {
|
||||
// other action -- fine that we didn't panic
|
||||
return
|
||||
}
|
||||
if !overflowed {
|
||||
// not enough bytes read to overflow -- fine that we didn't panic
|
||||
return
|
||||
}
|
||||
// ReadLimitExceededActionPanic, read too many bytes: should have panicked but didn't
|
||||
tErrorf("should have panicked: action=ReadLimitExceededActionPanic, but sent without issue")
|
||||
}()
|
||||
|
||||
ret, err := sendReceive(t, conn, size)
|
||||
|
||||
if err != nil {
|
||||
if !overflowed {
|
||||
// If there is no overflow, there should be no error
|
||||
return tErrorf("read: unexpected error: %v", err)
|
||||
}
|
||||
if err != io.EOF && err != ErrReadLimitExceeded {
|
||||
// EOF and ErrReadLimitExceeded are the only errors that should be returned
|
||||
return tErrorf("read: wrong error: %v", err)
|
||||
}
|
||||
if err == io.EOF && action != ReadLimitExceededActionTruncate {
|
||||
// EOF should only occur with truncation
|
||||
return tErrorf("read: unexpected EOF")
|
||||
}
|
||||
if err == ErrReadLimitExceeded && action != ReadLimitExceededActionError {
|
||||
// ErrReadLimitExceeded should only occur with ReadLimitExceededActionError
|
||||
return tErrorf("read: unexpected ErrReadLimitExceeded")
|
||||
}
|
||||
// Otherwise, fall through -- we still need to check that the data matches
|
||||
} else {
|
||||
if overflowed && action == ReadLimitExceededActionError {
|
||||
return tErrorf("read: should have gotten an error, but did not")
|
||||
}
|
||||
}
|
||||
expectedSize := size
|
||||
if overflowed {
|
||||
expectedSize = conn.BytesReadLimit - before
|
||||
}
|
||||
|
||||
if conn.BytesRead != before+expectedSize {
|
||||
return tErrorf("check: BytesRead value inconsistent; expected %d, got %d", before+expectedSize, conn.BytesRead)
|
||||
}
|
||||
if len(ret) != expectedSize {
|
||||
return tErrorf("check: expected %d bytes, got %d", expectedSize, len(ret))
|
||||
}
|
||||
if expectedSize > 0 && !checkTestBuffer(ret) {
|
||||
return tErrorf("Got back invalid data (%x)", ret)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send size testBuffer bytes to conn, then perform a read, and return the result/error.
|
||||
func sendReceive(t *testing.T, conn *TimeoutConnection, size int) ([]byte, error) {
|
||||
toSend := getTestBuffer(size)
|
||||
n, err := conn.Write(toSend)
|
||||
if err != nil {
|
||||
t.Fatalf("Send failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
if n != len(toSend) {
|
||||
t.Fatalf("Short write: expected to send %d bytes, returned %d", len(toSend), n)
|
||||
return nil, io.ErrShortWrite
|
||||
}
|
||||
readBuf := make([]byte, size)
|
||||
n, err = conn.Read(readBuf)
|
||||
return readBuf[0:n], err
|
||||
}
|
||||
|
||||
// Get a size-byte slice of sequential bytes (mod 256), starting from 0
|
||||
func getTestBuffer(size int) []byte {
|
||||
ret := make([]byte, size)
|
||||
for i := 0; i < size; i++ {
|
||||
ret[i] = byte(i & 0xff)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// Check that buf is of the type returned by getTestBuffer.
|
||||
func checkTestBuffer(buf []byte) bool {
|
||||
if buf == nil || len(buf) == 0 {
|
||||
return false
|
||||
}
|
||||
for i, v := range buf {
|
||||
if v != byte(i&0xff) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Send / receive cfg.sendSize bytes in a single shot and check that it behaves appropriately.
|
||||
func (cfg readLimitTestConfig) runSingleSend(t *testing.T, conn *TimeoutConnection, idx int) error {
|
||||
if err := checkedSendReceive(t, conn, cfg.sendSize); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send / receive cfg.sendSize bytes, split over five sends, and check that it behaves appropriately.
|
||||
func (cfg readLimitTestConfig) runMultiSend(t *testing.T, conn *TimeoutConnection, idx int) error {
|
||||
for i := 0; i < 5; i++ {
|
||||
if err := checkedSendReceive(t, conn, cfg.sendSize/5); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// A timeoutConnector that uses a dialer to dial the connections
|
||||
type dialerConnector struct {
|
||||
readLimitTestConfig
|
||||
|
||||
// This is lazily inited
|
||||
dialer *Dialer
|
||||
}
|
||||
|
||||
// Function that returns a connector
|
||||
type timeoutConnectorFactory func(readLimitTestConfig) timeoutConnector
|
||||
|
||||
// Dial the connection using the dialer (creating the dialer if necessary)
|
||||
func (d *dialerConnector) connect(ctx context.Context, t *testing.T, port int, idx int) (*TimeoutConnection, error) {
|
||||
if d.dialer == nil {
|
||||
d.dialer = NewDialer(&Dialer{
|
||||
BytesReadLimit: d.limit,
|
||||
ReadLimitExceededAction: d.action,
|
||||
})
|
||||
}
|
||||
var ret *TimeoutConnection
|
||||
conn, err := d.dialer.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||
if conn != nil {
|
||||
ret = conn.(*TimeoutConnection)
|
||||
}
|
||||
return ret, err
|
||||
}
|
||||
|
||||
func (d *dialerConnector) getConfig() readLimitTestConfig {
|
||||
return d.readLimitTestConfig
|
||||
}
|
||||
|
||||
func dialerTimeoutConnectorFactory(cfg readLimitTestConfig) timeoutConnector {
|
||||
return &dialerConnector{
|
||||
readLimitTestConfig: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// Dial using a direct call to DialTimeoutConnectionEx
|
||||
type directDial struct {
|
||||
readLimitTestConfig
|
||||
}
|
||||
|
||||
func (d *directDial) connect(ctx context.Context, t *testing.T, port int, idx int) (*TimeoutConnection, error) {
|
||||
conn, err := DialTimeoutConnectionEx("tcp", fmt.Sprintf("127.0.0.1:%d", port), time.Second, time.Second, time.Second, time.Second)
|
||||
var ret *TimeoutConnection
|
||||
if conn != nil {
|
||||
ret = conn.(*TimeoutConnection)
|
||||
ret.BytesReadLimit = d.limit
|
||||
ret.ReadLimitExceededAction = d.action
|
||||
}
|
||||
return ret, err
|
||||
}
|
||||
|
||||
func (d *directDial) getConfig() readLimitTestConfig {
|
||||
return d.readLimitTestConfig
|
||||
}
|
||||
|
||||
func directDialFactory(cfg readLimitTestConfig) timeoutConnector {
|
||||
return &directDial{cfg}
|
||||
}
|
||||
|
||||
var readLimitTestConfigs = map[string]readLimitTestConfig{
|
||||
// Check that a 2000-byte read gets truncated at 1000 bytes
|
||||
"truncate": {
|
||||
limit: 1000,
|
||||
sendSize: 2000,
|
||||
action: ReadLimitExceededActionTruncate,
|
||||
},
|
||||
|
||||
// Check that a 1005-byte read gets truncated at 1000 bytes
|
||||
"truncate_close": {
|
||||
limit: 1000,
|
||||
sendSize: 1005,
|
||||
action: ReadLimitExceededActionTruncate,
|
||||
},
|
||||
|
||||
// Check that a 2000-byte read errors after reading the first 1000 bytes
|
||||
"error": {
|
||||
limit: 1000,
|
||||
sendSize: 2000,
|
||||
action: ReadLimitExceededActionError,
|
||||
},
|
||||
|
||||
// Check that a 2000-byte read panics after reading the first 1000 bytes
|
||||
"panic": {
|
||||
limit: 1000,
|
||||
sendSize: 2000,
|
||||
action: ReadLimitExceededActionPanic,
|
||||
},
|
||||
|
||||
// Check that the default settings pass (backwards compatibility)
|
||||
"default": {},
|
||||
|
||||
// Check that a 100-byte read succeeds / is not truncated
|
||||
"happy": {
|
||||
limit: 1000,
|
||||
sendSize: 100,
|
||||
action: ReadLimitExceededActionPanic,
|
||||
},
|
||||
|
||||
// Check that a 1000-byte read succeeds / is not truncated
|
||||
"closeCall": {
|
||||
limit: 1000,
|
||||
sendSize: 1000,
|
||||
action: ReadLimitExceededActionPanic,
|
||||
},
|
||||
}
|
||||
|
||||
// Each of these gets run with each readLimitTestConfig
|
||||
var connTestConnectors = map[string]timeoutConnectorFactory{
|
||||
"directDial": directDialFactory,
|
||||
"dialerConnector": dialerTimeoutConnectorFactory,
|
||||
}
|
||||
|
||||
// Run a single full trial with the given connector: connect, send/receive the configured bytes, and
|
||||
// check that the response was properly truncated (or not), and that the bytes read total is
|
||||
// correctly tabulated.
|
||||
func runBytesReadLimitTrial(t *testing.T, connector timeoutConnector, idx int, method func(readLimitTestConfig, *testing.T, *TimeoutConnection, int) error) (result error) {
|
||||
cfg := connector.getConfig()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
port := 0x1234 + idx
|
||||
runEchoServer(t, port)
|
||||
conn, err := connector.connect(ctx, t, port, idx)
|
||||
if err != nil {
|
||||
t.Fatalf("Error dialing: %v", err)
|
||||
}
|
||||
expectedSize := cfg.sendSize
|
||||
if expectedSize > cfg.limit {
|
||||
expectedSize = cfg.limit
|
||||
}
|
||||
defer func() {
|
||||
if conn.BytesRead != expectedSize {
|
||||
result = fmt.Errorf("BytesRead(%d) != expected(%d)", conn.BytesRead, expectedSize)
|
||||
t.Error(result)
|
||||
}
|
||||
}()
|
||||
defer conn.Close()
|
||||
return method(cfg, t, conn, idx)
|
||||
}
|
||||
|
||||
// Run a full set of trials on the connector -- ten with a single send, and ten with multiple sends.
|
||||
func testBytesReadLimitOn(t *testing.T, connector timeoutConnector) error {
|
||||
for i := 0; i < 10; i++ {
|
||||
if err := runBytesReadLimitTrial(t, connector, i, readLimitTestConfig.runSingleSend); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
if err := runBytesReadLimitTrial(t, connector, i, readLimitTestConfig.runMultiSend); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check that the BytesReadLimit is enforced (or not) as expected:
|
||||
// 1. Create an echo server
|
||||
// 2. Dial a fresh TimeoutConnection to the echo server with the given BytesReadLimit / ReadLimitExceededAction
|
||||
// 3. Send the configured number of bytes in a single packet
|
||||
// 4. Check that it (succeeds / truncates / errors / panics) according to the config
|
||||
// 5. Repeat 10 times
|
||||
// 6. Repeat the above 10 more times, except in #3, split the send across five packets
|
||||
func TestBytesReadLimit(t *testing.T) {
|
||||
connectors := make(map[string]timeoutConnector)
|
||||
// Create a fresh connector for each configuration
|
||||
for cfgName, cfg := range readLimitTestConfigs {
|
||||
for connectorName, factory := range connTestConnectors {
|
||||
connectors[connectorName+"_"+cfgName] = factory(cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// Run each connector
|
||||
for name, connector := range connectors {
|
||||
t.Logf("Running %s", name)
|
||||
if err := testBytesReadLimitOn(t, connector); err != nil {
|
||||
t.Logf("Failed running %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
}
|
422
conn_timeout_test.go
Normal file
422
conn_timeout_test.go
Normal file
@ -0,0 +1,422 @@
|
||||
package zgrab2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Config for a single timeout test
|
||||
type connTimeoutTestConfig struct {
|
||||
// Name for the test for logging purposes
|
||||
name string
|
||||
|
||||
// Optional explicit endpoint to connect to (if absent, use 127.0.0.1)
|
||||
endpoint string
|
||||
|
||||
// TCP port number to communicate on
|
||||
port int
|
||||
|
||||
// Function used to dial new connections
|
||||
dialer func() (*TimeoutConnection, error)
|
||||
|
||||
// Client timeout values
|
||||
timeout time.Duration
|
||||
connectTimeout time.Duration
|
||||
readTimeout time.Duration
|
||||
writeTimeout time.Duration
|
||||
|
||||
// Time for server to wait after listening before accepting a connection
|
||||
acceptDelay time.Duration
|
||||
|
||||
// Time for server to wait after accepting before writing payload
|
||||
writeDelay time.Duration
|
||||
|
||||
// Time for server to wait before reading payload
|
||||
readDelay time.Duration
|
||||
|
||||
// Payload for server to send client after connecting
|
||||
serverToClientPayload []byte
|
||||
|
||||
// Payload for client to send server after reading the previous payload
|
||||
clientToServerPayload []byte
|
||||
|
||||
// Step when the client is expected to fail
|
||||
failStep testStep
|
||||
|
||||
// If non-empty, the error string returned by the client should contain this
|
||||
failError string
|
||||
}
|
||||
|
||||
// Standardized time units, separated by factors of 10.
|
||||
const (
|
||||
short = 100 * time.Millisecond
|
||||
medium = 1000 * time.Millisecond
|
||||
long = 10000 * time.Millisecond
|
||||
)
|
||||
|
||||
// enum type for the various locations where the test can fail
|
||||
type testStep string
|
||||
|
||||
const (
|
||||
testStepConnect = testStep("connect")
|
||||
testStepRead = testStep("read")
|
||||
testStepWrite = testStep("write")
|
||||
testStepDone = testStep("done")
|
||||
)
|
||||
|
||||
// Encapsulates a source for an error (client/server/???), the step where it occurred, and an
|
||||
// optional cause.
|
||||
type timeoutTestError struct {
|
||||
source string
|
||||
step testStep
|
||||
cause error
|
||||
}
|
||||
|
||||
func (err *timeoutTestError) Error() string {
|
||||
return fmt.Sprintf("%s error at %s: %v", err.source, err.step, err.cause)
|
||||
}
|
||||
|
||||
func serverError(step testStep, err error) *timeoutTestError {
|
||||
return &timeoutTestError{
|
||||
source: "server",
|
||||
step: step,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func clientError(step testStep, err error) *timeoutTestError {
|
||||
return &timeoutTestError{
|
||||
source: "client",
|
||||
step: step,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to ensure all data is written to a socket
|
||||
func _write(writer io.Writer, data []byte) error {
|
||||
n, err := writer.Write(data)
|
||||
if err == nil && n != len(data) {
|
||||
err = io.ErrShortWrite
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Run the configured server. As soon as it returns, it is listening.
|
||||
// Returns a channel that receives a timeoutTestError on error, or is closed on successful completion.
|
||||
func (cfg *connTimeoutTestConfig) runServer(t *testing.T) (chan *timeoutTestError) {
|
||||
errorChan := make(chan *timeoutTestError)
|
||||
if cfg.endpoint != "" {
|
||||
// Only listen on localhost
|
||||
return errorChan
|
||||
}
|
||||
listener, err := net.Listen("tcp", cfg.getEndpoint())
|
||||
if err != nil {
|
||||
logrus.Fatalf("Error listening: %v", err)
|
||||
}
|
||||
go func() {
|
||||
defer listener.Close()
|
||||
defer close(errorChan)
|
||||
time.Sleep(cfg.acceptDelay)
|
||||
sock, err := listener.Accept()
|
||||
if err != nil {
|
||||
errorChan <- serverError(testStepConnect, err)
|
||||
return
|
||||
}
|
||||
defer sock.Close()
|
||||
time.Sleep(cfg.writeDelay)
|
||||
if err := _write(sock, cfg.serverToClientPayload); err != nil {
|
||||
errorChan <- serverError(testStepWrite, err)
|
||||
return
|
||||
}
|
||||
time.Sleep(cfg.readDelay)
|
||||
buf := make([]byte, len(cfg.clientToServerPayload))
|
||||
n, err := io.ReadFull(sock, buf)
|
||||
if err != nil && err != io.EOF {
|
||||
errorChan <- serverError(testStepRead, err)
|
||||
return
|
||||
}
|
||||
if err == io.EOF && n < len(buf) {
|
||||
errorChan <- serverError(testStepRead, err)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(buf, cfg.clientToServerPayload) {
|
||||
t.Errorf("%s: clientToServerPayload mismatch", cfg.name)
|
||||
}
|
||||
return
|
||||
}()
|
||||
return errorChan
|
||||
}
|
||||
|
||||
// Get the configured endpoint
|
||||
func (cfg *connTimeoutTestConfig) getEndpoint() string {
|
||||
if cfg.endpoint != "" {
|
||||
return cfg.endpoint
|
||||
}
|
||||
return fmt.Sprintf("127.0.0.1:%d", cfg.port)
|
||||
}
|
||||
|
||||
// Dial a connection to the configured endpoint using a Dialer
|
||||
func (cfg *connTimeoutTestConfig) dialerDial() (*TimeoutConnection, error) {
|
||||
dialer := NewDialer(&Dialer{
|
||||
Timeout: cfg.timeout,
|
||||
ConnectTimeout: cfg.connectTimeout,
|
||||
ReadTimeout: cfg.readTimeout,
|
||||
WriteTimeout: cfg.writeTimeout,
|
||||
})
|
||||
ret, err := dialer.Dial("tcp", cfg.getEndpoint())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ret.(*TimeoutConnection), err
|
||||
}
|
||||
|
||||
// Dial a connection to the configured endpoint using a DialTimeoutConnectionEx
|
||||
func (cfg *connTimeoutTestConfig) directDial() (*TimeoutConnection, error) {
|
||||
ret, err := DialTimeoutConnectionEx("tcp", cfg.getEndpoint(), cfg.connectTimeout, cfg.timeout, cfg.readTimeout, cfg.writeTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ret.(*TimeoutConnection), err
|
||||
}
|
||||
|
||||
// Dial a connection to the configured endpoint using Dialer.DialContext
|
||||
func (cfg *connTimeoutTestConfig) contextDial() (*TimeoutConnection, error) {
|
||||
dialer := NewDialer(&Dialer{
|
||||
Timeout: cfg.timeout,
|
||||
ConnectTimeout: cfg.connectTimeout,
|
||||
ReadTimeout: cfg.readTimeout,
|
||||
WriteTimeout: cfg.writeTimeout,
|
||||
})
|
||||
ret, err := dialer.DialContext(context.Background(), "tcp", cfg.getEndpoint())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ret.(*TimeoutConnection), err
|
||||
}
|
||||
|
||||
// Run the client: connect to the server, read the payload, write the payload, disconnect.
|
||||
func (cfg *connTimeoutTestConfig) runClient(t *testing.T) (testStep, error) {
|
||||
conn, err := cfg.dialer()
|
||||
if err != nil {
|
||||
return testStepConnect, err
|
||||
}
|
||||
defer conn.Close()
|
||||
buf := make([]byte, len(cfg.serverToClientPayload))
|
||||
_, err = io.ReadFull(conn, buf)
|
||||
if err != nil {
|
||||
return testStepRead, err
|
||||
}
|
||||
if !bytes.Equal(cfg.serverToClientPayload, buf) {
|
||||
t.Errorf("%s: serverToClientPayload payload mismatch", cfg.name)
|
||||
}
|
||||
if err := _write(conn, cfg.clientToServerPayload); err != nil {
|
||||
return testStepWrite, err
|
||||
}
|
||||
return testStepDone, nil
|
||||
}
|
||||
|
||||
// Run the configured test -- start a server and a client to connect to it.
|
||||
func (cfg *connTimeoutTestConfig) run(t *testing.T) {
|
||||
done := make(chan *timeoutTestError)
|
||||
serverError := cfg.runServer(t)
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
close(done)
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
step, err := cfg.runClient(t)
|
||||
done <- clientError(step, err)
|
||||
}()
|
||||
go func() {
|
||||
time.Sleep(long + medium + short)
|
||||
done <- &timeoutTestError{source: "timeout"}
|
||||
}()
|
||||
var ret *timeoutTestError
|
||||
select {
|
||||
case err := <-serverError:
|
||||
t.Fatalf("%s: Server error: %v", cfg.name, err)
|
||||
case ret = <-done:
|
||||
if ret == nil {
|
||||
t.Fatalf("Channel unexpectedly closed")
|
||||
}
|
||||
}
|
||||
if ret.source != "client" {
|
||||
t.Fatalf("%s: Unexpected error from %s: %v", cfg.name, ret.source, ret.cause)
|
||||
}
|
||||
if ret.step != cfg.failStep {
|
||||
t.Errorf("%s: Failed at step %s, but expected to fail at step %s (error=%v)", cfg.name, ret.step, cfg.failStep, ret.cause)
|
||||
return
|
||||
}
|
||||
if cfg.failError != "" {
|
||||
errString := "none"
|
||||
if ret.cause != nil {
|
||||
errString = ret.cause.Error()
|
||||
}
|
||||
if !strings.Contains(errString, cfg.failError) {
|
||||
t.Errorf("%s: Expected an error (%s) at step %s, got %s", cfg.name, cfg.failError, cfg.failStep, errString)
|
||||
return
|
||||
}
|
||||
} else if ret.cause != nil {
|
||||
t.Errorf("%s: expected no error at step %s, but got %v", cfg.name, cfg.failStep, ret.cause)
|
||||
}
|
||||
}
|
||||
|
||||
var connTestConfigs = []connTimeoutTestConfig{
|
||||
// Long timeouts, short delays -- should succeed
|
||||
{
|
||||
name: "happy",
|
||||
port: 0x5613,
|
||||
timeout: long,
|
||||
connectTimeout: medium,
|
||||
readTimeout: medium,
|
||||
writeTimeout: medium,
|
||||
|
||||
acceptDelay: short,
|
||||
writeDelay: short,
|
||||
readDelay: short,
|
||||
|
||||
serverToClientPayload: []byte("abc"),
|
||||
clientToServerPayload: []byte("defghi"),
|
||||
|
||||
failStep: testStepDone,
|
||||
},
|
||||
// long session timeout, short connectTimeout. Use a non-local, nonexistent endpoint (localhost
|
||||
// would return "connection refused" immediately)
|
||||
{
|
||||
name: "connect_timeout",
|
||||
endpoint: "10.0.254.254:41591",
|
||||
timeout: long,
|
||||
connectTimeout: short,
|
||||
readTimeout: medium,
|
||||
writeTimeout: medium,
|
||||
|
||||
acceptDelay: short,
|
||||
writeDelay: short,
|
||||
readDelay: short,
|
||||
|
||||
serverToClientPayload: []byte("abc"),
|
||||
clientToServerPayload: []byte("defghi"),
|
||||
|
||||
failStep: testStepConnect,
|
||||
failError: "i/o timeout",
|
||||
},
|
||||
// short session timeout, medium connect timeout, with connect to nonexistent endpoint.
|
||||
{
|
||||
name: "session_connect_timeout",
|
||||
endpoint: "10.0.254.254:41591",
|
||||
timeout: short,
|
||||
connectTimeout: medium,
|
||||
readTimeout: medium,
|
||||
writeTimeout: medium,
|
||||
|
||||
acceptDelay: short,
|
||||
writeDelay: short,
|
||||
readDelay: short,
|
||||
|
||||
serverToClientPayload: []byte("abc"),
|
||||
clientToServerPayload: []byte("defghi"),
|
||||
|
||||
failStep: testStepConnect,
|
||||
failError: "i/o timeout",
|
||||
},
|
||||
// Get an IO timeout on the read.
|
||||
// sessionTimeout > acceptDelay + writeDelay > writeDelay > readTimeout
|
||||
{
|
||||
name: "read_timeout",
|
||||
port: 0x5614,
|
||||
timeout: long,
|
||||
connectTimeout: short,
|
||||
readTimeout: short,
|
||||
writeTimeout: short,
|
||||
|
||||
acceptDelay: short,
|
||||
writeDelay: medium,
|
||||
readDelay: short,
|
||||
|
||||
serverToClientPayload: []byte("abc"),
|
||||
clientToServerPayload: []byte("defghi"),
|
||||
|
||||
failStep: testStepRead,
|
||||
failError: "i/o timeout",
|
||||
},
|
||||
// Get a context timeout on a read.
|
||||
// readTimeout > writeDelay > timeout > acceptDelay
|
||||
{
|
||||
name: "session_read_timeout",
|
||||
port: 0x5615,
|
||||
timeout: short,
|
||||
connectTimeout: long,
|
||||
readTimeout: long,
|
||||
writeTimeout: long,
|
||||
|
||||
acceptDelay: 0,
|
||||
writeDelay: medium * 2,
|
||||
readDelay: 0,
|
||||
|
||||
serverToClientPayload: []byte("abc"),
|
||||
clientToServerPayload: []byte("defghi"),
|
||||
|
||||
failStep: testStepWrite,
|
||||
failError: "context deadline exceeded",
|
||||
},
|
||||
// Use a session timeout that is longer than any individual action's timeout.
|
||||
// acceptDelay+writeDelay+readDelay > timeout > acceptDelay >= writeDelay >= readDelay
|
||||
{
|
||||
name: "session_timeout",
|
||||
port: 0x5616,
|
||||
timeout: medium,
|
||||
connectTimeout: long,
|
||||
readTimeout: long,
|
||||
writeTimeout: long,
|
||||
|
||||
acceptDelay: time.Nanosecond * time.Duration(medium.Nanoseconds()/2+short.Nanoseconds()),
|
||||
writeDelay: time.Nanosecond * time.Duration(medium.Nanoseconds()/2+short.Nanoseconds()),
|
||||
readDelay: time.Nanosecond * time.Duration(medium.Nanoseconds()/2+short.Nanoseconds()),
|
||||
|
||||
serverToClientPayload: []byte("abc"),
|
||||
clientToServerPayload: []byte("defghi"),
|
||||
|
||||
failStep: testStepWrite,
|
||||
failError: "context deadline exceeded",
|
||||
},
|
||||
// TODO: How to test write timeout?
|
||||
}
|
||||
|
||||
// TestTimeoutConnectionTimeouts tests that the TimeoutConnection behaves as expected with respect
|
||||
// to timeouts.
|
||||
func TestTimeoutConnectionTimeouts(t *testing.T) {
|
||||
temp := make([]connTimeoutTestConfig, 0, len(connTestConfigs)*3)
|
||||
// Make three copies of connTestConfigs, one with each dial method.
|
||||
for _, cfg := range connTestConfigs {
|
||||
direct := cfg
|
||||
dialer := cfg
|
||||
ctxDialer := cfg
|
||||
|
||||
dialer.name = dialer.name + "_dialer"
|
||||
dialer.port = dialer.port + 100
|
||||
dialer.dialer = dialer.dialerDial
|
||||
|
||||
direct.name = direct.name + "_direct"
|
||||
direct.port = direct.port + 200
|
||||
direct.dialer = direct.directDial
|
||||
|
||||
ctxDialer.name = ctxDialer.name + "_context"
|
||||
ctxDialer.port = ctxDialer.port + 300
|
||||
ctxDialer.dialer = ctxDialer.contextDial
|
||||
temp = append(temp, direct, dialer, ctxDialer)
|
||||
}
|
||||
for _, cfg := range temp {
|
||||
t.Logf("Running %s", cfg.name)
|
||||
cfg.run(t)
|
||||
}
|
||||
}
|
347
modules/http/http_readlimit_test.go
Normal file
347
modules/http/http_readlimit_test.go
Normal file
@ -0,0 +1,347 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/zmap/zcrypto/tls"
|
||||
"github.com/zmap/zgrab2"
|
||||
)
|
||||
|
||||
// BEGIN Taken from handshake_server_test.go -- certs for TLS server
|
||||
func fromHex(s string) []byte {
|
||||
b, _ := hex.DecodeString(s)
|
||||
return b
|
||||
}
|
||||
|
||||
func bigFromString(s string) *big.Int {
|
||||
ret := new(big.Int)
|
||||
ret.SetString(s, 10)
|
||||
return ret
|
||||
}
|
||||
|
||||
var testRSACertificate = fromHex("308202b030820219a00302010202090085b0bba48a7fb8ca300d06092a864886f70d01010505003045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464301e170d3130303432343039303933385a170d3131303432343039303933385a3045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c746430819f300d06092a864886f70d010101050003818d0030818902818100bb79d6f517b5e5bf4610d0dc69bee62b07435ad0032d8a7a4385b71452e7a5654c2c78b8238cb5b482e5de1f953b7e62a52ca533d6fe125c7a56fcf506bffa587b263fb5cd04d3d0c921964ac7f4549f5abfef427100fe1899077f7e887d7df10439c4a22edb51c97ce3c04c3b326601cfafb11db8719a1ddbdb896baeda2d790203010001a381a73081a4301d0603551d0e04160414b1ade2855acfcb28db69ce2369ded3268e18883930750603551d23046e306c8014b1ade2855acfcb28db69ce2369ded3268e188839a149a4473045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c746482090085b0bba48a7fb8ca300c0603551d13040530030101ff300d06092a864886f70d010105050003818100086c4524c76bb159ab0c52ccf2b014d7879d7a6475b55a9566e4c52b8eae12661feb4f38b36e60d392fdf74108b52513b1187a24fb301dbaed98b917ece7d73159db95d31d78ea50565cd5825a2d5a5f33c4b6d8c97590968c0f5298b5cd981f89205ff2a01ca31b9694dda9fd57e970e8266d71999b266e3850296c90a7bdd9")
|
||||
var testSNICertificate = fromHex("308201f23082015da003020102020100300b06092a864886f70d01010530283110300e060355040a130741636d6520436f311430120603550403130b736e69746573742e636f6d301e170d3132303431313137343033355a170d3133303431313137343533355a30283110300e060355040a130741636d6520436f311430120603550403130b736e69746573742e636f6d30819d300b06092a864886f70d01010103818d0030818902818100bb79d6f517b5e5bf4610d0dc69bee62b07435ad0032d8a7a4385b71452e7a5654c2c78b8238cb5b482e5de1f953b7e62a52ca533d6fe125c7a56fcf506bffa587b263fb5cd04d3d0c921964ac7f4549f5abfef427100fe1899077f7e887d7df10439c4a22edb51c97ce3c04c3b326601cfafb11db8719a1ddbdb896baeda2d790203010001a3323030300e0603551d0f0101ff0404030200a0300d0603551d0e0406040401020304300f0603551d2304083006800401020304300b06092a864886f70d0101050381810089c6455f1c1f5ef8eb1ab174ee2439059f5c4259bb1a8d86cdb1d056f56a717da40e95ab90f59e8deaf627c157995094db0802266eb34fc6842dea8a4b68d9c1389103ab84fb9e1f85d9b5d23ff2312c8670fbb540148245a4ebafe264d90c8a4cf4f85b0fac12ac2fc4a3154bad52462868af96c62c6525d652b6e31845bdcc")
|
||||
|
||||
var testRSAPrivateKey = &rsa.PrivateKey{
|
||||
PublicKey: rsa.PublicKey{
|
||||
N: bigFromString("131650079503776001033793877885499001334664249354723305978524647182322416328664556247316495448366990052837680518067798333412266673813370895702118944398081598789828837447552603077848001020611640547221687072142537202428102790818451901395596882588063427854225330436740647715202971973145151161964464812406232198521"),
|
||||
E: 65537,
|
||||
},
|
||||
D: bigFromString("29354450337804273969007277378287027274721892607543397931919078829901848876371746653677097639302788129485893852488285045793268732234230875671682624082413996177431586734171663258657462237320300610850244186316880055243099640544518318093544057213190320837094958164973959123058337475052510833916491060913053867729"),
|
||||
Primes: []*big.Int{
|
||||
bigFromString("11969277782311800166562047708379380720136961987713178380670422671426759650127150688426177829077494755200794297055316163155755835813760102405344560929062149"),
|
||||
bigFromString("10998999429884441391899182616418192492905073053684657075974935218461686523870125521822756579792315215543092255516093840728890783887287417039645833477273829"),
|
||||
},
|
||||
}
|
||||
|
||||
// END Taken from handshake_server_test.go -- certs for TLS server
|
||||
|
||||
// Get the tls.Config object for the server; adapted from handshake_server_test.go.
|
||||
func getTLSConfig() *tls.Config {
|
||||
testConfig := &tls.Config{
|
||||
Time: func() time.Time { return time.Unix(0, 0) },
|
||||
Certificates: make([]tls.Certificate, 2),
|
||||
InsecureSkipVerify: true,
|
||||
MinVersion: tls.VersionSSL30,
|
||||
MaxVersion: tls.VersionTLS12,
|
||||
}
|
||||
testConfig.Certificates[0].Certificate = [][]byte{testRSACertificate}
|
||||
testConfig.Certificates[0].PrivateKey = testRSAPrivateKey
|
||||
testConfig.Certificates[1].Certificate = [][]byte{testSNICertificate}
|
||||
testConfig.Certificates[1].PrivateKey = testRSAPrivateKey
|
||||
testConfig.BuildNameToCertificate()
|
||||
return testConfig
|
||||
}
|
||||
|
||||
// Helper function to write and check for short writes
|
||||
func _write(writer io.Writer, data []byte) error {
|
||||
n, err := writer.Write(data)
|
||||
if err == nil && len(data) != n {
|
||||
err = io.ErrShortWrite
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Start a local server that sends responds to any requests with a cfg.headerSize-byte set of
|
||||
// headers followed by a cfg.bodySize-byte body.
|
||||
// The response ends up looking like this:
|
||||
// HTTP/1.0 200 OK
|
||||
// Bogus-Header: XXX...
|
||||
// Content-Length: <bodySize>
|
||||
//
|
||||
// XXXX....
|
||||
func (cfg *readLimitTestConfig) runFakeHTTPServer(t *testing.T) {
|
||||
endpoint := fmt.Sprintf("127.0.0.1:%d", cfg.port)
|
||||
listener, err := net.Listen("tcp", endpoint)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
go func() {
|
||||
defer listener.Close()
|
||||
sock, err := listener.Accept()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer sock.Close()
|
||||
if cfg.tls {
|
||||
tlsSock := tls.Server(sock, getTLSConfig())
|
||||
if err := tlsSock.Handshake(); err != nil {
|
||||
t.Fatalf("server handshake error: %v", err)
|
||||
}
|
||||
sock = tlsSock
|
||||
}
|
||||
// don't care what the client sends, always respond with a HTTP-like response
|
||||
buf := make([]byte, 1)
|
||||
_, err = sock.Read(buf)
|
||||
if err != nil {
|
||||
// any error, including EOF, is unexpected -- the client should send something
|
||||
t.Fatalf("Unexpected error reading from client: %v", err)
|
||||
}
|
||||
|
||||
head := "HTTP/1.0 200 OK\r\nBogus-Header: X"
|
||||
headSuffix := fmt.Sprintf("\r\nContent-Length: %d\r\n\r\n", cfg.bodySize)
|
||||
size := cfg.headerSize - len(head) - len(headSuffix)
|
||||
if size < 0 {
|
||||
t.Fatalf("Header size %d too small: must be at least %d bytes", cfg.headerSize, len(head)+len(headSuffix))
|
||||
}
|
||||
if err := _write(sock, []byte(head)); err != nil {
|
||||
t.Fatalf("write error: %v", err)
|
||||
}
|
||||
chunkSize := 256
|
||||
sent := len(head)
|
||||
chunk := []byte(strings.Repeat("X", chunkSize))
|
||||
for i := 0; i < size; i += chunkSize {
|
||||
if i+chunkSize > size {
|
||||
chunk = []byte(strings.Repeat("X", size-i))
|
||||
}
|
||||
if err := _write(sock, chunk); err != nil {
|
||||
t.Logf("Failed writing to client after %d bytes: %v", sent, err)
|
||||
return
|
||||
}
|
||||
sent += len(chunk)
|
||||
}
|
||||
|
||||
if err := _write(sock, []byte(headSuffix)); err != nil {
|
||||
t.Logf("Failed writing foot to client: %v", err)
|
||||
return
|
||||
}
|
||||
sent += len(headSuffix)
|
||||
body := strings.Repeat("X", cfg.bodySize)
|
||||
if err := _write(sock, []byte(body)); err != nil {
|
||||
t.Logf("Failed writing body to client: %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Get an HTTP scanner module with the desired config
|
||||
func (cfg *readLimitTestConfig) getScanner(t *testing.T) *Scanner {
|
||||
var module Module
|
||||
flags := module.NewFlags().(*Flags)
|
||||
flags.Endpoint = "/"
|
||||
flags.Method = "GET"
|
||||
flags.UserAgent = "Mozilla/5.0 zgrab/0.x"
|
||||
if cfg.maxBodySize&0x03ff != 0 {
|
||||
t.Fatalf("%d is not a valid maxBodySize (must be a multiple of 1024)", cfg.maxBodySize)
|
||||
}
|
||||
flags.MaxSize = cfg.maxBodySize / 1024
|
||||
flags.MaxRedirects = 0
|
||||
flags.Timeout = 1 * time.Second
|
||||
flags.Port = uint(cfg.port)
|
||||
flags.UseHTTPS = cfg.tls
|
||||
zgrab2.DefaultBytesReadLimit = cfg.maxReadSize
|
||||
scanner := module.NewScanner()
|
||||
scanner.Init(flags)
|
||||
return scanner.(*Scanner)
|
||||
}
|
||||
|
||||
// Configuration for a single test run
|
||||
type readLimitTestConfig struct {
|
||||
// if true, the client/server will use TLS. NOTE: the limits are on the *raw* connection.
|
||||
tls bool
|
||||
|
||||
// port where the server listens.
|
||||
port int
|
||||
|
||||
// Bodies larger than this are truncated. NOTE: this must be a multiple of 1024, since MaxSize
|
||||
// is given in kilobytes.
|
||||
maxBodySize int
|
||||
|
||||
// The maximum number of bytes to read from the (raw) socket. Beyond that data is truncated and
|
||||
// EOF is returned.
|
||||
maxReadSize int
|
||||
|
||||
// The size of the HTTP server's "header" (actually, all of the data before the body). Must be
|
||||
// at least 58 (the size of the static parts of the response).
|
||||
headerSize int
|
||||
|
||||
// The size of the HTTP body to send (the Content-Length).
|
||||
bodySize int
|
||||
|
||||
// The status that should be returned by the scan.
|
||||
expectedStatus zgrab2.ScanStatus
|
||||
|
||||
// If set, the error returned by the scan must contain this.
|
||||
expectedError string
|
||||
}
|
||||
|
||||
const (
|
||||
readLimitTestConfigHTTPBasePort = 0x7f7f
|
||||
readLimitTestConfigHTTPSBasePort = 0x7bbc
|
||||
)
|
||||
|
||||
var readLimitTestConfigs = map[string]*readLimitTestConfig{
|
||||
// The socket truncates the connection while reading the body. To the client it looks as if the
|
||||
// server closed the connection prior to sending Content-Length bytes; the result is success,
|
||||
// but with a truncated body.
|
||||
// bodySize + headerSize > maxReadSize > headerSize
|
||||
"truncate_read_body": {
|
||||
tls: false,
|
||||
port: readLimitTestConfigHTTPBasePort,
|
||||
maxBodySize: 2048,
|
||||
maxReadSize: 1024,
|
||||
headerSize: 64,
|
||||
bodySize: 4096,
|
||||
expectedStatus: zgrab2.SCAN_SUCCESS,
|
||||
},
|
||||
// NOTE: There is no tls_truncate_read_body, since the truncation will almost certainly occur
|
||||
// in the middle of a TLS packet -- so the response would always be "unexpected EOF"
|
||||
|
||||
// The HTTP library stops reading the body after reaching its internal limit. It returns success
|
||||
// and the truncated body.
|
||||
// maxReadSize > headerSize + bodySize > bodySize > maxBodySize
|
||||
"truncate_body": {
|
||||
tls: false,
|
||||
port: readLimitTestConfigHTTPBasePort + 1,
|
||||
maxBodySize: 2048,
|
||||
maxReadSize: 8192,
|
||||
headerSize: 64,
|
||||
bodySize: 4096,
|
||||
expectedStatus: zgrab2.SCAN_SUCCESS,
|
||||
},
|
||||
"tls_truncate_body": {
|
||||
tls: true,
|
||||
port: readLimitTestConfigHTTPSBasePort + 1,
|
||||
maxBodySize: 2048,
|
||||
maxReadSize: 8192,
|
||||
headerSize: 64,
|
||||
bodySize: 4096,
|
||||
expectedStatus: zgrab2.SCAN_SUCCESS,
|
||||
},
|
||||
|
||||
// The socket truncates the connection while reading the headers. The result isn't a valid HTTP
|
||||
// response, so the library returns an unexpected EOF error.
|
||||
// headerSize > maxReadSize
|
||||
"truncate_read_header": {
|
||||
tls: false,
|
||||
port: readLimitTestConfigHTTPBasePort + 2,
|
||||
maxBodySize: 1024,
|
||||
maxReadSize: 2048,
|
||||
headerSize: 3072,
|
||||
bodySize: 8,
|
||||
expectedError: "unexpected EOF",
|
||||
expectedStatus: zgrab2.SCAN_UNKNOWN_ERROR,
|
||||
},
|
||||
"tls_truncate_read_header": {
|
||||
tls: true,
|
||||
port: readLimitTestConfigHTTPSBasePort + 2,
|
||||
maxBodySize: 1024,
|
||||
maxReadSize: 2048,
|
||||
headerSize: 3072,
|
||||
bodySize: 8,
|
||||
expectedError: "unexpected EOF",
|
||||
expectedStatus: zgrab2.SCAN_UNKNOWN_ERROR,
|
||||
},
|
||||
|
||||
// Happy case. None of the limits are hit.
|
||||
// maxReadSize >= maxBodySize > bodySize + headerSize
|
||||
"happy_case": {
|
||||
tls: false,
|
||||
port: readLimitTestConfigHTTPBasePort + 3,
|
||||
maxBodySize: 8192,
|
||||
maxReadSize: 8192,
|
||||
headerSize: 1024,
|
||||
bodySize: 1024,
|
||||
expectedStatus: zgrab2.SCAN_SUCCESS,
|
||||
},
|
||||
"tls_happy_case": {
|
||||
tls: true,
|
||||
port: readLimitTestConfigHTTPSBasePort + 3,
|
||||
maxBodySize: 8192,
|
||||
maxReadSize: 8192,
|
||||
headerSize: 1024,
|
||||
bodySize: 1024,
|
||||
expectedStatus: zgrab2.SCAN_SUCCESS,
|
||||
},
|
||||
}
|
||||
|
||||
// Try to get the HTTP body from a result; otherwise return the empty string.
|
||||
func getBody(result interface{}) string {
|
||||
if result == nil {
|
||||
return ""
|
||||
}
|
||||
httpResult, ok := result.(*Results)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
response := httpResult.Response
|
||||
if response == nil {
|
||||
return ""
|
||||
}
|
||||
return response.BodyText
|
||||
}
|
||||
|
||||
// Run a single test with the given configuration.
|
||||
func (cfg *readLimitTestConfig) runTest(t *testing.T, testName string) {
|
||||
scanner := cfg.getScanner(t)
|
||||
cfg.runFakeHTTPServer(t)
|
||||
target := zgrab2.ScanTarget{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
}
|
||||
status, ret, err := scanner.Scan(target)
|
||||
|
||||
if status != cfg.expectedStatus {
|
||||
t.Errorf("Wrong status: expected %s, got %s", cfg.expectedStatus, status)
|
||||
}
|
||||
if err != nil {
|
||||
if !strings.Contains(err.Error(), cfg.expectedError) {
|
||||
t.Errorf("Wrong error: expected %s, got %s", err.Error(), cfg.expectedError)
|
||||
}
|
||||
} else if len(cfg.expectedError) > 0 {
|
||||
t.Errorf("Expected error '%s' but got none", cfg.expectedError)
|
||||
}
|
||||
if cfg.expectedStatus == zgrab2.SCAN_SUCCESS {
|
||||
body := getBody(ret)
|
||||
if body == "" {
|
||||
t.Errorf("Expected success, but got no body")
|
||||
} else {
|
||||
if len(body) > cfg.maxBodySize || len(body) > cfg.maxReadSize {
|
||||
t.Errorf("Body exceeds max size: len(body)=%d; maxBodySize=%d, maxReadSize=%d", len(body), cfg.maxBodySize, cfg.maxReadSize)
|
||||
}
|
||||
if !cfg.tls {
|
||||
if len(body)+cfg.headerSize > cfg.maxReadSize {
|
||||
t.Errorf("Body and header exceed max read size: len(body)=%d, headerSize=%d, maxReadSize=%d", len(body), cfg.headerSize, cfg.maxReadSize)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadLimitHTTP checks that the HTTP scanner works as expected with the default
|
||||
// ReadLimitExeededAction (specifically, ReadLimnitExceededActionTruncate) defined in conn.go.
|
||||
func TestReadLimitHTTP(t *testing.T) {
|
||||
if zgrab2.DefaultReadLimitExceededAction != zgrab2.ReadLimitExceededActionTruncate {
|
||||
t.Logf("Warning: DefaultReadLimitExceededAction is %s, not %s", zgrab2.DefaultReadLimitExceededAction, zgrab2.ReadLimitExceededActionTruncate)
|
||||
}
|
||||
for testName, cfg := range readLimitTestConfigs {
|
||||
cfg.runTest(t, testName)
|
||||
}
|
||||
}
|
@ -8,12 +8,14 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/zmap/zgrab2"
|
||||
@ -76,13 +78,14 @@ type Scanner struct {
|
||||
// scan holds the state for a single scan. This may entail multiple connections.
|
||||
// It is used to implement the zgrab2.Scanner interface.
|
||||
type scan struct {
|
||||
connections []net.Conn
|
||||
scanner *Scanner
|
||||
target *zgrab2.ScanTarget
|
||||
transport *http.Transport
|
||||
client *http.Client
|
||||
results Results
|
||||
url string
|
||||
connections []net.Conn
|
||||
scanner *Scanner
|
||||
target *zgrab2.ScanTarget
|
||||
transport *http.Transport
|
||||
client *http.Client
|
||||
results Results
|
||||
url string
|
||||
globalDeadline time.Time
|
||||
}
|
||||
|
||||
// NewFlags returns an empty Flags object.
|
||||
@ -142,15 +145,40 @@ func (scan *scan) Cleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
// Get a context whose deadline is the earliest of the context's deadline (if it has one) and the
|
||||
// global scan deadline.
|
||||
func (scan *scan) withDeadlineContext(ctx context.Context) context.Context {
|
||||
ctxDeadline, ok := ctx.Deadline()
|
||||
if !ok || scan.globalDeadline.Before(ctxDeadline) {
|
||||
ret, _ := context.WithDeadline(ctx, scan.globalDeadline)
|
||||
return ret
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
// Dial a connection using the configured timeouts, as well as the global deadline, and on success,
|
||||
// add the connection to the list of connections to be cleaned up.
|
||||
func (scan *scan) dialContext(ctx context.Context, net string, addr string) (net.Conn, error) {
|
||||
dialer := zgrab2.GetTimeoutConnectionDialer(scan.scanner.config.Timeout)
|
||||
|
||||
timeoutContext, _ := context.WithTimeout(context.Background(), scan.scanner.config.Timeout)
|
||||
|
||||
conn, err := dialer.DialContext(scan.withDeadlineContext(timeoutContext), net, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scan.connections = append(scan.connections, conn)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// getTLSDialer returns a Dial function that connects using the
|
||||
// zgrab2.GetTLSConnection()
|
||||
func (scan *scan) getTLSDialer() func(net, addr string) (net.Conn, error) {
|
||||
return func(net, addr string) (net.Conn, error) {
|
||||
outer, err := zgrab2.DialTimeoutConnection(net, addr, scan.scanner.config.Timeout)
|
||||
outer, err := scan.dialContext(context.Background(), net, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scan.connections = append(scan.connections, outer)
|
||||
tlsConn, err := scan.scanner.config.TLSFlags.GetTLSConnection(outer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -242,10 +270,11 @@ func (scanner *Scanner) newHTTPScan(t *zgrab2.ScanTarget) *scan {
|
||||
DisableCompression: false,
|
||||
MaxIdleConnsPerHost: scanner.config.MaxRedirects,
|
||||
},
|
||||
client: http.MakeNewClient(),
|
||||
client: http.MakeNewClient(),
|
||||
globalDeadline: time.Now().Add(scanner.config.Timeout),
|
||||
}
|
||||
ret.transport.DialTLS = ret.getTLSDialer()
|
||||
ret.transport.DialContext = zgrab2.GetTimeoutConnectionDialer(scanner.config.Timeout).DialContext
|
||||
ret.transport.DialContext = ret.dialContext
|
||||
ret.client.UserAgent = scanner.config.UserAgent
|
||||
ret.client.CheckRedirect = ret.getCheckRedirect()
|
||||
ret.client.Transport = ret.transport
|
||||
@ -334,6 +363,7 @@ func (scanner *Scanner) Scan(t zgrab2.ScanTarget) (zgrab2.ScanStatus, interface{
|
||||
// zgrab2 framework.
|
||||
func RegisterModule() {
|
||||
var module Module
|
||||
|
||||
_, err := zgrab2.AddCommand("http", "HTTP Banner Grab", "Grab a banner over HTTP", 80, &module)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -82,10 +82,7 @@ func (target *ScanTarget) OpenUDP(flags *BaseFlags, udp *UDPFlags) (net.Conn, er
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TimeoutConnection{
|
||||
Conn: conn,
|
||||
Timeout: flags.Timeout,
|
||||
}, nil
|
||||
return NewTimeoutConnection(nil, conn, flags.Timeout, 0, 0), nil
|
||||
}
|
||||
|
||||
// grabTarget calls handler for each action
|
||||
|
Loading…
Reference in New Issue
Block a user