remove unnecessary indirection on net.Conn (#27)
This commit is contained in:
parent
02f94e2f49
commit
a5d8d0b57a
@ -220,7 +220,7 @@ type Connection struct {
|
|||||||
// Enum to track connection status
|
// Enum to track connection status
|
||||||
State ConnectionState
|
State ConnectionState
|
||||||
// TCP or TLS-wrapped Connection pointer (IsSecure() will tell which)
|
// TCP or TLS-wrapped Connection pointer (IsSecure() will tell which)
|
||||||
Connection *net.Conn
|
Connection net.Conn
|
||||||
// The sequence number used with the server to number packets
|
// The sequence number used with the server to number packets
|
||||||
SequenceNumber uint8
|
SequenceNumber uint8
|
||||||
|
|
||||||
@ -540,7 +540,7 @@ func (c *Connection) sendPacket(packet WritablePacket) (*ConnectionLogEntry, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
// @TODO: Buffered send?
|
// @TODO: Buffered send?
|
||||||
_, err := (*c.Connection).Write(toSend)
|
_, err := c.Connection.Write(toSend)
|
||||||
return &logPacket, err
|
return &logPacket, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -564,9 +564,8 @@ func (c *Connection) decodePacket(body []byte) (PacketInfo, error) {
|
|||||||
// Read a packet and sequence identifier off of the given connection
|
// Read a packet and sequence identifier off of the given connection
|
||||||
func (c *Connection) readPacket() (*ConnectionLogEntry, error) {
|
func (c *Connection) readPacket() (*ConnectionLogEntry, error) {
|
||||||
// @TODO @FIXME Find/use conventional buffered packet-reading functions, handle timeouts / connection reset / etc
|
// @TODO @FIXME Find/use conventional buffered packet-reading functions, handle timeouts / connection reset / etc
|
||||||
conn := *c.Connection
|
reader := bufio.NewReader(c.Connection)
|
||||||
reader := bufio.NewReader(conn)
|
if terr := c.Connection.SetReadDeadline(time.Now().Add(c.Config.Timeout)); terr != nil {
|
||||||
if terr := conn.SetReadDeadline(time.Now().Add(c.Config.Timeout)); terr != nil {
|
|
||||||
return nil, fmt.Errorf("Error calling SetReadTimeout(): %s", terr)
|
return nil, fmt.Errorf("Error calling SetReadTimeout(): %s", terr)
|
||||||
}
|
}
|
||||||
var header [4]byte
|
var header [4]byte
|
||||||
@ -618,13 +617,12 @@ func (c *Connection) GetHandshake() *HandshakePacket {
|
|||||||
|
|
||||||
// Perform a TLS handshake using the configured TLSConfig on the current connection
|
// Perform a TLS handshake using the configured TLSConfig on the current connection
|
||||||
func (c *Connection) StartTLS() error {
|
func (c *Connection) StartTLS() error {
|
||||||
|
client := tls.Client(c.Connection, c.Config.TLSConfig)
|
||||||
client := tls.Client(*c.Connection, c.Config.TLSConfig)
|
|
||||||
err := client.Handshake()
|
err := client.Handshake()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("TLS Handshake error: %s", err)
|
return fmt.Errorf("TLS Handshake error: %s", err)
|
||||||
}
|
}
|
||||||
*(c.Connection) = client
|
c.Connection = client
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -666,7 +664,7 @@ func (c *Connection) Connect() error {
|
|||||||
log.Debugf("Error connecting: %v", err)
|
log.Debugf("Error connecting: %v", err)
|
||||||
return fmt.Errorf("Connect error: %s", err)
|
return fmt.Errorf("Connect error: %s", err)
|
||||||
}
|
}
|
||||||
c.Connection = &conn
|
c.Connection = conn
|
||||||
c.State = STATE_CONNECTED
|
c.State = STATE_CONNECTED
|
||||||
c.ConnectionLog = ConnectionLog{
|
c.ConnectionLog = ConnectionLog{
|
||||||
Handshake: nil,
|
Handshake: nil,
|
||||||
@ -706,7 +704,7 @@ func (c *Connection) Disconnect() error {
|
|||||||
}
|
}
|
||||||
c.State = STATE_NOT_CONNECTED
|
c.State = STATE_NOT_CONNECTED
|
||||||
// Change state even if close fails
|
// Change state even if close fails
|
||||||
return (*c.Connection).Close()
|
return c.Connection.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// NUL STRING type from https://web.archive.org/web/20160316113745/https://dev.mysql.com/doc/internals/en/string.html
|
// NUL STRING type from https://web.archive.org/web/20160316113745/https://dev.mysql.com/doc/internals/en/string.html
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
package modules
|
package modules
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/zmap/zgrab2"
|
"github.com/zmap/zgrab2"
|
||||||
"github.com/zmap/zgrab2/lib/mysql"
|
"github.com/zmap/zgrab2/lib/mysql"
|
||||||
@ -104,17 +102,7 @@ func (s *MySQLScanner) Scan(t zgrab2.ScanTarget) (status zgrab2.ScanStatus, resu
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
// Replace sql.Connection to allow hypothetical future calls to go over the secure connection
|
// Replace sql.Connection to allow hypothetical future calls to go over the secure connection
|
||||||
var netConn net.Conn = conn
|
sql.Connection = conn
|
||||||
sql.Connection = &netConn
|
|
||||||
// Works:
|
|
||||||
// var netConn net.Conn = conn
|
|
||||||
// sql.Connection = &netConn
|
|
||||||
// Does not work:
|
|
||||||
// sql.Connection = &conn // (**ZGrabConnection is not *net.Conn)
|
|
||||||
// sql.Connection = &(conn.(net.Conn)) // (conn is not an interface)
|
|
||||||
// sql.Connection = conn.Conn // (cannot use conn.Conn (type tls.Conn) as type *net.Conn)
|
|
||||||
// sql.Connection = &conn.Conn // (cannot use &conn.Conn (type *tls.Conn) as type *net.Conn)
|
|
||||||
// sql.Connection = &(conn.Conn.conn) // (cannot refer to unexported field or method conn)
|
|
||||||
}
|
}
|
||||||
// If we made it this far, the scan was a success.
|
// If we made it this far, the scan was a success.
|
||||||
return zgrab2.SCAN_SUCCESS, result, nil
|
return zgrab2.SCAN_SUCCESS, result, nil
|
||||||
|
4
tls.go
4
tls.go
@ -261,12 +261,12 @@ func (z *TLSConnection) Handshake() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TLSFlags) GetTLSConnection(conn *net.Conn) (*TLSConnection, error) {
|
func (t *TLSFlags) GetTLSConnection(conn net.Conn) (*TLSConnection, error) {
|
||||||
cfg, err := t.GetTLSConfig()
|
cfg, err := t.GetTLSConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Error getting TLSConfig for options: %s", err)
|
return nil, fmt.Errorf("Error getting TLSConfig for options: %s", err)
|
||||||
}
|
}
|
||||||
tlsClient := tls.Client(*conn, cfg)
|
tlsClient := tls.Client(conn, cfg)
|
||||||
wrappedClient := TLSConnection{
|
wrappedClient := TLSConnection{
|
||||||
Conn: *tlsClient,
|
Conn: *tlsClient,
|
||||||
flags: t,
|
flags: t,
|
||||||
|
Loading…
Reference in New Issue
Block a user