diff --git a/modules/tls.go b/modules/tls.go index 7fd6c83..5dc6ccd 100644 --- a/modules/tls.go +++ b/modules/tls.go @@ -62,20 +62,29 @@ func (s *TLSScanner) InitPerSender(senderID int) error { return nil } +// Scan opens a TCP connection to the target (default port 443), then performs +// a TLS handshake. If the handshake gets past the ServerHello stage, the +// handshake log is returned (along with any other TLS-related logs, such as +// heartbleed, if enabled). func (s *TLSScanner) Scan(t zgrab2.ScanTarget) (zgrab2.ScanStatus, interface{}, error) { - tcpConn, err := t.Open(&s.config.BaseFlags) + conn, err := t.OpenTLS(&s.config.BaseFlags, &s.config.TLSFlags) + if conn != nil { + defer conn.Close() + } if err != nil { - return zgrab2.TryGetScanStatus(err), &zgrab2.TLSLog{}, err + if conn != nil { + if log := conn.GetLog(); log != nil { + if log.HandshakeLog.ServerHello != nil { + // If we got far enough to get a valid ServerHello, then + // consider it to be a positive TLS detection. + return zgrab2.TryGetScanStatus(err), log, err + } + // Otherwise, detection failed. + } + } + return zgrab2.TryGetScanStatus(err), nil, err } - var conn *zgrab2.TLSConnection - if conn, err = s.config.TLSFlags.GetTLSConnection(tcpConn); err != nil { - return zgrab2.TryGetScanStatus(err), &zgrab2.TLSLog{}, err - } - result := conn.GetLog() - if err = conn.Handshake(); err != nil { - return zgrab2.TryGetScanStatus(err), result, err - } - return zgrab2.SCAN_SUCCESS, result, nil + return zgrab2.SCAN_SUCCESS, conn.GetLog(), nil } // Protocol returns the protocol identifer for the scanner. diff --git a/processing.go b/processing.go index 9ddcd8a..64b6c40 100644 --- a/processing.go +++ b/processing.go @@ -60,6 +60,18 @@ func (target *ScanTarget) Open(flags *BaseFlags) (net.Conn, error) { return DialTimeoutConnection("tcp", address, flags.Timeout, flags.BytesReadLimit) } +// OpenTLS connects to the ScanTarget using the configured flags, then performs +// the TLS handshake. On success error is nil, but the connection can be non-nil +// even if there is an error (this allows fetching the handshake log). +func (target *ScanTarget) OpenTLS(baseFlags *BaseFlags, tlsFlags *TLSFlags) (*TLSConnection, error) { + conn, err := tlsFlags.Connect(target, baseFlags) + if err != nil { + return conn, err + } + err = conn.Handshake() + return conn, err +} + // OpenUDP connects to the ScanTarget using the configured flags, and returns a net.Conn that uses the configured timeouts for Read/Write operations. // Note that the UDP "connection" does not have an associated timeout. func (target *ScanTarget) OpenUDP(flags *BaseFlags, udp *UDPFlags) (net.Conn, error) { diff --git a/tls.go b/tls.go index 92a9765..bd12ad8 100644 --- a/tls.go +++ b/tls.go @@ -80,6 +80,10 @@ func getCSV(arg string) []string { } func (t *TLSFlags) GetTLSConfig() (*tls.Config, error) { + return t.GetTLSConfigForTarget(nil) +} + +func (t *TLSFlags) GetTLSConfigForTarget(target *ScanTarget) (*tls.Config, error) { var err error // TODO: Find standard names @@ -129,7 +133,14 @@ func (t *TLSFlags) GetTLSConfig() (*tls.Config, error) { } if t.ServerName != "" { // TODO: In the original zgrab, this was only set of NoSNI was not set (though in that case, it set it to the scanning host name) + // Here, if an explicit ServerName is given, set that, ignoring NoSNI. ret.ServerName = t.ServerName + } else { + // If no explicit ServerName is given, and SNI is not disabled, use the + // target's domain name (if available). + if !t.NoSNI && target != nil { + ret.ServerName = target.Domain + } } if t.VerifyServerCertificate { ret.InsecureSkipVerify = false @@ -281,8 +292,23 @@ func (conn *TLSConnection) Close() error { return conn.Conn.Close() } +// Connect opens the TCP connection to the target using the given configuration, +// and then returns the configured wrapped TLS connection. The caller must still +// call Handshake(). +func (t *TLSFlags) Connect(target *ScanTarget, flags *BaseFlags) (*TLSConnection, error) { + tcpConn, err := target.Open(flags) + if err != nil { + return nil, err + } + return t.GetTLSConnectionForTarget(tcpConn, target) +} + func (t *TLSFlags) GetTLSConnection(conn net.Conn) (*TLSConnection, error) { - cfg, err := t.GetTLSConfig() + return t.GetTLSConnectionForTarget(conn, nil) +} + +func (t *TLSFlags) GetTLSConnectionForTarget(conn net.Conn, target *ScanTarget) (*TLSConnection, error) { + cfg, err := t.GetTLSConfigForTarget(target) if err != nil { return nil, fmt.Errorf("Error getting TLSConfig for options: %s", err) }