check for nested context.Canceled
This commit is contained in:
rodič
a421534ee1
revize
ebd665ee25
|
@ -235,7 +235,7 @@ func (s *Subscriber) connect(ctx context.Context) (*websocket.Conn, error) {
|
|||
conn, _, err := websocket.Dial(ctx, url, nil)
|
||||
if err != nil {
|
||||
log.WithError(err).Debugf("dial error")
|
||||
if err == context.Canceled {
|
||||
if checkForCancelled(err) {
|
||||
return nil, err
|
||||
}
|
||||
log.Debugf("reconnecting in %s", b.Duration())
|
||||
|
@ -318,3 +318,11 @@ func (s *Subscriber) readloop(ctx context.Context, conn *websocket.Conn, msgs ch
|
|||
msgs <- msg
|
||||
}
|
||||
}
|
||||
|
||||
func checkForCancelled(err error) bool {
|
||||
next := err
|
||||
for next != nil && next != context.Canceled {
|
||||
next = errors.Unwrap(next)
|
||||
}
|
||||
return next == context.Canceled
|
||||
}
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -39,3 +42,67 @@ func TestClientPublish(t *testing.T) {
|
|||
assert.Equal(topic, msg.Topic)
|
||||
assert.Equal([]byte("hello world"), msg.Payload)
|
||||
}
|
||||
|
||||
func TestCancelContextOnSubscribe(t *testing.T) {
|
||||
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
logDir, err := ioutil.TempDir("", "msgbus-logs-*")
|
||||
require.NoError(err)
|
||||
defer os.RemoveAll(logDir)
|
||||
|
||||
mb, err := msgbus.NewMessageBus(msgbus.WithLogPath(logDir))
|
||||
require.NoError(err)
|
||||
defer os.RemoveAll(logDir)
|
||||
|
||||
server := httptest.NewServer(mb)
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, nil)
|
||||
|
||||
sub := client.Subscribe("hello", 1, func(msg *msgbus.Message) error { return nil })
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
failAfter := time.NewTimer(5 * time.Second)
|
||||
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
errChan <- sub.Run(ctx)
|
||||
}()
|
||||
|
||||
loop := true
|
||||
for loop {
|
||||
select {
|
||||
case <-failAfter.C:
|
||||
loop = false
|
||||
t.Fatal("Run() with cancelled context did not exit on time")
|
||||
case runError := <-errChan:
|
||||
loop = false
|
||||
assert.ErrorIs(runError, context.Canceled)
|
||||
}
|
||||
}
|
||||
|
||||
cancel()
|
||||
<-ctx.Done()
|
||||
|
||||
}
|
||||
|
||||
func TestIsContextCancelWrapped(t *testing.T) {
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.False(checkForCancelled(nil))
|
||||
assert.False(checkForCancelled(ErrConnectionFailed))
|
||||
assert.True(checkForCancelled(context.Canceled))
|
||||
|
||||
levelOne := fmt.Errorf("level one :%w", context.Canceled)
|
||||
assert.True(checkForCancelled(levelOne))
|
||||
levelTwo := fmt.Errorf("level two :%w", levelOne)
|
||||
assert.True(checkForCancelled(levelTwo))
|
||||
levelThree := fmt.Errorf("level two :%w", levelTwo)
|
||||
assert.True(checkForCancelled(levelThree))
|
||||
|
||||
}
|
||||
|
|
Načítá se…
Odkázat v novém úkolu