remove unnecessary indirection on net.Conn (#27)

This commit is contained in:
justinbastress 2017-12-19 16:21:16 -05:00 committed by GitHub
parent 02f94e2f49
commit a5d8d0b57a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 25 deletions

View File

@ -220,7 +220,7 @@ type Connection struct {
// Enum to track connection status
State ConnectionState
// TCP or TLS-wrapped Connection pointer (IsSecure() will tell which)
Connection *net.Conn
Connection net.Conn
// The sequence number used with the server to number packets
SequenceNumber uint8
@ -540,7 +540,7 @@ func (c *Connection) sendPacket(packet WritablePacket) (*ConnectionLogEntry, err
}
// @TODO: Buffered send?
_, err := (*c.Connection).Write(toSend)
_, err := c.Connection.Write(toSend)
return &logPacket, err
}
@ -564,9 +564,8 @@ func (c *Connection) decodePacket(body []byte) (PacketInfo, error) {
// Read a packet and sequence identifier off of the given connection
func (c *Connection) readPacket() (*ConnectionLogEntry, error) {
// @TODO @FIXME Find/use conventional buffered packet-reading functions, handle timeouts / connection reset / etc
conn := *c.Connection
reader := bufio.NewReader(conn)
if terr := conn.SetReadDeadline(time.Now().Add(c.Config.Timeout)); terr != nil {
reader := bufio.NewReader(c.Connection)
if terr := c.Connection.SetReadDeadline(time.Now().Add(c.Config.Timeout)); terr != nil {
return nil, fmt.Errorf("Error calling SetReadTimeout(): %s", terr)
}
var header [4]byte
@ -618,13 +617,12 @@ func (c *Connection) GetHandshake() *HandshakePacket {
// Perform a TLS handshake using the configured TLSConfig on the current connection
func (c *Connection) StartTLS() error {
client := tls.Client(*c.Connection, c.Config.TLSConfig)
client := tls.Client(c.Connection, c.Config.TLSConfig)
err := client.Handshake()
if err != nil {
return fmt.Errorf("TLS Handshake error: %s", err)
}
*(c.Connection) = client
c.Connection = client
return nil
}
@ -666,7 +664,7 @@ func (c *Connection) Connect() error {
log.Debugf("Error connecting: %v", err)
return fmt.Errorf("Connect error: %s", err)
}
c.Connection = &conn
c.Connection = conn
c.State = STATE_CONNECTED
c.ConnectionLog = ConnectionLog{
Handshake: nil,
@ -706,7 +704,7 @@ func (c *Connection) Disconnect() error {
}
c.State = STATE_NOT_CONNECTED
// Change state even if close fails
return (*c.Connection).Close()
return c.Connection.Close()
}
// NUL STRING type from https://web.archive.org/web/20160316113745/https://dev.mysql.com/doc/internals/en/string.html

View File

@ -1,8 +1,6 @@
package modules
import (
"net"
log "github.com/sirupsen/logrus"
"github.com/zmap/zgrab2"
"github.com/zmap/zgrab2/lib/mysql"
@ -104,17 +102,7 @@ func (s *MySQLScanner) Scan(t zgrab2.ScanTarget) (status zgrab2.ScanStatus, resu
panic(err)
}
// Replace sql.Connection to allow hypothetical future calls to go over the secure connection
var netConn net.Conn = conn
sql.Connection = &netConn
// Works:
// var netConn net.Conn = conn
// sql.Connection = &netConn
// Does not work:
// sql.Connection = &conn // (**ZGrabConnection is not *net.Conn)
// sql.Connection = &(conn.(net.Conn)) // (conn is not an interface)
// sql.Connection = conn.Conn // (cannot use conn.Conn (type tls.Conn) as type *net.Conn)
// sql.Connection = &conn.Conn // (cannot use &conn.Conn (type *tls.Conn) as type *net.Conn)
// sql.Connection = &(conn.Conn.conn) // (cannot refer to unexported field or method conn)
sql.Connection = conn
}
// If we made it this far, the scan was a success.
return zgrab2.SCAN_SUCCESS, result, nil

4
tls.go
View File

@ -261,12 +261,12 @@ func (z *TLSConnection) Handshake() error {
}
}
func (t *TLSFlags) GetTLSConnection(conn *net.Conn) (*TLSConnection, error) {
func (t *TLSFlags) GetTLSConnection(conn net.Conn) (*TLSConnection, error) {
cfg, err := t.GetTLSConfig()
if err != nil {
return nil, fmt.Errorf("Error getting TLSConfig for options: %s", err)
}
tlsClient := tls.Client(*conn, cfg)
tlsClient := tls.Client(conn, cfg)
wrappedClient := TLSConnection{
Conn: *tlsClient,
flags: t,