use TNSDriver in tests

This commit is contained in:
Justin Bastress 2018-02-27 16:26:32 -05:00
parent b8d979e3b1
commit 7a61e3b2a9

@ -143,7 +143,7 @@ var validTNSData = map[string]TestCase{
DataFlags: 0, DataFlags: 0,
Data: (&TNSDataNSN{ Data: (&TNSDataNSN{
ID: 0xdeadbeef, ID: 0xdeadbeef,
Version: EncodeReleaseVersion("10.2.0.3.0"), Version: encodeReleaseVersion("10.2.0.3.0"),
Options: NSNOptions(0), Options: NSNOptions(0),
Services: []NSNService{ Services: []NSNService{
NSNService{ NSNService{
@ -451,23 +451,25 @@ func min(a, b int) int {
return b return b
} }
func getTNSDriver() *TNSDriver {
return &TNSDriver{Mode: TNSModeOld}
}
// TODO: TNSRedirect // TODO: TNSRedirect
// TODO: Invalid cases // TODO: Invalid cases
func TestTNSHeaderEncode(t *testing.T) { func TestTNSHeaderEncode(t *testing.T) {
driver := getTNSDriver()
for hex, header := range validHeaders { for hex, header := range validHeaders {
bin := fromHex(hex) bin := fromHex(hex)
encoded := header.Encode() encoded := header.Encode()
if !bytes.Equal(bin, encoded) { if !bytes.Equal(bin, encoded) {
t.Errorf("TNSHeader.Encode mismatch:[\n%s\n]", interleave(bin, encoded)) t.Errorf("TNSHeader.Encode mismatch:[\n%s\n]", interleave(bin, encoded))
} }
decoded, rest, err := DecodeTNSHeader(bin) decoded, err := driver.ReadTNSHeader(getSliceReader(bin))
if err != nil { if err != nil {
t.Fatalf("Decode error: %v", err) t.Fatalf("Decode error: %v", err)
} }
if len(rest) > 0 {
t.Fatalf("Leftover data (%d bytes)", len(rest))
}
jsonHeader := serialize(header) jsonHeader := serialize(header)
jsonDecoded := serialize(decoded) jsonDecoded := serialize(decoded)
if !bytes.Equal(jsonHeader, jsonDecoded) { if !bytes.Equal(jsonHeader, jsonDecoded) {
@ -477,14 +479,15 @@ func TestTNSHeaderEncode(t *testing.T) {
} }
func TestTNSConnect(t *testing.T) { func TestTNSConnect(t *testing.T) {
driver := getTNSDriver()
for tag, info := range validTNSConnect { for tag, info := range validTNSConnect {
bin := fromHex(info.Encoding) bin := fromHex(info.Encoding)
encoded := info.Value.Encode() encoded := driver.EncodePacket(info.Value)
if !bytes.Equal(bin, encoded) { if !bytes.Equal(bin, encoded) {
t.Errorf("%s: TNSConnect.Encode mismatch:[\n%s\n]", tag, interleave(bin, encoded)) t.Errorf("%s: TNSConnect.Encode mismatch:[\n%s\n]", tag, interleave(bin, encoded))
} }
reader := getSliceReader(bin) reader := getSliceReader(bin)
response, err := ReadTNSPacket(reader) response, err := driver.ReadTNSPacket(reader)
if err != nil { if err != nil {
t.Fatalf("%s: Error reading TNSConnect packet: %v", tag, err) t.Fatalf("%s: Error reading TNSConnect packet: %v", tag, err)
} }
@ -505,14 +508,15 @@ func TestTNSConnect(t *testing.T) {
} }
func TestTNSAccept(t *testing.T) { func TestTNSAccept(t *testing.T) {
driver := getTNSDriver()
for tag, info := range validTNSAccept { for tag, info := range validTNSAccept {
bin := fromHex(info.Encoding) bin := fromHex(info.Encoding)
encoded := info.Value.Encode() encoded := driver.EncodePacket(info.Value)
if !bytes.Equal(bin, encoded) { if !bytes.Equal(bin, encoded) {
t.Errorf("%s: TNSAccept.Encode mismatch:[\n%s\n]", tag, interleave(bin, encoded)) t.Errorf("%s: TNSAccept.Encode mismatch:[\n%s\n]", tag, interleave(bin, encoded))
} }
reader := getSliceReader(bin) reader := getSliceReader(bin)
response, err := ReadTNSPacket(reader) response, err := driver.ReadTNSPacket(reader)
if err != nil { if err != nil {
t.Fatalf("%s: Error reading TNSAccept packet: %v", tag, err) t.Fatalf("%s: Error reading TNSAccept packet: %v", tag, err)
} }
@ -533,14 +537,15 @@ func TestTNSAccept(t *testing.T) {
} }
func TestTNSData(t *testing.T) { func TestTNSData(t *testing.T) {
driver := getTNSDriver()
for tag, info := range validTNSData { for tag, info := range validTNSData {
bin := fromHex(info.Encoding) bin := fromHex(info.Encoding)
encoded := info.Value.Encode() encoded := driver.EncodePacket(info.Value)
if !bytes.Equal(bin, encoded) { if !bytes.Equal(bin, encoded) {
t.Errorf("%s: TNSData.Encode mismatch:[\n%s\n]", tag, interleave(bin, encoded)) t.Errorf("%s: TNSData.Encode mismatch:[\n%s\n]", tag, interleave(bin, encoded))
} }
reader := getSliceReader(bin) reader := getSliceReader(bin)
response, err := ReadTNSPacket(reader) response, err := driver.ReadTNSPacket(reader)
if err != nil { if err != nil {
t.Fatalf("%s: Error reading TNSData packet: %v", tag, err) t.Fatalf("%s: Error reading TNSData packet: %v", tag, err)
} }
@ -682,9 +687,9 @@ func TestDescriptorGetValue(t *testing.T) {
func removeSpace(s string) string { func removeSpace(s string) string {
ret := strings.Replace(s, "\r", "", -1) ret := strings.Replace(s, "\r", "", -1)
ret = strings.Replace(s, "\n", "", -1) ret = strings.Replace(ret, "\n", "", -1)
ret = strings.Replace(s, "\t", "", -1) ret = strings.Replace(ret, "\t", "", -1)
ret = strings.Replace(s, " ", "", -1) ret = strings.Replace(ret, " ", "", -1)
return ret return ret
} }