Add support for BytesReadLimit parameter in BaseFlags

Some protocols may require more data than others.  To accomodate those,
allow the BytesReadLimit to be changed by means of BaseFlags.

By setting BaseFlags.BytesReadLimit prior to calling .Open(), scanners
can override the default limit to one that is appropriate for the data
collected.
This commit is contained in:
Jeff Cody 2018-10-16 13:51:06 -04:00
parent 6c186abf2e
commit ec59b49540
No known key found for this signature in database
GPG Key ID: BDBE7B27C0DE3057
4 changed files with 17 additions and 15 deletions

17
conn.go
View File

@ -167,7 +167,7 @@ func (c *TimeoutConnection) SetDeadline(deadline time.Time) error {
// GetTimeoutDialFunc returns a DialFunc that dials with the given timeout
func GetTimeoutDialFunc(timeout time.Duration) func(string, string) (net.Conn, error) {
return func(proto, target string) (net.Conn, error) {
return DialTimeoutConnection(proto, target, timeout)
return DialTimeoutConnection(proto, target, timeout, 0)
}
}
@ -217,12 +217,13 @@ func (c *TimeoutConnection) SetDefaults() *TimeoutConnection {
}
// NewTimeoutConnection returns a new TimeoutConnection with the appropriate defaults.
func NewTimeoutConnection(ctx context.Context, conn net.Conn, timeout, readTimeout, writeTimeout time.Duration) *TimeoutConnection {
func NewTimeoutConnection(ctx context.Context, conn net.Conn, timeout, readTimeout, writeTimeout time.Duration, bytesReadLimit int) *TimeoutConnection {
ret := (&TimeoutConnection{
Conn: conn,
Timeout: timeout,
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
BytesReadLimit: bytesReadLimit,
}).SetDefaults()
if ctx == nil {
ctx = context.Background()
@ -232,7 +233,7 @@ func NewTimeoutConnection(ctx context.Context, conn net.Conn, timeout, readTimeo
}
// DialTimeoutConnectionEx dials the target and returns a net.Conn that uses the configured timeouts for Read/Write operations.
func DialTimeoutConnectionEx(proto string, target string, dialTimeout, sessionTimeout, readTimeout, writeTimeout time.Duration) (net.Conn, error) {
func DialTimeoutConnectionEx(proto string, target string, dialTimeout, sessionTimeout, readTimeout, writeTimeout time.Duration, bytesReadLimit int) (net.Conn, error) {
var conn net.Conn
var err error
if dialTimeout > 0 {
@ -246,12 +247,12 @@ func DialTimeoutConnectionEx(proto string, target string, dialTimeout, sessionTi
}
return nil, err
}
return NewTimeoutConnection(context.Background(), conn, sessionTimeout, readTimeout, writeTimeout), nil
return NewTimeoutConnection(context.Background(), conn, sessionTimeout, readTimeout, writeTimeout, bytesReadLimit), nil
}
// DialTimeoutConnection dials the target and returns a net.Conn that uses the configured single timeout for all operations.
func DialTimeoutConnection(proto string, target string, timeout time.Duration) (net.Conn, error) {
return DialTimeoutConnectionEx(proto, target, timeout, timeout, timeout, timeout)
func DialTimeoutConnection(proto string, target string, timeout time.Duration, bytesReadLimit int) (net.Conn, error) {
return DialTimeoutConnectionEx(proto, target, timeout, timeout, timeout, timeout, bytesReadLimit)
}
// Dialer provides Dial and DialContext methods to get connections with the given timeout.
@ -303,7 +304,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
if err != nil {
return nil, err
}
ret := NewTimeoutConnection(ctx, conn, d.Timeout, d.ReadTimeout, d.WriteTimeout)
ret := NewTimeoutConnection(ctx, conn, d.Timeout, d.ReadTimeout, d.WriteTimeout, d.BytesReadLimit)
ret.BytesReadLimit = d.BytesReadLimit
ret.ReadLimitExceededAction = d.ReadLimitExceededAction
return ret, nil
@ -311,7 +312,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
// Dial returns a connection with the configured timeout.
func (d *Dialer) Dial(proto string, target string) (net.Conn, error) {
return DialTimeoutConnectionEx(proto, target, d.ConnectTimeout, d.Timeout, d.ReadTimeout, d.WriteTimeout)
return DialTimeoutConnectionEx(proto, target, d.ConnectTimeout, d.Timeout, d.ReadTimeout, d.WriteTimeout, 0)
}
// GetTimeoutConnectionDialer gets a Dialer that dials connections with the given timeout.

View File

@ -63,10 +63,11 @@ type ScanFlags interface {
// BaseFlags contains the options that every flags type must embed
type BaseFlags struct {
Port uint `short:"p" long:"port" description:"Specify port to grab on"`
Name string `short:"n" long:"name" description:"Specify name for output json, only necessary if scanning multiple modules"`
Timeout time.Duration `short:"t" long:"timeout" description:"Set connection timeout (0 = no timeout)" default:"10s"`
Trigger string `short:"g" long:"trigger" description:"Invoke only on targets with specified tag"`
Port uint `short:"p" long:"port" description:"Specify port to grab on"`
Name string `short:"n" long:"name" description:"Specify name for output json, only necessary if scanning multiple modules"`
Timeout time.Duration `short:"t" long:"timeout" description:"Set connection timeout (0 = no timeout)" default:"10s"`
Trigger string `short:"g" long:"trigger" description:"Invoke only on targets with specified tag"`
BytesReadLimit int `short:"m" long:"maxbytes" description:"Maximum byte read limit per scan (0 = defaults)"`
}
// UDPFlags contains the common options used for all UDP scans

View File

@ -623,7 +623,7 @@ func (scan *scan) getCheckRedirect(scanner *Scanner) func(*http.Request, *http.R
// Taken from zgrab2 http library, slightly modified to use slightly leaner scan object
func (scan *scan) getTLSDialer(scanner *Scanner) func(net, addr string) (net.Conn, error) {
return func(net, addr string) (net.Conn, error) {
outer, err := zgrab2.DialTimeoutConnection(net, addr, scanner.config.BaseFlags.Timeout)
outer, err := zgrab2.DialTimeoutConnection(net, addr, scanner.config.BaseFlags.Timeout, 0)
if err != nil {
return nil, err
}

View File

@ -57,7 +57,7 @@ func (target *ScanTarget) Host() string {
// Open connects to the ScanTarget using the configured flags, and returns a net.Conn that uses the configured timeouts for Read/Write operations.
func (target *ScanTarget) Open(flags *BaseFlags) (net.Conn, error) {
address := net.JoinHostPort(target.Host(), fmt.Sprintf("%d", flags.Port))
return DialTimeoutConnection("tcp", address, flags.Timeout)
return DialTimeoutConnection("tcp", address, flags.Timeout, flags.BytesReadLimit)
}
// OpenUDP connects to the ScanTarget using the configured flags, and returns a net.Conn that uses the configured timeouts for Read/Write operations.
@ -82,7 +82,7 @@ func (target *ScanTarget) OpenUDP(flags *BaseFlags, udp *UDPFlags) (net.Conn, er
if err != nil {
return nil, err
}
return NewTimeoutConnection(nil, conn, flags.Timeout, 0, 0), nil
return NewTimeoutConnection(nil, conn, flags.Timeout, 0, 0, flags.BytesReadLimit), nil
}
// grabTarget calls handler for each action