zgrab2/utility.go

229 lines
6.6 KiB
Go

package zgrab2
import (
"errors"
"net"
"regexp"
"strconv"
"strings"
"time"
"github.com/zmap/zflags"
"github.com/sirupsen/logrus"
"runtime/debug"
)
// parser is the package-wide command-line parser. Global options are bound to
// the package-level config value; AddGroup and AddCommand register further
// option groups and module commands on it.
var parser *flags.Parser

// init constructs the global parser with the flags.Default options, binding
// the global options to config.
func init() {
	parser = flags.NewParser(&config, flags.Default)
}
// NewIniParser returns an INI parser built on top of the package's global
// command-line parser.
func NewIniParser() *flags.IniParser {
	iniParser := flags.NewIniParser(parser)
	return iniParser
}
// AddGroup exposes the parser's AddGroup function, allowing extension of the
// global arguments.
//
// The signature has no error result, so a failure to register the group is
// logged instead of being silently discarded.
func AddGroup(shortDescription string, longDescription string, data interface{}) {
	if _, err := parser.AddGroup(shortDescription, longDescription, data); err != nil {
		logrus.Errorf("zgrab2: could not add option group %q: %v", shortDescription, err)
	}
}
// AddCommand registers module m with the global parser under the name
// command, and records m in the modules table under the same name.
//
// The command's "port" option defaults to port, and its "name" option defaults
// to command. The module's flags must therefore expose long options named
// "port" and "name".
//
// It returns the created *flags.Command. It returns an error if the parser
// rejects the command, or if either of those options is missing. Previously,
// a missing option caused a nil-pointer panic.
func AddCommand(command string, shortDescription string, longDescription string, port int, m ScanModule) (*flags.Command, error) {
	cmd, err := parser.AddCommand(command, shortDescription, longDescription, m)
	if err != nil {
		return nil, err
	}
	portOption := cmd.FindOptionByLongName("port")
	if portOption == nil {
		return nil, errors.New("zgrab2: module " + command + " does not define a --port option")
	}
	nameOption := cmd.FindOptionByLongName("name")
	if nameOption == nil {
		return nil, errors.New("zgrab2: module " + command + " does not define a --name option")
	}
	// Itoa keeps an (invalid) negative port visible as "-1" instead of
	// wrapping it to a huge unsigned value.
	portOption.Default = []string{strconv.Itoa(port)}
	nameOption.Default = []string{command}
	modules[command] = m
	return cmd, nil
}
// ParseCommandLine runs the global parser over the given arguments. When
// parsing succeeds, the framework (global) configuration is validated right
// away.
//
// It returns:
//   - the leftover positional arguments;
//   - the module type reported by the parser;
//   - the parsed flags as ScanFlags (the zero ScanFlags value if they are not
//     ScanFlags);
//   - the parse error, if any.
func ParseCommandLine(flags []string) ([]string, string, ScanFlags, error) {
	positional, moduleType, rawFlags, parseErr := parser.ParseCommandLine(flags)
	if parseErr == nil {
		validateFrameworkConfiguration()
	}
	var scanFlags ScanFlags
	if typed, ok := rawFlags.(ScanFlags); ok {
		scanFlags = typed
	}
	return positional, moduleType, scanFlags, parseErr
}
// ReadAvailable reads what it can from conn without blocking for more than
// 10ms per read (after the first read, which uses conn's existing deadline),
// and returns at most 512 KiB.
//
// It calls ReadAvailableWithOptions with a totalTimeout of 0, so the
// session-wide limit is resolved there. That limit is conn's Timeout when conn
// is a *TimeoutConnection, and 1 second otherwise.
func ReadAvailable(conn net.Conn) ([]byte, error) {
	const defaultReadTimeout = 10 * time.Millisecond
	const defaultMaxReadSize = 1024 * 512
	// If the buffer size exactly matches the number of bytes returned, we hit
	// a corner case where we attempt another read even though nothing is
	// available. Otherwise we can return without blocking at all. So it is
	// better to be large than small; the worst case is getting exactly the
	// right number of bytes.
	const defaultBufferSize = 8209
	return ReadAvailableWithOptions(conn, defaultBufferSize, defaultReadTimeout, 0, defaultMaxReadSize)
}
// errTotalTimeout is the error type returned when a ReadAvailableWithOptions
// session runs out of its overall time budget. It implements net.Error, so
// err.(net.Error).Timeout() works, and IsTimeoutError reports it as a timeout.
type errTotalTimeout string

// ErrTotalTimeout is returned by ReadAvailableWithOptions when the overall
// deadline for the session passes between reads.
const ErrTotalTimeout = errTotalTimeout("timeout")

// Compile-time check that errTotalTimeout satisfies net.Error.
var _ net.Error = ErrTotalTimeout

// Error returns the error message.
func (err errTotalTimeout) Error() string {
	return string(err)
}

// Timeout always reports true: this error only ever signals a timeout.
func (err errTotalTimeout) Timeout() bool {
	return true
}

// Temporary always reports false.
func (err errTotalTimeout) Temporary() bool {
	return false
}
// ReadAvailableWithOptions reads whatever can be read, up to maxReadSize bytes,
// from conn. No single read blocks for longer than readTimeout, and the whole
// session does not run past totalTimeout. bufferSize is the size of each read.
// bufferSize and maxReadSize are assumed to be non-negative.
//
// The first read uses whatever read deadline is already set on conn, and its
// error (timeout or otherwise) is returned as-is. Every later read sets conn's
// read deadline to now+readTimeout, and that deadline is still in place when
// this function returns.
//
// Timeout handling:
//   - A totalTimeout of 0 means use conn's Timeout if conn is a
//     *TimeoutConnection, and 1 second otherwise.
//   - A negative resolved total timeout disables the overall deadline.
//   - A timeout on any read after the first ends the session with a nil
//     error.
//   - If the overall deadline passes between reads, ErrTotalTimeout is
//     returned.
//
// On failure, returns anything it was able to read along with the error.
func ReadAvailableWithOptions(conn net.Conn, bufferSize int, readTimeout time.Duration, totalTimeout time.Duration, maxReadSize int) ([]byte, error) {
	// Spelled out rather than using the Go 1.21 builtin min, so this does not
	// shadow the builtin or require a newer toolchain.
	minInt := func(a, b int) int {
		if a < b {
			return a
		}
		return b
	}

	var totalDeadline time.Time
	if totalTimeout == 0 {
		// Would be nice if this could be taken from the SetReadDeadline(),
		// but that's not possible in general.
		const defaultTotalTimeout = 1 * time.Second
		totalTimeout = defaultTotalTimeout
		if timeoutConn, ok := conn.(*TimeoutConnection); ok {
			totalTimeout = timeoutConn.Timeout
		}
	}
	if totalTimeout > 0 {
		totalDeadline = time.Now().Add(totalTimeout)
	}

	buf := make([]byte, bufferSize)
	ret := make([]byte, 0)
	// remaining is how many more bytes may be read before maxReadSize is hit.
	remaining := maxReadSize

	// The first read uses any pre-assigned deadline.
	n, err := conn.Read(buf[:minInt(bufferSize, remaining)])
	ret = append(ret, buf[:n]...)
	remaining -= n
	if err != nil || remaining <= 0 {
		return ret, err
	}

	// Keep reading until the peer is quiet for readTimeout, the byte budget is
	// spent, a read fails, or the overall deadline passes.
	for totalDeadline.IsZero() || time.Now().Before(totalDeadline) {
		// Errors from SetReadDeadline are deliberately ignored here.
		_ = conn.SetReadDeadline(time.Now().Add(readTimeout))
		n, err = conn.Read(buf[:minInt(bufferSize, remaining)])
		ret = append(ret, buf[:n]...)
		remaining -= n
		if err != nil {
			// A per-read timeout means nothing more arrived in time. That is
			// the normal way for this loop to end, not a failure.
			if IsTimeoutError(err) {
				return ret, nil
			}
			return ret, err
		}
		// Check the remaining budget (already reduced by n), not n itself.
		// Comparing n against the reduced budget ended the session early.
		if remaining <= 0 {
			return ret, nil
		}
	}
	return ret, ErrTotalTimeout
}
// InsufficientBufferError is returned by ReadUntilRegex when the
// caller-supplied buffer is full and reading cannot continue.
var InsufficientBufferError = errors.New("not enough buffer space")
// ReadUntilRegex reads from connection into res until one of three things
// happens: the bytes read so far match expr, res is full, or Read returns an
// error. It returns the number of bytes written into res.
//
// The error result is:
//   - nil when the data matched, even if the match exactly fills res;
//   - InsufficientBufferError when res filled up (or was empty) without a
//     match;
//   - otherwise, the error returned by connection.Read.
//
// Bytes delivered together with a Read error are matched before the error is
// considered, as the io.Reader contract requires.
func ReadUntilRegex(connection net.Conn, res []byte, expr *regexp.Regexp) (int, error) {
	length := 0
	for {
		// Checking before reading also avoids a zero-length Read on an empty
		// or already-full buffer.
		if length == len(res) {
			return length, InsufficientBufferError
		}
		n, err := connection.Read(res[length:])
		length += n
		if err != nil && n == 0 {
			// Nothing new arrived, so there is nothing new to match.
			return length, err
		}
		// Match before checking for a full buffer, so a response that
		// exactly fills res still counts as a success.
		if expr.Match(res[:length]) {
			return length, nil
		}
		if err != nil {
			return length, err
		}
	}
}
// TLDMatches reports whether host1 and host2 have exactly the same final
// dot-separated label, after stripPortNumber has removed any port. The
// comparison is case-sensitive. A host with no dots is compared as a whole,
// and a trailing dot yields an empty final label.
func TLDMatches(host1 string, host2 string) bool {
	lastLabel := func(host string) string {
		name := stripPortNumber(host)
		// LastIndex is -1 when there is no dot, so the whole name is kept.
		return name[strings.LastIndex(name, ".")+1:]
	}
	return lastLabel(host1) == lastLabel(host2)
}
// stripPortNumber returns host without a trailing ":port".
//
// net.SplitHostPort handles bracketed IPv6 literals, so "[2001:db8::1]:443"
// yields "2001:db8::1" instead of being truncated at its first colon. Inputs
// that SplitHostPort rejects, such as a name with no port, fall back to
// dropping everything from the first ':' onward.
func stripPortNumber(host string) string {
	if hostOnly, _, err := net.SplitHostPort(host); err == nil {
		return hostOnly
	}
	if colon := strings.IndexByte(host, ':'); colon >= 0 {
		return host[:colon]
	}
	return host
}
// timeoutError matches any error that can report whether it was caused by a
// timeout; every net.Error satisfies it.
type timeoutError interface{ Timeout() bool }
// IsTimeoutError checks if the given error corresponds to a timeout (of any
// type).
//
// Each error in the chain is inspected in turn: err first, then whatever
// errors.Unwrap returns, and so on. Timeouts wrapped with fmt.Errorf("...: %w")
// are therefore recognized too. The first link that decides the answer wins:
//   - if it has a Timeout() bool method, its result is returned;
//   - if it is a *ScanError, the result is whether its Status is
//     SCAN_IO_TIMEOUT or SCAN_CONNECTION_TIMEOUT.
//
// A nil error, or a chain with neither kind of link, is not a timeout.
func IsTimeoutError(err error) bool {
	for current := err; current != nil; current = errors.Unwrap(current) {
		if cast, ok := current.(timeoutError); ok {
			return cast.Timeout()
		}
		if cast, ok := current.(*ScanError); ok {
			return cast.Status == SCAN_IO_TIMEOUT || cast.Status == SCAN_CONNECTION_TIMEOUT
		}
	}
	return false
}
// LogPanic is meant to be deferred directly, for example:
//
//	defer zgrab2.LogPanic("Error decoding body '%x'", body)
//
// If the surrounding function is not panicking, it does nothing. Otherwise it
// does three things, in order:
//   - logs the current stack trace and the panic value;
//   - logs the caller's message (format and args, as for logrus.Errorf);
//   - panics again with the same value, so the panic keeps propagating.
func LogPanic(format string, args ...interface{}) {
	// recover only intercepts a panic when it is called directly by the
	// deferred function, which is why LogPanic itself must be the deferred
	// call.
	if panicValue := recover(); panicValue != nil {
		logrus.Errorf("Uncaught panic at %s: %v", string(debug.Stack()), panicValue)
		logrus.Errorf(format, args...)
		panic(panicValue)
	}
}