From a5d8d0b57ad3236595f17d22130d84ae0950fc73 Mon Sep 17 00:00:00 2001 From: justinbastress <33579608+justinbastress@users.noreply.github.com> Date: Tue, 19 Dec 2017 16:21:16 -0500 Subject: [PATCH] remove unnecessary indirection on net.Conn (#27) --- lib/mysql/mysql.go | 18 ++++++++---------- modules/mysql.go | 14 +------------- tls.go | 4 ++-- 3 files changed, 11 insertions(+), 25 deletions(-) diff --git a/lib/mysql/mysql.go b/lib/mysql/mysql.go index 2faaa12..e3e68c0 100644 --- a/lib/mysql/mysql.go +++ b/lib/mysql/mysql.go @@ -220,7 +220,7 @@ type Connection struct { // Enum to track connection status State ConnectionState // TCP or TLS-wrapped Connection pointer (IsSecure() will tell which) - Connection *net.Conn + Connection net.Conn // The sequence number used with the server to number packets SequenceNumber uint8 @@ -540,7 +540,7 @@ func (c *Connection) sendPacket(packet WritablePacket) (*ConnectionLogEntry, err } // @TODO: Buffered send? - _, err := (*c.Connection).Write(toSend) + _, err := c.Connection.Write(toSend) return &logPacket, err } @@ -564,9 +564,8 @@ func (c *Connection) decodePacket(body []byte) (PacketInfo, error) { // Read a packet and sequence identifier off of the given connection func (c *Connection) readPacket() (*ConnectionLogEntry, error) { // @TODO @FIXME Find/use conventional buffered packet-reading functions, handle timeouts / connection reset / etc - conn := *c.Connection - reader := bufio.NewReader(conn) - if terr := conn.SetReadDeadline(time.Now().Add(c.Config.Timeout)); terr != nil { + reader := bufio.NewReader(c.Connection) + if terr := c.Connection.SetReadDeadline(time.Now().Add(c.Config.Timeout)); terr != nil { return nil, fmt.Errorf("Error calling SetReadTimeout(): %s", terr) } var header [4]byte @@ -618,13 +617,12 @@ func (c *Connection) GetHandshake() *HandshakePacket { // Perform a TLS handshake using the configured TLSConfig on the current connection func (c *Connection) StartTLS() error { - - client := tls.Client(*c.Connection, c.Config.TLSConfig) + client := tls.Client(c.Connection, c.Config.TLSConfig) err := client.Handshake() if err != nil { return fmt.Errorf("TLS Handshake error: %s", err) } - *(c.Connection) = client + c.Connection = client return nil } @@ -666,7 +664,7 @@ func (c *Connection) Connect() error { log.Debugf("Error connecting: %v", err) return fmt.Errorf("Connect error: %s", err) } - c.Connection = &conn + c.Connection = conn c.State = STATE_CONNECTED c.ConnectionLog = ConnectionLog{ Handshake: nil, @@ -706,7 +704,7 @@ func (c *Connection) Disconnect() error { } c.State = STATE_NOT_CONNECTED // Change state even if close fails - return (*c.Connection).Close() + return c.Connection.Close() } // NUL STRING type from https://web.archive.org/web/20160316113745/https://dev.mysql.com/doc/internals/en/string.html diff --git a/modules/mysql.go b/modules/mysql.go index 03d66f4..16a6809 100644 --- a/modules/mysql.go +++ b/modules/mysql.go @@ -1,8 +1,6 @@ package modules import ( - "net" - log "github.com/sirupsen/logrus" "github.com/zmap/zgrab2" "github.com/zmap/zgrab2/lib/mysql" @@ -104,17 +102,7 @@ func (s *MySQLScanner) Scan(t zgrab2.ScanTarget) (status zgrab2.ScanStatus, resu panic(err) } // Replace sql.Connection to allow hypothetical future calls to go over the secure connection - var netConn net.Conn = conn - sql.Connection = &netConn - // Works: - // var netConn net.Conn = conn - // sql.Connection = &netConn - // Does not work: - // sql.Connection = &conn // (**ZGrabConnection is not *net.Conn) - // sql.Connection = &(conn.(net.Conn)) // (conn is not an interface) - // sql.Connection = conn.Conn // (cannot use conn.Conn (type tls.Conn) as type *net.Conn) - // sql.Connection = &conn.Conn // (cannot use &conn.Conn (type *tls.Conn) as type *net.Conn) - // sql.Connection = &(conn.Conn.conn) // (cannot refer to unexported field or method conn) + sql.Connection = conn } // If we made it this far, the scan was a success. return zgrab2.SCAN_SUCCESS, result, nil diff --git a/tls.go b/tls.go index 1012e92..8978a64 100644 --- a/tls.go +++ b/tls.go @@ -261,12 +261,12 @@ func (z *TLSConnection) Handshake() error { } } -func (t *TLSFlags) GetTLSConnection(conn *net.Conn) (*TLSConnection, error) { +func (t *TLSFlags) GetTLSConnection(conn net.Conn) (*TLSConnection, error) { cfg, err := t.GetTLSConfig() if err != nil { return nil, fmt.Errorf("Error getting TLSConfig for options: %s", err) } - tlsClient := tls.Client(*conn, cfg) + tlsClient := tls.Client(conn, cfg) wrappedClient := TLSConnection{ Conn: *tlsClient, flags: t,