diff --git a/client/client.go b/client/client.go index a45fd0e..83a1410 100644 --- a/client/client.go +++ b/client/client.go @@ -235,7 +235,7 @@ func (s *Subscriber) connect(ctx context.Context) (*websocket.Conn, error) { conn, _, err := websocket.Dial(ctx, url, nil) if err != nil { log.WithError(err).Debugf("dial error") - if err == context.Canceled { + if errors.Is(err, context.Canceled) { return nil, err } log.Debugf("reconnecting in %s", b.Duration()) diff --git a/client/client_test.go b/client/client_test.go index d60373d..174d2a4 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,10 +1,12 @@ package client import ( + "context" "io/ioutil" "net/http/httptest" "os" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -39,3 +41,50 @@ func TestClientPublish(t *testing.T) { assert.Equal(topic, msg.Topic) assert.Equal([]byte("hello world"), msg.Payload) } + +func TestCancelContextOnSubscribe(t *testing.T) { + + assert := assert.New(t) + require := require.New(t) + + logDir, err := ioutil.TempDir("", "msgbus-logs-*") + require.NoError(err) + defer os.RemoveAll(logDir) + + mb, err := msgbus.NewMessageBus(msgbus.WithLogPath(logDir)) + require.NoError(err) + defer os.RemoveAll(logDir) + + server := httptest.NewServer(mb) + defer server.Close() + + client := NewClient(server.URL, nil) + + sub := client.Subscribe("hello", 1, func(msg *msgbus.Message) error { return nil }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + failAfter := time.NewTimer(5 * time.Second) + + errChan := make(chan error) + go func() { + errChan <- sub.Run(ctx) + }() + + loop := true + for loop { + select { + case <-failAfter.C: + loop = false + t.Fatal("Run() with cancelled context did not exit on time") + case runError := <-errChan: + loop = false + assert.ErrorIs(runError, context.Canceled) + } + } + + cancel() + <-ctx.Done() + +}