Add support for BytesReadLimit parameter in BaseFlags
Some protocols may require more data than others. To accomodate those, allow the BytesReadLimit to be changed by means of BaseFlags. By setting BaseFlags.BytesReadLimit prior to calling .Open(), scanners can override the default limit to one that is appropriate for the data collected.
This commit is contained in:
parent
6c186abf2e
commit
ec59b49540
17
conn.go
17
conn.go
|
@ -167,7 +167,7 @@ func (c *TimeoutConnection) SetDeadline(deadline time.Time) error {
|
|||
// GetTimeoutDialFunc returns a DialFunc that dials with the given timeout
|
||||
func GetTimeoutDialFunc(timeout time.Duration) func(string, string) (net.Conn, error) {
|
||||
return func(proto, target string) (net.Conn, error) {
|
||||
return DialTimeoutConnection(proto, target, timeout)
|
||||
return DialTimeoutConnection(proto, target, timeout, 0)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -217,12 +217,13 @@ func (c *TimeoutConnection) SetDefaults() *TimeoutConnection {
|
|||
}
|
||||
|
||||
// NewTimeoutConnection returns a new TimeoutConnection with the appropriate defaults.
|
||||
func NewTimeoutConnection(ctx context.Context, conn net.Conn, timeout, readTimeout, writeTimeout time.Duration) *TimeoutConnection {
|
||||
func NewTimeoutConnection(ctx context.Context, conn net.Conn, timeout, readTimeout, writeTimeout time.Duration, bytesReadLimit int) *TimeoutConnection {
|
||||
ret := (&TimeoutConnection{
|
||||
Conn: conn,
|
||||
Timeout: timeout,
|
||||
ReadTimeout: readTimeout,
|
||||
WriteTimeout: writeTimeout,
|
||||
BytesReadLimit: bytesReadLimit,
|
||||
}).SetDefaults()
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
|
@ -232,7 +233,7 @@ func NewTimeoutConnection(ctx context.Context, conn net.Conn, timeout, readTimeo
|
|||
}
|
||||
|
||||
// DialTimeoutConnectionEx dials the target and returns a net.Conn that uses the configured timeouts for Read/Write operations.
|
||||
func DialTimeoutConnectionEx(proto string, target string, dialTimeout, sessionTimeout, readTimeout, writeTimeout time.Duration) (net.Conn, error) {
|
||||
func DialTimeoutConnectionEx(proto string, target string, dialTimeout, sessionTimeout, readTimeout, writeTimeout time.Duration, bytesReadLimit int) (net.Conn, error) {
|
||||
var conn net.Conn
|
||||
var err error
|
||||
if dialTimeout > 0 {
|
||||
|
@ -246,12 +247,12 @@ func DialTimeoutConnectionEx(proto string, target string, dialTimeout, sessionTi
|
|||
}
|
||||
return nil, err
|
||||
}
|
||||
return NewTimeoutConnection(context.Background(), conn, sessionTimeout, readTimeout, writeTimeout), nil
|
||||
return NewTimeoutConnection(context.Background(), conn, sessionTimeout, readTimeout, writeTimeout, bytesReadLimit), nil
|
||||
}
|
||||
|
||||
// DialTimeoutConnection dials the target and returns a net.Conn that uses the configured single timeout for all operations.
|
||||
func DialTimeoutConnection(proto string, target string, timeout time.Duration) (net.Conn, error) {
|
||||
return DialTimeoutConnectionEx(proto, target, timeout, timeout, timeout, timeout)
|
||||
func DialTimeoutConnection(proto string, target string, timeout time.Duration, bytesReadLimit int) (net.Conn, error) {
|
||||
return DialTimeoutConnectionEx(proto, target, timeout, timeout, timeout, timeout, bytesReadLimit)
|
||||
}
|
||||
|
||||
// Dialer provides Dial and DialContext methods to get connections with the given timeout.
|
||||
|
@ -303,7 +304,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret := NewTimeoutConnection(ctx, conn, d.Timeout, d.ReadTimeout, d.WriteTimeout)
|
||||
ret := NewTimeoutConnection(ctx, conn, d.Timeout, d.ReadTimeout, d.WriteTimeout, d.BytesReadLimit)
|
||||
ret.BytesReadLimit = d.BytesReadLimit
|
||||
ret.ReadLimitExceededAction = d.ReadLimitExceededAction
|
||||
return ret, nil
|
||||
|
@ -311,7 +312,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
|
|||
|
||||
// Dial returns a connection with the configured timeout.
|
||||
func (d *Dialer) Dial(proto string, target string) (net.Conn, error) {
|
||||
return DialTimeoutConnectionEx(proto, target, d.ConnectTimeout, d.Timeout, d.ReadTimeout, d.WriteTimeout)
|
||||
return DialTimeoutConnectionEx(proto, target, d.ConnectTimeout, d.Timeout, d.ReadTimeout, d.WriteTimeout, 0)
|
||||
}
|
||||
|
||||
// GetTimeoutConnectionDialer gets a Dialer that dials connections with the given timeout.
|
||||
|
|
|
@ -63,10 +63,11 @@ type ScanFlags interface {
|
|||
|
||||
// BaseFlags contains the options that every flags type must embed
|
||||
type BaseFlags struct {
|
||||
Port uint `short:"p" long:"port" description:"Specify port to grab on"`
|
||||
Name string `short:"n" long:"name" description:"Specify name for output json, only necessary if scanning multiple modules"`
|
||||
Timeout time.Duration `short:"t" long:"timeout" description:"Set connection timeout (0 = no timeout)" default:"10s"`
|
||||
Trigger string `short:"g" long:"trigger" description:"Invoke only on targets with specified tag"`
|
||||
Port uint `short:"p" long:"port" description:"Specify port to grab on"`
|
||||
Name string `short:"n" long:"name" description:"Specify name for output json, only necessary if scanning multiple modules"`
|
||||
Timeout time.Duration `short:"t" long:"timeout" description:"Set connection timeout (0 = no timeout)" default:"10s"`
|
||||
Trigger string `short:"g" long:"trigger" description:"Invoke only on targets with specified tag"`
|
||||
BytesReadLimit int `short:"m" long:"maxbytes" description:"Maximum byte read limit per scan (0 = defaults)"`
|
||||
}
|
||||
|
||||
// UDPFlags contains the common options used for all UDP scans
|
||||
|
|
|
@ -623,7 +623,7 @@ func (scan *scan) getCheckRedirect(scanner *Scanner) func(*http.Request, *http.R
|
|||
// Taken from zgrab2 http library, slightly modified to use slightly leaner scan object
|
||||
func (scan *scan) getTLSDialer(scanner *Scanner) func(net, addr string) (net.Conn, error) {
|
||||
return func(net, addr string) (net.Conn, error) {
|
||||
outer, err := zgrab2.DialTimeoutConnection(net, addr, scanner.config.BaseFlags.Timeout)
|
||||
outer, err := zgrab2.DialTimeoutConnection(net, addr, scanner.config.BaseFlags.Timeout, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -57,7 +57,7 @@ func (target *ScanTarget) Host() string {
|
|||
// Open connects to the ScanTarget using the configured flags, and returns a net.Conn that uses the configured timeouts for Read/Write operations.
|
||||
func (target *ScanTarget) Open(flags *BaseFlags) (net.Conn, error) {
|
||||
address := net.JoinHostPort(target.Host(), fmt.Sprintf("%d", flags.Port))
|
||||
return DialTimeoutConnection("tcp", address, flags.Timeout)
|
||||
return DialTimeoutConnection("tcp", address, flags.Timeout, flags.BytesReadLimit)
|
||||
}
|
||||
|
||||
// OpenUDP connects to the ScanTarget using the configured flags, and returns a net.Conn that uses the configured timeouts for Read/Write operations.
|
||||
|
@ -82,7 +82,7 @@ func (target *ScanTarget) OpenUDP(flags *BaseFlags, udp *UDPFlags) (net.Conn, er
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewTimeoutConnection(nil, conn, flags.Timeout, 0, 0), nil
|
||||
return NewTimeoutConnection(nil, conn, flags.Timeout, 0, 0, flags.BytesReadLimit), nil
|
||||
}
|
||||
|
||||
// grabTarget calls handler for each action
|
||||
|
|
Loading…
Reference in New Issue