prologic-msgbus/client/client.go

321 lines
6.6 KiB
Go

package client
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/jpillora/backoff"
sync "github.com/sasha-s/go-deadlock"
log "github.com/sirupsen/logrus"
"nhooyr.io/websocket"
"nhooyr.io/websocket/wsjson"
"git.mills.io/prologic/msgbus"
)
// Default client intervals, all expressed in seconds. NewClient falls back to
// these when the corresponding Options field is left at zero.
const (
	// DefaultReconnectInterval is the initial delay (in seconds) a subscriber
	// waits before retrying a failed connection.
	DefaultReconnectInterval = 2

	// DefaultMaxReconnectInterval is the upper bound (in seconds) on the
	// exponential backoff between reconnection attempts.
	DefaultMaxReconnectInterval = 64

	// DefaultPingInterval is the default time interval (in seconds) between
	// pings sent by a subscriber.
	DefaultPingInterval = 60
)

// ErrConnectionFailed reports a failed connection to the msgbus server.
var ErrConnectionFailed = errors.New("error: connection failed")
// noopHandler is installed when a subscriber is created without a handler:
// it ignores every message and always reports success.
func noopHandler(_ *msgbus.Message) error {
	return nil
}
// Client talks to a msgbus server rooted at url, over HTTP for Pull and
// Publish and over WebSocket for subscriptions. Pull and Publish hold the
// embedded lock for reading while they run.
type Client struct {
	sync.RWMutex

	// url is the server base URL without a trailing slash.
	url string

	// Timing settings copied into every Subscriber created from this client.
	pingInterval, reconnectInterval, maxReconnectInterval time.Duration
}
// Options customises the timing behaviour of a Client. All values are in
// seconds; a zero value selects the matching Default* constant.
type Options struct {
	// PingInterval is the delay between keep-alive pings on a subscription.
	PingInterval int
	// ReconnectInterval is the initial delay before retrying a failed connection.
	ReconnectInterval int
	// MaxReconnectInterval caps the delay between reconnection attempts.
	MaxReconnectInterval int
}
// NewClient returns a Client for the msgbus server at url. Any trailing slash
// on url is removed.
//
// options may be nil. Each interval in options is given in seconds; a value
// that is not positive (zero or negative) falls back to DefaultPingInterval,
// DefaultReconnectInterval or DefaultMaxReconnectInterval respectively.
// Negative values are rejected because a non-positive ping interval would
// make time.NewTicker panic once a subscriber runs.
func NewClient(url string, options *Options) *Client {
	pingInterval := DefaultPingInterval
	reconnectInterval := DefaultReconnectInterval
	maxReconnectInterval := DefaultMaxReconnectInterval

	if options != nil {
		if options.PingInterval > 0 {
			pingInterval = options.PingInterval
		}
		if options.ReconnectInterval > 0 {
			reconnectInterval = options.ReconnectInterval
		}
		if options.MaxReconnectInterval > 0 {
			maxReconnectInterval = options.MaxReconnectInterval
		}
	}

	return &Client{
		url:                  strings.TrimSuffix(url, "/"),
		pingInterval:         time.Duration(pingInterval) * time.Second,
		reconnectInterval:    time.Duration(reconnectInterval) * time.Second,
		maxReconnectInterval: time.Duration(maxReconnectInterval) * time.Second,
	}
}
// Pull fetches a single message from topic with an HTTP GET.
//
// It returns (nil, nil) when the topic is empty, which the server signals
// with 204 No Content (or 404 Not Found, accepted for backwards
// compatibility — presumably with older servers). Any other non-2xx status is
// returned as an error instead of being decoded as a message.
func (c *Client) Pull(topic string) (msg *msgbus.Message, err error) {
	c.RLock()
	defer c.RUnlock()

	url := fmt.Sprintf("%s/%s", c.url, topic)
	client := &http.Client{}
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return nil, fmt.Errorf("error constructing request: %w", err)
	}
	res, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("error pulling message: %w", err)
	}
	// Close the body on every path, including the empty-queue return below,
	// so the underlying connection is released.
	defer res.Body.Close()

	// XXX: StatusNotFound is for backwards compatibility only.
	if res.StatusCode == http.StatusNoContent || res.StatusCode == http.StatusNotFound {
		// Empty queue
		return nil, nil
	}
	if res.StatusCode < 200 || res.StatusCode > 299 {
		return nil, fmt.Errorf("unexpected response: %s", res.Status)
	}

	if err := json.NewDecoder(res.Body).Decode(&msg); err != nil {
		return nil, fmt.Errorf("error decoding message: %w", err)
	}
	return msg, nil
}
// Publish sends message to topic with an HTTP PUT.
//
// The server must answer 202 Accepted; any other status is reported as an
// error. Transport and request-construction errors are wrapped with %w so
// callers can inspect them with errors.Is / errors.As.
func (c *Client) Publish(topic, message string) error {
	c.RLock()
	defer c.RUnlock()

	url := fmt.Sprintf("%s/%s", c.url, topic)
	client := &http.Client{}
	req, err := http.NewRequest(http.MethodPut, url, bytes.NewBufferString(message))
	if err != nil {
		return fmt.Errorf("error constructing request: %w", err)
	}
	res, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("error publishing message: %w", err)
	}
	// Always close the body so the underlying connection is not leaked.
	defer res.Body.Close()

	if res.StatusCode != http.StatusAccepted {
		return fmt.Errorf("unexpected response: %s", res.Status)
	}
	return nil
}
// Subscribe creates a Subscriber for topic, starting at index, that passes
// each received message to handler (nil discards messages). It only builds
// the Subscriber; nothing is received until its Run method is called.
func (c *Client) Subscribe(topic string, index int64, handler msgbus.HandlerFunc) *Subscriber {
	sub := NewSubscriber(c, topic, index, handler)
	return sub
}
// Subscriber receives the messages published on a single topic over a
// WebSocket connection and hands each one to its handler. Build one with
// Client.Subscribe or NewSubscriber and start it with Run.
type Subscriber struct {
	// Taken for writing while index is advanced.
	sync.RWMutex

	client  *Client
	topic   string
	index   int64 // index sent to the server when (re)connecting
	handler msgbus.HandlerFunc

	// Timing settings copied from the owning Client.
	pingInterval, reconnectInterval, maxReconnectInterval time.Duration
}
// NewSubscriber builds a Subscriber for topic on client, starting at index.
// A nil handler is replaced with one that discards every message. The
// client's ping and reconnect intervals are copied into the subscriber.
func NewSubscriber(client *Client, topic string, index int64, handler msgbus.HandlerFunc) *Subscriber {
	sub := &Subscriber{
		client:               client,
		topic:                topic,
		index:                index,
		handler:              handler,
		pingInterval:         client.pingInterval,
		reconnectInterval:    client.reconnectInterval,
		maxReconnectInterval: client.maxReconnectInterval,
	}
	if sub.handler == nil {
		sub.handler = noopHandler
	}
	return sub
}
// url builds the WebSocket URL used to subscribe to s.topic, carrying the
// subscriber's current index as the "index" query parameter.
//
// Base URLs with an https or wss scheme map to wss; everything else maps to
// ws. The index is read under the subscriber's read lock because Run updates
// it (via maybeUpdateIndex) concurrently with readloop reconnecting.
func (s *Subscriber) url() string {
	u, err := url.Parse(s.client.url)
	if err != nil {
		// NOTE(review): this terminates the whole process on a malformed base
		// URL; NewClient does not validate the URL, so this is reachable.
		log.Fatalf("invalid url: %s", s.client.url)
	}

	// url.Parse lower-cases the scheme, so plain comparisons suffice. Checking
	// the parsed scheme (rather than a raw-string prefix) also keeps an
	// explicit wss:// base URL secure instead of downgrading it to ws://.
	switch u.Scheme {
	case "https", "wss":
		u.Scheme = "wss"
	default:
		u.Scheme = "ws"
	}

	s.RLock()
	index := s.index
	s.RUnlock()

	u.Path += fmt.Sprintf("/%s", s.topic)
	q := u.Query()
	q.Set("index", strconv.FormatInt(index, 10))
	u.RawQuery = q.Encode()
	return u.String()
}
// maybeUpdateIndex moves the subscriber's index to one past msg.ID, so a
// later reconnect does not request msg again. A zero or negative index is
// left untouched (NOTE(review): presumably those values have a special
// meaning to the server — confirm).
func (s *Subscriber) maybeUpdateIndex(msg *msgbus.Message) {
	s.Lock()
	defer s.Unlock()

	if s.index <= 0 {
		return
	}

	next := msg.ID + 1
	log.Debugf("updating index from %d to %d", s.index, next)
	s.index = next
}
// connect dials the subscriber's WebSocket endpoint. It retries failed dials
// with jittered exponential backoff, starting at reconnectInterval and capped
// at maxReconnectInterval.
//
// It returns the connection once a dial succeeds. It returns a non-nil error
// only when ctx is done (cancelled or past its deadline), either from the dial
// itself or while waiting between attempts. In both cases the error wraps the
// context error, so errors.Is(err, context.Canceled) keeps working.
func (s *Subscriber) connect(ctx context.Context) (*websocket.Conn, error) {
	b := &backoff.Backoff{
		Min:    s.reconnectInterval,
		Max:    s.maxReconnectInterval,
		Factor: 2,
		Jitter: true,
	}

	for {
		url := s.url()
		log.Debugf("connecting to %s", url)
		conn, _, err := websocket.Dial(ctx, url, nil)
		if err == nil {
			log.Debug("connected!")
			return conn, nil
		}

		log.WithError(err).Debugf("dial error")
		// Give up once ctx is done; checking ctx.Err() also covers a deadline,
		// which previously made this loop retry forever.
		if errors.Is(err, context.Canceled) || ctx.Err() != nil {
			return nil, err
		}

		// Duration advances the backoff's attempt counter, so call it exactly
		// once per retry and use the same value for logging and waiting.
		delay := b.Duration()
		log.Debugf("reconnecting in %s", delay)

		// Wait for the delay, but stop early if ctx is cancelled meanwhile.
		timer := time.NewTimer(delay)
		select {
		case <-ctx.Done():
			timer.Stop()
			return nil, ctx.Err()
		case <-timer.C:
		}
	}
}
// Run connects to the server and hands every message received on the topic
// to the subscriber's handler. Before each handler call it advances the
// subscriber's index. It keeps running until ctx is done or the read loop
// closes its channel.
//
// It returns an error only if the initial connection fails; after that it
// always returns nil. Handler errors are logged and do not stop the
// subscriber.
func (s *Subscriber) Run(ctx context.Context) error {
	conn, err := s.connect(ctx)
	if err != nil {
		return fmt.Errorf("error connecting: %w", err)
	}

	incoming := make(chan *msgbus.Message)
	go s.writeloop(ctx, conn)
	go s.readloop(ctx, conn, incoming)

	for {
		select {
		case msg, open := <-incoming:
			if !open {
				log.Debug("readloop closed")
				return nil
			}
			s.maybeUpdateIndex(msg)
			if herr := s.handler(msg); herr != nil {
				log.Warnf("error handling message: %s", herr)
			}
		case <-ctx.Done():
			log.Debug("context done")
			if cerr := conn.Close(websocket.StatusNormalClosure, "Subscriber stopped"); cerr != nil {
				log.WithError(cerr).Debug("error closing connection")
			}
			return nil
		}
	}
}
// writeloop pings conn every pingInterval. It returns when ctx is done or a
// ping fails.
//
// A non-positive pingInterval (possible if the Subscriber was built from a
// Client with bad options) falls back to DefaultPingInterval instead of
// making time.NewTicker panic.
func (s *Subscriber) writeloop(ctx context.Context, conn *websocket.Conn) {
	interval := s.pingInterval
	if interval <= 0 {
		interval = DefaultPingInterval * time.Second
	}

	t := time.NewTicker(interval)
	defer t.Stop()

	for {
		select {
		case <-ctx.Done():
			log.Debug("context done")
			return
		case <-t.C:
			log.Debug("sending ping...")
			if err := conn.Ping(ctx); err != nil {
				log.WithError(err).Debug("ping error")
				return
			}
		}
	}
}
// readloop reads JSON-encoded messages from conn and forwards them on msgs.
//
// After a read error it reconnects via connect, which resumes from the
// subscriber's current index, and starts a new writeloop for the new
// connection. It closes msgs and returns when ctx is done or reconnecting
// fails.
//
// Sends on msgs also watch ctx: Run may already have returned after ctx was
// cancelled, and an unguarded send would block this goroutine forever.
func (s *Subscriber) readloop(ctx context.Context, conn *websocket.Conn, msgs chan *msgbus.Message) {
	defer close(msgs)

	for {
		var msg *msgbus.Message
		if err := wsjson.Read(ctx, conn, &msg); err != nil {
			// No point reconnecting once we are shutting down.
			if ctx.Err() != nil {
				log.WithError(err).Debug("read error, context done")
				return
			}
			log.WithError(err).Debug("read error, reconnecting")
			next, rerr := s.connect(ctx)
			if rerr != nil {
				log.WithError(rerr).Debug("error reconnecting")
				return
			}
			conn = next
			go s.writeloop(ctx, conn)
			continue
		}

		// A JSON null frame decodes to a nil message; skip it rather than
		// passing nil on to maybeUpdateIndex (which dereferences it).
		if msg == nil {
			continue
		}

		select {
		case msgs <- msg:
		case <-ctx.Done():
			return
		}
	}
}