diff --git a/fake_resolver.go b/fake_resolver.go new file mode 100644 index 0000000..00578db --- /dev/null +++ b/fake_resolver.go @@ -0,0 +1,165 @@ +package zgrab2 + +import ( + "context" + "errors" + "fmt" + "golang.org/x/net/dns/dnsmessage" + "net" + "time" +) + +// Fake DNS Resolver, to force a DNS lookup to return a pinned address +// Inspired by the golang net/dnsclient_unix_test.go code +// +// For a given IP, create a new Resolver that wraps a fake +// DNS server. This resolver will always return an IP that +// is represented by "ipstr", for DNS queries of the same +// IP type. Otherwise, it will return a DNS lookup error. +func NewFakeResolver(ipstr string) (*net.Resolver, error) { + ip := net.ParseIP(ipstr) + if len(ip) < 4 { + return nil, fmt.Errorf("Fake resolver can't use non-IP '%s'", ipstr) + } + fDNS := FakeDNSServer{ + IP: ip, + } + return &net.Resolver{ + PreferGo: true, // Needed to force the use of the Go internal resolver + Dial: fDNS.DialContext, + }, nil +} + +type FakeDNSServer struct { + // Any domain name will resolve to this IP. It can be either ipv4 or ipv6 + IP net.IP +} + +// For a given DNS query, return the hard-coded IP that is part of +// FakeDNSServer. +// +// It will work with either ipv4 or ipv6 addresses; if a TypeA question +// is received, we will only return the IP if what we have to return is +// ipv4. The same for TypeAAAA and ipv6. +func (f *FakeDNSServer) fakeDNS(s string, dmsg dnsmessage.Message) (r dnsmessage.Message, err error) { + + r = dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: dmsg.ID, + Response: true, + }, + Questions: dmsg.Questions, + } + ipv6 := f.IP.To16() + ipv4 := f.IP.To4() + switch t := dmsg.Questions[0].Type; { + case t == dnsmessage.TypeA && ipv4 != nil: + var ip [4]byte + copy(ip[:], []byte(ipv4)) + r.Answers = []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Name: dmsg.Questions[0].Name, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + Length: 4, + }, + Body: &dnsmessage.AResource{ + A: ip, + }, + }, + } + case t == dnsmessage.TypeAAAA && ipv4 == nil: + var ip [16]byte + copy(ip[:], []byte(ipv6)) + r.Answers = []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Name: dmsg.Questions[0].Name, + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + Length: 16, + }, + Body: &dnsmessage.AAAAResource{ + AAAA: ip, + }, + }, + } + default: + r.Header.RCode = dnsmessage.RCodeNameError + } + + return r, nil +} + +// This merely wraps a custom net.Conn, that is only good for DNS +// messages +func (f *FakeDNSServer) DialContext(ctx context.Context, network, + address string) (net.Conn, error) { + + conn := &fakeDNSPacketConn{ + fakeDNSConn: fakeDNSConn{ + server: f, + network: network, + address: address, + }, + } + return conn, nil +} + +type fakeDNSConn struct { + net.Conn + server *FakeDNSServer + network string + address string + dmsg dnsmessage.Message +} + +func (fc *fakeDNSConn) Read(b []byte) (int, error) { + resp, err := fc.server.fakeDNS(fc.address, fc.dmsg) + if err != nil { + return 0, err + } + + bb := make([]byte, 2, 514) + bb, err = resp.AppendPack(bb) + if err != nil { + return 0, fmt.Errorf("cannot marshal DNS message: %v", err) + } + + bb = bb[2:] + if len(b) < len(bb) { + return 0, errors.New("read would fragment DNS message") + } + + copy(b, bb) + return len(bb), nil +} + +func (fc *fakeDNSConn) Write(b []byte) (int, error) { + if fc.dmsg.Unpack(b) != nil { + return 0, fmt.Errorf("cannot unmarshal DNS message fake %s (%d)", fc.network, len(b)) + } + return len(b), nil +} + +func (fc *fakeDNSConn) SetDeadline(deadline time.Time) error { + return nil +} + +func (fc *fakeDNSConn) Close() error { + return nil +} + +type fakeDNSPacketConn struct { + net.PacketConn + fakeDNSConn +} + +func (f *fakeDNSPacketConn) SetDeadline(deadline time.Time) error { + return nil +} + +func (f *fakeDNSPacketConn) Close() error { + return f.fakeDNSConn.Close() +} diff --git a/modules/http/scanner.go b/modules/http/scanner.go index 4292ca9..e8033a7 100644 --- a/modules/http/scanner.go +++ b/modules/http/scanner.go @@ -158,12 +158,40 @@ func (scan *scan) withDeadlineContext(ctx context.Context) context.Context { // Dial a connection using the configured timeouts, as well as the global deadline, and on success, // add the connection to the list of connections to be cleaned up. -func (scan *scan) dialContext(ctx context.Context, net string, addr string) (net.Conn, error) { +func (scan *scan) dialContext(ctx context.Context, network string, addr string) (net.Conn, error) { dialer := zgrab2.GetTimeoutConnectionDialer(scan.scanner.config.Timeout) + switch network { + case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": + // If the scan is for a specific IP, and a domain name is provided, we + // don't want to just let the http library resolve the domain. Create + // a fake resolver that we will use, that always returns the IP we are + // given to scan. + if scan.target.IP != nil && scan.target.Domain != "" { + host, _, err := net.SplitHostPort(addr) + if err != nil { + log.Errorf("http/scanner.go dialContext: unable to split host:port '%s'", addr) + log.Errorf("No fake resolver, IP address may be incorrect: %s", err) + } else { + // In the case of redirects, we don't want to blindly use the + // IP we were given to scan, however. Only use the fake + // resolver if the domain originally specified for the scan + // target matches the current address being looked up in this + // DialContext. + if host == scan.target.Domain { + resolver, err := zgrab2.NewFakeResolver(scan.target.IP.String()) + if err != nil { + return nil, err + } + dialer.Dialer.Resolver = resolver + } + } + } + } + timeoutContext, _ := context.WithTimeout(context.Background(), scan.scanner.config.Timeout) - conn, err := dialer.DialContext(scan.withDeadlineContext(timeoutContext), net, addr) + conn, err := dialer.DialContext(scan.withDeadlineContext(timeoutContext), network, addr) if err != nil { return nil, err } @@ -210,7 +238,8 @@ func redirectsToLocalhost(host string) bool { return false } -// Taken from zgrab/zlib/grabber.go -- get a CheckRedirect callback that uses the redirectToLocalhost and MaxRedirects config +// Taken from zgrab/zlib/grabber.go -- get a CheckRedirect callback that uses +// the redirectToLocalhost and MaxRedirects config func (scan *scan) getCheckRedirect() func(*http.Request, *http.Response, []*http.Request) error { return func(req *http.Request, res *http.Response, via []*http.Request) error { if !scan.scanner.config.FollowLocalhostRedirects && redirectsToLocalhost(req.URL.Hostname()) {