397 lines
12 KiB
Go
397 lines
12 KiB
Go
package zgrab2
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// runEchoServer starts a single-connection TCP echo server on 127.0.0.1:port.
//
// The listener is opened synchronously, so a bind failure fails the calling test right away via
// t.Fatal. A background goroutine then accepts exactly one connection and echoes back every chunk it
// reads (with a 250ms write deadline per chunk). The connection and the listener are both closed when
// the goroutine exits, i.e. when the peer disconnects or an I/O error occurs.
func runEchoServer(t *testing.T, port int) {
	endpoint := fmt.Sprintf("127.0.0.1:%d", port)
	listener, err := net.Listen("tcp", endpoint)
	if err != nil {
		t.Fatal(err)
	}
	go func() {
		defer listener.Close()
		sock, err := listener.Accept()
		if err != nil {
			// t.Fatal / FailNow must only be called from the goroutine running the test, so
			// failures in this server goroutine are reported with t.Error followed by a return.
			t.Error(err)
			return
		}
		defer sock.Close()

		buf := make([]byte, 1024)
		for {
			n, err := sock.Read(buf)
			if err != nil {
				// EOF and connection resets are the normal ways for the client to go away.
				if err != io.EOF && !strings.Contains(err.Error(), "connection reset") {
					t.Error(err)
				}
				return
			}
			if err := sock.SetWriteDeadline(time.Now().Add(time.Millisecond * 250)); err != nil {
				t.Logf("Unexpected error setting write deadline: %v", err)
				return
			}
			if _, err := sock.Write(buf[:n]); err != nil {
				// The client may legitimately hang up before the echo is delivered (e.g. once its
				// read limit has been hit), so those errors are not worth logging.
				if err != io.EOF && !strings.Contains(err.Error(), "connection reset") && !strings.Contains(err.Error(), "broken pipe") {
					t.Logf("Unexpected error writing to client: %v", err)
				}
				return
			}
		}
	}()
}
|
|
|
|
// Interface for getting a TimeoutConnection; we want to test both the dialer and the direct Dial functions.
|
|
type timeoutConnector interface {
|
|
connect(ctx context.Context, t *testing.T, port int, idx int) (*TimeoutConnection, error)
|
|
getConfig() readLimitTestConfig
|
|
}
|
|
|
|
// Config for a test case
|
|
type readLimitTestConfig struct {
|
|
// The maximum bytes the connection should read
|
|
limit int
|
|
// The number of bytes that should be sent (so iff sendSize > limit, the action should be triggered)
|
|
sendSize int
|
|
// The action to run when too much data is sent
|
|
action ReadLimitExceededAction
|
|
}
|
|
|
|
// Call sendReceive(), and check that the input/output match, and that any expected errors / truncation occurs.
|
|
func checkedSendReceive(t *testing.T, conn *TimeoutConnection, size int) (result error) {
|
|
// helper to report + return an error
|
|
tErrorf := func(format string, args ...interface{}) error {
|
|
result = fmt.Errorf(format, args)
|
|
t.Error(result)
|
|
return result
|
|
}
|
|
|
|
// We will check that this increases by the correct size
|
|
before := conn.BytesRead
|
|
|
|
// This is true if we expect an overflow to occur (and so the ReadLimitExceededAction should fire)
|
|
overflowed := (before + size) > conn.BytesReadLimit
|
|
|
|
// Don't want to keep re-typing this
|
|
action := conn.ReadLimitExceededAction
|
|
|
|
defer func() {
|
|
if result != nil {
|
|
// log any previous error -- more may still follow
|
|
t.Error(result)
|
|
}
|
|
err := recover()
|
|
if err != nil {
|
|
if action != ReadLimitExceededActionPanic {
|
|
// no reason to panic unless that is the action
|
|
panic(err)
|
|
}
|
|
if !overflowed {
|
|
tErrorf("panicked early: only sent %d bytes so far, but limit=%d", before+size, conn.BytesReadLimit)
|
|
return
|
|
}
|
|
if err == ErrReadLimitExceeded {
|
|
// We read too much data and this is the right error: silently succeed
|
|
return
|
|
}
|
|
tErrorf("wrong panic error: got %v, expected ErrReadlimitExceeded", err)
|
|
return
|
|
}
|
|
|
|
if action != ReadLimitExceededActionPanic {
|
|
// other action -- fine that we didn't panic
|
|
return
|
|
}
|
|
if !overflowed {
|
|
// not enough bytes read to overflow -- fine that we didn't panic
|
|
return
|
|
}
|
|
// ReadLimitExceededActionPanic, read too many bytes: should have panicked but didn't
|
|
tErrorf("should have panicked: action=ReadLimitExceededActionPanic, but sent without issue")
|
|
}()
|
|
|
|
ret, err := sendReceive(t, conn, size)
|
|
|
|
if err != nil {
|
|
if !overflowed {
|
|
// If there is no overflow, there should be no error
|
|
return tErrorf("read: unexpected error: %v", err)
|
|
}
|
|
if err != io.EOF && err != ErrReadLimitExceeded {
|
|
// EOF and ErrReadLimitExceeded are the only errors that should be returned
|
|
return tErrorf("read: wrong error: %v", err)
|
|
}
|
|
if err == io.EOF && action != ReadLimitExceededActionTruncate {
|
|
// EOF should only occur with truncation
|
|
return tErrorf("read: unexpected EOF")
|
|
}
|
|
if err == ErrReadLimitExceeded && action != ReadLimitExceededActionError {
|
|
// ErrReadLimitExceeded should only occur with ReadLimitExceededActionError
|
|
return tErrorf("read: unexpected ErrReadLimitExceeded")
|
|
}
|
|
// Otherwise, fall through -- we still need to check that the data matches
|
|
} else {
|
|
if overflowed && action == ReadLimitExceededActionError {
|
|
return tErrorf("read: should have gotten an error, but did not")
|
|
}
|
|
}
|
|
expectedSize := size
|
|
if overflowed {
|
|
expectedSize = conn.BytesReadLimit - before
|
|
}
|
|
|
|
if conn.BytesRead != before+expectedSize {
|
|
return tErrorf("check: BytesRead value inconsistent; expected %d, got %d", before+expectedSize, conn.BytesRead)
|
|
}
|
|
if len(ret) != expectedSize {
|
|
return tErrorf("check: expected %d bytes, got %d", expectedSize, len(ret))
|
|
}
|
|
if expectedSize > 0 && !checkTestBuffer(ret) {
|
|
return tErrorf("Got back invalid data (%x)", ret)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Send size testBuffer bytes to conn, then perform a read, and return the result/error.
|
|
func sendReceive(t *testing.T, conn *TimeoutConnection, size int) ([]byte, error) {
|
|
toSend := getTestBuffer(size)
|
|
n, err := conn.Write(toSend)
|
|
if err != nil {
|
|
t.Fatalf("Send failed: %v", err)
|
|
return nil, err
|
|
}
|
|
if n != len(toSend) {
|
|
t.Fatalf("Short write: expected to send %d bytes, returned %d", len(toSend), n)
|
|
return nil, io.ErrShortWrite
|
|
}
|
|
readBuf := make([]byte, size)
|
|
n, err = conn.Read(readBuf)
|
|
return readBuf[0:n], err
|
|
}
|
|
|
|
// getTestBuffer returns a size-byte slice filled with the sequence 0, 1, 2, ..., wrapping back to 0
// after 255.
func getTestBuffer(size int) []byte {
	pattern := make([]byte, size)
	for pos := range pattern {
		// Converting to byte keeps only the low eight bits, i.e. pos mod 256.
		pattern[pos] = byte(pos)
	}
	return pattern
}
|
|
|
|
// checkTestBuffer reports whether buf is a non-empty slice holding the pattern produced by
// getTestBuffer (byte i equals i mod 256). A nil or empty buf is rejected.
func checkTestBuffer(buf []byte) bool {
	if len(buf) == 0 {
		return false
	}
	for pos := 0; pos < len(buf); pos++ {
		if buf[pos] != byte(pos) {
			return false
		}
	}
	return true
}
|
|
|
|
// Send / receive cfg.sendSize bytes in a single shot and check that it behaves appropriately.
|
|
func (cfg readLimitTestConfig) runSingleSend(t *testing.T, conn *TimeoutConnection, idx int) error {
|
|
if err := checkedSendReceive(t, conn, cfg.sendSize); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Send / receive cfg.sendSize bytes, split over five sends, and check that it behaves appropriately.
|
|
func (cfg readLimitTestConfig) runMultiSend(t *testing.T, conn *TimeoutConnection, idx int) error {
|
|
for i := 0; i < 5; i++ {
|
|
if err := checkedSendReceive(t, conn, cfg.sendSize/5); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// A timeoutConnector that uses a dialer to dial the connections
|
|
type dialerConnector struct {
|
|
readLimitTestConfig
|
|
|
|
// This is lazily inited
|
|
dialer *Dialer
|
|
}
|
|
|
|
// Function that returns a connector
|
|
type timeoutConnectorFactory func(readLimitTestConfig) timeoutConnector
|
|
|
|
// Dial the connection using the dialer (creating the dialer if necessary)
|
|
func (d *dialerConnector) connect(ctx context.Context, t *testing.T, port int, idx int) (*TimeoutConnection, error) {
|
|
if d.dialer == nil {
|
|
d.dialer = NewDialer(&Dialer{
|
|
BytesReadLimit: d.limit,
|
|
ReadLimitExceededAction: d.action,
|
|
})
|
|
}
|
|
var ret *TimeoutConnection
|
|
conn, err := d.dialer.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
|
if conn != nil {
|
|
ret = conn.(*TimeoutConnection)
|
|
}
|
|
return ret, err
|
|
}
|
|
|
|
func (d *dialerConnector) getConfig() readLimitTestConfig {
|
|
return d.readLimitTestConfig
|
|
}
|
|
|
|
func dialerTimeoutConnectorFactory(cfg readLimitTestConfig) timeoutConnector {
|
|
return &dialerConnector{
|
|
readLimitTestConfig: cfg,
|
|
}
|
|
}
|
|
|
|
// Dial using a direct call to DialTimeoutConnectionEx
|
|
type directDial struct {
|
|
readLimitTestConfig
|
|
}
|
|
|
|
func (d *directDial) connect(ctx context.Context, t *testing.T, port int, idx int) (*TimeoutConnection, error) {
|
|
conn, err := DialTimeoutConnectionEx("tcp", fmt.Sprintf("127.0.0.1:%d", port), time.Second, time.Second, time.Second, time.Second, d.limit)
|
|
var ret *TimeoutConnection
|
|
if conn != nil {
|
|
ret = conn.(*TimeoutConnection)
|
|
ret.BytesReadLimit = d.limit
|
|
ret.ReadLimitExceededAction = d.action
|
|
}
|
|
return ret, err
|
|
}
|
|
|
|
func (d *directDial) getConfig() readLimitTestConfig {
|
|
return d.readLimitTestConfig
|
|
}
|
|
|
|
func directDialFactory(cfg readLimitTestConfig) timeoutConnector {
|
|
return &directDial{cfg}
|
|
}
|
|
|
|
var readLimitTestConfigs = map[string]readLimitTestConfig{
|
|
// Check that a 2000-byte read gets truncated at 1000 bytes
|
|
"truncate": {
|
|
limit: 1000,
|
|
sendSize: 2000,
|
|
action: ReadLimitExceededActionTruncate,
|
|
},
|
|
|
|
// Check that a 1005-byte read gets truncated at 1000 bytes
|
|
"truncate_close": {
|
|
limit: 1000,
|
|
sendSize: 1005,
|
|
action: ReadLimitExceededActionTruncate,
|
|
},
|
|
|
|
// Check that a 2000-byte read errors after reading the first 1000 bytes
|
|
"error": {
|
|
limit: 1000,
|
|
sendSize: 2000,
|
|
action: ReadLimitExceededActionError,
|
|
},
|
|
|
|
// Check that a 2000-byte read panics after reading the first 1000 bytes
|
|
"panic": {
|
|
limit: 1000,
|
|
sendSize: 2000,
|
|
action: ReadLimitExceededActionPanic,
|
|
},
|
|
|
|
// Check that the default settings pass (backwards compatibility)
|
|
"default": {},
|
|
|
|
// Check that a 100-byte read succeeds / is not truncated
|
|
"happy": {
|
|
limit: 1000,
|
|
sendSize: 100,
|
|
action: ReadLimitExceededActionPanic,
|
|
},
|
|
|
|
// Check that a 1000-byte read succeeds / is not truncated
|
|
"closeCall": {
|
|
limit: 1000,
|
|
sendSize: 1000,
|
|
action: ReadLimitExceededActionPanic,
|
|
},
|
|
}
|
|
|
|
// Each of these gets run with each readLimitTestConfig
|
|
var connTestConnectors = map[string]timeoutConnectorFactory{
|
|
"directDial": directDialFactory,
|
|
"dialerConnector": dialerTimeoutConnectorFactory,
|
|
}
|
|
|
|
// Run a single full trial with the given connector: connect, send/receive the configured bytes, and
|
|
// check that the response was properly truncated (or not), and that the bytes read total is
|
|
// correctly tabulated.
|
|
func runBytesReadLimitTrial(t *testing.T, connector timeoutConnector, idx int, method func(readLimitTestConfig, *testing.T, *TimeoutConnection, int) error) (result error) {
|
|
cfg := connector.getConfig()
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
port := 0x1234 + idx
|
|
runEchoServer(t, port)
|
|
conn, err := connector.connect(ctx, t, port, idx)
|
|
if err != nil {
|
|
t.Fatalf("Error dialing: %v", err)
|
|
}
|
|
expectedSize := cfg.sendSize
|
|
if expectedSize > cfg.limit {
|
|
expectedSize = cfg.limit
|
|
}
|
|
defer func() {
|
|
if conn.BytesRead != expectedSize {
|
|
result = fmt.Errorf("BytesRead(%d) != expected(%d)", conn.BytesRead, expectedSize)
|
|
t.Error(result)
|
|
}
|
|
}()
|
|
defer conn.Close()
|
|
return method(cfg, t, conn, idx)
|
|
}
|
|
|
|
// Run a full set of trials on the connector -- ten with a single send, and ten with multiple sends.
|
|
func testBytesReadLimitOn(t *testing.T, connector timeoutConnector) error {
|
|
for i := 0; i < 10; i++ {
|
|
if err := runBytesReadLimitTrial(t, connector, i, readLimitTestConfig.runSingleSend); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
for i := 0; i < 10; i++ {
|
|
if err := runBytesReadLimitTrial(t, connector, i, readLimitTestConfig.runMultiSend); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Check that the BytesReadLimit is enforced (or not) as expected:
|
|
// 1. Create an echo server
|
|
// 2. Dial a fresh TimeoutConnection to the echo server with the given BytesReadLimit / ReadLimitExceededAction
|
|
// 3. Send the configured number of bytes in a single packet
|
|
// 4. Check that it (succeeds / truncates / errors / panics) according to the config
|
|
// 5. Repeat 10 times
|
|
// 6. Repeat the above 10 more times, except in #3, split the send across five packets
|
|
func TestBytesReadLimit(t *testing.T) {
|
|
connectors := make(map[string]timeoutConnector)
|
|
// Create a fresh connector for each configuration
|
|
for cfgName, cfg := range readLimitTestConfigs {
|
|
for connectorName, factory := range connTestConnectors {
|
|
connectors[connectorName+"_"+cfgName] = factory(cfg)
|
|
}
|
|
}
|
|
|
|
// Run each connector
|
|
for name, connector := range connectors {
|
|
t.Logf("Running %s", name)
|
|
if err := testBytesReadLimitOn(t, connector); err != nil {
|
|
t.Logf("Failed running %s: %v", name, err)
|
|
}
|
|
}
|
|
}
|