diff --git a/modules/oracle/types.go b/modules/oracle/types.go index 840f6f0..fab5a3e 100644 --- a/modules/oracle/types.go +++ b/modules/oracle/types.go @@ -16,7 +16,7 @@ import ( var ( // ErrInvalidData is returned when the server returns syntactically-invalid // (or very unlikely / problematic) data. - ErrInvalidData error = errors.New("server returned invalid data") + ErrInvalidData = errors.New("server returned invalid data") // ErrInvalidInput is returned when user-supplied data is not valid. ErrInvalidInput = errors.New("caller provided invalid input") @@ -27,7 +27,7 @@ var ( // ErrBufferTooSmall is returned when the caller provides a buffer that is // too small for the required data. - ErrBufferTooSmall error = errors.New("buffer too small") + ErrBufferTooSmall = errors.New("buffer too small") ) // References: @@ -131,16 +131,68 @@ func (reader *sliceReader) Read(output []byte) (int, error) { return n, nil } +// TNSMode determines the format of the TNSHeader; used in TNSDriver. +type TNSMode int + +const ( + // TNSModeOld uses the pre-12c format TNSHeader, with 16-bit lengths. + TNSModeOld TNSMode = 0 + + // TNSMode12c uses the newer TNSHeader format, with 32-bit lengths and no + // PacketChecksum. + TNSMode12c = 1 +) + +// TNSDriver abstracts the bottom-level TNS packet encoding. +type TNSDriver struct { + // Mode determines what type of packets will be sent -- TNSModeOld or + // TNSMode12c. + Mode TNSMode +} + +// EncodePacket encodes the packet (header + body). If header is nil, create one +// with no flags and the type set to the body's type. If header.Length == 0, set +// it to the appropriate value (length of encoded body + 8). +func (driver *TNSDriver) EncodePacket(packet *TNSPacket) []byte { + body := packet.Body.Encode() + if packet.Header == nil { + packet.Header = &TNSHeader{ + mode: driver.Mode, + Length: 0, + PacketChecksum: 0, + Type: packet.Body.GetType(), + // Flags -- aka Reserved Byte -- is "04" in some Connect packets? + Flags: 0, + HeaderChecksum: 0, + } + } + if packet.Header.Length == 0 { + // It is up to the user to check the body length for overflows before calling Encode + if driver.Mode == TNSModeOld { + if (len(body) + 8) > 0xffff { + panic(fmt.Errorf("Body too large to fit into 16-bit length (%d bytes)", len(body))) + } + packet.Header.Length = uint32(len(body) + 8) + } else { + packet.Header.Length = uint32(len(body) + 8) + } + } + header := packet.Header.Encode() + return append(header, body...) +} + // TNSFlags is the type for the TNS header's flags. type TNSFlags uint8 // TNSHeader is the 8-byte header that precedes all TNS packets. type TNSHeader struct { + mode TNSMode + // Length is the big-endian length of the entire packet, including the 8 // bytes of the header itself. // For versions prior to 12(c?), the length is a uint16. For newer versions, // it is a uint32 (taking the place of the PacketChecksum) - Length uint16 + Length uint32 // PacketChecksum is in practice set to 0. PacketChecksum uint16 @@ -159,37 +211,47 @@ type TNSHeader struct { func (header *TNSHeader) Encode() []byte { ret := make([]byte, 8) next := outputBuffer(ret) - next.pushU16(header.Length) - next.pushU16(header.PacketChecksum) - next.pushU8(byte(header.Type)) - next.pushU8(byte(header.Flags)) - next.pushU16(header.HeaderChecksum) + switch header.mode { + case TNSModeOld: + if header.Length > 0xffff { + panic(ErrInvalidData) + } + next.pushU16(uint16(header.Length)) + next.pushU16(header.PacketChecksum) + next.pushU8(byte(header.Type)) + next.pushU8(byte(header.Flags)) + next.pushU16(header.HeaderChecksum) + case TNSMode12c: + next.pushU32(header.Length) + next.pushU8(byte(header.Type)) + next.pushU8(byte(header.Flags)) + next.pushU16(header.HeaderChecksum) + default: + panic(fmt.Errorf("Bad TNSDriver mode 0x%x", header.mode)) + } return ret } -// DecodeTNSHeader reads the header from the first 8 bytes of buf. -// The decoded header is returned as well as a slice pointing past the end of -// the header in buf. On failure, returns nil/nil/error. -func DecodeTNSHeader(buf []byte) (*TNSHeader, []byte, error) { - if len(buf) < 8 { - return nil, nil, ErrBufferTooSmall - } - ret, err := ReadTNSHeader(getSliceReader(buf)) - if err != nil { - return nil, nil, err - } - return ret, buf[8:], nil -} - // ReadTNSHeader reads/decodes a TNSHeader from the first 8 bytes of the stream. -func ReadTNSHeader(reader io.Reader) (*TNSHeader, error) { +func (driver *TNSDriver) ReadTNSHeader(reader io.Reader) (*TNSHeader, error) { ret := TNSHeader{} + ret.mode = driver.Mode next := startReading(reader) - next.read(&ret.Length) - next.read(&ret.PacketChecksum) - next.read(&ret.Type) - next.read(&ret.Flags) - next.read(&ret.HeaderChecksum) + switch driver.Mode { + case TNSModeOld: + var length uint16 + next.read(&length) + ret.Length = uint32(length) + next.read(&ret.PacketChecksum) + next.read(&ret.Type) + next.read(&ret.Flags) + next.read(&ret.HeaderChecksum) + case TNSMode12c: + next.read(&ret.Length) + next.read(&ret.Type) + next.read(&ret.Flags) + next.read(&ret.HeaderChecksum) + } if err := next.Error(); err != nil { return nil, err } @@ -813,7 +875,7 @@ func ReadTNSRefuse(reader io.Reader, header *TNSHeader) (*TNSRefuse, error) { if err := next.Error(); err != nil { return nil, err } - if ret.DataLength != header.Length-8-4 { + if uint32(ret.DataLength) != header.Length-8-4 { return nil, ErrInvalidData } return ret, nil @@ -884,10 +946,10 @@ func (v ReleaseVersion) Bytes() []byte { // EncodeReleaseVersion gets a ReleaseVersion instance from its dotted-decimal // representation, e.g.: // EncodeReleaseVersion("64.3.2.1.0") = ReleaseVersion(0x40320100). -func EncodeReleaseVersion(value string) ReleaseVersion { +func EncodeReleaseVersion(value string) (ReleaseVersion, error) { parts := strings.Split(value, ".") if len(parts) != 5 { - panic(ErrInvalidInput) + return 0, ErrInvalidInput } numbers := make([]uint32, 5) maxValue := []int{ @@ -900,14 +962,22 @@ func EncodeReleaseVersion(value string) ReleaseVersion { for i, v := range parts { n, err := strconv.ParseUint(v, 10, 16) if err != nil { - panic(ErrInvalidInput) + return 0, ErrInvalidInput } if int(n) > maxValue[i] { - panic(ErrInvalidInput) + return 0, ErrInvalidInput } numbers[i] = uint32(n) } - return ReleaseVersion((numbers[0] << 24) | (numbers[1] << 20) | (numbers[2] << 16) | (numbers[3] << 8) | numbers[4]) + return ReleaseVersion((numbers[0] << 24) | (numbers[1] << 20) | (numbers[2] << 16) | (numbers[3] << 8) | numbers[4]), nil +} + +func encodeReleaseVersion(value string) ReleaseVersion { + ret, err := EncodeReleaseVersion(value) + if err != nil { + panic(err) + } + return ret } // DataFlags is a 16-bit flags field used in the TNSData packet. @@ -1146,7 +1216,7 @@ const ( // NSNValueTypeUB2 identifies an unsigned 16-bit big-endian integer. NSNValueTypeUB2 = 3 - // NSNValueTypeUB2 identifies an unsigned 32-bit big-endian integer. + // NSNValueTypeUB4 identifies an unsigned 32-bit big-endian integer. NSNValueTypeUB4 = 4 // NSNValueTypeVersion identifies a 32-bit ReleaseVersion value. @@ -1222,7 +1292,7 @@ func (value *NSNValue) MarshalJSON() ([]byte, error) { func NSNValueVersion(v string) *NSNValue { return &NSNValue{ Type: NSNValueTypeVersion, - Value: EncodeReleaseVersion(v).Bytes(), + Value: encodeReleaseVersion(v).Bytes(), } } @@ -1397,6 +1467,9 @@ func ReadTNSDataNSN(reader io.Reader) (*TNSDataNSN, error) { } next.read(&ret.Version) n, err := next.readU16() + if err != nil { + return nil, err + } if n >= 0x0100 { // arbitrary but certainly sufficiently-high value -- n here is the // number of "services", which is typically 4. @@ -1438,7 +1511,7 @@ type TNSPacket struct { // Encode the packet (header + body). If header is nil, create one with no flags // and the type set to the body's type. If header.Length == 0, set it to the // appropriate value (length of encoded body + 8). -func (packet *TNSPacket) Encode() []byte { +func (packet *TNSPacket) oldEncode() []byte { body := packet.Body.Encode() if packet.Header == nil { packet.Header = &TNSHeader{ @@ -1451,8 +1524,11 @@ func (packet *TNSPacket) Encode() []byte { } } if packet.Header.Length == 0 { + if len(body)+8 > 0xffff { + panic(fmt.Errorf("Body too large to fit into 16-bit length (%d bytes)", len(body))) + } // It is up to the user to check the body length for overflows before calling Encode - packet.Header.Length = uint16(len(body) + 8) + packet.Header.Length = uint32(len(body) + 8) } header := packet.Header.Encode() return append(header, body...) @@ -1460,10 +1536,10 @@ func (packet *TNSPacket) Encode() []byte { // ReadTNSPacket reads a TNSPacket from the stream, or returns nil + an error // if one cannot be read. -func ReadTNSPacket(reader io.Reader) (*TNSPacket, error) { +func (driver *TNSDriver) ReadTNSPacket(reader io.Reader) (*TNSPacket, error) { var body TNSPacketBody var err error - header, err := ReadTNSHeader(reader) + header, err := driver.ReadTNSHeader(reader) if err != nil { return nil, err } @@ -1497,9 +1573,9 @@ type DescriptorEntry struct { Value string `json:"value"` } -// Oracle "Descriptors" are nested series of parens, used for e.g. -// connect descriptors and for error strings. -// To simplify their usage in searches, they are stored in a flattened form. +// Descriptor is a nested series of parens, used for e.g. connect descriptors +// and for error responses. To simplify their usage in searches, they are stored +// in a flattened form. // Since duplicate keys are allowed, a simple map will not work, so instead // it stores an list of key/value pairs, in the order they appear in the string. // NOTE: This is insufficient to re-construct the input (since there is no way