package msgbus

import (
	"compress/flate"
	"compress/gzip"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/andybalholm/brotli"
	sync "github.com/sasha-s/go-deadlock"
	log "github.com/sirupsen/logrus"

	"github.com/gorilla/websocket"
)

const (
	// DefaultMaxQueueSize is the default capacity handed to NewQueue for each
	// topic's queue. If that capacity is a message count (TODO confirm against
	// NewQueue), a full queue of max-size payloads holds ~8MB (1024 * 8KB).
	DefaultMaxQueueSize = 1024
	// DefaultMaxPayloadSize is the default maximum size, in bytes, of a
	// published message body; ServeHTTP rejects larger bodies with 413.
	DefaultMaxPayloadSize = 8192 // 8KB
	// DefaultBufferLength is the default capacity of each subscriber's
	// channel. NotifyAll drops messages for a subscriber whose channel is full.
	DefaultBufferLength = 256
	// Time allowed to write a message to the peer.
	writeWait = 10 * time.Second
	// Time allowed to read the next pong message from the peer.
	pongWait = 60 * time.Second
	// Send pings to peer with this period. Must be less than pongWait.
	pingPeriod = (pongWait * 9) / 10
)

// upgrader upgrades subscription requests to WebSocket connections.
// NOTE: CheckOrigin accepts every origin, so any web page may open a
// subscription socket (consistent with the "Access-Control-Allow-Origin: *"
// header ServeHTTP sets).
// TODO: Make this configurable?
var upgrader = websocket.Upgrader{
	ReadBufferSize:  4096,
	WriteBufferSize: 4096,
	CheckOrigin:     func(r *http.Request) bool { return true },
}

// HandlerFunc is the signature of a callback that handles a single Message.
type HandlerFunc func(msg *Message) error

// Topic is a named stream of messages.
type Topic struct {
	// Name identifies the topic (the request path in ServeHTTP).
	Name string `json:"name"`
	// Sequence is the ID MessageBus.NewMessage assigns to the next message
	// created for this topic; it is incremented after every message.
	Sequence int `json:"seq"`
	// Created is when the topic was first registered.
	Created time.Time `json:"created"`
}

// String returns the topic name.
func (t *Topic) String() string { return t.Name }

// Message is a single payload published to a Topic.
type Message struct {
	// ID is the topic's Sequence value when the message was created.
	ID    int    `json:"id"`
	Topic *Topic `json:"topic"`
	// Payload is the raw message body; encoding/json emits []byte as base64.
	Payload []byte    `json:"payload"`
	Created time.Time `json:"created"`
}

// SubscribeOption customises a subscription made with MessageBus.Subscribe.
type SubscribeOption func(*SubscriberOptions)

// SubscriberOptions holds the settings applied by SubscribeOption values.
type SubscriberOptions struct {
	// Index is the lowest message ID to replay from the topic's queue when
	// subscribing; a negative Index disables replay.
	Index int
}

// WithIndex sets the index to start subscribing from
func WithIndex(index int) SubscribeOption {
	return func(o *SubscriberOptions) { o.Index = index }
}

// SubscriberConfig configures a Subscribers set created by NewSubscribers.
type SubscriberConfig struct {
	// BufferLength is the capacity of each subscriber's channel.
	BufferLength int
}

// Subscribers is the set of subscriber channels for one topic, keyed by
// subscriber ID. Every method takes the embedded lock.
type Subscribers struct {
	sync.RWMutex
	// buflen is the capacity given to each new subscriber channel.
	buflen int
	// ids and chs always hold the same set of keys.
	ids map[string]bool
	chs map[string]chan Message
}
func NewSubscribers(config *SubscriberConfig) *Subscribers { var ( bufferLength int ) if config != nil { bufferLength = config.BufferLength } else { bufferLength = DefaultBufferLength } return &Subscribers{ buflen: bufferLength, ids: make(map[string]bool), chs: make(map[string]chan Message), } } // Len ... func (subs *Subscribers) Len() int { subs.RLock() defer subs.RUnlock() return len(subs.ids) } // AddSubscriber ... func (subs *Subscribers) AddSubscriber(id string) chan Message { subs.Lock() defer subs.Unlock() if subs.buflen <= 0 { log.Fatal("subscriber buflen <= 0") } ch := make(chan Message, subs.buflen) subs.ids[id] = true subs.chs[id] = ch return ch } // RemoveSubscriber ... func (subs *Subscribers) RemoveSubscriber(id string) { subs.Lock() defer subs.Unlock() delete(subs.ids, id) close(subs.chs[id]) delete(subs.chs, id) } // HasSubscriber ... func (subs *Subscribers) HasSubscriber(id string) bool { subs.RLock() defer subs.RUnlock() _, ok := subs.ids[id] return ok } // GetSubscriber ... func (subs *Subscribers) GetSubscriber(id string) (chan Message, bool) { subs.RLock() defer subs.RUnlock() ch, ok := subs.chs[id] if !ok { return nil, false } return ch, true } // NotifyAll ... func (subs *Subscribers) NotifyAll(message Message) int { subs.RLock() defer subs.RUnlock() i := 0 for id, ch := range subs.chs { select { case ch <- message: i++ default: // TODO: Drop this client? // TODO: Retry later? log.Warnf("cannot publish message to %s: %+v", id, message) } } return i } // Options ... type Options struct { BufferLength int MaxQueueSize int MaxPayloadSize int WithMetrics bool } // MessageBus ... type MessageBus struct { sync.RWMutex metrics *Metrics bufferLength int maxQueueSize int maxPayloadSize int topics map[string]*Topic queues map[*Topic]*Queue subscribers map[*Topic]*Subscribers } // New ... 
// New creates a MessageBus configured by options; nil selects the package
// defaults. A non-positive BufferLength or MaxPayloadSize is also replaced by
// its default. With a zero BufferLength the first subscription would abort
// the process (log.Fatal in AddSubscriber). With a zero MaxPayloadSize every
// non-empty publish would be rejected. MaxQueueSize is passed to NewQueue as
// given. When WithMetrics is set, the bus metrics are created via NewMetrics.
func New(options *Options) *MessageBus {
	bufferLength := DefaultBufferLength
	maxQueueSize := DefaultMaxQueueSize
	maxPayloadSize := DefaultMaxPayloadSize
	withMetrics := false

	if options != nil {
		if options.BufferLength > 0 {
			bufferLength = options.BufferLength
		}
		// NOTE(review): what a zero or negative size means is up to NewQueue,
		// so the caller's value is kept unchanged.
		maxQueueSize = options.MaxQueueSize
		if options.MaxPayloadSize > 0 {
			maxPayloadSize = options.MaxPayloadSize
		}
		withMetrics = options.WithMetrics
	}

	var metrics *Metrics
	if withMetrics {
		metrics = newBusMetrics()
	}

	return &MessageBus{
		metrics:        metrics,
		bufferLength:   bufferLength,
		maxQueueSize:   maxQueueSize,
		maxPayloadSize: maxPayloadSize,
		topics:         make(map[string]*Topic),
		queues:         make(map[*Topic]*Queue),
		subscribers:    make(map[*Topic]*Subscribers),
	}
}

// newBusMetrics creates the "msgbus" Metrics and registers the server,
// client, bus and queue metrics used throughout the package.
func newBusMetrics() *Metrics {
	metrics := NewMetrics("msgbus")

	ctime := time.Now()

	// server uptime counter
	metrics.NewCounterFunc(
		"server", "uptime",
		"Number of nanoseconds the server has been running",
		func() float64 {
			return float64(time.Since(ctime).Nanoseconds())
		},
	)

	// server requests counter
	metrics.NewCounter(
		"server", "requests",
		"Number of total requests processed",
	)

	// client latency summary
	metrics.NewSummary(
		"client", "latency_seconds",
		"Client latency in seconds",
	)

	// client errors counter
	metrics.NewCounter(
		"client", "errors",
		"Number of errors publishing messages to clients",
	)

	// bus messages counter
	metrics.NewCounter(
		"bus", "messages",
		"Number of total messages exchanged",
	)

	// bus dropped counter
	metrics.NewCounter(
		"bus", "dropped",
		"Number of messages dropped to subscribers",
	)

	// bus delivered counter
	metrics.NewCounter(
		"bus", "delivered",
		"Number of messages delivered to subscribers",
	)

	// bus fetched counter
	metrics.NewCounter(
		"bus", "fetched",
		"Number of messages fetched from clients",
	)

	// bus topics counter (topics are never removed, so it only grows)
	metrics.NewCounter(
		"bus", "topics",
		"Number of active topics registered",
	)

	// queue len gauge vec
	metrics.NewGaugeVec(
		"queue", "len",
		"Queue length of each topic",
		[]string{"topic"},
	)

	// queue size gauge vec
	// TODO: Implement this gauge by somehow getting queue sizes per topic!
	metrics.NewGaugeVec(
		"queue", "size",
		"Queue size of each topic",
		[]string{"topic"},
	)

	// bus subscribers gauge
	metrics.NewGauge(
		"bus", "subscribers",
		"Number of active subscribers",
	)

	return metrics
}

// Len returns the number of known topics.
func (mb *MessageBus) Len() int {
	mb.RLock()
	defer mb.RUnlock()

	return len(mb.topics)
}

// Metrics returns the bus metrics, or nil if the bus was created without
// WithMetrics.
func (mb *MessageBus) Metrics() *Metrics {
	return mb.metrics
}

// NewTopic returns the topic with the given name, creating and registering
// it (and bumping the topics counter) on first use.
func (mb *MessageBus) NewTopic(topic string) *Topic {
	mb.Lock()
	defer mb.Unlock()

	t, ok := mb.topics[topic]
	if !ok {
		t = &Topic{Name: topic, Created: time.Now()}
		mb.topics[topic] = t
		if mb.metrics != nil {
			mb.metrics.Counter("bus", "topics").Inc()
		}
	}
	return t
}

// NewMessage builds a message for topic with the next ID from the topic's
// sequence and advances the sequence. The bus lock serialises the
// read-and-increment, because HTTP handlers create messages for the same
// topic concurrently. It must not be called while holding mb's lock.
func (mb *MessageBus) NewMessage(topic *Topic, payload []byte) Message {
	mb.Lock()
	id := topic.Sequence
	topic.Sequence++
	mb.Unlock()

	if mb.metrics != nil {
		mb.metrics.Counter("bus", "messages").Inc()
	}

	return Message{
		ID:      id,
		Topic:   topic,
		Payload: payload,
		Created: time.Now(),
	}
}

// Put appends message to its topic's queue (creating the queue on first use)
// and then offers it to the topic's current subscribers.
func (mb *MessageBus) Put(message Message) {
	mb.Lock()
	defer mb.Unlock()

	t := message.Topic
	q, ok := mb.queues[t]
	if !ok {
		q = NewQueue(mb.maxQueueSize)
		mb.queues[t] = q
	}
	q.Push(message)

	if mb.metrics != nil {
		// NOTE(review): if Queue.Push evicts when full, this gauge over-counts;
		// confirm against Queue.
		mb.metrics.GaugeVec("queue", "len").WithLabelValues(t.Name).Inc()
	}

	mb.publish(message)
}

// Get removes and returns the next message from t's queue, as produced by
// Queue.Pop. It reports false if t has no queue or the queue is empty.
// The write lock is taken because Pop mutates the queue; concurrent Gets
// under a shared read lock would race.
func (mb *MessageBus) Get(t *Topic) (Message, bool) {
	mb.Lock()
	defer mb.Unlock()

	q, ok := mb.queues[t]
	if !ok {
		return Message{}, false
	}

	m := q.Pop()
	if m == nil {
		return Message{}, false
	}

	if mb.metrics != nil {
		mb.metrics.Counter("bus", "fetched").Inc()
		mb.metrics.GaugeVec("queue", "len").WithLabelValues(t.Name).Dec()
	}

	return m.(Message), true
}
func (mb *MessageBus) publish(message Message) { subs, ok := mb.subscribers[message.Topic] if !ok { log.Debugf("no subscribers for %s", message.Topic.Name) return } log.Debug("notifying subscribers") n := subs.NotifyAll(message) if n != subs.Len() && mb.metrics != nil { log.Warnf("%d/%d subscribers notified", n, subs.Len()) mb.metrics.Counter("bus", "dropped").Add(float64(subs.Len() - n)) } } // Subscribe ... func (mb *MessageBus) Subscribe(id, topic string, opts ...SubscribeOption) chan Message { mb.Lock() defer mb.Unlock() t, ok := mb.topics[topic] if !ok { t = &Topic{Name: topic, Created: time.Now()} mb.topics[topic] = t } subs, ok := mb.subscribers[t] if !ok { subs = NewSubscribers(&SubscriberConfig{BufferLength: mb.bufferLength}) mb.subscribers[t] = subs } if subs.HasSubscriber(id) { // Already verified the listener exists log.Debugf("already have subscriber %s", id) ch, _ := subs.GetSubscriber(id) return ch } if mb.metrics != nil { mb.metrics.Gauge("bus", "subscribers").Inc() } o := &SubscriberOptions{} for _, opt := range opts { opt(o) } ch := subs.AddSubscriber(id) q, ok := mb.queues[t] if !ok { log.Debug("nothing in queue, returning ch") return ch } if o.Index >= 0 && o.Index <= q.Len() { var n int log.Debugf("subscriber wants to start from %d", o.Index) q.ForEach(func(item interface{}) error { msg := item.(Message) log.Debugf("found #%v", msg) if msg.ID >= o.Index { ch <- msg n++ } return nil }) log.Debugf("published %d messages", n) } return ch } // Unsubscribe ... 
func (mb *MessageBus) Unsubscribe(id, topic string) { mb.Lock() defer mb.Unlock() t, ok := mb.topics[topic] if !ok { return } subs, ok := mb.subscribers[t] if !ok { return } if subs.HasSubscriber(id) { // Already verified the listener exists subs.RemoveSubscriber(id) if mb.metrics != nil { mb.metrics.Gauge("bus", "subscribers").Dec() } } } func (mb *MessageBus) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer func() { if mb.metrics != nil { mb.metrics.Counter("server", "requests").Inc() } }() w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Accept-Encoding", "br, gzip, deflate") if r.Method == "GET" && (r.URL.Path == "/" || r.URL.Path == "") { // XXX: guard with a mutex? out, err := json.Marshal(mb.topics) if err != nil { msg := fmt.Sprintf("error serializing topics: %s", err) http.Error(w, msg, http.StatusInternalServerError) return } w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json") w.Write(out) return } topic := strings.Trim(r.URL.Path, "/") t := mb.NewTopic(topic) log.Debugf("request for topic %#v", t.Name) switch r.Method { case "POST", "PUT": defer r.Body.Close() if r.ContentLength > int64(mb.maxPayloadSize) { msg := "payload exceeds max-payload-size" http.Error(w, msg, http.StatusRequestEntityTooLarge) return } var rd io.Reader = r.Body ce := r.Header.Get("Content-Encoding") switch ce { case "": case "br": rd = brotli.NewReader(rd) case "gzip": gz, err := gzip.NewReader(rd) if err != nil { msg := fmt.Sprintf("error reading payload: %s", err) http.Error(w, msg, http.StatusBadRequest) return } defer gz.Close() rd = gz case "deflate": fl := flate.NewReader(rd) defer fl.Close() rd = fl default: msg := fmt.Sprintf("error reading payload: not acceptable: %v", ce) http.Error(w, msg, http.StatusNotAcceptable) return } body, err := ioutil.ReadAll(rd) if err != nil { msg := fmt.Sprintf("error reading payload: %s", err) http.Error(w, msg, http.StatusBadRequest) return } if len(body) > mb.maxPayloadSize { msg := 
"payload exceeds max-payload-size" http.Error(w, msg, http.StatusRequestEntityTooLarge) return } mb.Put(mb.NewMessage(t, body)) w.WriteHeader(http.StatusAccepted) case "GET": if r.Header.Get("Upgrade") == "websocket" { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Errorf("error creating websocket client: %s", err) return } i := SafeParseInt(r.URL.Query().Get("index"), -1) log.Debugf("new subscriber for %s from %s", t.Name, r.RemoteAddr) NewClient(conn, t, i, mb).Start() return } message, ok := mb.Get(t) if !ok { http.Error(w, "No Messages", http.StatusNoContent) return } out, err := json.Marshal(message) if err != nil { msg := fmt.Sprintf("error serializing message: %s", err) http.Error(w, msg, http.StatusInternalServerError) return } w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json") w.Write(out) case "DELETE": http.Error(w, "Not Implemented", http.StatusNotImplemented) // TODO: Implement deleting topics } } // Client ... type Client struct { conn *websocket.Conn topic *Topic index int bus *MessageBus id string ch chan Message } // NewClient ... 
func NewClient(conn *websocket.Conn, topic *Topic, index int, bus *MessageBus) *Client { return &Client{conn: conn, topic: topic, index: index, bus: bus} } func (c *Client) readPump() { defer func() { c.conn.Close() }() c.conn.SetReadDeadline(time.Now().Add(pongWait)) c.conn.SetPongHandler(func(message string) error { t, err := strconv.ParseInt(message, 10, 64) d := time.Duration(time.Now().UnixNano() - t) if err != nil { log.Warnf("garbage pong reply from %s: %s", c.id, err) } else { log.Debugf("pong latency of %s: %s", c.id, d) } c.conn.SetReadDeadline(time.Now().Add(pongWait)) if c.bus.metrics != nil { v := c.bus.metrics.Summary("client", "latency_seconds") v.Observe(d.Seconds()) } return nil }) for { _, message, err := c.conn.ReadMessage() if err != nil { c.bus.Unsubscribe(c.id, c.topic.Name) return } log.Debugf("recieved message from %s: %s", c.id, message) } } func (c *Client) writePump() { ticker := time.NewTicker(pingPeriod) defer func() { ticker.Stop() c.conn.Close() }() var err error for { select { case msg, ok := <-c.ch: c.conn.SetWriteDeadline(time.Now().Add(writeWait)) if !ok { // The bus closed the channel. message := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "bus closed") c.conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(writeWait)) return } err = c.conn.WriteJSON(msg) if err != nil { // TODO: Retry? Put the message back in the queue? log.Errorf("Error sending msg to %s: %s", c.id, err) if c.bus.metrics != nil { c.bus.metrics.Counter("client", "errors").Inc() } } else { if c.bus.metrics != nil { c.bus.metrics.Counter("bus", "delivered").Inc() } } case <-ticker.C: c.conn.SetWriteDeadline(time.Now().Add(writeWait)) t := time.Now() message := []byte(fmt.Sprintf("%d", t.UnixNano())) if err := c.conn.WriteMessage(websocket.PingMessage, message); err != nil { log.Errorf("error sending ping to %s: %s", c.id, err) return } } } } // Start ... 
// Start subscribes the client to its topic, using its remote address as the
// subscriber ID and its index as the replay start. It then runs the write and
// read pumps in their own goroutines. A close frame from the peer
// unsubscribes the client and is echoed back with the same code and text.
func (c *Client) Start() {
	c.id = c.conn.RemoteAddr().String()
	c.ch = c.bus.Subscribe(c.id, c.topic.Name, WithIndex(c.index))

	onClose := func(code int, text string) error {
		c.bus.Unsubscribe(c.id, c.topic.Name)
		reply := websocket.FormatCloseMessage(code, text)
		deadline := time.Now().Add(writeWait)
		c.conn.WriteControl(websocket.CloseMessage, reply, deadline)
		return nil
	}
	c.conn.SetCloseHandler(onClose)

	go c.writePump()
	go c.readPump()
}