Add a dialer for the non-TLS case in HTTP, so that --timeout is honored even if --use-https is not set (issue #109)

This commit is contained in:
Justin Bastress 2018-05-01 13:33:21 -04:00
parent 5c0cbeeee0
commit 3254857b58
2 changed files with 43 additions and 7 deletions

49
conn.go

@ -3,6 +3,7 @@ package zgrab2
import (
"net"
"time"
"context"
)
// TimeoutConnection wraps an existing net.Conn connection, overriding the Read/Write methods to use the configured timeouts
@ -31,13 +32,6 @@ func (c *TimeoutConnection) Write(b []byte) (n int, err error) {
return c.Conn.Write(b)
}
// GetTimeoutDialer returns a Dialer function that dials with the given timeout
func GetTimeoutDialer(timeout time.Duration) func(string, string) (net.Conn, error) {
return func(proto, target string) (net.Conn, error) {
return DialTimeoutConnection(proto, target, timeout)
}
}
// DialTimeoutConnection dials the target and returns a net.Conn that uses the configured timeouts for Read/Write operations.
func DialTimeoutConnection(proto string, target string, timeout time.Duration) (net.Conn, error) {
var conn net.Conn
@ -58,3 +52,44 @@ func DialTimeoutConnection(proto string, target string, timeout time.Duration) (
Timeout: timeout,
}, nil
}
// Dialer provides Dial and DialContext methods to get connections with the given timeout.
type Dialer struct {
// Timeout is the maximum time to wait for a connection or I/O.
Timeout time.Duration
// dialer is an auxiliary dialer used for DialContext (the result gets wrapped in a TimeoutConnection).
Dialer *net.Dialer
}
// DialContext wraps the connection returned by net.Dialer.DialContext() with a TimeoutConnection.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
// ensure that our aux dialer is up-to-date
d.Dialer.Timeout = d.Timeout
d.Dialer.KeepAlive = d.Timeout
ret, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
return nil, err
}
return &TimeoutConnection{
Conn: ret,
Timeout: d.Timeout,
}, nil
}
// Dial returns a connection with the configured timeout.
func (d *Dialer) Dial(proto string, target string) (net.Conn, error) {
return DialTimeoutConnection(proto, target, d.Timeout)
}
// GetTimeoutConnectionDialer gets a Dialer that dials connections with the given timeout.
func GetTimeoutConnectionDialer(timeout time.Duration) *Dialer {
return &Dialer{
Timeout: timeout,
Dialer: &net.Dialer{
Timeout: timeout,
KeepAlive: timeout,
DualStack: true,
},
}
}

@ -241,6 +241,7 @@ func (scanner *Scanner) newHTTPScan(t *zgrab2.ScanTarget) *scan {
client: http.MakeNewClient(),
}
ret.transport.DialTLS = ret.getTLSDialer()
ret.transport.DialContext = zgrab2.GetTimeoutConnectionDialer(time.Duration(scanner.config.Timeout) * time.Second).DialContext
ret.client.UserAgent = scanner.config.UserAgent
ret.client.CheckRedirect = ret.getCheckRedirect()
ret.client.Transport = ret.transport