zgrab2/conn.go
2018-02-02 14:52:03 -05:00

61 lines
1.6 KiB
Go

package zgrab2
import (
"net"
"time"
)
// TimeoutConnection wraps an existing net.Conn connection, overriding the Read/Write methods to use the configured timeouts
type TimeoutConnection struct {
net.Conn
Timeout time.Duration
}
// TimeoutConnection.Read calls Read() on the underlying connection, using any configured deadlines
func (c *TimeoutConnection) Read(b []byte) (n int, err error) {
if c.Timeout > 0 {
if err = c.Conn.SetReadDeadline(time.Now().Add(c.Timeout)); err != nil {
return 0, err
}
}
return c.Conn.Read(b)
}
// TimeoutConnection.Write calls Write() on the underlying connection, using any configured deadlines
func (c *TimeoutConnection) Write(b []byte) (n int, err error) {
if c.Timeout > 0 {
if err = c.Conn.SetWriteDeadline(time.Now().Add(c.Timeout)); err != nil {
return 0, err
}
}
return c.Conn.Write(b)
}
// GetTimeoutDialer returns a dial function with the given timeout bound in.
// Each call to the returned function forwards its two string arguments, along
// with timeout, to DialTimeoutConnection and returns that call's result.
func GetTimeoutDialer(timeout time.Duration) func(string, string) (net.Conn, error) {
	dial := func(network, address string) (net.Conn, error) {
		return DialTimeoutConnection(network, address, timeout)
	}
	return dial
}
// DialTimeoutConnection dials target over proto and, on success, returns the
// connection wrapped in a *TimeoutConnection carrying the same timeout.
//
// A positive timeout is also applied to the dial itself. With a zero or
// negative timeout, no dial timeout is configured here. On failure the
// returned net.Conn is nil and the dial error is returned as-is.
func DialTimeoutConnection(proto string, target string, timeout time.Duration) (net.Conn, error) {
	// A zero-value Dialer behaves like net.Dial; setting Timeout makes it
	// behave like net.DialTimeout.
	var dialer net.Dialer
	if timeout > 0 {
		dialer.Timeout = timeout
	}

	rawConn, dialErr := dialer.Dial(proto, target)
	if dialErr != nil {
		// Defensive: if a connection came back alongside the error, release it.
		if rawConn != nil {
			rawConn.Close()
		}
		return nil, dialErr
	}

	wrapped := &TimeoutConnection{
		Conn:    rawConn,
		Timeout: timeout,
	}
	return wrapped, nil
}