package client

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"

	"github.com/jpillora/backoff"
	sync "github.com/sasha-s/go-deadlock"
	log "github.com/sirupsen/logrus"
	"nhooyr.io/websocket"
	"nhooyr.io/websocket/wsjson"

	"git.mills.io/prologic/msgbus"
)

const (
	// DefaultReconnectInterval is the default minimum delay (in seconds)
	// between websocket reconnection attempts.
	DefaultReconnectInterval = 2

	// DefaultMaxReconnectInterval is the default maximum delay (in seconds)
	// between websocket reconnection attempts.
	DefaultMaxReconnectInterval = 64

	// DefaultPingInterval is the default time interval (in seconds) between pings
	DefaultPingInterval = 60
)

var (
	// ErrConnectionFailed is a sentinel error value for a failed connection.
	// NOTE(review): nothing in this file returns it; confirm external users
	// before changing or removing it.
	ErrConnectionFailed = errors.New("error: connection failed")
)

// noopHandler is the handler used when a subscriber is created without one.
// It ignores the message and always returns nil.
func noopHandler(msg *msgbus.Message) error { return nil }

// Client is a client for a msgbus server reachable at a base URL. It carries
// the ping and reconnect timings that subscribers created from it inherit.
type Client struct {
	sync.RWMutex

	// url is the server base URL without a trailing slash.
	url string

	pingInterval         time.Duration
	reconnectInterval    time.Duration
	maxReconnectInterval time.Duration
}

// Options holds the optional timing settings accepted by NewClient. All
// values are expressed in seconds; a field left at zero selects the
// corresponding Default* constant.
type Options struct {
	PingInterval         int
	ReconnectInterval    int
	MaxReconnectInterval int
}

// NewClient returns a Client for the server at url. A single trailing "/" is
// stripped from url. options may be nil, in which case every interval takes
// its default value.
//
// Only strictly positive option values override the defaults: a zero or
// negative interval would otherwise produce a non-positive duration, which
// makes time.NewTicker panic in the subscriber's ping loop and yields a
// meaningless reconnect backoff.
func NewClient(url string, options *Options) *Client {
	var (
		pingInterval         = DefaultPingInterval
		reconnectInterval    = DefaultReconnectInterval
		maxReconnectInterval = DefaultMaxReconnectInterval
	)

	if options != nil {
		if options.PingInterval > 0 {
			pingInterval = options.PingInterval
		}
		if options.ReconnectInterval > 0 {
			reconnectInterval = options.ReconnectInterval
		}
		if options.MaxReconnectInterval > 0 {
			maxReconnectInterval = options.MaxReconnectInterval
		}
	}

	return &Client{
		url:                  strings.TrimSuffix(url, "/"),
		pingInterval:         time.Duration(pingInterval) * time.Second,
		reconnectInterval:    time.Duration(reconnectInterval) * time.Second,
		maxReconnectInterval: time.Duration(maxReconnectInterval) * time.Second,
	}
}
func (c *Client) Pull(topic string) (msg *msgbus.Message, err error) { c.RLock() defer c.RUnlock() url := fmt.Sprintf("%s/%s", c.url, topic) client := &http.Client{} req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, err } res, err := client.Do(req) if err != nil { return nil, err } // XXX: StatusNotFound is for backwards compatibility only for older clients. if res.StatusCode == http.StatusNoContent || res.StatusCode == http.StatusNotFound { // Empty queue return nil, nil } defer res.Body.Close() if err := json.NewDecoder(res.Body).Decode(&msg); err != nil { return nil, err } return msg, nil } // Publish ... func (c *Client) Publish(topic, message string) error { c.RLock() defer c.RUnlock() var payload bytes.Buffer payload.Write([]byte(message)) url := fmt.Sprintf("%s/%s", c.url, topic) client := &http.Client{} req, err := http.NewRequest("PUT", url, &payload) if err != nil { return fmt.Errorf("error constructing request: %s", err) } res, err := client.Do(req) if err != nil { return fmt.Errorf("error publishing message: %s", err) } if res.StatusCode != http.StatusAccepted { return fmt.Errorf("unexpected response: %s", res.Status) } return nil } // Subscribe ... func (c *Client) Subscribe(topic string, index int64, handler msgbus.HandlerFunc) *Subscriber { return NewSubscriber(c, topic, index, handler) } // Subscriber ... type Subscriber struct { sync.RWMutex client *Client topic string index int64 handler msgbus.HandlerFunc pingInterval time.Duration reconnectInterval time.Duration maxReconnectInterval time.Duration } // NewSubscriber ... 
func NewSubscriber(client *Client, topic string, index int64, handler msgbus.HandlerFunc) *Subscriber { if handler == nil { handler = noopHandler } return &Subscriber{ client: client, topic: topic, index: index, handler: handler, pingInterval: client.pingInterval, reconnectInterval: client.reconnectInterval, maxReconnectInterval: client.maxReconnectInterval, } } func (s *Subscriber) url() string { u, err := url.Parse(s.client.url) if err != nil { log.Fatalf("invalid url: %s", s.client.url) } if strings.HasPrefix(s.client.url, "https") { u.Scheme = "wss" } else { u.Scheme = "ws" } u.Path += fmt.Sprintf("/%s", s.topic) q := u.Query() q.Set("index", strconv.FormatInt(s.index, 10)) u.RawQuery = q.Encode() return u.String() } func (s *Subscriber) maybeUpdateIndex(msg *msgbus.Message) { s.Lock() defer s.Unlock() if s.index > 0 { log.Debugf("updating index from %d to %d", s.index, (msg.ID + 1)) // NB: We update to index +1 so we don't keep getting the previous message( s.index = msg.ID + 1 } } func (s *Subscriber) connect(ctx context.Context) (*websocket.Conn, error) { b := &backoff.Backoff{ Min: s.reconnectInterval, Max: s.maxReconnectInterval, Factor: 2, Jitter: true, } for { url := s.url() log.Debugf("connecting to %s", url) conn, _, err := websocket.Dial(ctx, url, nil) if err != nil { log.WithError(err).Debugf("dial error") if errors.Is(err, context.Canceled) { return nil, err } log.Debugf("reconnecting in %s", b.Duration()) time.Sleep(b.Duration()) continue } log.Debug("connected!") return conn, nil } // never reached } // Run runs the subscriber client with the provided context func (s *Subscriber) Run(ctx context.Context) error { conn, err := s.connect(ctx) if err != nil { return fmt.Errorf("error connecting: %w", err) } msgs := make(chan *msgbus.Message) go s.writeloop(ctx, conn) go s.readloop(ctx, conn, msgs) for { select { case <-ctx.Done(): log.Debug("context done") if err := conn.Close(websocket.StatusNormalClosure, "Subscriber stopped"); err != nil { 
log.WithError(err).Debug("error closing connection") } return nil case msg, ok := <-msgs: if !ok { log.Debug("readloop closed") return nil } s.maybeUpdateIndex(msg) if err := s.handler(msg); err != nil { log.Warnf("error handling message: %s", err) } } } } func (s *Subscriber) writeloop(ctx context.Context, conn *websocket.Conn) { t := time.NewTicker(s.pingInterval) defer t.Stop() for { select { case <-ctx.Done(): log.Debug("context done") return case <-t.C: log.Debug("sending ping...") if err := conn.Ping(ctx); err != nil { log.WithError(err).Debug("ping error") return } } } } func (s *Subscriber) readloop(ctx context.Context, conn *websocket.Conn, msgs chan *msgbus.Message) { for { var msg *msgbus.Message err := wsjson.Read(ctx, conn, &msg) if err != nil { log.WithError(err).Debug("read error, reconnecting") conn, err = s.connect(ctx) if err != nil { log.WithError(err).Debug("error reconnecting") close(msgs) break } go s.writeloop(ctx, conn) continue } msgs <- msg } }