golint; godocs; add TNSDriver to handle multiple TNSHeader formats

This commit is contained in:
Justin Bastress 2018-02-27 16:26:16 -05:00
parent e06794df8b
commit b8d979e3b1

@ -16,7 +16,7 @@ import (
var (
// ErrInvalidData is returned when the server returns syntactically-invalid
// (or very unlikely / problematic) data.
ErrInvalidData error = errors.New("server returned invalid data")
ErrInvalidData = errors.New("server returned invalid data")
// ErrInvalidInput is returned when user-supplied data is not valid.
ErrInvalidInput = errors.New("caller provided invalid input")
@ -27,7 +27,7 @@ var (
// ErrBufferTooSmall is returned when the caller provides a buffer that is
// too small for the required data.
ErrBufferTooSmall error = errors.New("buffer too small")
ErrBufferTooSmall = errors.New("buffer too small")
)
// References:
@ -131,16 +131,68 @@ func (reader *sliceReader) Read(output []byte) (int, error) {
return n, nil
}
// TNSMode determines the format of the TNSHeader; used in TNSDriver.
type TNSMode int
const (
// TNSModeOld uses the pre-12c format TNSHeader, with 16-bit lengths.
TNSModeOld TNSMode = 0
// TNSMode12c uses the newer TNSHeader format, with 32-bit lengths and no
// PacketChecksum.
TNSMode12c = 1
)
// TNSDriver abstracts the bottom-level TNS packet encoding.
type TNSDriver struct {
// Mode determines what type of packets will be sent -- TNSModeOld or
// TNSMode12c.
Mode TNSMode
}
// EncodePacket encodes the packet (header + body). If header is nil, create one
// with no flags and the type set to the body's type. If header.Length == 0, set
// it to the appropriate value (length of encoded body + 8).
func (driver *TNSDriver) EncodePacket(packet *TNSPacket) []byte {
body := packet.Body.Encode()
if packet.Header == nil {
packet.Header = &TNSHeader{
mode: driver.Mode,
Length: 0,
PacketChecksum: 0,
Type: packet.Body.GetType(),
// Flags -- aka Reserved Byte -- is "04" in some Connect packets?
Flags: 0,
HeaderChecksum: 0,
}
}
if packet.Header.Length == 0 {
// It is up to the user to check the body length for overflows before calling Encode
if driver.Mode == TNSModeOld {
if (len(body) + 8) > 0xffff {
panic(fmt.Errorf("Body too large to fit into 16-bit length (%d bytes)", len(body)))
}
packet.Header.Length = uint32(len(body) + 8)
} else {
packet.Header.Length = uint32(len(body) + 8)
}
}
header := packet.Header.Encode()
return append(header, body...)
}
// TNSFlags is the type for the TNS header's flags.
type TNSFlags uint8
// TNSHeader is the 8-byte header that precedes all TNS packets.
type TNSHeader struct {
mode TNSMode
// Length is the big-endian length of the entire packet, including the 8
// bytes of the header itself.
// For versions prior to 12(c?), the length is a uint16. For newer versions,
// it is a uint32 (taking the place of the PacketChecksum)
Length uint16
Length uint32
// PacketChecksum is in practice set to 0.
PacketChecksum uint16
@ -159,37 +211,47 @@ type TNSHeader struct {
func (header *TNSHeader) Encode() []byte {
ret := make([]byte, 8)
next := outputBuffer(ret)
next.pushU16(header.Length)
next.pushU16(header.PacketChecksum)
next.pushU8(byte(header.Type))
next.pushU8(byte(header.Flags))
next.pushU16(header.HeaderChecksum)
switch header.mode {
case TNSModeOld:
if header.Length > 0xffff {
panic(ErrInvalidData)
}
next.pushU16(uint16(header.Length))
next.pushU16(header.PacketChecksum)
next.pushU8(byte(header.Type))
next.pushU8(byte(header.Flags))
next.pushU16(header.HeaderChecksum)
case TNSMode12c:
next.pushU32(header.Length)
next.pushU8(byte(header.Type))
next.pushU8(byte(header.Flags))
next.pushU16(header.HeaderChecksum)
default:
panic(fmt.Errorf("Bad TNSDriver mode 0x%x", header.mode))
}
return ret
}
// DecodeTNSHeader reads the header from the first 8 bytes of buf.
// The decoded header is returned as well as a slice pointing past the end of
// the header in buf. On failure, returns nil/nil/error.
func DecodeTNSHeader(buf []byte) (*TNSHeader, []byte, error) {
if len(buf) < 8 {
return nil, nil, ErrBufferTooSmall
}
ret, err := ReadTNSHeader(getSliceReader(buf))
if err != nil {
return nil, nil, err
}
return ret, buf[8:], nil
}
// ReadTNSHeader reads/decodes a TNSHeader from the first 8 bytes of the stream.
func ReadTNSHeader(reader io.Reader) (*TNSHeader, error) {
func (driver *TNSDriver) ReadTNSHeader(reader io.Reader) (*TNSHeader, error) {
ret := TNSHeader{}
ret.mode = driver.Mode
next := startReading(reader)
next.read(&ret.Length)
next.read(&ret.PacketChecksum)
next.read(&ret.Type)
next.read(&ret.Flags)
next.read(&ret.HeaderChecksum)
switch driver.Mode {
case TNSModeOld:
var length uint16
next.read(&length)
ret.Length = uint32(length)
next.read(&ret.PacketChecksum)
next.read(&ret.Type)
next.read(&ret.Flags)
next.read(&ret.HeaderChecksum)
case TNSMode12c:
next.read(&ret.Length)
next.read(&ret.Type)
next.read(&ret.Flags)
next.read(&ret.HeaderChecksum)
}
if err := next.Error(); err != nil {
return nil, err
}
@ -813,7 +875,7 @@ func ReadTNSRefuse(reader io.Reader, header *TNSHeader) (*TNSRefuse, error) {
if err := next.Error(); err != nil {
return nil, err
}
if ret.DataLength != header.Length-8-4 {
if uint32(ret.DataLength) != header.Length-8-4 {
return nil, ErrInvalidData
}
return ret, nil
@ -884,10 +946,10 @@ func (v ReleaseVersion) Bytes() []byte {
// EncodeReleaseVersion gets a ReleaseVersion instance from its dotted-decimal
// representation, e.g.:
// EncodeReleaseVersion("64.3.2.1.0") = ReleaseVersion(0x40320100).
func EncodeReleaseVersion(value string) ReleaseVersion {
func EncodeReleaseVersion(value string) (ReleaseVersion, error) {
parts := strings.Split(value, ".")
if len(parts) != 5 {
panic(ErrInvalidInput)
return 0, ErrInvalidInput
}
numbers := make([]uint32, 5)
maxValue := []int{
@ -900,14 +962,22 @@ func EncodeReleaseVersion(value string) ReleaseVersion {
for i, v := range parts {
n, err := strconv.ParseUint(v, 10, 16)
if err != nil {
panic(ErrInvalidInput)
return 0, ErrInvalidInput
}
if int(n) > maxValue[i] {
panic(ErrInvalidInput)
return 0, ErrInvalidInput
}
numbers[i] = uint32(n)
}
return ReleaseVersion((numbers[0] << 24) | (numbers[1] << 20) | (numbers[2] << 16) | (numbers[3] << 8) | numbers[4])
return ReleaseVersion((numbers[0] << 24) | (numbers[1] << 20) | (numbers[2] << 16) | (numbers[3] << 8) | numbers[4]), nil
}
func encodeReleaseVersion(value string) ReleaseVersion {
ret, err := EncodeReleaseVersion(value)
if err != nil {
panic(err)
}
return ret
}
// DataFlags is a 16-bit flags field used in the TNSData packet.
@ -1146,7 +1216,7 @@ const (
// NSNValueTypeUB2 identifies an unsigned 16-bit big-endian integer.
NSNValueTypeUB2 = 3
// NSNValueTypeUB2 identifies an unsigned 32-bit big-endian integer.
// NSNValueTypeUB4 identifies an unsigned 32-bit big-endian integer.
NSNValueTypeUB4 = 4
// NSNValueTypeVersion identifies a 32-bit ReleaseVersion value.
@ -1222,7 +1292,7 @@ func (value *NSNValue) MarshalJSON() ([]byte, error) {
func NSNValueVersion(v string) *NSNValue {
return &NSNValue{
Type: NSNValueTypeVersion,
Value: EncodeReleaseVersion(v).Bytes(),
Value: encodeReleaseVersion(v).Bytes(),
}
}
@ -1397,6 +1467,9 @@ func ReadTNSDataNSN(reader io.Reader) (*TNSDataNSN, error) {
}
next.read(&ret.Version)
n, err := next.readU16()
if err != nil {
return nil, err
}
if n >= 0x0100 {
// arbitrary but certainly sufficiently-high value -- n here is the
// number of "services", which is typically 4.
@ -1438,7 +1511,7 @@ type TNSPacket struct {
// Encode the packet (header + body). If header is nil, create one with no flags
// and the type set to the body's type. If header.Length == 0, set it to the
// appropriate value (length of encoded body + 8).
func (packet *TNSPacket) Encode() []byte {
func (packet *TNSPacket) oldEncode() []byte {
body := packet.Body.Encode()
if packet.Header == nil {
packet.Header = &TNSHeader{
@ -1451,8 +1524,11 @@ func (packet *TNSPacket) Encode() []byte {
}
}
if packet.Header.Length == 0 {
if len(body)+8 > 0xffff {
panic(fmt.Errorf("Body too large to fit into 16-bit length (%d bytes)", len(body)))
}
// It is up to the user to check the body length for overflows before calling Encode
packet.Header.Length = uint16(len(body) + 8)
packet.Header.Length = uint32(len(body) + 8)
}
header := packet.Header.Encode()
return append(header, body...)
@ -1460,10 +1536,10 @@ func (packet *TNSPacket) Encode() []byte {
// ReadTNSPacket reads a TNSPacket from the stream, or returns nil + an error
// if one cannot be read.
func ReadTNSPacket(reader io.Reader) (*TNSPacket, error) {
func (driver *TNSDriver) ReadTNSPacket(reader io.Reader) (*TNSPacket, error) {
var body TNSPacketBody
var err error
header, err := ReadTNSHeader(reader)
header, err := driver.ReadTNSHeader(reader)
if err != nil {
return nil, err
}
@ -1497,9 +1573,9 @@ type DescriptorEntry struct {
Value string `json:"value"`
}
// Oracle "Descriptors" are nested series of parens, used for e.g.
// connect descriptors and for error strings.
// To simplify their usage in searches, they are stored in a flattened form.
// Descriptor is a nested series of parens, used for e.g. connect descriptors
// and for error responses. To simplify their usage in searches, they are stored
// in a flattened form.
// Since duplicate keys are allowed, a simple map will not work, so instead
// it stores an list of key/value pairs, in the order they appear in the string.
// NOTE: This is insufficient to re-construct the input (since there is no way