mirror of
https://git.mills.io/prologic/msgbus.git
synced 2024-06-28 09:41:43 +00:00
daafbf1c60
Co-authored-by: James Mills <prologic@shortcircuit.net.au> Reviewed-on: https://git.mills.io/prologic/msgbus/pulls/37 Reviewed-by: xuu <xuu@noreply@mills.io>
321 lines
6.6 KiB
Go
321 lines
6.6 KiB
Go
package client
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/jpillora/backoff"
|
|
sync "github.com/sasha-s/go-deadlock"
|
|
log "github.com/sirupsen/logrus"
|
|
"nhooyr.io/websocket"
|
|
"nhooyr.io/websocket/wsjson"
|
|
|
|
"git.mills.io/prologic/msgbus"
|
|
)
|
|
|
|
// Timing defaults applied by NewClient whenever the corresponding Options
// field is left unset. Every value is a whole number of seconds; NewClient
// converts them to time.Duration.

// DefaultReconnectInterval is the default initial back-off delay, in
// seconds, before a Subscriber retries a failed websocket dial.
const DefaultReconnectInterval = 2

// DefaultMaxReconnectInterval is the default upper bound, in seconds, on the
// back-off delay between websocket reconnect attempts.
const DefaultMaxReconnectInterval = 64

// DefaultPingInterval is the default time interval (in seconds) between
// websocket pings sent by a Subscriber.
const DefaultPingInterval = 60
|
|
|
|
// ErrConnectionFailed is a sentinel error describing a failed connection to
// the message bus; compare against it with errors.Is.
// NOTE(review): nothing in this file returns it — confirm where callers
// expect to receive it.
var ErrConnectionFailed = errors.New("error: connection failed")
|
|
|
|
// noopHandler is the fallback handler NewSubscriber installs when the caller
// passes a nil handler: it ignores the message and always reports success.
func noopHandler(_ *msgbus.Message) error {
	return nil
}
|
|
|
|
// Client talks to a msgbus server over HTTP (Pull, Publish) and creates
// websocket Subscribers (Subscribe). Construct it with NewClient.
//
// The embedded RWMutex is read-locked by Pull and Publish for the duration
// of each request.
type Client struct {
	sync.RWMutex

	// url is the server base URL; NewClient strips any trailing "/".
	url string

	// Timing settings copied into every Subscriber built from this Client.
	pingInterval         time.Duration // delay between websocket pings
	reconnectInterval    time.Duration // initial reconnect back-off
	maxReconnectInterval time.Duration // ceiling for reconnect back-off
}
|
|
|
|
// Options tunes the timing behaviour of a Client created by NewClient.
// All fields are whole seconds. A zero field keeps the matching package
// default (DefaultPingInterval, DefaultReconnectInterval,
// DefaultMaxReconnectInterval).
type Options struct {
	// PingInterval is the number of seconds between websocket pings.
	PingInterval int
	// ReconnectInterval is the initial reconnect back-off in seconds.
	ReconnectInterval int
	// MaxReconnectInterval caps the reconnect back-off in seconds.
	MaxReconnectInterval int
}
|
|
|
|
// NewClient ...
|
|
func NewClient(url string, options *Options) *Client {
|
|
var (
|
|
pingInterval = DefaultPingInterval
|
|
reconnectInterval = DefaultReconnectInterval
|
|
maxReconnectInterval = DefaultMaxReconnectInterval
|
|
)
|
|
|
|
url = strings.TrimSuffix(url, "/")
|
|
|
|
client := &Client{url: url}
|
|
|
|
if options != nil {
|
|
if options.PingInterval != 0 {
|
|
pingInterval = options.PingInterval
|
|
}
|
|
|
|
if options.ReconnectInterval != 0 {
|
|
reconnectInterval = options.ReconnectInterval
|
|
}
|
|
|
|
if options.MaxReconnectInterval != 0 {
|
|
maxReconnectInterval = options.MaxReconnectInterval
|
|
}
|
|
}
|
|
|
|
client.pingInterval = time.Duration(pingInterval) * time.Second
|
|
client.reconnectInterval = time.Duration(reconnectInterval) * time.Second
|
|
client.maxReconnectInterval = time.Duration(maxReconnectInterval) * time.Second
|
|
|
|
return client
|
|
}
|
|
|
|
// Pull ...
|
|
func (c *Client) Pull(topic string) (msg *msgbus.Message, err error) {
|
|
c.RLock()
|
|
defer c.RUnlock()
|
|
|
|
url := fmt.Sprintf("%s/%s", c.url, topic)
|
|
client := &http.Client{}
|
|
|
|
req, err := http.NewRequest("GET", url, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
res, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// XXX: StatusNotFound is for backwards compatibility only for older clients.
|
|
if res.StatusCode == http.StatusNoContent || res.StatusCode == http.StatusNotFound {
|
|
// Empty queue
|
|
return nil, nil
|
|
}
|
|
|
|
defer res.Body.Close()
|
|
if err := json.NewDecoder(res.Body).Decode(&msg); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return msg, nil
|
|
}
|
|
|
|
// Publish ...
|
|
func (c *Client) Publish(topic, message string) error {
|
|
c.RLock()
|
|
defer c.RUnlock()
|
|
|
|
var payload bytes.Buffer
|
|
|
|
payload.Write([]byte(message))
|
|
|
|
url := fmt.Sprintf("%s/%s", c.url, topic)
|
|
|
|
client := &http.Client{}
|
|
|
|
req, err := http.NewRequest("PUT", url, &payload)
|
|
if err != nil {
|
|
return fmt.Errorf("error constructing request: %s", err)
|
|
}
|
|
|
|
res, err := client.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("error publishing message: %s", err)
|
|
}
|
|
|
|
if res.StatusCode != http.StatusAccepted {
|
|
return fmt.Errorf("unexpected response: %s", res.Status)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Subscribe ...
|
|
func (c *Client) Subscribe(topic string, index int64, handler msgbus.HandlerFunc) *Subscriber {
|
|
return NewSubscriber(c, topic, index, handler)
|
|
}
|
|
|
|
// Subscriber ...
|
|
type Subscriber struct {
|
|
sync.RWMutex
|
|
|
|
client *Client
|
|
|
|
topic string
|
|
index int64
|
|
|
|
handler msgbus.HandlerFunc
|
|
|
|
pingInterval time.Duration
|
|
reconnectInterval time.Duration
|
|
maxReconnectInterval time.Duration
|
|
}
|
|
|
|
// NewSubscriber ...
|
|
func NewSubscriber(client *Client, topic string, index int64, handler msgbus.HandlerFunc) *Subscriber {
|
|
if handler == nil {
|
|
handler = noopHandler
|
|
}
|
|
|
|
return &Subscriber{
|
|
client: client,
|
|
topic: topic,
|
|
index: index,
|
|
handler: handler,
|
|
|
|
pingInterval: client.pingInterval,
|
|
reconnectInterval: client.reconnectInterval,
|
|
maxReconnectInterval: client.maxReconnectInterval,
|
|
}
|
|
}
|
|
|
|
func (s *Subscriber) url() string {
|
|
u, err := url.Parse(s.client.url)
|
|
if err != nil {
|
|
log.Fatalf("invalid url: %s", s.client.url)
|
|
}
|
|
|
|
if strings.HasPrefix(s.client.url, "https") {
|
|
u.Scheme = "wss"
|
|
} else {
|
|
u.Scheme = "ws"
|
|
}
|
|
|
|
u.Path += fmt.Sprintf("/%s", s.topic)
|
|
q := u.Query()
|
|
q.Set("index", strconv.FormatInt(s.index, 10))
|
|
u.RawQuery = q.Encode()
|
|
|
|
return u.String()
|
|
}
|
|
|
|
func (s *Subscriber) maybeUpdateIndex(msg *msgbus.Message) {
|
|
s.Lock()
|
|
defer s.Unlock()
|
|
|
|
if s.index > 0 {
|
|
log.Debugf("updating index from %d to %d", s.index, (msg.ID + 1))
|
|
// NB: We update to index +1 so we don't keep getting the previous message(
|
|
s.index = msg.ID + 1
|
|
}
|
|
}
|
|
|
|
func (s *Subscriber) connect(ctx context.Context) (*websocket.Conn, error) {
|
|
b := &backoff.Backoff{
|
|
Min: s.reconnectInterval,
|
|
Max: s.maxReconnectInterval,
|
|
Factor: 2,
|
|
Jitter: true,
|
|
}
|
|
|
|
for {
|
|
url := s.url()
|
|
log.Debugf("connecting to %s", url)
|
|
conn, _, err := websocket.Dial(ctx, url, nil)
|
|
if err != nil {
|
|
log.WithError(err).Debugf("dial error")
|
|
if err == context.Canceled {
|
|
return nil, err
|
|
}
|
|
log.Debugf("reconnecting in %s", b.Duration())
|
|
time.Sleep(b.Duration())
|
|
continue
|
|
}
|
|
log.Debug("connected!")
|
|
return conn, nil
|
|
}
|
|
|
|
// never reached
|
|
}
|
|
|
|
// Run runs the subscriber client with the provided context
|
|
func (s *Subscriber) Run(ctx context.Context) error {
|
|
conn, err := s.connect(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("error connecting: %w", err)
|
|
}
|
|
|
|
msgs := make(chan *msgbus.Message)
|
|
go s.writeloop(ctx, conn)
|
|
go s.readloop(ctx, conn, msgs)
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
log.Debug("context done")
|
|
if err := conn.Close(websocket.StatusNormalClosure, "Subscriber stopped"); err != nil {
|
|
log.WithError(err).Debug("error closing connection")
|
|
}
|
|
return nil
|
|
case msg, ok := <-msgs:
|
|
if !ok {
|
|
log.Debug("readloop closed")
|
|
return nil
|
|
}
|
|
s.maybeUpdateIndex(msg)
|
|
if err := s.handler(msg); err != nil {
|
|
log.Warnf("error handling message: %s", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Subscriber) writeloop(ctx context.Context, conn *websocket.Conn) {
|
|
t := time.NewTicker(s.pingInterval)
|
|
defer t.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
log.Debug("context done")
|
|
return
|
|
case <-t.C:
|
|
log.Debug("sending ping...")
|
|
if err := conn.Ping(ctx); err != nil {
|
|
log.WithError(err).Debug("ping error")
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Subscriber) readloop(ctx context.Context, conn *websocket.Conn, msgs chan *msgbus.Message) {
|
|
for {
|
|
var msg *msgbus.Message
|
|
err := wsjson.Read(ctx, conn, &msg)
|
|
if err != nil {
|
|
log.WithError(err).Debug("read error, reconnecting")
|
|
conn, err = s.connect(ctx)
|
|
if err != nil {
|
|
log.WithError(err).Debug("error reconnecting")
|
|
close(msgs)
|
|
break
|
|
}
|
|
go s.writeloop(ctx, conn)
|
|
continue
|
|
}
|
|
msgs <- msg
|
|
}
|
|
}
|