Add support for BytesReadLimit parameter in BaseFlags

Some protocols may require more data than others.  To accomodate those,
allow the BytesReadLimit to be changed by means of BaseFlags.

By setting BaseFlags.BytesReadLimit prior to calling .Open(), scanners
can override the default limit to one that is appropriate for the data
collected.
This commit is contained in:
Jeff Cody 2018-10-16 13:51:06 -04:00
parent 6c186abf2e
commit ec59b49540
No known key found for this signature in database
GPG Key ID: BDBE7B27C0DE3057
4 changed files with 17 additions and 15 deletions

17
conn.go

@ -167,7 +167,7 @@ func (c *TimeoutConnection) SetDeadline(deadline time.Time) error {
// GetTimeoutDialFunc returns a DialFunc that dials with the given timeout // GetTimeoutDialFunc returns a DialFunc that dials with the given timeout
func GetTimeoutDialFunc(timeout time.Duration) func(string, string) (net.Conn, error) { func GetTimeoutDialFunc(timeout time.Duration) func(string, string) (net.Conn, error) {
return func(proto, target string) (net.Conn, error) { return func(proto, target string) (net.Conn, error) {
return DialTimeoutConnection(proto, target, timeout) return DialTimeoutConnection(proto, target, timeout, 0)
} }
} }
@ -217,12 +217,13 @@ func (c *TimeoutConnection) SetDefaults() *TimeoutConnection {
} }
// NewTimeoutConnection returns a new TimeoutConnection with the appropriate defaults. // NewTimeoutConnection returns a new TimeoutConnection with the appropriate defaults.
func NewTimeoutConnection(ctx context.Context, conn net.Conn, timeout, readTimeout, writeTimeout time.Duration) *TimeoutConnection { func NewTimeoutConnection(ctx context.Context, conn net.Conn, timeout, readTimeout, writeTimeout time.Duration, bytesReadLimit int) *TimeoutConnection {
ret := (&TimeoutConnection{ ret := (&TimeoutConnection{
Conn: conn, Conn: conn,
Timeout: timeout, Timeout: timeout,
ReadTimeout: readTimeout, ReadTimeout: readTimeout,
WriteTimeout: writeTimeout, WriteTimeout: writeTimeout,
BytesReadLimit: bytesReadLimit,
}).SetDefaults() }).SetDefaults()
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
@ -232,7 +233,7 @@ func NewTimeoutConnection(ctx context.Context, conn net.Conn, timeout, readTimeo
} }
// DialTimeoutConnectionEx dials the target and returns a net.Conn that uses the configured timeouts for Read/Write operations. // DialTimeoutConnectionEx dials the target and returns a net.Conn that uses the configured timeouts for Read/Write operations.
func DialTimeoutConnectionEx(proto string, target string, dialTimeout, sessionTimeout, readTimeout, writeTimeout time.Duration) (net.Conn, error) { func DialTimeoutConnectionEx(proto string, target string, dialTimeout, sessionTimeout, readTimeout, writeTimeout time.Duration, bytesReadLimit int) (net.Conn, error) {
var conn net.Conn var conn net.Conn
var err error var err error
if dialTimeout > 0 { if dialTimeout > 0 {
@ -246,12 +247,12 @@ func DialTimeoutConnectionEx(proto string, target string, dialTimeout, sessionTi
} }
return nil, err return nil, err
} }
return NewTimeoutConnection(context.Background(), conn, sessionTimeout, readTimeout, writeTimeout), nil return NewTimeoutConnection(context.Background(), conn, sessionTimeout, readTimeout, writeTimeout, bytesReadLimit), nil
} }
// DialTimeoutConnection dials the target and returns a net.Conn that uses the configured single timeout for all operations. // DialTimeoutConnection dials the target and returns a net.Conn that uses the configured single timeout for all operations.
func DialTimeoutConnection(proto string, target string, timeout time.Duration) (net.Conn, error) { func DialTimeoutConnection(proto string, target string, timeout time.Duration, bytesReadLimit int) (net.Conn, error) {
return DialTimeoutConnectionEx(proto, target, timeout, timeout, timeout, timeout) return DialTimeoutConnectionEx(proto, target, timeout, timeout, timeout, timeout, bytesReadLimit)
} }
// Dialer provides Dial and DialContext methods to get connections with the given timeout. // Dialer provides Dial and DialContext methods to get connections with the given timeout.
@ -303,7 +304,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
if err != nil { if err != nil {
return nil, err return nil, err
} }
ret := NewTimeoutConnection(ctx, conn, d.Timeout, d.ReadTimeout, d.WriteTimeout) ret := NewTimeoutConnection(ctx, conn, d.Timeout, d.ReadTimeout, d.WriteTimeout, d.BytesReadLimit)
ret.BytesReadLimit = d.BytesReadLimit ret.BytesReadLimit = d.BytesReadLimit
ret.ReadLimitExceededAction = d.ReadLimitExceededAction ret.ReadLimitExceededAction = d.ReadLimitExceededAction
return ret, nil return ret, nil
@ -311,7 +312,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
// Dial returns a connection with the configured timeout. // Dial returns a connection with the configured timeout.
func (d *Dialer) Dial(proto string, target string) (net.Conn, error) { func (d *Dialer) Dial(proto string, target string) (net.Conn, error) {
return DialTimeoutConnectionEx(proto, target, d.ConnectTimeout, d.Timeout, d.ReadTimeout, d.WriteTimeout) return DialTimeoutConnectionEx(proto, target, d.ConnectTimeout, d.Timeout, d.ReadTimeout, d.WriteTimeout, 0)
} }
// GetTimeoutConnectionDialer gets a Dialer that dials connections with the given timeout. // GetTimeoutConnectionDialer gets a Dialer that dials connections with the given timeout.

@ -63,10 +63,11 @@ type ScanFlags interface {
// BaseFlags contains the options that every flags type must embed // BaseFlags contains the options that every flags type must embed
type BaseFlags struct { type BaseFlags struct {
Port uint `short:"p" long:"port" description:"Specify port to grab on"` Port uint `short:"p" long:"port" description:"Specify port to grab on"`
Name string `short:"n" long:"name" description:"Specify name for output json, only necessary if scanning multiple modules"` Name string `short:"n" long:"name" description:"Specify name for output json, only necessary if scanning multiple modules"`
Timeout time.Duration `short:"t" long:"timeout" description:"Set connection timeout (0 = no timeout)" default:"10s"` Timeout time.Duration `short:"t" long:"timeout" description:"Set connection timeout (0 = no timeout)" default:"10s"`
Trigger string `short:"g" long:"trigger" description:"Invoke only on targets with specified tag"` Trigger string `short:"g" long:"trigger" description:"Invoke only on targets with specified tag"`
BytesReadLimit int `short:"m" long:"maxbytes" description:"Maximum byte read limit per scan (0 = defaults)"`
} }
// UDPFlags contains the common options used for all UDP scans // UDPFlags contains the common options used for all UDP scans

@ -623,7 +623,7 @@ func (scan *scan) getCheckRedirect(scanner *Scanner) func(*http.Request, *http.R
// Taken from zgrab2 http library, slightly modified to use slightly leaner scan object // Taken from zgrab2 http library, slightly modified to use slightly leaner scan object
func (scan *scan) getTLSDialer(scanner *Scanner) func(net, addr string) (net.Conn, error) { func (scan *scan) getTLSDialer(scanner *Scanner) func(net, addr string) (net.Conn, error) {
return func(net, addr string) (net.Conn, error) { return func(net, addr string) (net.Conn, error) {
outer, err := zgrab2.DialTimeoutConnection(net, addr, scanner.config.BaseFlags.Timeout) outer, err := zgrab2.DialTimeoutConnection(net, addr, scanner.config.BaseFlags.Timeout, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -57,7 +57,7 @@ func (target *ScanTarget) Host() string {
// Open connects to the ScanTarget using the configured flags, and returns a net.Conn that uses the configured timeouts for Read/Write operations. // Open connects to the ScanTarget using the configured flags, and returns a net.Conn that uses the configured timeouts for Read/Write operations.
func (target *ScanTarget) Open(flags *BaseFlags) (net.Conn, error) { func (target *ScanTarget) Open(flags *BaseFlags) (net.Conn, error) {
address := net.JoinHostPort(target.Host(), fmt.Sprintf("%d", flags.Port)) address := net.JoinHostPort(target.Host(), fmt.Sprintf("%d", flags.Port))
return DialTimeoutConnection("tcp", address, flags.Timeout) return DialTimeoutConnection("tcp", address, flags.Timeout, flags.BytesReadLimit)
} }
// OpenUDP connects to the ScanTarget using the configured flags, and returns a net.Conn that uses the configured timeouts for Read/Write operations. // OpenUDP connects to the ScanTarget using the configured flags, and returns a net.Conn that uses the configured timeouts for Read/Write operations.
@ -82,7 +82,7 @@ func (target *ScanTarget) OpenUDP(flags *BaseFlags, udp *UDPFlags) (net.Conn, er
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewTimeoutConnection(nil, conn, flags.Timeout, 0, 0), nil return NewTimeoutConnection(nil, conn, flags.Timeout, 0, 0, flags.BytesReadLimit), nil
} }
// grabTarget calls handler for each action // grabTarget calls handler for each action