// Package msgbus implements a topic-based message bus. Messages are queued in
// memory per topic, persisted to per-topic write-ahead logs, and served over
// HTTP (publish via POST/PUT, poll via GET) and websockets (push delivery).
package msgbus

import (
	"compress/flate"
	"compress/gzip"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"io/ioutil"
	"net/http"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"github.com/andybalholm/brotli"
	securejoin "github.com/cyphar/filepath-securejoin"
	sync "github.com/sasha-s/go-deadlock"
	log "github.com/sirupsen/logrus"
	"github.com/tidwall/wal"
	msgpack "github.com/vmihailenco/msgpack/v5"

	"github.com/gorilla/websocket"
)

const (
	// Time allowed to write a message to the peer.
	writeWait = 10 * time.Second

	// Time allowed to read the next pong message from the peer.
	pongWait = 60 * time.Second

	// Send pings to peer with this period. Must be less than pongWait.
	pingPeriod = (pongWait * 9) / 10
)

var (
	// ErrBufferFull is returned by the replay callback in Subscribe() when a
	// new subscriber's buffer fills up and no more queued messages can be
	// enqueued for delivery; Subscribe logs it and stops replaying.
	ErrBufferFull = errors.New("error: subscriber buffer full")
)

// upgrader upgrades HTTP requests to websocket connections for subscribers.
// NOTE(review): CheckOrigin accepts every origin, so any web page can open a
// subscription on behalf of a visitor's browser.
// TODO: Make this configurable?
var upgrader = websocket.Upgrader{
	ReadBufferSize:  4096,
	WriteBufferSize: 4096,
	CheckOrigin:     func(r *http.Request) bool { return true },
}

// HandlerFunc is a callback that receives a single message and may return an
// error. No use of it is visible in this section of the package.
type HandlerFunc func(msg *Message) error

// Topic is a named stream of messages. Sequence holds the ID of the most
// recently allocated message on the topic (0 when none has been allocated).
type Topic struct {
	Name     string    `json:"name"`
	Sequence int64     `json:"seq"`
	Created  time.Time `json:"created"`
}

// String returns the topic's name.
func (t *Topic) String() string { return t.Name }

// Message is a single payload published to a Topic. It is encoded with
// msgpack for the write-ahead log (Bytes/LoadMessage) and as JSON when sent
// to HTTP and websocket clients.
type Message struct {
	ID      int64     `json:"id"`
	Topic   *Topic    `json:"topic"`
	Payload []byte    `json:"payload"`
	Created time.Time `json:"created"`
}

// LoadMessage decodes a msgpack-encoded message such as one produced by
// Message.Bytes.
func LoadMessage(data []byte) (m Message, err error) {
	err = msgpack.Unmarshal(data, &m)
	return
}

// Bytes encodes the message with msgpack.
func (m Message) Bytes() ([]byte, error) {
	return msgpack.Marshal(m)
}

// SubscribeOption mutates SubscriberOptions; pass options to
// MessageBus.Subscribe.
type SubscribeOption func(*SubscriberOptions)

// SubscriberOptions holds per-subscription settings.
type SubscriberOptions struct {
	// Index is the message ID from which queued messages are replayed to a
	// new subscriber; values <= 0 disable replay.
	Index int64
}

// WithIndex sets the index to start subscribing from
func WithIndex(index int64) SubscribeOption {
	return func(o *SubscriberOptions) {
		o.Index = index
	}
}

// SubscriberConfig configures a Subscribers set.
type SubscriberConfig struct { BufferLength int } // Subscribers ... type Subscribers struct { sync.RWMutex buflen int ids map[string]bool chs map[string]chan Message } // NewSubscribers ... func NewSubscribers(config *SubscriberConfig) *Subscribers { var ( bufferLength int ) if config != nil { bufferLength = config.BufferLength } else { bufferLength = DefaultBufferLength } return &Subscribers{ buflen: bufferLength, ids: make(map[string]bool), chs: make(map[string]chan Message), } } // Len ... func (subs *Subscribers) Len() int { subs.RLock() defer subs.RUnlock() return len(subs.ids) } // AddSubscriber ... func (subs *Subscribers) AddSubscriber(id string) chan Message { subs.Lock() defer subs.Unlock() if subs.buflen <= 0 { log.Fatal("subscriber buflen <= 0") } ch := make(chan Message, subs.buflen) subs.ids[id] = true subs.chs[id] = ch return ch } // RemoveSubscriber ... func (subs *Subscribers) RemoveSubscriber(id string) { subs.Lock() defer subs.Unlock() delete(subs.ids, id) close(subs.chs[id]) delete(subs.chs, id) } // HasSubscriber ... func (subs *Subscribers) HasSubscriber(id string) bool { subs.RLock() defer subs.RUnlock() _, ok := subs.ids[id] return ok } // GetSubscriber ... func (subs *Subscribers) GetSubscriber(id string) (chan Message, bool) { subs.RLock() defer subs.RUnlock() ch, ok := subs.chs[id] if !ok { return nil, false } return ch, true } // NotifyAll ... func (subs *Subscribers) NotifyAll(message Message) int { subs.RLock() defer subs.RUnlock() i := 0 for id, ch := range subs.chs { select { case ch <- message: i++ default: // TODO: Drop this client? // TODO: Retry later? log.Warnf("cannot publish message %s#%d to %s", message.Topic.Name, message.ID, id) } } return i } // MessageBus ... 
type MessageBus struct { sync.RWMutex options *Options metrics *Metrics topics map[string]*Topic queues map[*Topic]*Queue logs map[*Topic]*wal.Log subscribers map[*Topic]*Subscribers } // NewMessageBus creates a new message bus with the provided options func NewMessageBus(opts ...Option) (*MessageBus, error) { options := NewDefaultOptions() for _, opt := range opts { if err := opt(options); err != nil { return nil, err } } var metrics *Metrics if options.Metrics { metrics = NewMetrics("msgbus") ctime := time.Now() // server uptime counter metrics.NewCounterFunc( "server", "uptime", "Number of nanoseconds the server has been running", func() float64 { return float64(time.Since(ctime).Nanoseconds()) }, ) // server requests counter metrics.NewCounter( "server", "requests", "Number of total requests processed", ) // client latency summary metrics.NewSummary( "client", "latency_seconds", "Client latency in seconds", ) // client errors counter metrics.NewCounter( "client", "errors", "Number of errors publishing messages to clients", ) // bus messages counter metrics.NewCounter( "bus", "messages", "Number of total messages exchanged", ) // bus dropped counter metrics.NewCounter( "bus", "dropped", "Number of messages dropped to subscribers", ) // bus delivered counter metrics.NewCounter( "bus", "delivered", "Number of messages delivered to subscribers", ) // bus fetched counter metrics.NewCounter( "bus", "fetched", "Number of messages fetched from clients", ) // bus topics gauge metrics.NewCounter( "bus", "topics", "Number of active topics registered", ) // queue len gauge vec metrics.NewGaugeVec( "queue", "len", "Queue length of each topic", []string{"topic"}, ) // queue size gauge vec // TODO: Implement this gauge by somehow getting queue sizes per topic! 
metrics.NewGaugeVec( "queue", "size", "Queue length of each topic", []string{"topic"}, ) // bus subscribers gauge metrics.NewGauge( "bus", "subscribers", "Number of active subscribers", ) } mb := &MessageBus{ options: options, metrics: metrics, topics: make(map[string]*Topic), queues: make(map[*Topic]*Queue), logs: make(map[*Topic]*wal.Log), subscribers: make(map[*Topic]*Subscribers), } if err := mb.readLogs(); err != nil { return nil, fmt.Errorf("error reading logs: %w", err) } return mb, nil } func (mb *MessageBus) readLog(f fs.FileInfo) error { log.Debugf("reading log %s", f.Name()) l, err := wal.Open(filepath.Join(mb.options.LogPath, f.Name()), nil) if err != nil { return fmt.Errorf("error opening log %s: %w", f, err) } t := mb.newTopic(f.Name()) t.Created = f.ModTime() first, err := l.FirstIndex() if err != nil { return fmt.Errorf("error reading first index: %w", err) } last, err := l.LastIndex() if err != nil { return fmt.Errorf("error reading last index: %w", err) } log.Debugf("first index: %d", first) log.Debugf("last index: %d", last) t.Sequence = int64(last) q := NewQueue(mb.options.MaxQueueSize) start := int64(last) - int64(mb.options.MaxQueueSize) if start < 0 { start = int64(first) } end := int64(last) log.Debugf("start: %d", start) log.Debugf("end: %d", end) for i := start; i <= end; i++ { data, err := l.Read(uint64(i)) if err != nil { return fmt.Errorf("error reading log %d: %w", i, err) } msg, err := LoadMessage(data) if err != nil { return fmt.Errorf("error deserialing log %d: %w", i, err) } q.Push(msg) } mb.queues[t] = q return nil } func (mb *MessageBus) readLogs() error { mb.Lock() defer mb.Unlock() log.Debug("reading logs...") dirs, err := os.ReadDir(mb.options.LogPath) if err != nil { return fmt.Errorf("error listing logs: %w", err) } for _, dir := range dirs { if dir.IsDir() { info, err := dir.Info() if err != nil { return fmt.Errorf("error reading log path %s: %w", dir, err) } if err := mb.readLog(info); err != nil { return err } } } return 
nil } // Len ... func (mb *MessageBus) Len() int { return len(mb.topics) } // Metrics ... func (mb *MessageBus) Metrics() *Metrics { return mb.metrics } func (mb *MessageBus) newTopic(topic string) *Topic { t, ok := mb.topics[topic] if !ok { t = &Topic{Name: topic, Created: time.Now()} mb.topics[topic] = t if mb.metrics != nil { mb.metrics.Counter("bus", "topics").Inc() } } return t } // NewTopic ... func (mb *MessageBus) NewTopic(topic string) *Topic { mb.Lock() defer mb.Unlock() return mb.newTopic(topic) } // NewMessage ... func (mb *MessageBus) NewMessage(topic *Topic, payload []byte) Message { defer func() { topic.Sequence++ if mb.metrics != nil { mb.metrics.Counter("bus", "messages").Inc() } }() return Message{ ID: topic.Sequence + 1, Topic: topic, Payload: payload, Created: time.Now(), } } // Put ... func (mb *MessageBus) Put(message Message) error { mb.Lock() defer mb.Unlock() t := message.Topic l, ok := mb.logs[t] if !ok { fn, err := securejoin.SecureJoin(mb.options.LogPath, t.Name) if err != nil { return fmt.Errorf("error creating logfile filename for %s: %w", t.Name, err) } l, err = wal.Open(fn, nil) if err != nil { return fmt.Errorf("error opening logfile %s: %w", fn, err) } mb.logs[t] = l } id := uint64(message.ID) data, err := message.Bytes() if err != nil { return fmt.Errorf("error serializing message %d: %w", id, err) } if err := l.Write(id, data); err != nil { return fmt.Errorf("error writing message %d to logfile: %w", message.ID, err) } q, ok := mb.queues[t] if !ok { q = NewQueue(mb.options.MaxQueueSize) mb.queues[t] = q } q.Push(message) if mb.metrics != nil { mb.metrics.GaugeVec("queue", "len").WithLabelValues(t.Name).Inc() } mb.publish(message) return nil } // Get ... 
func (mb *MessageBus) Get(t *Topic) (Message, bool) { mb.RLock() defer mb.RUnlock() q, ok := mb.queues[t] if !ok { return Message{}, false } m := q.Pop() if m == nil { return Message{}, false } if mb.metrics != nil { mb.metrics.Counter("bus", "fetched").Inc() mb.metrics.GaugeVec("queue", "len").WithLabelValues(t.Name).Dec() } return m.(Message), true } // publish ... func (mb *MessageBus) publish(message Message) { subs, ok := mb.subscribers[message.Topic] if !ok { log.Debugf("no subscribers for %s", message.Topic.Name) return } log.Debug("notifying subscribers") n := subs.NotifyAll(message) if n != subs.Len() && mb.metrics != nil { log.Warnf("%d/%d subscribers notified", n, subs.Len()) mb.metrics.Counter("bus", "dropped").Add(float64(subs.Len() - n)) } } // Subscribe ... func (mb *MessageBus) Subscribe(id, topic string, opts ...SubscribeOption) chan Message { mb.Lock() defer mb.Unlock() t, ok := mb.topics[topic] if !ok { t = &Topic{Name: topic, Created: time.Now()} mb.topics[topic] = t } subs, ok := mb.subscribers[t] if !ok { subs = NewSubscribers(&SubscriberConfig{BufferLength: mb.options.BufferLength}) mb.subscribers[t] = subs } if subs.HasSubscriber(id) { // Already verified the listener exists log.Debugf("already have subscriber %s", id) ch, _ := subs.GetSubscriber(id) return ch } if mb.metrics != nil { mb.metrics.Gauge("bus", "subscribers").Inc() } o := &SubscriberOptions{} for _, opt := range opts { opt(o) } ch := subs.AddSubscriber(id) q, ok := mb.queues[t] if !ok { log.Debug("nothing in queue, returning ch") return ch } // IF client requests to start from index >= 0 (-1 from the head) // AND the topic's sequence number hasn't been reset (< Index) THEN if o.Index > 0 && o.Index <= t.Sequence { var n int log.Debugf("subscriber wants to start from %d", o.Index) err := q.ForEach(func(item interface{}) error { if msg, ok := item.(Message); ok && msg.ID >= o.Index { log.Debugf("found msg %s#%d", msg.Topic.Name, msg.ID) select { case ch <- msg: n++ default: if 
mb.metrics != nil { mb.metrics.Counter("bus", "dropped").Inc() } return ErrBufferFull } } return nil }) if err != nil { log.WithError(err).Error("error publishing messages to new subscriber") } log.Debugf("published %d messages", n) return ch } // Otherwise, s.Index was eitehr 0 (start from head) // OR > topic.Sequence in which case (start from head) return ch } // Unsubscribe ... func (mb *MessageBus) Unsubscribe(id, topic string) { mb.Lock() defer mb.Unlock() t, ok := mb.topics[topic] if !ok { return } subs, ok := mb.subscribers[t] if !ok { return } if subs.HasSubscriber(id) { // Already verified the listener exists subs.RemoveSubscriber(id) if mb.metrics != nil { mb.metrics.Gauge("bus", "subscribers").Dec() } } } func (mb *MessageBus) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer func() { if mb.metrics != nil { mb.metrics.Counter("server", "requests").Inc() } }() w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Accept-Encoding", "br, gzip, deflate") if r.Method == "GET" && (r.URL.Path == "/" || r.URL.Path == "") { // XXX: guard with a mutex? 
out, err := json.Marshal(mb.topics) if err != nil { msg := fmt.Sprintf("error serializing topics: %s", err) http.Error(w, msg, http.StatusInternalServerError) return } w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json") w.Write(out) return } topic := strings.Trim(r.URL.Path, "/") t := mb.NewTopic(topic) log.Debugf("request for topic %s", t.Name) switch r.Method { case "POST", "PUT": defer r.Body.Close() if r.ContentLength > int64(mb.options.MaxPayloadSize) { msg := "payload exceeds max-payload-size" http.Error(w, msg, http.StatusRequestEntityTooLarge) return } var rd io.Reader = r.Body ce := r.Header.Get("Content-Encoding") switch ce { case "": case "br": rd = brotli.NewReader(rd) case "gzip": gz, err := gzip.NewReader(rd) if err != nil { msg := fmt.Sprintf("error reading payload: %s", err) http.Error(w, msg, http.StatusBadRequest) return } defer gz.Close() rd = gz case "deflate": fl := flate.NewReader(rd) defer fl.Close() rd = fl default: msg := fmt.Sprintf("error reading payload: not acceptable: %v", ce) http.Error(w, msg, http.StatusNotAcceptable) return } body, err := ioutil.ReadAll(rd) if err != nil { msg := fmt.Sprintf("error reading payload: %s", err) http.Error(w, msg, http.StatusBadRequest) return } if len(body) > mb.options.MaxPayloadSize { msg := "payload exceeds max-payload-size" http.Error(w, msg, http.StatusRequestEntityTooLarge) return } mb.Put(mb.NewMessage(t, body)) w.WriteHeader(http.StatusAccepted) case "GET": if r.Header.Get("Upgrade") == "websocket" { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Errorf("error creating websocket client: %s", err) return } i := SafeParseInt64(r.URL.Query().Get("index"), -1) log.Debugf("new subscriber for %s from %s", t.Name, r.RemoteAddr) NewClient(conn, t, i, mb).Start() return } message, ok := mb.Get(t) if !ok { http.Error(w, "No Messages", http.StatusNoContent) return } out, err := json.Marshal(message) if err != nil { msg := fmt.Sprintf("error serializing message: 
%s", err) http.Error(w, msg, http.StatusInternalServerError) return } w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json") w.Write(out) case "DELETE": http.Error(w, "Not Implemented", http.StatusNotImplemented) // TODO: Implement deleting topics } } // Client ... type Client struct { conn *websocket.Conn topic *Topic index int64 bus *MessageBus id string ch chan Message } // NewClient ... func NewClient(conn *websocket.Conn, topic *Topic, index int64, bus *MessageBus) *Client { return &Client{ conn: conn, topic: topic, index: index, bus: bus, id: MustGenerateULID(), } } func (c *Client) readPump() { defer func() { c.conn.Close() }() c.conn.SetReadDeadline(time.Now().Add(pongWait)) c.conn.SetPongHandler(func(message string) error { t, err := strconv.ParseInt(message, 10, 64) d := time.Duration(time.Now().UnixNano() - t) if err != nil { log.Warnf("garbage pong reply from %s: %s", c.id, err) } else { log.Debugf("pong latency of %s: %s", c.id, d) } c.conn.SetReadDeadline(time.Now().Add(pongWait)) if c.bus.metrics != nil { v := c.bus.metrics.Summary("client", "latency_seconds") v.Observe(d.Seconds()) } return nil }) for { _, message, err := c.conn.ReadMessage() if err != nil { c.bus.Unsubscribe(c.id, c.topic.Name) return } log.Debugf("recieved message from %s: %s", c.id, message) } } func (c *Client) writePump() { ticker := time.NewTicker(pingPeriod) defer func() { ticker.Stop() c.conn.Close() }() var err error for { select { case msg, ok := <-c.ch: c.conn.SetWriteDeadline(time.Now().Add(writeWait)) if !ok { // The bus closed the channel. message := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "bus closed") c.conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(writeWait)) return } err = c.conn.WriteJSON(msg) if err != nil { // TODO: Retry? Put the message back in the queue? 
log.Errorf("Error sending msg to %s: %s", c.id, err) if c.bus.metrics != nil { c.bus.metrics.Counter("client", "errors").Inc() } } else { if c.bus.metrics != nil { c.bus.metrics.Counter("bus", "delivered").Inc() } } case <-ticker.C: c.conn.SetWriteDeadline(time.Now().Add(writeWait)) t := time.Now() message := []byte(fmt.Sprintf("%d", t.UnixNano())) if err := c.conn.WriteMessage(websocket.PingMessage, message); err != nil { log.Errorf("error sending ping to %s: %s", c.id, err) return } } } } // Start ... func (c *Client) Start() { c.ch = c.bus.Subscribe(c.id, c.topic.Name, WithIndex(c.index)) c.conn.SetCloseHandler(func(code int, text string) error { c.bus.Unsubscribe(c.id, c.topic.Name) message := websocket.FormatCloseMessage(code, text) c.conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(writeWait)) return nil }) go c.writePump() go c.readPump() }