diff --git a/config.go b/config.go index 12402be..e91ca62 100644 --- a/config.go +++ b/config.go @@ -1,6 +1,7 @@ package zgrab2 import ( + "net" "net/http" "os" "runtime" @@ -16,7 +17,7 @@ type Config struct { InputFileName string `short:"f" long:"input-file" default:"-" description:"Input filename, use - for stdin"` MetaFileName string `short:"m" long:"metadata-file" default:"-" description:"Metadata filename, use - for stderr"` LogFileName string `short:"l" long:"log-file" default:"-" description:"Log filename, use - for stderr"` - Interface string `short:"i" long:"interface" description:"Network interface to send on"` + LocalAddress string `long:"source-ip" description:"Local source IP address to use for making connections"` Senders int `short:"s" long:"senders" default:"1000" description:"Number of send goroutines to use"` Debug bool `long:"debug" description:"Include debug fields in the output."` GOMAXPROCS int `long:"gomaxprocs" default:"0" description:"Set GOMAXPROCS"` @@ -30,6 +31,7 @@ type Config struct { logFile *os.File inputTargets InputTargetsFunc outputResults OutputResultsFunc + localAddr *net.TCPAddr } // SetInputFunc sets the target input function to the provided function. @@ -61,6 +63,14 @@ func validateFrameworkConfiguration() { } SetInputFunc(InputTargetsCSV) + if config.LocalAddress != "" { + parsed := net.ParseIP(config.LocalAddress) + if parsed == nil { + log.Fatalf("Error parsing local interface %s as IP", config.LocalAddress) + } + config.localAddr = &net.TCPAddr{parsed, 0, ""} + } + if config.InputFileName == "-" { config.inputFile = os.Stdin } else { diff --git a/conn.go b/conn.go index d9f7cfa..8c590c1 100644 --- a/conn.go +++ b/conn.go @@ -219,10 +219,10 @@ func (c *TimeoutConnection) SetDefaults() *TimeoutConnection { // NewTimeoutConnection returns a new TimeoutConnection with the appropriate defaults. func NewTimeoutConnection(ctx context.Context, conn net.Conn, timeout, readTimeout, writeTimeout time.Duration, bytesReadLimit int) *TimeoutConnection { ret := (&TimeoutConnection{ - Conn: conn, - Timeout: timeout, - ReadTimeout: readTimeout, - WriteTimeout: writeTimeout, + Conn: conn, + Timeout: timeout, + ReadTimeout: readTimeout, + WriteTimeout: writeTimeout, BytesReadLimit: bytesReadLimit, }).SetDefaults() if ctx == nil { @@ -298,6 +298,10 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. // ensure that our aux dialer is up-to-date; copied from http/transport.go d.Dialer.Timeout = d.getTimeout(d.ConnectTimeout) d.Dialer.KeepAlive = d.Timeout + + // Copy over the source IP if set, or nil + d.Dialer.LocalAddr = config.localAddr + dialContext, cancelDial := context.WithTimeout(ctx, d.Dialer.Timeout) defer cancelDial() conn, err := d.Dialer.DialContext(dialContext, network, address) diff --git a/go.mod b/go.mod index b03adc7..260f5d7 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,6 @@ module github.com/zmap/zgrab2 go 1.12 require ( - github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect - github.com/modern-go/reflect2 v1.0.1 // indirect github.com/prometheus/client_golang v1.1.0 github.com/sirupsen/logrus v1.4.2 github.com/zmap/zcrypto v0.0.0-20190729165852-9051775e6a2e