zgrab2/conn_timeout_test.go

439 lines
12 KiB
Go

package zgrab2
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"strings"
"testing"
"time"
"github.com/sirupsen/logrus"
)
// connTimeoutTestConfig describes a single timeout scenario: how the client
// dials, how long each side waits, what the two sides exchange, and where the
// client is expected to stop.
type connTimeoutTestConfig struct {
	// name identifies the scenario in log and failure messages.
	name string

	// endpoint, when non-empty, is dialed as-is and no local server is started.
	// When empty, 127.0.0.1 and port are used instead.
	endpoint string

	// port is the TCP port used on 127.0.0.1.
	port int

	// dialer opens the client connection for this scenario.
	dialer func() (*TimeoutConnection, error)

	// Client-side timeout values handed to the dial helpers. The scenario
	// comments refer to timeout as the "session" timeout.
	timeout        time.Duration
	connectTimeout time.Duration
	readTimeout    time.Duration
	writeTimeout   time.Duration

	// acceptDelay is how long the server waits, once listening, before it
	// accepts a connection.
	acceptDelay time.Duration
	// writeDelay is how long the server waits after accepting before it sends
	// its payload.
	writeDelay time.Duration
	// readDelay is how long the server waits before reading the client's
	// payload.
	readDelay time.Duration

	// serverToClientPayload is sent by the server once the client is connected.
	serverToClientPayload []byte
	// clientToServerPayload is sent by the client after it has read the
	// server's payload.
	clientToServerPayload []byte

	// failStep is the step at which the client is expected to stop.
	failStep testStep
	// failError, when non-empty, must appear in the client's error string.
	failError string
}
// Standardized durations for the scenarios: 100ms, 1s and 10s, each ten times
// the previous one.
const (
	short  = 100 * time.Millisecond
	medium = 10 * short
	long   = 10 * medium
)
// testStep names a phase of the client exchange; scenarios use it to say where
// the client is expected to stop.
type testStep string

// The phases in the order the client goes through them. testStepDone means
// every phase completed.
const (
	testStepConnect testStep = "connect"
	testStepRead    testStep = "read"
	testStepWrite   testStep = "write"
	testStepDone    testStep = "done"
)
// timeoutTestError records an outcome observed while running a scenario: who
// reported it, during which step, and the underlying error if there was one.
type timeoutTestError struct {
	// source says which party reported the outcome, e.g. "client" or "server".
	source string
	// step is the phase that was in progress.
	step testStep
	// cause is the underlying error; it may be nil.
	cause error
}
// Error renders the outcome as "<source> error at <step>: <cause>".
func (e *timeoutTestError) Error() string {
	const layout = "%s error at %s: %v"
	return fmt.Sprintf(layout, e.source, e.step, e.cause)
}
// serverError builds a timeoutTestError attributed to the server for the given
// step and cause.
func serverError(step testStep, err error) *timeoutTestError {
	result := timeoutTestError{source: "server", step: step, cause: err}
	return &result
}
// clientError builds a timeoutTestError attributed to the client for the given
// step and cause.
func clientError(step testStep, err error) *timeoutTestError {
	result := timeoutTestError{source: "client", step: step, cause: err}
	return &result
}
// _write sends data with a single Write call and reports io.ErrShortWrite if
// the writer accepted fewer bytes than requested without returning an error of
// its own.
func _write(writer io.Writer, data []byte) error {
	written, err := writer.Write(data)
	if err != nil {
		return err
	}
	if written != len(data) {
		return io.ErrShortWrite
	}
	return nil
}
// runServer starts the server side of the scenario and returns a channel that
// reports its outcome.
//
// When cfg.endpoint is set, no server is started; the returned channel never
// receives a value and is never closed. Otherwise the listener is already open
// when runServer returns, and a goroutine accepts one connection, sends
// serverToClientPayload and then reads clientToServerPayload, sleeping for
// acceptDelay, writeDelay and readDelay before the respective steps.
//
// The channel receives at most one *timeoutTestError and is closed when the
// goroutine finishes. It is buffered for that one value so the goroutine can
// always finish, and release the socket and the listener, even when the caller
// has already stopped receiving.
func (cfg *connTimeoutTestConfig) runServer(t *testing.T) chan *timeoutTestError {
	errorChan := make(chan *timeoutTestError, 1)
	if cfg.endpoint != "" {
		// Only listen on localhost; an explicit endpoint is dialed as-is.
		return errorChan
	}
	listener, err := net.Listen("tcp", cfg.getEndpoint())
	if err != nil {
		logrus.Fatalf("Error listening: %v", err)
	}
	go func() {
		// Deferred calls run in reverse order: the accepted socket (if any) is
		// closed first, then errorChan, then the listener.
		defer func() {
			if closeErr := listener.Close(); closeErr != nil {
				t.Errorf("%s: error closing connection: %v", cfg.name, closeErr)
			}
		}()
		defer close(errorChan)

		time.Sleep(cfg.acceptDelay)
		sock, err := listener.Accept()
		if err != nil {
			errorChan <- serverError(testStepConnect, err)
			return
		}
		defer func() {
			if closeErr := sock.Close(); closeErr != nil {
				t.Errorf("%s: error closing connection: %v", cfg.name, closeErr)
			}
		}()

		time.Sleep(cfg.writeDelay)
		if err := _write(sock, cfg.serverToClientPayload); err != nil {
			errorChan <- serverError(testStepWrite, err)
			return
		}

		time.Sleep(cfg.readDelay)
		buf := make([]byte, len(cfg.clientToServerPayload))
		// io.ReadFull returns a nil error only when buf was filled completely
		// (io.EOF means nothing was read, io.ErrUnexpectedEOF a partial read),
		// so any error is a failed read.
		if _, err := io.ReadFull(sock, buf); err != nil {
			errorChan <- serverError(testStepRead, err)
			return
		}
		if !bytes.Equal(buf, cfg.clientToServerPayload) {
			t.Errorf("%s: clientToServerPayload mismatch", cfg.name)
		}
	}()
	return errorChan
}
// getEndpoint returns the address the client dials: the explicit endpoint when
// one is configured, otherwise 127.0.0.1 with the configured port.
func (cfg *connTimeoutTestConfig) getEndpoint() string {
	if cfg.endpoint == "" {
		return fmt.Sprintf("127.0.0.1:%d", cfg.port)
	}
	return cfg.endpoint
}
// dialerDial connects to the configured endpoint with Dialer.Dial, using a
// Dialer built from the scenario's timeout values.
func (cfg *connTimeoutTestConfig) dialerDial() (*TimeoutConnection, error) {
	settings := &Dialer{
		Timeout:        cfg.timeout,
		ConnectTimeout: cfg.connectTimeout,
		ReadTimeout:    cfg.readTimeout,
		WriteTimeout:   cfg.writeTimeout,
	}
	d := NewDialer(settings)
	conn, err := d.Dial("tcp", cfg.getEndpoint())
	if err != nil {
		return nil, err
	}
	// Unchecked assertion: panics if Dial hands back any other connection type.
	return conn.(*TimeoutConnection), nil
}
// directDial connects to the configured endpoint by calling
// DialTimeoutConnectionEx directly with the scenario's timeout values.
func (cfg *connTimeoutTestConfig) directDial() (*TimeoutConnection, error) {
	conn, err := DialTimeoutConnectionEx(
		"tcp",
		cfg.getEndpoint(),
		cfg.connectTimeout,
		cfg.timeout,
		cfg.readTimeout,
		cfg.writeTimeout,
		0,
	)
	if err != nil {
		return nil, err
	}
	// Unchecked assertion: panics if any other connection type is returned.
	return conn.(*TimeoutConnection), nil
}
// contextDial connects to the configured endpoint with Dialer.DialContext and
// a background context, using a Dialer built from the scenario's timeout
// values.
func (cfg *connTimeoutTestConfig) contextDial() (*TimeoutConnection, error) {
	settings := &Dialer{
		Timeout:        cfg.timeout,
		ConnectTimeout: cfg.connectTimeout,
		ReadTimeout:    cfg.readTimeout,
		WriteTimeout:   cfg.writeTimeout,
	}
	d := NewDialer(settings)
	conn, err := d.DialContext(context.Background(), "tcp", cfg.getEndpoint())
	if err != nil {
		return nil, err
	}
	// Unchecked assertion: panics if DialContext hands back any other
	// connection type.
	return conn.(*TimeoutConnection), nil
}
// runClient plays the client side of the scenario: dial, read the server's
// payload, send the client's payload, then close the connection.
//
// It returns the step that failed together with its error, or testStepDone and
// a nil error when every step succeeded. A payload mismatch or a failure to
// close the connection is reported through t and does not change the result.
func (cfg *connTimeoutTestConfig) runClient(t *testing.T) (testStep, error) {
	conn, dialErr := cfg.dialer()
	if dialErr != nil {
		return testStepConnect, dialErr
	}
	defer func() {
		if closeErr := conn.Close(); closeErr != nil {
			t.Errorf("%s: error closing connection: %v", cfg.name, closeErr)
		}
	}()

	received := make([]byte, len(cfg.serverToClientPayload))
	if _, readErr := io.ReadFull(conn, received); readErr != nil {
		return testStepRead, readErr
	}
	if !bytes.Equal(cfg.serverToClientPayload, received) {
		t.Errorf("%s: serverToClientPayload payload mismatch", cfg.name)
	}

	writeErr := _write(conn, cfg.clientToServerPayload)
	if writeErr != nil {
		return testStepWrite, writeErr
	}
	return testStepDone, nil
}
// run executes the scenario: it starts the server and the client and then
// checks that the client stopped at cfg.failStep with the expected error.
//
// A watchdog reports a "timeout" outcome if the client has not finished after
// long+medium+short. Server errors other than io.EOF abort the test, as does
// any outcome that does not come from the client. When cfg.failError is empty
// the client must have finished its final step without an error; otherwise the
// client's error string must contain cfg.failError.
func (cfg *connTimeoutTestConfig) run(t *testing.T) {
	// The client and the watchdog may each deliver one outcome. With room for
	// both, neither sender can block forever once run has stopped receiving.
	done := make(chan *timeoutTestError, 2)
	serverErrs := cfg.runServer(t)

	go func() {
		defer func() {
			if r := recover(); r != nil {
				// Wake up the receiver before propagating the panic.
				close(done)
				panic(r)
			}
		}()
		step, err := cfg.runClient(t)
		done <- clientError(step, err)
	}()

	watchdog := time.AfterFunc(long+medium+short, func() {
		done <- &timeoutTestError{source: "timeout"}
	})
	defer watchdog.Stop()

	var ret *timeoutTestError
	select {
	case err := <-serverErrs:
		// A nil value means the server finished and closed its channel.
		if err != nil && !errors.Is(err.cause, io.EOF) {
			t.Fatalf("%s: Server error: %v", cfg.name, err)
		}
		ret = <-done
	case ret = <-done:
	}
	// done is only ever closed by the client goroutine's panic handler.
	if ret == nil {
		t.Fatalf("Channel unexpectedly closed")
	}
	if ret.source != "client" {
		t.Fatalf("%s: Unexpected error from %s: %v", cfg.name, ret.source, ret.cause)
	}
	if ret.step != cfg.failStep {
		t.Errorf("%s: Failed at step %s, but expected to fail at step %s (error=%v)", cfg.name, ret.step, cfg.failStep, ret.cause)
		return
	}
	if cfg.failError == "" {
		if ret.cause != nil {
			t.Errorf("%s: expected no error at step %s, but got %v", cfg.name, cfg.failStep, ret.cause)
		}
		return
	}
	errString := "none"
	if ret.cause != nil {
		errString = ret.cause.Error()
	}
	if !strings.Contains(errString, cfg.failError) {
		t.Errorf("%s: Expected an error (%s) at step %s, got %s", cfg.name, cfg.failError, cfg.failStep, errString)
	}
}
// connTestConfigs lists the base timeout scenarios. The dialer field is left
// unset here; TestTimeoutConnectionTimeouts fills it in for each dial method.
var connTestConfigs = []connTimeoutTestConfig{
	// Long timeouts, short delays -- should succeed
	{
		name:                  "happy",
		port:                  0x5613,
		timeout:               long,
		connectTimeout:        medium,
		readTimeout:           medium,
		writeTimeout:          medium,
		acceptDelay:           short,
		writeDelay:            short,
		readDelay:             short,
		serverToClientPayload: []byte("abc"),
		clientToServerPayload: []byte("defghi"),
		failStep:              testStepDone,
	},
	// Long session timeout, short connectTimeout. Use a non-local, nonexistent
	// endpoint (localhost would return "connection refused" immediately).
	{
		name:                  "connect_timeout",
		endpoint:              "10.0.254.254:41591",
		timeout:               long,
		connectTimeout:        short,
		readTimeout:           medium,
		writeTimeout:          medium,
		acceptDelay:           short,
		writeDelay:            short,
		readDelay:             short,
		serverToClientPayload: []byte("abc"),
		clientToServerPayload: []byte("defghi"),
		failStep:              testStepConnect,
		failError:             "i/o timeout",
	},
	// Short session timeout, medium connect timeout, connecting to a
	// nonexistent endpoint.
	{
		name:                  "session_connect_timeout",
		endpoint:              "10.0.254.254:41591",
		timeout:               short,
		connectTimeout:        medium,
		readTimeout:           medium,
		writeTimeout:          medium,
		acceptDelay:           short,
		writeDelay:            short,
		readDelay:             short,
		serverToClientPayload: []byte("abc"),
		clientToServerPayload: []byte("defghi"),
		failStep:              testStepConnect,
		failError:             "i/o timeout",
	},
	// Get an IO timeout on the read.
	// sessionTimeout > acceptDelay + writeDelay > writeDelay > readTimeout
	{
		name:                  "read_timeout",
		port:                  0x5614,
		timeout:               long,
		connectTimeout:        short,
		readTimeout:           short,
		writeTimeout:          short,
		acceptDelay:           short,
		writeDelay:            medium,
		readDelay:             short,
		serverToClientPayload: []byte("abc"),
		clientToServerPayload: []byte("defghi"),
		failStep:              testStepRead,
		failError:             "i/o timeout",
	},
	// The session timeout runs out while the client is still waiting for the
	// server's payload. The scenario expects the client to report a context
	// deadline error at the write step.
	// readTimeout > writeDelay > timeout > acceptDelay
	{
		name:                  "session_read_timeout",
		port:                  0x5615,
		timeout:               short,
		connectTimeout:        long,
		readTimeout:           long,
		writeTimeout:          long,
		acceptDelay:           0,
		writeDelay:            medium * 2,
		readDelay:             0,
		serverToClientPayload: []byte("abc"),
		clientToServerPayload: []byte("defghi"),
		failStep:              testStepWrite,
		failError:             "context deadline exceeded",
	},
	// Each server delay is shorter than the session timeout, but together they
	// exceed it; the per-action connect/read/write timeouts are longer still.
	// acceptDelay+writeDelay+readDelay > timeout > acceptDelay >= writeDelay >= readDelay
	{
		name:                  "session_timeout",
		port:                  0x5616,
		timeout:               medium,
		connectTimeout:        long,
		readTimeout:           long,
		writeTimeout:          long,
		acceptDelay:           medium/2 + short,
		writeDelay:            medium/2 + short,
		readDelay:             medium/2 + short,
		serverToClientPayload: []byte("abc"),
		clientToServerPayload: []byte("defghi"),
		failStep:              testStepWrite,
		failError:             "context deadline exceeded",
	},
	// TODO: How to test write timeout?
}
// TestTimeoutConnectionTimeouts tests that the TimeoutConnection behaves as
// expected with respect to timeouts.
//
// Every entry of connTestConfigs is run three times, once per dial method
// (DialTimeoutConnectionEx, Dialer.Dial and Dialer.DialContext). Each variant
// gets its own name suffix and port offset so the variants never share a
// listening port.
func TestTimeoutConnectionTimeouts(t *testing.T) {
	variants := make([]connTimeoutTestConfig, 0, len(connTestConfigs)*3)
	for _, base := range connTestConfigs {
		direct := base
		dialer := base
		ctxDialer := base

		// Each dial func is a method value bound to the address of the local
		// copy it is taken from; that copy carries the same name, port and
		// timeouts as the element appended below.
		dialer.name = dialer.name + "_dialer"
		dialer.port = dialer.port + 100
		dialer.dialer = dialer.dialerDial

		direct.name = direct.name + "_direct"
		direct.port = direct.port + 200
		direct.dialer = direct.directDial

		ctxDialer.name = ctxDialer.name + "_context"
		ctxDialer.port = ctxDialer.port + 300
		ctxDialer.dialer = ctxDialer.contextDial

		variants = append(variants, direct, dialer, ctxDialer)
	}

	// Iterate by index so each scenario runs against its own slice element.
	// run starts goroutines that keep using the config after it returns, so
	// they must not share a range variable that a later iteration overwrites.
	for i := range variants {
		cfg := &variants[i]
		t.Logf("Running %s", cfg.name)
		cfg.run(t)
	}
}