346 lines
11 KiB
Go
346 lines
11 KiB
Go
package zgrab2
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"time"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// ReadLimitExceededAction selects how a TimeoutConnection responds when a read would go past its
// BytesReadLimit.
type ReadLimitExceededAction string

const (
	// ReadLimitExceededActionNotSet is the zero value; it lets SetDefaults tell an explicitly chosen
	// action apart from one that was never configured.
	ReadLimitExceededActionNotSet ReadLimitExceededAction = ""

	// ReadLimitExceededActionTruncate stops the read at BytesReadLimit bytes and reports a synthetic
	// io.EOF; the truncation is logged at debug level.
	ReadLimitExceededActionTruncate ReadLimitExceededAction = "truncate"

	// ReadLimitExceededActionError truncates as well, but makes Read return n, ErrReadLimitExceeded.
	ReadLimitExceededActionError ReadLimitExceededAction = "error"

	// ReadLimitExceededActionPanic makes Read panic with ErrReadLimitExceeded.
	ReadLimitExceededActionPanic ReadLimitExceededAction = "panic"
)

// Package-wide defaults applied by SetDefaults when the corresponding field is left at its zero value.
var (
	// DefaultBytesReadLimit is the per-connection read cap (256 MiB) used when none is configured.
	DefaultBytesReadLimit = 256 << 20

	// DefaultReadLimitExceededAction is the action used when none is configured.
	DefaultReadLimitExceededAction = ReadLimitExceededActionTruncate

	// DefaultSessionTimeout is the session lifetime used when no timeout is configured.
	DefaultSessionTimeout = time.Minute
)

// ErrReadLimitExceeded is returned (or panicked with) when a read reaches BytesReadLimit under the
// error or panic actions.
var ErrReadLimitExceeded = errors.New("read limit exceeded")
|
|
|
|
// TODO: Refactor this into TimeoutConnection, BoundedReader, LoggedReader, etc
|
|
// TimeoutConnection wraps an existing net.Conn connection, overriding the Read/Write methods to use the configured timeouts
|
|
type TimeoutConnection struct {
|
|
net.Conn
|
|
ctx context.Context
|
|
Timeout time.Duration
|
|
ReadTimeout time.Duration
|
|
WriteTimeout time.Duration
|
|
BytesRead int
|
|
BytesWritten int
|
|
BytesReadLimit int
|
|
ReadLimitExceededAction ReadLimitExceededAction
|
|
Cancel context.CancelFunc
|
|
explicitReadDeadline bool
|
|
explicitWriteDeadline bool
|
|
explicitDeadline bool
|
|
}
|
|
|
|
// TimeoutConnection.Read calls Read() on the underlying connection, using any configured deadlines
|
|
func (c *TimeoutConnection) Read(b []byte) (n int, err error) {
|
|
if err := c.checkContext(); err != nil {
|
|
return 0, err
|
|
}
|
|
origSize := len(b)
|
|
if c.BytesRead+len(b) >= c.BytesReadLimit {
|
|
b = b[0 : c.BytesReadLimit-c.BytesRead]
|
|
}
|
|
if c.explicitReadDeadline || c.explicitDeadline {
|
|
c.explicitReadDeadline = false
|
|
c.explicitDeadline = false
|
|
} else if readTimeout := c.getTimeout(c.ReadTimeout); readTimeout > 0 {
|
|
if err = c.Conn.SetReadDeadline(time.Now().Add(readTimeout)); err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
n, err = c.Conn.Read(b)
|
|
c.BytesRead += n
|
|
if err == nil && origSize != len(b) && n == len(b) {
|
|
// we had to shrink the output buffer AND we used up the whole shrunk size, AND we're not at EOF
|
|
switch c.ReadLimitExceededAction {
|
|
case ReadLimitExceededActionTruncate:
|
|
logrus.Debug("Truncated read from %d bytes to %d bytes (hit limit of %d bytes)", origSize, n, c.BytesReadLimit)
|
|
err = io.EOF
|
|
case ReadLimitExceededActionError:
|
|
return n, ErrReadLimitExceeded
|
|
case ReadLimitExceededActionPanic:
|
|
panic(ErrReadLimitExceeded)
|
|
default:
|
|
logrus.Fatalf("Unrecognized ReadLimitExceededAction: %s", c.ReadLimitExceededAction)
|
|
}
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
// TimeoutConnection.Write calls Write() on the underlying connection, using any configured deadlines.
|
|
func (c *TimeoutConnection) Write(b []byte) (n int, err error) {
|
|
if err := c.checkContext(); err != nil {
|
|
return 0, err
|
|
}
|
|
if c.explicitWriteDeadline || c.explicitDeadline {
|
|
c.explicitWriteDeadline = false
|
|
c.explicitDeadline = false
|
|
} else if writeTimeout := c.getTimeout(c.WriteTimeout); writeTimeout > 0 {
|
|
if err = c.Conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
n, err = c.Conn.Write(b)
|
|
c.BytesWritten += n
|
|
return n, err
|
|
}
|
|
|
|
// SetReadDeadline sets an explicit ReadDeadline that will override the timeout
|
|
// for one read. Use deadline = 0 to clear the deadline.
|
|
func (c *TimeoutConnection) SetReadDeadline(deadline time.Time) error {
|
|
if err := c.checkContext(); err != nil {
|
|
return err
|
|
}
|
|
if !deadline.IsZero() {
|
|
err := c.Conn.SetReadDeadline(deadline)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
c.explicitReadDeadline = !deadline.IsZero()
|
|
return nil
|
|
}
|
|
|
|
// SetWriteDeadline sets an explicit WriteDeadline that will override the
|
|
// WriteDeadline for one write. Use deadline = 0 to clear the deadline.
|
|
func (c *TimeoutConnection) SetWriteDeadline(deadline time.Time) error {
|
|
if err := c.checkContext(); err != nil {
|
|
return err
|
|
}
|
|
if !deadline.IsZero() {
|
|
err := c.Conn.SetWriteDeadline(deadline)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
c.explicitWriteDeadline = deadline.IsZero()
|
|
return nil
|
|
}
|
|
|
|
// SetDeadline sets a read / write deadline that will override the deadline for
|
|
// a single read/write. Use deadline = 0 to clear the deadline.
|
|
func (c *TimeoutConnection) SetDeadline(deadline time.Time) error {
|
|
if err := c.checkContext(); err != nil {
|
|
return err
|
|
}
|
|
if !deadline.IsZero() {
|
|
err := c.Conn.SetDeadline(deadline)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
c.explicitDeadline = deadline.IsZero()
|
|
return nil
|
|
}
|
|
|
|
// GetTimeoutDialer returns a DialFuncn that dials with the given timeout
|
|
func GetTimeoutDialFunc(timeout time.Duration) func(string, string) (net.Conn, error) {
|
|
return func(proto, target string) (net.Conn, error) {
|
|
return DialTimeoutConnection(proto, target, timeout)
|
|
}
|
|
}
|
|
|
|
// Close the underlying connection.
|
|
func (c *TimeoutConnection) Close() error {
|
|
return c.Conn.Close()
|
|
}
|
|
|
|
// Get the timeout for the given field, falling back to the global timeout.
|
|
func (c *TimeoutConnection) getTimeout(field time.Duration) time.Duration {
|
|
if field == 0 {
|
|
return c.Timeout
|
|
}
|
|
return field
|
|
}
|
|
|
|
// Check if the context has been cancelled, and if so, return an error (either the context error, or
|
|
// if the context error is nil, ErrTotalTimeout).
|
|
func (c *TimeoutConnection) checkContext() error {
|
|
if c.ctx == nil {
|
|
return nil
|
|
}
|
|
select {
|
|
case <-c.ctx.Done():
|
|
if err := c.ctx.Err(); err != nil {
|
|
return err
|
|
} else {
|
|
return ErrTotalTimeout
|
|
}
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (c *TimeoutConnection) SetDefaults() *TimeoutConnection {
|
|
if c.BytesReadLimit == 0 {
|
|
c.BytesReadLimit = DefaultBytesReadLimit
|
|
}
|
|
if c.ReadLimitExceededAction == ReadLimitExceededActionNotSet {
|
|
c.ReadLimitExceededAction = DefaultReadLimitExceededAction
|
|
}
|
|
if c.Timeout == 0 {
|
|
c.Timeout = DefaultSessionTimeout
|
|
}
|
|
return c
|
|
}
|
|
|
|
func NewTimeoutConnection(ctx context.Context, conn net.Conn, timeout, readTimeout, writeTimeout time.Duration) *TimeoutConnection {
|
|
ret := (&TimeoutConnection{
|
|
Conn: conn,
|
|
Timeout: timeout,
|
|
ReadTimeout: readTimeout,
|
|
WriteTimeout: writeTimeout,
|
|
}).SetDefaults()
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
ret.ctx, ret.Cancel = context.WithTimeout(ctx, timeout)
|
|
return ret
|
|
}
|
|
|
|
// DialTimeoutConnectionEx dials the target and returns a net.Conn that uses the configured timeouts for Read/Write operations.
|
|
func DialTimeoutConnectionEx(proto string, target string, dialTimeout, sessionTimeout, readTimeout, writeTimeout time.Duration) (net.Conn, error) {
|
|
var conn net.Conn
|
|
var err error
|
|
if dialTimeout > 0 {
|
|
conn, err = net.DialTimeout(proto, target, dialTimeout)
|
|
} else {
|
|
conn, err = net.DialTimeout(proto, target, sessionTimeout)
|
|
}
|
|
if err != nil {
|
|
if conn != nil {
|
|
conn.Close()
|
|
}
|
|
return nil, err
|
|
}
|
|
return NewTimeoutConnection(context.Background(), conn, sessionTimeout, readTimeout, writeTimeout), nil
|
|
}
|
|
|
|
// DialTimeoutConnection dials the target and returns a net.Conn that uses the configured single timeout for all operations.
|
|
func DialTimeoutConnection(proto string, target string, timeout time.Duration) (net.Conn, error) {
|
|
return DialTimeoutConnectionEx(proto, target, timeout, timeout, timeout, timeout)
|
|
}
|
|
|
|
// Dialer provides Dial and DialContext methods to get connections with the given timeout.
|
|
type Dialer struct {
|
|
// Timeout is the maximum time to wait for the entire session, after which any operations on the
|
|
// connection will fail.
|
|
Timeout time.Duration
|
|
|
|
// ConnectTimeout is the maximum time to wait for a connection.
|
|
ConnectTimeout time.Duration
|
|
|
|
// ReadTimeout is the maximum time to wait for a Read
|
|
ReadTimeout time.Duration
|
|
|
|
// WriteTimeout is the maximum time to wait for a Write
|
|
WriteTimeout time.Duration
|
|
|
|
// Dialer is an auxiliary dialer used for DialContext (the result gets wrapped in a
|
|
// TimeoutConnection).
|
|
Dialer *net.Dialer
|
|
|
|
// BytesReadLimit is the maximum number of bytes that connections dialed with this dialer will
|
|
// read before erroring.
|
|
BytesReadLimit int
|
|
|
|
// ReadLimitExceededAction describes how connections dialed with this dialer deal with exceeding
|
|
// the BytesReadLimit.
|
|
ReadLimitExceededAction ReadLimitExceededAction
|
|
}
|
|
|
|
func (d *Dialer) getTimeout(field time.Duration) time.Duration {
|
|
if field == 0 {
|
|
return d.Timeout
|
|
}
|
|
return field
|
|
}
|
|
|
|
// DialContext wraps the connection returned by net.Dialer.DialContext() with a TimeoutConnection.
|
|
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
|
if d.Timeout != 0 {
|
|
ctx, _ = context.WithTimeout(ctx, d.Timeout)
|
|
}
|
|
// ensure that our aux dialer is up-to-date; copied from http/transport.go
|
|
d.Dialer.Timeout = d.getTimeout(d.ConnectTimeout)
|
|
d.Dialer.KeepAlive = d.Timeout
|
|
dialContext, cancelDial := context.WithTimeout(ctx, d.Dialer.Timeout)
|
|
defer cancelDial()
|
|
conn, err := d.Dialer.DialContext(dialContext, network, address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ret := NewTimeoutConnection(ctx, conn, d.Timeout, d.ReadTimeout, d.WriteTimeout)
|
|
ret.BytesReadLimit = d.BytesReadLimit
|
|
ret.ReadLimitExceededAction = d.ReadLimitExceededAction
|
|
return ret, nil
|
|
}
|
|
|
|
// Dial returns a connection with the configured timeout.
|
|
func (d *Dialer) Dial(proto string, target string) (net.Conn, error) {
|
|
return DialTimeoutConnectionEx(proto, target, d.ConnectTimeout, d.Timeout, d.ReadTimeout, d.WriteTimeout)
|
|
}
|
|
|
|
// GetTimeoutConnectionDialer gets a Dialer that dials connections with the given timeout.
|
|
func GetTimeoutConnectionDialer(timeout time.Duration) *Dialer {
|
|
return NewDialer(&Dialer{Timeout: timeout})
|
|
}
|
|
|
|
// SetDefaults for the Dialer.
|
|
func (d *Dialer) SetDefaults() *Dialer {
|
|
if d.Timeout == 0 {
|
|
d.Timeout = DefaultSessionTimeout
|
|
}
|
|
if d.ReadLimitExceededAction == ReadLimitExceededActionNotSet {
|
|
d.ReadLimitExceededAction = DefaultReadLimitExceededAction
|
|
}
|
|
if d.BytesReadLimit == 0 {
|
|
d.BytesReadLimit = DefaultBytesReadLimit
|
|
}
|
|
if d.Dialer == nil {
|
|
d.Dialer = &net.Dialer{
|
|
Timeout: d.Timeout,
|
|
KeepAlive: d.Timeout,
|
|
DualStack: true,
|
|
}
|
|
}
|
|
return d
|
|
}
|
|
|
|
// Create a new Dialer with default settings.
|
|
func NewDialer(value *Dialer) *Dialer {
|
|
if value == nil {
|
|
value = &Dialer{}
|
|
}
|
|
return value.SetDefaults()
|
|
}
|