golint; godocs; add TNSDriver to handle multiple TNSHeader formats
This commit is contained in:
parent
e06794df8b
commit
b8d979e3b1
@ -16,7 +16,7 @@ import (
|
|||||||
var (
|
var (
|
||||||
// ErrInvalidData is returned when the server returns syntactically-invalid
|
// ErrInvalidData is returned when the server returns syntactically-invalid
|
||||||
// (or very unlikely / problematic) data.
|
// (or very unlikely / problematic) data.
|
||||||
ErrInvalidData error = errors.New("server returned invalid data")
|
ErrInvalidData = errors.New("server returned invalid data")
|
||||||
|
|
||||||
// ErrInvalidInput is returned when user-supplied data is not valid.
|
// ErrInvalidInput is returned when user-supplied data is not valid.
|
||||||
ErrInvalidInput = errors.New("caller provided invalid input")
|
ErrInvalidInput = errors.New("caller provided invalid input")
|
||||||
@ -27,7 +27,7 @@ var (
|
|||||||
|
|
||||||
// ErrBufferTooSmall is returned when the caller provides a buffer that is
|
// ErrBufferTooSmall is returned when the caller provides a buffer that is
|
||||||
// too small for the required data.
|
// too small for the required data.
|
||||||
ErrBufferTooSmall error = errors.New("buffer too small")
|
ErrBufferTooSmall = errors.New("buffer too small")
|
||||||
)
|
)
|
||||||
|
|
||||||
// References:
|
// References:
|
||||||
@ -131,16 +131,68 @@ func (reader *sliceReader) Read(output []byte) (int, error) {
|
|||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TNSMode determines the format of the TNSHeader; used in TNSDriver.
|
||||||
|
type TNSMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// TNSModeOld uses the pre-12c format TNSHeader, with 16-bit lengths.
|
||||||
|
TNSModeOld TNSMode = 0
|
||||||
|
|
||||||
|
// TNSMode12c uses the newer TNSHeader format, with 32-bit lengths and no
|
||||||
|
// PacketChecksum.
|
||||||
|
TNSMode12c = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
// TNSDriver abstracts the bottom-level TNS packet encoding.
|
||||||
|
type TNSDriver struct {
|
||||||
|
// Mode determines what type of packets will be sent -- TNSModeOld or
|
||||||
|
// TNSMode12c.
|
||||||
|
Mode TNSMode
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodePacket encodes the packet (header + body). If header is nil, create one
|
||||||
|
// with no flags and the type set to the body's type. If header.Length == 0, set
|
||||||
|
// it to the appropriate value (length of encoded body + 8).
|
||||||
|
func (driver *TNSDriver) EncodePacket(packet *TNSPacket) []byte {
|
||||||
|
body := packet.Body.Encode()
|
||||||
|
if packet.Header == nil {
|
||||||
|
packet.Header = &TNSHeader{
|
||||||
|
mode: driver.Mode,
|
||||||
|
Length: 0,
|
||||||
|
PacketChecksum: 0,
|
||||||
|
Type: packet.Body.GetType(),
|
||||||
|
// Flags -- aka Reserved Byte -- is "04" in some Connect packets?
|
||||||
|
Flags: 0,
|
||||||
|
HeaderChecksum: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if packet.Header.Length == 0 {
|
||||||
|
// It is up to the user to check the body length for overflows before calling Encode
|
||||||
|
if driver.Mode == TNSModeOld {
|
||||||
|
if (len(body) + 8) > 0xffff {
|
||||||
|
panic(fmt.Errorf("Body too large to fit into 16-bit length (%d bytes)", len(body)))
|
||||||
|
}
|
||||||
|
packet.Header.Length = uint32(len(body) + 8)
|
||||||
|
} else {
|
||||||
|
packet.Header.Length = uint32(len(body) + 8)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
header := packet.Header.Encode()
|
||||||
|
return append(header, body...)
|
||||||
|
}
|
||||||
|
|
||||||
// TNSFlags is the type for the TNS header's flags.
|
// TNSFlags is the type for the TNS header's flags.
|
||||||
type TNSFlags uint8
|
type TNSFlags uint8
|
||||||
|
|
||||||
// TNSHeader is the 8-byte header that precedes all TNS packets.
|
// TNSHeader is the 8-byte header that precedes all TNS packets.
|
||||||
type TNSHeader struct {
|
type TNSHeader struct {
|
||||||
|
mode TNSMode
|
||||||
|
|
||||||
// Length is the big-endian length of the entire packet, including the 8
|
// Length is the big-endian length of the entire packet, including the 8
|
||||||
// bytes of the header itself.
|
// bytes of the header itself.
|
||||||
// For versions prior to 12(c?), the length is a uint16. For newer versions,
|
// For versions prior to 12(c?), the length is a uint16. For newer versions,
|
||||||
// it is a uint32 (taking the place of the PacketChecksum)
|
// it is a uint32 (taking the place of the PacketChecksum)
|
||||||
Length uint16
|
Length uint32
|
||||||
|
|
||||||
// PacketChecksum is in practice set to 0.
|
// PacketChecksum is in practice set to 0.
|
||||||
PacketChecksum uint16
|
PacketChecksum uint16
|
||||||
@ -159,37 +211,47 @@ type TNSHeader struct {
|
|||||||
func (header *TNSHeader) Encode() []byte {
|
func (header *TNSHeader) Encode() []byte {
|
||||||
ret := make([]byte, 8)
|
ret := make([]byte, 8)
|
||||||
next := outputBuffer(ret)
|
next := outputBuffer(ret)
|
||||||
next.pushU16(header.Length)
|
switch header.mode {
|
||||||
next.pushU16(header.PacketChecksum)
|
case TNSModeOld:
|
||||||
next.pushU8(byte(header.Type))
|
if header.Length > 0xffff {
|
||||||
next.pushU8(byte(header.Flags))
|
panic(ErrInvalidData)
|
||||||
next.pushU16(header.HeaderChecksum)
|
}
|
||||||
|
next.pushU16(uint16(header.Length))
|
||||||
|
next.pushU16(header.PacketChecksum)
|
||||||
|
next.pushU8(byte(header.Type))
|
||||||
|
next.pushU8(byte(header.Flags))
|
||||||
|
next.pushU16(header.HeaderChecksum)
|
||||||
|
case TNSMode12c:
|
||||||
|
next.pushU32(header.Length)
|
||||||
|
next.pushU8(byte(header.Type))
|
||||||
|
next.pushU8(byte(header.Flags))
|
||||||
|
next.pushU16(header.HeaderChecksum)
|
||||||
|
default:
|
||||||
|
panic(fmt.Errorf("Bad TNSDriver mode 0x%x", header.mode))
|
||||||
|
}
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecodeTNSHeader reads the header from the first 8 bytes of buf.
|
|
||||||
// The decoded header is returned as well as a slice pointing past the end of
|
|
||||||
// the header in buf. On failure, returns nil/nil/error.
|
|
||||||
func DecodeTNSHeader(buf []byte) (*TNSHeader, []byte, error) {
|
|
||||||
if len(buf) < 8 {
|
|
||||||
return nil, nil, ErrBufferTooSmall
|
|
||||||
}
|
|
||||||
ret, err := ReadTNSHeader(getSliceReader(buf))
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
return ret, buf[8:], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadTNSHeader reads/decodes a TNSHeader from the first 8 bytes of the stream.
|
// ReadTNSHeader reads/decodes a TNSHeader from the first 8 bytes of the stream.
|
||||||
func ReadTNSHeader(reader io.Reader) (*TNSHeader, error) {
|
func (driver *TNSDriver) ReadTNSHeader(reader io.Reader) (*TNSHeader, error) {
|
||||||
ret := TNSHeader{}
|
ret := TNSHeader{}
|
||||||
|
ret.mode = driver.Mode
|
||||||
next := startReading(reader)
|
next := startReading(reader)
|
||||||
next.read(&ret.Length)
|
switch driver.Mode {
|
||||||
next.read(&ret.PacketChecksum)
|
case TNSModeOld:
|
||||||
next.read(&ret.Type)
|
var length uint16
|
||||||
next.read(&ret.Flags)
|
next.read(&length)
|
||||||
next.read(&ret.HeaderChecksum)
|
ret.Length = uint32(length)
|
||||||
|
next.read(&ret.PacketChecksum)
|
||||||
|
next.read(&ret.Type)
|
||||||
|
next.read(&ret.Flags)
|
||||||
|
next.read(&ret.HeaderChecksum)
|
||||||
|
case TNSMode12c:
|
||||||
|
next.read(&ret.Length)
|
||||||
|
next.read(&ret.Type)
|
||||||
|
next.read(&ret.Flags)
|
||||||
|
next.read(&ret.HeaderChecksum)
|
||||||
|
}
|
||||||
if err := next.Error(); err != nil {
|
if err := next.Error(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -813,7 +875,7 @@ func ReadTNSRefuse(reader io.Reader, header *TNSHeader) (*TNSRefuse, error) {
|
|||||||
if err := next.Error(); err != nil {
|
if err := next.Error(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if ret.DataLength != header.Length-8-4 {
|
if uint32(ret.DataLength) != header.Length-8-4 {
|
||||||
return nil, ErrInvalidData
|
return nil, ErrInvalidData
|
||||||
}
|
}
|
||||||
return ret, nil
|
return ret, nil
|
||||||
@ -884,10 +946,10 @@ func (v ReleaseVersion) Bytes() []byte {
|
|||||||
// EncodeReleaseVersion gets a ReleaseVersion instance from its dotted-decimal
|
// EncodeReleaseVersion gets a ReleaseVersion instance from its dotted-decimal
|
||||||
// representation, e.g.:
|
// representation, e.g.:
|
||||||
// EncodeReleaseVersion("64.3.2.1.0") = ReleaseVersion(0x40320100).
|
// EncodeReleaseVersion("64.3.2.1.0") = ReleaseVersion(0x40320100).
|
||||||
func EncodeReleaseVersion(value string) ReleaseVersion {
|
func EncodeReleaseVersion(value string) (ReleaseVersion, error) {
|
||||||
parts := strings.Split(value, ".")
|
parts := strings.Split(value, ".")
|
||||||
if len(parts) != 5 {
|
if len(parts) != 5 {
|
||||||
panic(ErrInvalidInput)
|
return 0, ErrInvalidInput
|
||||||
}
|
}
|
||||||
numbers := make([]uint32, 5)
|
numbers := make([]uint32, 5)
|
||||||
maxValue := []int{
|
maxValue := []int{
|
||||||
@ -900,14 +962,22 @@ func EncodeReleaseVersion(value string) ReleaseVersion {
|
|||||||
for i, v := range parts {
|
for i, v := range parts {
|
||||||
n, err := strconv.ParseUint(v, 10, 16)
|
n, err := strconv.ParseUint(v, 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(ErrInvalidInput)
|
return 0, ErrInvalidInput
|
||||||
}
|
}
|
||||||
if int(n) > maxValue[i] {
|
if int(n) > maxValue[i] {
|
||||||
panic(ErrInvalidInput)
|
return 0, ErrInvalidInput
|
||||||
}
|
}
|
||||||
numbers[i] = uint32(n)
|
numbers[i] = uint32(n)
|
||||||
}
|
}
|
||||||
return ReleaseVersion((numbers[0] << 24) | (numbers[1] << 20) | (numbers[2] << 16) | (numbers[3] << 8) | numbers[4])
|
return ReleaseVersion((numbers[0] << 24) | (numbers[1] << 20) | (numbers[2] << 16) | (numbers[3] << 8) | numbers[4]), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeReleaseVersion(value string) ReleaseVersion {
|
||||||
|
ret, err := EncodeReleaseVersion(value)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
// DataFlags is a 16-bit flags field used in the TNSData packet.
|
// DataFlags is a 16-bit flags field used in the TNSData packet.
|
||||||
@ -1146,7 +1216,7 @@ const (
|
|||||||
// NSNValueTypeUB2 identifies an unsigned 16-bit big-endian integer.
|
// NSNValueTypeUB2 identifies an unsigned 16-bit big-endian integer.
|
||||||
NSNValueTypeUB2 = 3
|
NSNValueTypeUB2 = 3
|
||||||
|
|
||||||
// NSNValueTypeUB2 identifies an unsigned 32-bit big-endian integer.
|
// NSNValueTypeUB4 identifies an unsigned 32-bit big-endian integer.
|
||||||
NSNValueTypeUB4 = 4
|
NSNValueTypeUB4 = 4
|
||||||
|
|
||||||
// NSNValueTypeVersion identifies a 32-bit ReleaseVersion value.
|
// NSNValueTypeVersion identifies a 32-bit ReleaseVersion value.
|
||||||
@ -1222,7 +1292,7 @@ func (value *NSNValue) MarshalJSON() ([]byte, error) {
|
|||||||
func NSNValueVersion(v string) *NSNValue {
|
func NSNValueVersion(v string) *NSNValue {
|
||||||
return &NSNValue{
|
return &NSNValue{
|
||||||
Type: NSNValueTypeVersion,
|
Type: NSNValueTypeVersion,
|
||||||
Value: EncodeReleaseVersion(v).Bytes(),
|
Value: encodeReleaseVersion(v).Bytes(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1397,6 +1467,9 @@ func ReadTNSDataNSN(reader io.Reader) (*TNSDataNSN, error) {
|
|||||||
}
|
}
|
||||||
next.read(&ret.Version)
|
next.read(&ret.Version)
|
||||||
n, err := next.readU16()
|
n, err := next.readU16()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if n >= 0x0100 {
|
if n >= 0x0100 {
|
||||||
// arbitrary but certainly sufficiently-high value -- n here is the
|
// arbitrary but certainly sufficiently-high value -- n here is the
|
||||||
// number of "services", which is typically 4.
|
// number of "services", which is typically 4.
|
||||||
@ -1438,7 +1511,7 @@ type TNSPacket struct {
|
|||||||
// Encode the packet (header + body). If header is nil, create one with no flags
|
// Encode the packet (header + body). If header is nil, create one with no flags
|
||||||
// and the type set to the body's type. If header.Length == 0, set it to the
|
// and the type set to the body's type. If header.Length == 0, set it to the
|
||||||
// appropriate value (length of encoded body + 8).
|
// appropriate value (length of encoded body + 8).
|
||||||
func (packet *TNSPacket) Encode() []byte {
|
func (packet *TNSPacket) oldEncode() []byte {
|
||||||
body := packet.Body.Encode()
|
body := packet.Body.Encode()
|
||||||
if packet.Header == nil {
|
if packet.Header == nil {
|
||||||
packet.Header = &TNSHeader{
|
packet.Header = &TNSHeader{
|
||||||
@ -1451,8 +1524,11 @@ func (packet *TNSPacket) Encode() []byte {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if packet.Header.Length == 0 {
|
if packet.Header.Length == 0 {
|
||||||
|
if len(body)+8 > 0xffff {
|
||||||
|
panic(fmt.Errorf("Body too large to fit into 16-bit length (%d bytes)", len(body)))
|
||||||
|
}
|
||||||
// It is up to the user to check the body length for overflows before calling Encode
|
// It is up to the user to check the body length for overflows before calling Encode
|
||||||
packet.Header.Length = uint16(len(body) + 8)
|
packet.Header.Length = uint32(len(body) + 8)
|
||||||
}
|
}
|
||||||
header := packet.Header.Encode()
|
header := packet.Header.Encode()
|
||||||
return append(header, body...)
|
return append(header, body...)
|
||||||
@ -1460,10 +1536,10 @@ func (packet *TNSPacket) Encode() []byte {
|
|||||||
|
|
||||||
// ReadTNSPacket reads a TNSPacket from the stream, or returns nil + an error
|
// ReadTNSPacket reads a TNSPacket from the stream, or returns nil + an error
|
||||||
// if one cannot be read.
|
// if one cannot be read.
|
||||||
func ReadTNSPacket(reader io.Reader) (*TNSPacket, error) {
|
func (driver *TNSDriver) ReadTNSPacket(reader io.Reader) (*TNSPacket, error) {
|
||||||
var body TNSPacketBody
|
var body TNSPacketBody
|
||||||
var err error
|
var err error
|
||||||
header, err := ReadTNSHeader(reader)
|
header, err := driver.ReadTNSHeader(reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -1497,9 +1573,9 @@ type DescriptorEntry struct {
|
|||||||
Value string `json:"value"`
|
Value string `json:"value"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Oracle "Descriptors" are nested series of parens, used for e.g.
|
// Descriptor is a nested series of parens, used for e.g. connect descriptors
|
||||||
// connect descriptors and for error strings.
|
// and for error responses. To simplify their usage in searches, they are stored
|
||||||
// To simplify their usage in searches, they are stored in a flattened form.
|
// in a flattened form.
|
||||||
// Since duplicate keys are allowed, a simple map will not work, so instead
|
// Since duplicate keys are allowed, a simple map will not work, so instead
|
||||||
// it stores an list of key/value pairs, in the order they appear in the string.
|
// it stores an list of key/value pairs, in the order they appear in the string.
|
||||||
// NOTE: This is insufficient to re-construct the input (since there is no way
|
// NOTE: This is insufficient to re-construct the input (since there is no way
|
||||||
|
Loading…
Reference in New Issue
Block a user