package client

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/jpillora/backoff"
	log "github.com/sirupsen/logrus"
	"nhooyr.io/websocket"
	"nhooyr.io/websocket/wsjson"

	"git.mills.io/prologic/msgbus"
)

const (
	// DefaultReconnectInterval is the initial reconnect backoff, in seconds,
	// used when Options.ReconnectInterval is zero.
	DefaultReconnectInterval = 2
	// DefaultMaxReconnectInterval is the upper bound on reconnect backoff, in
	// seconds, used when Options.MaxReconnectInterval is zero.
	DefaultMaxReconnectInterval = 64
	// DefaultPingInterval is the default time interval between pings
	DefaultPingInterval = 60 * time.Second
)

// Client is a msgbus client bound to a single server base URL.
type Client struct {
	sync.RWMutex

	// url is the server base URL with any trailing "/" removed.
	url string

	reconnectInterval    time.Duration
	maxReconnectInterval time.Duration
}

// Options configures a Client. Zero-valued fields fall back to the
// corresponding Default* constants.
type Options struct {
	// ReconnectInterval is the initial reconnect backoff, in seconds.
	ReconnectInterval int
	// MaxReconnectInterval is the maximum reconnect backoff, in seconds.
	MaxReconnectInterval int
}

// NewClient returns a Client for the server at url. A trailing "/" on url is
// trimmed. options may be nil, in which case all defaults are used.
func NewClient(url string, options *Options) *Client {
	var (
		reconnectInterval    = DefaultReconnectInterval
		maxReconnectInterval = DefaultMaxReconnectInterval
	)

	url = strings.TrimSuffix(url, "/")

	client := &Client{url: url}

	if options != nil {
		if options.ReconnectInterval != 0 {
			reconnectInterval = options.ReconnectInterval
		}

		if options.MaxReconnectInterval != 0 {
			maxReconnectInterval = options.MaxReconnectInterval
		}
	}

	client.reconnectInterval = time.Duration(reconnectInterval) * time.Second
	client.maxReconnectInterval = time.Duration(maxReconnectInterval) * time.Second

	return client
}

// Handle writes msg to standard output as a single line of JSON terminated by
// "\r\n". It is used as the handler when none is supplied to Subscribe.
//
// It returns the JSON marshalling error or the stdout write error, if any.
func (c *Client) Handle(msg *msgbus.Message) error {
	out, err := json.Marshal(msg)
	if err != nil {
		return err
	}

	// Write the record and its terminator in one call so concurrent handlers
	// cannot interleave a terminator between another record's bytes.
	out = append(out, '\r', '\n')
	if _, err := os.Stdout.Write(out); err != nil {
		return err
	}

	return nil
}

// Pull retrieves a message for topic from the server with an HTTP GET.
func (c *Client) Pull(topic string) (msg *msgbus.Message, err error) { c.RLock() defer c.RUnlock() url := fmt.Sprintf("%s/%s", c.url, topic) client := &http.Client{} req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, err } res, err := client.Do(req) if err != nil { return nil, err } if res.StatusCode == http.StatusNotFound { // Empty queue return nil, nil } defer res.Body.Close() if err := json.NewDecoder(res.Body).Decode(&msg); err != nil { return nil, err } if err := c.Handle(msg); err != nil { return nil, err } return msg, nil } // Publish ... func (c *Client) Publish(topic, message string) error { c.RLock() defer c.RUnlock() var payload bytes.Buffer payload.Write([]byte(message)) url := fmt.Sprintf("%s/%s", c.url, topic) client := &http.Client{} req, err := http.NewRequest("PUT", url, &payload) if err != nil { return fmt.Errorf("error constructing request: %s", err) } res, err := client.Do(req) if err != nil { return fmt.Errorf("error publishing message: %s", err) } if res.StatusCode != http.StatusAccepted { return fmt.Errorf("unexpected response: %s", res.Status) } return nil } // Subscribe ... func (c *Client) Subscribe(topic string, handler msgbus.HandlerFunc) *Subscriber { return NewSubscriber(c, topic, handler) } // Subscriber ... type Subscriber struct { sync.RWMutex conn *websocket.Conn client *Client topic string handler msgbus.HandlerFunc url string reconnectInterval time.Duration maxReconnectInterval time.Duration closeWriteChan chan bool } // NewSubscriber ... 
func NewSubscriber(client *Client, topic string, handler msgbus.HandlerFunc) *Subscriber { if handler == nil { handler = client.Handle } u, err := url.Parse(client.url) if err != nil { log.Fatalf("invalid url: %s", client.url) } if strings.HasPrefix(client.url, "https") { u.Scheme = "wss" } else { u.Scheme = "ws" } u.Path += fmt.Sprintf("/%s", topic) url := u.String() return &Subscriber{ client: client, topic: topic, handler: handler, url: url, reconnectInterval: client.reconnectInterval, maxReconnectInterval: client.maxReconnectInterval, closeWriteChan: make(chan bool, 1), } } func (s *Subscriber) closeAndReconnect() { s.closeWriteChan <- true s.RLock() s.conn.Close(websocket.StatusNormalClosure, "Closing and reconnecting...") s.RUnlock() go s.connect() } func (s *Subscriber) connect() { s.RLock() b := &backoff.Backoff{ Min: s.reconnectInterval, Max: s.maxReconnectInterval, Factor: 2, Jitter: false, } s.RUnlock() for { conn, _, err := websocket.Dial(context.TODO(), s.url, nil) if err != nil { time.Sleep(b.Duration()) continue } s.Lock() s.conn = conn s.closeWriteChan = make(chan bool, 1) s.Unlock() go s.readLoop() go s.writeLoop() break } } func (s *Subscriber) readLoop() { var msg *msgbus.Message for { if err := wsjson.Read(context.TODO(), s.conn, &msg); err != nil { s.closeAndReconnect() return } if err := s.handler(msg); err != nil { log.Warnf("error handling message: %s", err) } } } func (s *Subscriber) writeLoop() { ticker := time.NewTicker(DefaultPingInterval) defer func() { ticker.Stop() s.RLock() defer s.RUnlock() if s.conn != nil { s.conn.Close(websocket.StatusNormalClosure, "Closed writeLoop()") } }() for { select { case <-ticker.C: if err := s.conn.Ping(context.TODO()); err != nil { s.closeAndReconnect() return } case <-s.closeWriteChan: return } } } // Start ... func (s *Subscriber) Start() { go s.connect() } // Stop ... 
func (s *Subscriber) Stop() { s.Lock() defer s.Unlock() if err := s.conn.Close(websocket.StatusNormalClosure, "Subscriber stopped"); err != nil { log.Warnf("error sending close message: %s", err) } s.conn = nil }