remove unnecessary indirection on net.Conn (#27)
This commit is contained in:
parent
02f94e2f49
commit
a5d8d0b57a
@ -220,7 +220,7 @@ type Connection struct {
|
||||
// Enum to track connection status
|
||||
State ConnectionState
|
||||
// TCP or TLS-wrapped Connection pointer (IsSecure() will tell which)
|
||||
Connection *net.Conn
|
||||
Connection net.Conn
|
||||
// The sequence number used with the server to number packets
|
||||
SequenceNumber uint8
|
||||
|
||||
@ -540,7 +540,7 @@ func (c *Connection) sendPacket(packet WritablePacket) (*ConnectionLogEntry, err
|
||||
}
|
||||
|
||||
// @TODO: Buffered send?
|
||||
_, err := (*c.Connection).Write(toSend)
|
||||
_, err := c.Connection.Write(toSend)
|
||||
return &logPacket, err
|
||||
}
|
||||
|
||||
@ -564,9 +564,8 @@ func (c *Connection) decodePacket(body []byte) (PacketInfo, error) {
|
||||
// Read a packet and sequence identifier off of the given connection
|
||||
func (c *Connection) readPacket() (*ConnectionLogEntry, error) {
|
||||
// @TODO @FIXME Find/use conventional buffered packet-reading functions, handle timeouts / connection reset / etc
|
||||
conn := *c.Connection
|
||||
reader := bufio.NewReader(conn)
|
||||
if terr := conn.SetReadDeadline(time.Now().Add(c.Config.Timeout)); terr != nil {
|
||||
reader := bufio.NewReader(c.Connection)
|
||||
if terr := c.Connection.SetReadDeadline(time.Now().Add(c.Config.Timeout)); terr != nil {
|
||||
return nil, fmt.Errorf("Error calling SetReadTimeout(): %s", terr)
|
||||
}
|
||||
var header [4]byte
|
||||
@ -618,13 +617,12 @@ func (c *Connection) GetHandshake() *HandshakePacket {
|
||||
|
||||
// Perform a TLS handshake using the configured TLSConfig on the current connection
|
||||
func (c *Connection) StartTLS() error {
|
||||
|
||||
client := tls.Client(*c.Connection, c.Config.TLSConfig)
|
||||
client := tls.Client(c.Connection, c.Config.TLSConfig)
|
||||
err := client.Handshake()
|
||||
if err != nil {
|
||||
return fmt.Errorf("TLS Handshake error: %s", err)
|
||||
}
|
||||
*(c.Connection) = client
|
||||
c.Connection = client
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -666,7 +664,7 @@ func (c *Connection) Connect() error {
|
||||
log.Debugf("Error connecting: %v", err)
|
||||
return fmt.Errorf("Connect error: %s", err)
|
||||
}
|
||||
c.Connection = &conn
|
||||
c.Connection = conn
|
||||
c.State = STATE_CONNECTED
|
||||
c.ConnectionLog = ConnectionLog{
|
||||
Handshake: nil,
|
||||
@ -706,7 +704,7 @@ func (c *Connection) Disconnect() error {
|
||||
}
|
||||
c.State = STATE_NOT_CONNECTED
|
||||
// Change state even if close fails
|
||||
return (*c.Connection).Close()
|
||||
return c.Connection.Close()
|
||||
}
|
||||
|
||||
// NUL STRING type from https://web.archive.org/web/20160316113745/https://dev.mysql.com/doc/internals/en/string.html
|
||||
|
@ -1,8 +1,6 @@
|
||||
package modules
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/zmap/zgrab2"
|
||||
"github.com/zmap/zgrab2/lib/mysql"
|
||||
@ -104,17 +102,7 @@ func (s *MySQLScanner) Scan(t zgrab2.ScanTarget) (status zgrab2.ScanStatus, resu
|
||||
panic(err)
|
||||
}
|
||||
// Replace sql.Connection to allow hypothetical future calls to go over the secure connection
|
||||
var netConn net.Conn = conn
|
||||
sql.Connection = &netConn
|
||||
// Works:
|
||||
// var netConn net.Conn = conn
|
||||
// sql.Connection = &netConn
|
||||
// Does not work:
|
||||
// sql.Connection = &conn // (**ZGrabConnection is not *net.Conn)
|
||||
// sql.Connection = &(conn.(net.Conn)) // (conn is not an interface)
|
||||
// sql.Connection = conn.Conn // (cannot use conn.Conn (type tls.Conn) as type *net.Conn)
|
||||
// sql.Connection = &conn.Conn // (cannot use &conn.Conn (type *tls.Conn) as type *net.Conn)
|
||||
// sql.Connection = &(conn.Conn.conn) // (cannot refer to unexported field or method conn)
|
||||
sql.Connection = conn
|
||||
}
|
||||
// If we made it this far, the scan was a success.
|
||||
return zgrab2.SCAN_SUCCESS, result, nil
|
||||
|
4
tls.go
4
tls.go
@ -261,12 +261,12 @@ func (z *TLSConnection) Handshake() error {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TLSFlags) GetTLSConnection(conn *net.Conn) (*TLSConnection, error) {
|
||||
func (t *TLSFlags) GetTLSConnection(conn net.Conn) (*TLSConnection, error) {
|
||||
cfg, err := t.GetTLSConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error getting TLSConfig for options: %s", err)
|
||||
}
|
||||
tlsClient := tls.Client(*conn, cfg)
|
||||
tlsClient := tls.Client(conn, cfg)
|
||||
wrappedClient := TLSConnection{
|
||||
Conn: *tlsClient,
|
||||
flags: t,
|
||||
|
Loading…
Reference in New Issue
Block a user