Refactor client with correct reconnecting behaviour, contexts and a clean design (#37)

Co-authored-by: James Mills <prologic@shortcircuit.net.au>
Reviewed-on: https://git.mills.io/prologic/msgbus/pulls/37
Reviewed-by: xuu <xuu@noreply@mills.io>
This commit is contained in:
James Mills 2022-04-05 02:59:48 +00:00
parent 1f25a21e20
commit daafbf1c60
2 changed files with 102 additions and 89 deletions

View File

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
@ -27,8 +28,12 @@ const (
// DefaultMaxReconnectInterval ...
DefaultMaxReconnectInterval = 64
// DefaultPingInterval is the default time interval between pings
DefaultPingInterval = 60 * time.Second
// DefaultPingInterval is the default time interval (in seconds) between pings
DefaultPingInterval = 60
)
var (
ErrConnectionFailed = errors.New("error: connection failed")
)
func noopHandler(msg *msgbus.Message) error { return nil }
@ -39,12 +44,14 @@ type Client struct {
url string
pingInterval time.Duration
reconnectInterval time.Duration
maxReconnectInterval time.Duration
}
// Options ...
type Options struct {
PingInterval int
ReconnectInterval int
MaxReconnectInterval int
}
@ -52,6 +59,7 @@ type Options struct {
// NewClient ...
func NewClient(url string, options *Options) *Client {
var (
pingInterval = DefaultPingInterval
reconnectInterval = DefaultReconnectInterval
maxReconnectInterval = DefaultMaxReconnectInterval
)
@ -61,6 +69,10 @@ func NewClient(url string, options *Options) *Client {
client := &Client{url: url}
if options != nil {
if options.PingInterval != 0 {
pingInterval = options.PingInterval
}
if options.ReconnectInterval != 0 {
reconnectInterval = options.ReconnectInterval
}
@ -70,6 +82,7 @@ func NewClient(url string, options *Options) *Client {
}
}
client.pingInterval = time.Duration(pingInterval) * time.Second
client.reconnectInterval = time.Duration(reconnectInterval) * time.Second
client.maxReconnectInterval = time.Duration(maxReconnectInterval) * time.Second
@ -147,8 +160,6 @@ func (c *Client) Subscribe(topic string, index int64, handler msgbus.HandlerFunc
type Subscriber struct {
sync.RWMutex
conn *websocket.Conn
client *Client
topic string
@ -156,6 +167,7 @@ type Subscriber struct {
handler msgbus.HandlerFunc
pingInterval time.Duration
reconnectInterval time.Duration
maxReconnectInterval time.Duration
}
@ -172,6 +184,7 @@ func NewSubscriber(client *Client, topic string, index int64, handler msgbus.Han
index: index,
handler: handler,
pingInterval: client.pingInterval,
reconnectInterval: client.reconnectInterval,
maxReconnectInterval: client.maxReconnectInterval,
}
@ -208,95 +221,100 @@ func (s *Subscriber) maybeUpdateIndex(msg *msgbus.Message) {
}
}
func (s *Subscriber) closeAndReconnect() {
s.RLock()
if s.conn != nil {
s.conn.Close(websocket.StatusNormalClosure, "Closing and reconnecting...")
go s.connect()
}
s.RUnlock()
}
func (s *Subscriber) connect() {
s.RLock()
func (s *Subscriber) connect(ctx context.Context) (*websocket.Conn, error) {
b := &backoff.Backoff{
Min: s.reconnectInterval,
Max: s.maxReconnectInterval,
Factor: 2,
Jitter: false,
Jitter: true,
}
s.RUnlock()
ctx := context.Background()
for {
conn, _, err := websocket.Dial(ctx, s.url(), nil)
url := s.url()
log.Debugf("connecting to %s", url)
conn, _, err := websocket.Dial(ctx, url, nil)
if err != nil {
log.WithError(err).Debugf("dial error")
if err == context.Canceled {
return nil, err
}
log.Debugf("reconnecting in %s", b.Duration())
time.Sleep(b.Duration())
continue
}
s.Lock()
s.conn = conn
s.Unlock()
go s.readLoop(ctx)
go s.heartbeat(ctx, DefaultPingInterval)
break
}
}
func (s *Subscriber) readLoop(ctx context.Context) {
var msg *msgbus.Message
for {
err := wsjson.Read(ctx, s.conn, &msg)
if err != nil {
s.closeAndReconnect()
return
}
s.maybeUpdateIndex(msg)
if err := s.handler(msg); err != nil {
log.Warnf("error handling message: %s", err)
}
}
}
// Start ...
func (s *Subscriber) Start() {
go s.connect()
}
// Stop ...
func (s *Subscriber) Stop() {
s.Lock()
defer s.Unlock()
if err := s.conn.Close(websocket.StatusNormalClosure, "Subscriber stopped"); err != nil {
log.Warnf("error sending close message: %s", err)
log.Debug("connected!")
return conn, nil
}
s.conn = nil
// never reached
}
func (s *Subscriber) heartbeat(ctx context.Context, d time.Duration) {
t := time.NewTimer(d)
defer t.Stop()
// Run runs the subscriber client with the provided context
func (s *Subscriber) Run(ctx context.Context) error {
conn, err := s.connect(ctx)
if err != nil {
return fmt.Errorf("error connecting: %w", err)
}
msgs := make(chan *msgbus.Message)
go s.writeloop(ctx, conn)
go s.readloop(ctx, conn, msgs)
for {
select {
case <-ctx.Done():
return
case <-t.C:
log.Debug("context done")
if err := conn.Close(websocket.StatusNormalClosure, "Subscriber stopped"); err != nil {
log.WithError(err).Debug("error closing connection")
}
return nil
case msg, ok := <-msgs:
if !ok {
log.Debug("readloop closed")
return nil
}
s.maybeUpdateIndex(msg)
if err := s.handler(msg); err != nil {
log.Warnf("error handling message: %s", err)
}
}
// c.Ping returns on receiving a pong
err := s.conn.Ping(ctx)
if err != nil {
s.closeAndReconnect()
}
t.Reset(time.Minute)
}
}
func (s *Subscriber) writeloop(ctx context.Context, conn *websocket.Conn) {
t := time.NewTicker(s.pingInterval)
defer t.Stop()
for {
select {
case <-ctx.Done():
log.Debug("context done")
return
case <-t.C:
log.Debug("sending ping...")
if err := conn.Ping(ctx); err != nil {
log.WithError(err).Debug("ping error")
return
}
}
}
}
func (s *Subscriber) readloop(ctx context.Context, conn *websocket.Conn, msgs chan *msgbus.Message) {
for {
var msg *msgbus.Message
err := wsjson.Read(ctx, conn, &msg)
if err != nil {
log.WithError(err).Debug("read error, reconnecting")
conn, err = s.connect(ctx)
if err != nil {
log.WithError(err).Debug("error reconnecting")
close(msgs)
break
}
go s.writeloop(ctx, conn)
continue
}
msgs <- msg
}
}

View File

@ -1,6 +1,7 @@
package main
import (
"context"
"encoding/json"
"fmt"
"log"
@ -106,20 +107,14 @@ func subscribe(client *client.Client, topic string, index int64, command string,
topic = defaultTopic
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := client.Subscribe(topic, index, handler(command, args))
s.Start()
go s.Run(ctx)
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
sig := <-sigs
log.Printf("caught signal %s: ", sig)
s.Stop()
done <- true
}()
<-done
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
sig := <-sigCh
log.Printf("caught signal %s: ", sig)
}