Merge pull request #180 from codyprime/dev

Add support for BytesReadLimit parameter in BaseFlags
This commit is contained in:
justinbastress 2018-10-23 10:54:30 -04:00 committed by GitHub
commit e5b7392ab4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 17 additions and 15 deletions

17
conn.go

@ -167,7 +167,7 @@ func (c *TimeoutConnection) SetDeadline(deadline time.Time) error {
// GetTimeoutDialFunc returns a DialFunc that dials with the given timeout
func GetTimeoutDialFunc(timeout time.Duration) func(string, string) (net.Conn, error) {
return func(proto, target string) (net.Conn, error) {
return DialTimeoutConnection(proto, target, timeout)
return DialTimeoutConnection(proto, target, timeout, 0)
}
}
@ -217,12 +217,13 @@ func (c *TimeoutConnection) SetDefaults() *TimeoutConnection {
}
// NewTimeoutConnection returns a new TimeoutConnection with the appropriate defaults.
func NewTimeoutConnection(ctx context.Context, conn net.Conn, timeout, readTimeout, writeTimeout time.Duration) *TimeoutConnection {
func NewTimeoutConnection(ctx context.Context, conn net.Conn, timeout, readTimeout, writeTimeout time.Duration, bytesReadLimit int) *TimeoutConnection {
ret := (&TimeoutConnection{
Conn: conn,
Timeout: timeout,
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
BytesReadLimit: bytesReadLimit,
}).SetDefaults()
if ctx == nil {
ctx = context.Background()
@ -232,7 +233,7 @@ func NewTimeoutConnection(ctx context.Context, conn net.Conn, timeout, readTimeo
}
// DialTimeoutConnectionEx dials the target and returns a net.Conn that uses the configured timeouts for Read/Write operations.
func DialTimeoutConnectionEx(proto string, target string, dialTimeout, sessionTimeout, readTimeout, writeTimeout time.Duration) (net.Conn, error) {
func DialTimeoutConnectionEx(proto string, target string, dialTimeout, sessionTimeout, readTimeout, writeTimeout time.Duration, bytesReadLimit int) (net.Conn, error) {
var conn net.Conn
var err error
if dialTimeout > 0 {
@ -246,12 +247,12 @@ func DialTimeoutConnectionEx(proto string, target string, dialTimeout, sessionTi
}
return nil, err
}
return NewTimeoutConnection(context.Background(), conn, sessionTimeout, readTimeout, writeTimeout), nil
return NewTimeoutConnection(context.Background(), conn, sessionTimeout, readTimeout, writeTimeout, bytesReadLimit), nil
}
// DialTimeoutConnection dials the target and returns a net.Conn that uses the configured single timeout for all operations.
func DialTimeoutConnection(proto string, target string, timeout time.Duration) (net.Conn, error) {
return DialTimeoutConnectionEx(proto, target, timeout, timeout, timeout, timeout)
func DialTimeoutConnection(proto string, target string, timeout time.Duration, bytesReadLimit int) (net.Conn, error) {
return DialTimeoutConnectionEx(proto, target, timeout, timeout, timeout, timeout, bytesReadLimit)
}
// Dialer provides Dial and DialContext methods to get connections with the given timeout.
@ -303,7 +304,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
if err != nil {
return nil, err
}
ret := NewTimeoutConnection(ctx, conn, d.Timeout, d.ReadTimeout, d.WriteTimeout)
ret := NewTimeoutConnection(ctx, conn, d.Timeout, d.ReadTimeout, d.WriteTimeout, d.BytesReadLimit)
ret.BytesReadLimit = d.BytesReadLimit
ret.ReadLimitExceededAction = d.ReadLimitExceededAction
return ret, nil
@ -311,7 +312,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
// Dial returns a connection with the configured timeout.
func (d *Dialer) Dial(proto string, target string) (net.Conn, error) {
return DialTimeoutConnectionEx(proto, target, d.ConnectTimeout, d.Timeout, d.ReadTimeout, d.WriteTimeout)
return DialTimeoutConnectionEx(proto, target, d.ConnectTimeout, d.Timeout, d.ReadTimeout, d.WriteTimeout, 0)
}
// GetTimeoutConnectionDialer gets a Dialer that dials connections with the given timeout.

@ -63,10 +63,11 @@ type ScanFlags interface {
// BaseFlags contains the options that every flags type must embed
type BaseFlags struct {
Port uint `short:"p" long:"port" description:"Specify port to grab on"`
Name string `short:"n" long:"name" description:"Specify name for output json, only necessary if scanning multiple modules"`
Timeout time.Duration `short:"t" long:"timeout" description:"Set connection timeout (0 = no timeout)" default:"10s"`
Trigger string `short:"g" long:"trigger" description:"Invoke only on targets with specified tag"`
Port uint `short:"p" long:"port" description:"Specify port to grab on"`
Name string `short:"n" long:"name" description:"Specify name for output json, only necessary if scanning multiple modules"`
Timeout time.Duration `short:"t" long:"timeout" description:"Set connection timeout (0 = no timeout)" default:"10s"`
Trigger string `short:"g" long:"trigger" description:"Invoke only on targets with specified tag"`
BytesReadLimit int `short:"m" long:"maxbytes" description:"Maximum byte read limit per scan (0 = defaults)"`
}
// UDPFlags contains the common options used for all UDP scans

@ -623,7 +623,7 @@ func (scan *scan) getCheckRedirect(scanner *Scanner) func(*http.Request, *http.R
// Taken from zgrab2 http library, slightly modified to use slightly leaner scan object
func (scan *scan) getTLSDialer(scanner *Scanner) func(net, addr string) (net.Conn, error) {
return func(net, addr string) (net.Conn, error) {
outer, err := zgrab2.DialTimeoutConnection(net, addr, scanner.config.BaseFlags.Timeout)
outer, err := zgrab2.DialTimeoutConnection(net, addr, scanner.config.BaseFlags.Timeout, 0)
if err != nil {
return nil, err
}

@ -57,7 +57,7 @@ func (target *ScanTarget) Host() string {
// Open connects to the ScanTarget using the configured flags, and returns a net.Conn that uses the configured timeouts for Read/Write operations.
func (target *ScanTarget) Open(flags *BaseFlags) (net.Conn, error) {
address := net.JoinHostPort(target.Host(), fmt.Sprintf("%d", flags.Port))
return DialTimeoutConnection("tcp", address, flags.Timeout)
return DialTimeoutConnection("tcp", address, flags.Timeout, flags.BytesReadLimit)
}
// OpenUDP connects to the ScanTarget using the configured flags, and returns a net.Conn that uses the configured timeouts for Read/Write operations.
@ -82,7 +82,7 @@ func (target *ScanTarget) OpenUDP(flags *BaseFlags, udp *UDPFlags) (net.Conn, er
if err != nil {
return nil, err
}
return NewTimeoutConnection(nil, conn, flags.Timeout, 0, 0), nil
return NewTimeoutConnection(nil, conn, flags.Timeout, 0, 0, flags.BytesReadLimit), nil
}
// grabTarget calls handler for each action