Add a dialer for the non-TLS case in HTTP, so that --timeout is honored even if --use-https is not set (issue #109)
This commit is contained in:
parent
5c0cbeeee0
commit
3254857b58
49
conn.go
49
conn.go
@ -3,6 +3,7 @@ package zgrab2
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
"context"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TimeoutConnection wraps an existing net.Conn connection, overriding the Read/Write methods to use the configured timeouts
|
// TimeoutConnection wraps an existing net.Conn connection, overriding the Read/Write methods to use the configured timeouts
|
||||||
@ -31,13 +32,6 @@ func (c *TimeoutConnection) Write(b []byte) (n int, err error) {
|
|||||||
return c.Conn.Write(b)
|
return c.Conn.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTimeoutDialer returns a Dialer function that dials with the given timeout
|
|
||||||
func GetTimeoutDialer(timeout time.Duration) func(string, string) (net.Conn, error) {
|
|
||||||
return func(proto, target string) (net.Conn, error) {
|
|
||||||
return DialTimeoutConnection(proto, target, timeout)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialTimeoutConnection dials the target and returns a net.Conn that uses the configured timeouts for Read/Write operations.
|
// DialTimeoutConnection dials the target and returns a net.Conn that uses the configured timeouts for Read/Write operations.
|
||||||
func DialTimeoutConnection(proto string, target string, timeout time.Duration) (net.Conn, error) {
|
func DialTimeoutConnection(proto string, target string, timeout time.Duration) (net.Conn, error) {
|
||||||
var conn net.Conn
|
var conn net.Conn
|
||||||
@ -58,3 +52,44 @@ func DialTimeoutConnection(proto string, target string, timeout time.Duration) (
|
|||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dialer provides Dial and DialContext methods to get connections with the given timeout.
|
||||||
|
type Dialer struct {
|
||||||
|
// Timeout is the maximum time to wait for a connection or I/O.
|
||||||
|
Timeout time.Duration
|
||||||
|
|
||||||
|
// dialer is an auxiliary dialer used for DialContext (the result gets wrapped in a TimeoutConnection).
|
||||||
|
Dialer *net.Dialer
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialContext wraps the connection returned by net.Dialer.DialContext() with a TimeoutConnection.
|
||||||
|
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
// ensure that our aux dialer is up-to-date
|
||||||
|
d.Dialer.Timeout = d.Timeout
|
||||||
|
d.Dialer.KeepAlive = d.Timeout
|
||||||
|
ret, err := d.Dialer.DialContext(ctx, network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &TimeoutConnection{
|
||||||
|
Conn: ret,
|
||||||
|
Timeout: d.Timeout,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial returns a connection with the configured timeout.
|
||||||
|
func (d *Dialer) Dial(proto string, target string) (net.Conn, error) {
|
||||||
|
return DialTimeoutConnection(proto, target, d.Timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTimeoutConnectionDialer gets a Dialer that dials connections with the given timeout.
|
||||||
|
func GetTimeoutConnectionDialer(timeout time.Duration) *Dialer {
|
||||||
|
return &Dialer{
|
||||||
|
Timeout: timeout,
|
||||||
|
Dialer: &net.Dialer{
|
||||||
|
Timeout: timeout,
|
||||||
|
KeepAlive: timeout,
|
||||||
|
DualStack: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
@ -241,6 +241,7 @@ func (scanner *Scanner) newHTTPScan(t *zgrab2.ScanTarget) *scan {
|
|||||||
client: http.MakeNewClient(),
|
client: http.MakeNewClient(),
|
||||||
}
|
}
|
||||||
ret.transport.DialTLS = ret.getTLSDialer()
|
ret.transport.DialTLS = ret.getTLSDialer()
|
||||||
|
ret.transport.DialContext = zgrab2.GetTimeoutConnectionDialer(time.Duration(scanner.config.Timeout) * time.Second).DialContext
|
||||||
ret.client.UserAgent = scanner.config.UserAgent
|
ret.client.UserAgent = scanner.config.UserAgent
|
||||||
ret.client.CheckRedirect = ret.getCheckRedirect()
|
ret.client.CheckRedirect = ret.getCheckRedirect()
|
||||||
ret.client.Transport = ret.transport
|
ret.client.Transport = ret.transport
|
||||||
|
Loading…
Reference in New Issue
Block a user