golint; godocs; add TNSDriver to handle multiple TNSHeader formats
This commit is contained in:
parent
e06794df8b
commit
b8d979e3b1
@ -16,7 +16,7 @@ import (
|
||||
var (
|
||||
// ErrInvalidData is returned when the server returns syntactically-invalid
|
||||
// (or very unlikely / problematic) data.
|
||||
ErrInvalidData error = errors.New("server returned invalid data")
|
||||
ErrInvalidData = errors.New("server returned invalid data")
|
||||
|
||||
// ErrInvalidInput is returned when user-supplied data is not valid.
|
||||
ErrInvalidInput = errors.New("caller provided invalid input")
|
||||
@ -27,7 +27,7 @@ var (
|
||||
|
||||
// ErrBufferTooSmall is returned when the caller provides a buffer that is
|
||||
// too small for the required data.
|
||||
ErrBufferTooSmall error = errors.New("buffer too small")
|
||||
ErrBufferTooSmall = errors.New("buffer too small")
|
||||
)
|
||||
|
||||
// References:
|
||||
@ -131,16 +131,68 @@ func (reader *sliceReader) Read(output []byte) (int, error) {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// TNSMode determines the format of the TNSHeader; used in TNSDriver.
|
||||
type TNSMode int
|
||||
|
||||
const (
|
||||
// TNSModeOld uses the pre-12c format TNSHeader, with 16-bit lengths.
|
||||
TNSModeOld TNSMode = 0
|
||||
|
||||
// TNSMode12c uses the newer TNSHeader format, with 32-bit lengths and no
|
||||
// PacketChecksum.
|
||||
TNSMode12c = 1
|
||||
)
|
||||
|
||||
// TNSDriver abstracts the bottom-level TNS packet encoding.
|
||||
type TNSDriver struct {
|
||||
// Mode determines what type of packets will be sent -- TNSModeOld or
|
||||
// TNSMode12c.
|
||||
Mode TNSMode
|
||||
}
|
||||
|
||||
// EncodePacket encodes the packet (header + body). If header is nil, create one
|
||||
// with no flags and the type set to the body's type. If header.Length == 0, set
|
||||
// it to the appropriate value (length of encoded body + 8).
|
||||
func (driver *TNSDriver) EncodePacket(packet *TNSPacket) []byte {
|
||||
body := packet.Body.Encode()
|
||||
if packet.Header == nil {
|
||||
packet.Header = &TNSHeader{
|
||||
mode: driver.Mode,
|
||||
Length: 0,
|
||||
PacketChecksum: 0,
|
||||
Type: packet.Body.GetType(),
|
||||
// Flags -- aka Reserved Byte -- is "04" in some Connect packets?
|
||||
Flags: 0,
|
||||
HeaderChecksum: 0,
|
||||
}
|
||||
}
|
||||
if packet.Header.Length == 0 {
|
||||
// It is up to the user to check the body length for overflows before calling Encode
|
||||
if driver.Mode == TNSModeOld {
|
||||
if (len(body) + 8) > 0xffff {
|
||||
panic(fmt.Errorf("Body too large to fit into 16-bit length (%d bytes)", len(body)))
|
||||
}
|
||||
packet.Header.Length = uint32(len(body) + 8)
|
||||
} else {
|
||||
packet.Header.Length = uint32(len(body) + 8)
|
||||
}
|
||||
}
|
||||
header := packet.Header.Encode()
|
||||
return append(header, body...)
|
||||
}
|
||||
|
||||
// TNSFlags is the type for the TNS header's flags.
|
||||
type TNSFlags uint8
|
||||
|
||||
// TNSHeader is the 8-byte header that precedes all TNS packets.
|
||||
type TNSHeader struct {
|
||||
mode TNSMode
|
||||
|
||||
// Length is the big-endian length of the entire packet, including the 8
|
||||
// bytes of the header itself.
|
||||
// For versions prior to 12(c?), the length is a uint16. For newer versions,
|
||||
// it is a uint32 (taking the place of the PacketChecksum)
|
||||
Length uint16
|
||||
Length uint32
|
||||
|
||||
// PacketChecksum is in practice set to 0.
|
||||
PacketChecksum uint16
|
||||
@ -159,37 +211,47 @@ type TNSHeader struct {
|
||||
func (header *TNSHeader) Encode() []byte {
|
||||
ret := make([]byte, 8)
|
||||
next := outputBuffer(ret)
|
||||
next.pushU16(header.Length)
|
||||
switch header.mode {
|
||||
case TNSModeOld:
|
||||
if header.Length > 0xffff {
|
||||
panic(ErrInvalidData)
|
||||
}
|
||||
next.pushU16(uint16(header.Length))
|
||||
next.pushU16(header.PacketChecksum)
|
||||
next.pushU8(byte(header.Type))
|
||||
next.pushU8(byte(header.Flags))
|
||||
next.pushU16(header.HeaderChecksum)
|
||||
case TNSMode12c:
|
||||
next.pushU32(header.Length)
|
||||
next.pushU8(byte(header.Type))
|
||||
next.pushU8(byte(header.Flags))
|
||||
next.pushU16(header.HeaderChecksum)
|
||||
default:
|
||||
panic(fmt.Errorf("Bad TNSDriver mode 0x%x", header.mode))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// DecodeTNSHeader reads the header from the first 8 bytes of buf.
|
||||
// The decoded header is returned as well as a slice pointing past the end of
|
||||
// the header in buf. On failure, returns nil/nil/error.
|
||||
func DecodeTNSHeader(buf []byte) (*TNSHeader, []byte, error) {
|
||||
if len(buf) < 8 {
|
||||
return nil, nil, ErrBufferTooSmall
|
||||
}
|
||||
ret, err := ReadTNSHeader(getSliceReader(buf))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return ret, buf[8:], nil
|
||||
}
|
||||
|
||||
// ReadTNSHeader reads/decodes a TNSHeader from the first 8 bytes of the stream.
|
||||
func ReadTNSHeader(reader io.Reader) (*TNSHeader, error) {
|
||||
func (driver *TNSDriver) ReadTNSHeader(reader io.Reader) (*TNSHeader, error) {
|
||||
ret := TNSHeader{}
|
||||
ret.mode = driver.Mode
|
||||
next := startReading(reader)
|
||||
next.read(&ret.Length)
|
||||
switch driver.Mode {
|
||||
case TNSModeOld:
|
||||
var length uint16
|
||||
next.read(&length)
|
||||
ret.Length = uint32(length)
|
||||
next.read(&ret.PacketChecksum)
|
||||
next.read(&ret.Type)
|
||||
next.read(&ret.Flags)
|
||||
next.read(&ret.HeaderChecksum)
|
||||
case TNSMode12c:
|
||||
next.read(&ret.Length)
|
||||
next.read(&ret.Type)
|
||||
next.read(&ret.Flags)
|
||||
next.read(&ret.HeaderChecksum)
|
||||
}
|
||||
if err := next.Error(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -813,7 +875,7 @@ func ReadTNSRefuse(reader io.Reader, header *TNSHeader) (*TNSRefuse, error) {
|
||||
if err := next.Error(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ret.DataLength != header.Length-8-4 {
|
||||
if uint32(ret.DataLength) != header.Length-8-4 {
|
||||
return nil, ErrInvalidData
|
||||
}
|
||||
return ret, nil
|
||||
@ -884,10 +946,10 @@ func (v ReleaseVersion) Bytes() []byte {
|
||||
// EncodeReleaseVersion gets a ReleaseVersion instance from its dotted-decimal
|
||||
// representation, e.g.:
|
||||
// EncodeReleaseVersion("64.3.2.1.0") = ReleaseVersion(0x40320100).
|
||||
func EncodeReleaseVersion(value string) ReleaseVersion {
|
||||
func EncodeReleaseVersion(value string) (ReleaseVersion, error) {
|
||||
parts := strings.Split(value, ".")
|
||||
if len(parts) != 5 {
|
||||
panic(ErrInvalidInput)
|
||||
return 0, ErrInvalidInput
|
||||
}
|
||||
numbers := make([]uint32, 5)
|
||||
maxValue := []int{
|
||||
@ -900,14 +962,22 @@ func EncodeReleaseVersion(value string) ReleaseVersion {
|
||||
for i, v := range parts {
|
||||
n, err := strconv.ParseUint(v, 10, 16)
|
||||
if err != nil {
|
||||
panic(ErrInvalidInput)
|
||||
return 0, ErrInvalidInput
|
||||
}
|
||||
if int(n) > maxValue[i] {
|
||||
panic(ErrInvalidInput)
|
||||
return 0, ErrInvalidInput
|
||||
}
|
||||
numbers[i] = uint32(n)
|
||||
}
|
||||
return ReleaseVersion((numbers[0] << 24) | (numbers[1] << 20) | (numbers[2] << 16) | (numbers[3] << 8) | numbers[4])
|
||||
return ReleaseVersion((numbers[0] << 24) | (numbers[1] << 20) | (numbers[2] << 16) | (numbers[3] << 8) | numbers[4]), nil
|
||||
}
|
||||
|
||||
func encodeReleaseVersion(value string) ReleaseVersion {
|
||||
ret, err := EncodeReleaseVersion(value)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// DataFlags is a 16-bit flags field used in the TNSData packet.
|
||||
@ -1146,7 +1216,7 @@ const (
|
||||
// NSNValueTypeUB2 identifies an unsigned 16-bit big-endian integer.
|
||||
NSNValueTypeUB2 = 3
|
||||
|
||||
// NSNValueTypeUB2 identifies an unsigned 32-bit big-endian integer.
|
||||
// NSNValueTypeUB4 identifies an unsigned 32-bit big-endian integer.
|
||||
NSNValueTypeUB4 = 4
|
||||
|
||||
// NSNValueTypeVersion identifies a 32-bit ReleaseVersion value.
|
||||
@ -1222,7 +1292,7 @@ func (value *NSNValue) MarshalJSON() ([]byte, error) {
|
||||
func NSNValueVersion(v string) *NSNValue {
|
||||
return &NSNValue{
|
||||
Type: NSNValueTypeVersion,
|
||||
Value: EncodeReleaseVersion(v).Bytes(),
|
||||
Value: encodeReleaseVersion(v).Bytes(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -1397,6 +1467,9 @@ func ReadTNSDataNSN(reader io.Reader) (*TNSDataNSN, error) {
|
||||
}
|
||||
next.read(&ret.Version)
|
||||
n, err := next.readU16()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if n >= 0x0100 {
|
||||
// arbitrary but certainly sufficiently-high value -- n here is the
|
||||
// number of "services", which is typically 4.
|
||||
@ -1438,7 +1511,7 @@ type TNSPacket struct {
|
||||
// Encode the packet (header + body). If header is nil, create one with no flags
|
||||
// and the type set to the body's type. If header.Length == 0, set it to the
|
||||
// appropriate value (length of encoded body + 8).
|
||||
func (packet *TNSPacket) Encode() []byte {
|
||||
func (packet *TNSPacket) oldEncode() []byte {
|
||||
body := packet.Body.Encode()
|
||||
if packet.Header == nil {
|
||||
packet.Header = &TNSHeader{
|
||||
@ -1451,8 +1524,11 @@ func (packet *TNSPacket) Encode() []byte {
|
||||
}
|
||||
}
|
||||
if packet.Header.Length == 0 {
|
||||
if len(body)+8 > 0xffff {
|
||||
panic(fmt.Errorf("Body too large to fit into 16-bit length (%d bytes)", len(body)))
|
||||
}
|
||||
// It is up to the user to check the body length for overflows before calling Encode
|
||||
packet.Header.Length = uint16(len(body) + 8)
|
||||
packet.Header.Length = uint32(len(body) + 8)
|
||||
}
|
||||
header := packet.Header.Encode()
|
||||
return append(header, body...)
|
||||
@ -1460,10 +1536,10 @@ func (packet *TNSPacket) Encode() []byte {
|
||||
|
||||
// ReadTNSPacket reads a TNSPacket from the stream, or returns nil + an error
|
||||
// if one cannot be read.
|
||||
func ReadTNSPacket(reader io.Reader) (*TNSPacket, error) {
|
||||
func (driver *TNSDriver) ReadTNSPacket(reader io.Reader) (*TNSPacket, error) {
|
||||
var body TNSPacketBody
|
||||
var err error
|
||||
header, err := ReadTNSHeader(reader)
|
||||
header, err := driver.ReadTNSHeader(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -1497,9 +1573,9 @@ type DescriptorEntry struct {
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
// Oracle "Descriptors" are nested series of parens, used for e.g.
|
||||
// connect descriptors and for error strings.
|
||||
// To simplify their usage in searches, they are stored in a flattened form.
|
||||
// Descriptor is a nested series of parens, used for e.g. connect descriptors
|
||||
// and for error responses. To simplify their usage in searches, they are stored
|
||||
// in a flattened form.
|
||||
// Since duplicate keys are allowed, a simple map will not work, so instead
|
||||
// it stores an list of key/value pairs, in the order they appear in the string.
|
||||
// NOTE: This is insufficient to re-construct the input (since there is no way
|
||||
|
Loading…
Reference in New Issue
Block a user