Refactor client with correct reconnecting behaviour, contexts and a clean design
Cette révision appartient à :
Parent
1f25a21e20
révision
f79a1b6a1e
173
client/client.go
173
client/client.go
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -27,11 +28,15 @@ const (
|
|||
// DefaultMaxReconnectInterval ...
|
||||
DefaultMaxReconnectInterval = 64
|
||||
|
||||
// DefaultPingInterval is the default time interval between pings
|
||||
DefaultPingInterval = 60 * time.Second
|
||||
// DefaultPingInterval is the default time interval (in seconds) between pings
|
||||
DefaultPingInterval = 5
|
||||
)
|
||||
|
||||
func noopHandler(msg *msgbus.Message) error { return nil }
|
||||
var (
|
||||
ErrConnectionFailed = errors.New("error: connection failed")
|
||||
)
|
||||
|
||||
func noopHandler(msg msgbus.Message) error { return nil }
|
||||
|
||||
// Client ...
|
||||
type Client struct {
|
||||
|
@ -39,12 +44,14 @@ type Client struct {
|
|||
|
||||
url string
|
||||
|
||||
pingInterval time.Duration
|
||||
reconnectInterval time.Duration
|
||||
maxReconnectInterval time.Duration
|
||||
}
|
||||
|
||||
// Options ...
|
||||
type Options struct {
|
||||
PingInterval int
|
||||
ReconnectInterval int
|
||||
MaxReconnectInterval int
|
||||
}
|
||||
|
@ -52,6 +59,7 @@ type Options struct {
|
|||
// NewClient ...
|
||||
func NewClient(url string, options *Options) *Client {
|
||||
var (
|
||||
pingInterval = DefaultPingInterval
|
||||
reconnectInterval = DefaultReconnectInterval
|
||||
maxReconnectInterval = DefaultMaxReconnectInterval
|
||||
)
|
||||
|
@ -61,6 +69,10 @@ func NewClient(url string, options *Options) *Client {
|
|||
client := &Client{url: url}
|
||||
|
||||
if options != nil {
|
||||
if options.PingInterval != 0 {
|
||||
pingInterval = options.PingInterval
|
||||
}
|
||||
|
||||
if options.ReconnectInterval != 0 {
|
||||
reconnectInterval = options.ReconnectInterval
|
||||
}
|
||||
|
@ -70,6 +82,7 @@ func NewClient(url string, options *Options) *Client {
|
|||
}
|
||||
}
|
||||
|
||||
client.pingInterval = time.Duration(pingInterval) * time.Second
|
||||
client.reconnectInterval = time.Duration(reconnectInterval) * time.Second
|
||||
client.maxReconnectInterval = time.Duration(maxReconnectInterval) * time.Second
|
||||
|
||||
|
@ -147,8 +160,6 @@ func (c *Client) Subscribe(topic string, index int64, handler msgbus.HandlerFunc
|
|||
type Subscriber struct {
|
||||
sync.RWMutex
|
||||
|
||||
conn *websocket.Conn
|
||||
|
||||
client *Client
|
||||
|
||||
topic string
|
||||
|
@ -156,6 +167,7 @@ type Subscriber struct {
|
|||
|
||||
handler msgbus.HandlerFunc
|
||||
|
||||
pingInterval time.Duration
|
||||
reconnectInterval time.Duration
|
||||
maxReconnectInterval time.Duration
|
||||
}
|
||||
|
@ -172,12 +184,16 @@ func NewSubscriber(client *Client, topic string, index int64, handler msgbus.Han
|
|||
index: index,
|
||||
handler: handler,
|
||||
|
||||
pingInterval: client.pingInterval,
|
||||
reconnectInterval: client.reconnectInterval,
|
||||
maxReconnectInterval: client.maxReconnectInterval,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Subscriber) url() string {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
|
||||
u, err := url.Parse(s.client.url)
|
||||
if err != nil {
|
||||
log.Fatalf("invalid url: %s", s.client.url)
|
||||
|
@ -197,7 +213,7 @@ func (s *Subscriber) url() string {
|
|||
return u.String()
|
||||
}
|
||||
|
||||
func (s *Subscriber) maybeUpdateIndex(msg *msgbus.Message) {
|
||||
func (s *Subscriber) maybeUpdateIndex(msg msgbus.Message) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
|
@ -208,95 +224,102 @@ func (s *Subscriber) maybeUpdateIndex(msg *msgbus.Message) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Subscriber) closeAndReconnect() {
|
||||
s.RLock()
|
||||
if s.conn != nil {
|
||||
s.conn.Close(websocket.StatusNormalClosure, "Closing and reconnecting...")
|
||||
go s.connect()
|
||||
}
|
||||
s.RUnlock()
|
||||
}
|
||||
|
||||
func (s *Subscriber) connect() {
|
||||
func (s *Subscriber) connect(ctx context.Context) (*websocket.Conn, error) {
|
||||
s.RLock()
|
||||
b := &backoff.Backoff{
|
||||
Min: s.reconnectInterval,
|
||||
Max: s.maxReconnectInterval,
|
||||
Factor: 2,
|
||||
Jitter: false,
|
||||
Jitter: true,
|
||||
}
|
||||
s.RUnlock()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
for {
|
||||
conn, _, err := websocket.Dial(ctx, s.url(), nil)
|
||||
url := s.url()
|
||||
log.Debugf("connecting to %s", url)
|
||||
conn, _, err := websocket.Dial(ctx, url, nil)
|
||||
if err != nil {
|
||||
log.WithError(err).Debugf("dial error")
|
||||
if err == context.Canceled {
|
||||
return nil, err
|
||||
}
|
||||
log.Debugf("reconnecting in %s", b.Duration())
|
||||
time.Sleep(b.Duration())
|
||||
continue
|
||||
}
|
||||
|
||||
s.Lock()
|
||||
s.conn = conn
|
||||
s.Unlock()
|
||||
|
||||
go s.readLoop(ctx)
|
||||
go s.heartbeat(ctx, DefaultPingInterval)
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Subscriber) readLoop(ctx context.Context) {
|
||||
var msg *msgbus.Message
|
||||
|
||||
for {
|
||||
err := wsjson.Read(ctx, s.conn, &msg)
|
||||
if err != nil {
|
||||
s.closeAndReconnect()
|
||||
return
|
||||
}
|
||||
s.maybeUpdateIndex(msg)
|
||||
|
||||
if err := s.handler(msg); err != nil {
|
||||
log.Warnf("error handling message: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Start ...
|
||||
func (s *Subscriber) Start() {
|
||||
go s.connect()
|
||||
}
|
||||
|
||||
// Stop ...
|
||||
func (s *Subscriber) Stop() {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
if err := s.conn.Close(websocket.StatusNormalClosure, "Subscriber stopped"); err != nil {
|
||||
log.Warnf("error sending close message: %s", err)
|
||||
|
||||
log.Debug("connected!")
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
s.conn = nil
|
||||
// never reached
|
||||
}
|
||||
|
||||
func (s *Subscriber) heartbeat(ctx context.Context, d time.Duration) {
|
||||
t := time.NewTimer(d)
|
||||
defer t.Stop()
|
||||
func (s *Subscriber) Run(ctx context.Context) error {
|
||||
conn, err := s.connect(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting: %w", err)
|
||||
}
|
||||
|
||||
msgs := make(chan msgbus.Message)
|
||||
go s.writeloop(ctx, conn)
|
||||
go s.readloop(ctx, conn, msgs)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
log.Debug("context done")
|
||||
if err := conn.Close(websocket.StatusNormalClosure, "Subscriber stopped"); err != nil {
|
||||
log.WithError(err).Debug("error closing connection")
|
||||
}
|
||||
return nil
|
||||
case msg, ok := <-msgs:
|
||||
if !ok {
|
||||
log.Debug("readloop closed")
|
||||
return nil
|
||||
}
|
||||
s.maybeUpdateIndex(msg)
|
||||
if err := s.handler(msg); err != nil {
|
||||
log.Warnf("error handling message: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// c.Ping returns on receiving a pong
|
||||
err := s.conn.Ping(ctx)
|
||||
if err != nil {
|
||||
s.closeAndReconnect()
|
||||
}
|
||||
t.Reset(time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Subscriber) writeloop(ctx context.Context, conn *websocket.Conn) {
|
||||
s.RLock()
|
||||
t := time.NewTicker(s.pingInterval)
|
||||
s.RUnlock()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debug("context done")
|
||||
return
|
||||
case <-t.C:
|
||||
log.Debug("sending ping...")
|
||||
if err := conn.Ping(ctx); err != nil {
|
||||
log.WithError(err).Debug("ping error")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Subscriber) readloop(ctx context.Context, conn *websocket.Conn, msgs chan msgbus.Message) {
|
||||
for {
|
||||
var msg msgbus.Message
|
||||
err := wsjson.Read(ctx, conn, &msg)
|
||||
if err != nil {
|
||||
log.WithError(err).Debug("read error, reconnecting")
|
||||
conn, err = s.connect(ctx)
|
||||
if err != nil {
|
||||
log.WithError(err).Debug("error reconnecting")
|
||||
close(msgs)
|
||||
break
|
||||
}
|
||||
go s.writeloop(ctx, conn)
|
||||
continue
|
||||
}
|
||||
msgs <- msg
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
|
@ -64,7 +65,7 @@ func init() {
|
|||
}
|
||||
|
||||
func handler(command string, args []string) msgbus.HandlerFunc {
|
||||
return func(msg *msgbus.Message) error {
|
||||
return func(msg msgbus.Message) error {
|
||||
out, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
log.Printf("error marshalling message: %s", err)
|
||||
|
@ -106,20 +107,14 @@ func subscribe(client *client.Client, topic string, index int64, command string,
|
|||
topic = defaultTopic
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s := client.Subscribe(topic, index, handler(command, args))
|
||||
s.Start()
|
||||
go s.Run(ctx)
|
||||
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
go func() {
|
||||
sig := <-sigs
|
||||
log.Printf("caught signal %s: ", sig)
|
||||
s.Stop()
|
||||
done <- true
|
||||
}()
|
||||
|
||||
<-done
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
sig := <-sigCh
|
||||
log.Printf("caught signal %s: ", sig)
|
||||
}
|
||||
|
|
|
@ -53,7 +53,7 @@ var upgrader = websocket.Upgrader{
|
|||
}
|
||||
|
||||
// HandlerFunc ...
|
||||
type HandlerFunc func(msg *Message) error
|
||||
type HandlerFunc func(msg Message) error
|
||||
|
||||
// Topic ...
|
||||
type Topic struct {
|
||||
|
|
Chargement…
Référencer dans un nouveau ticket