From ac91a3e484511b116ef3e2150923e6b9f1b69115 Mon Sep 17 00:00:00 2001 From: Daniel Oaks Date: Thu, 17 Aug 2017 18:23:24 +1000 Subject: [PATCH] strings: Follow latest advice on PRECIS regarding string stabilizing --- irc/strings.go | 24 +++++++++++++++++++++--- irc/strings_test.go | 25 +++++++++++++++---------- 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/irc/strings.go b/irc/strings.go index d8f09502..fc46a740 100644 --- a/irc/strings.go +++ b/irc/strings.go @@ -17,13 +17,31 @@ const ( ) var ( - errInvalidCharacter = errors.New("Invalid character") - errEmpty = errors.New("String is empty") + errCouldNotStabilize = errors.New("Could not stabilize string while casefolding") + errInvalidCharacter = errors.New("Invalid character") + errEmpty = errors.New("String is empty") ) // Casefold returns a casefolded string, without doing any name or channel character checks. func Casefold(str string) (string, error) { - return precis.UsernameCaseMapped.CompareKey(str) + var err error + oldStr := str + // follow the stabilizing rules laid out here: + // https://tools.ietf.org/html/draft-ietf-precis-7564bis-10.html#section-7 + for i := 0; i < 4; i++ { + str, err = precis.UsernameCaseMapped.CompareKey(str) + if err != nil { + return "", err + } + if oldStr == str { + break + } + oldStr = str + } + if oldStr != str { + return "", errCouldNotStabilize + } + return str, nil } // CasefoldChannel returns a casefolded version of a channel name. diff --git a/irc/strings_test.go b/irc/strings_test.go index 808f6ca3..7f23af83 100644 --- a/irc/strings_test.go +++ b/irc/strings_test.go @@ -1,4 +1,5 @@ // Copyright (c) 2017 Euan Kemp +// Copyright (c) 2017 Daniel Oaks // released under the MIT license package irc @@ -50,14 +51,16 @@ func TestCasefoldChannel(t *testing.T) { for i, tt := range testCases { t.Run(fmt.Sprintf("case %d: %s", i, tt.channel), func(t *testing.T) { res, err := CasefoldChannel(tt.channel) - if tt.err { - if err == nil { - t.Errorf("expected error") - } + if tt.err && err == nil { + t.Errorf("expected error when casefolding [%s], but did not receive one", tt.channel) + return + } + if !tt.err && err != nil { + t.Errorf("unexpected error while casefolding [%s]: %s", tt.channel, err.Error()) return } if tt.folded != res { - t.Errorf("expected %v to be %v", tt.folded, res) + t.Errorf("expected [%v] to be [%v]", res, tt.folded) } }) } @@ -91,14 +94,16 @@ func TestCasefoldName(t *testing.T) { for i, tt := range testCases { t.Run(fmt.Sprintf("case %d: %s", i, tt.name), func(t *testing.T) { res, err := CasefoldName(tt.name) - if tt.err { - if err == nil { - t.Errorf("expected error") - } + if tt.err && err == nil { + t.Errorf("expected error when casefolding [%s], but did not receive one", tt.name) + return + } + if !tt.err && err != nil { + t.Errorf("unexpected error while casefolding [%s]: %s", tt.name, err.Error()) return } if tt.folded != res { - t.Errorf("expected %v to be %v", tt.folded, res) + t.Errorf("expected [%v] to be [%v]", res, tt.folded) } }) }