prologic-msgbus/msgbus.go

754 lines
15 KiB
Go

package msgbus
import (
"compress/flate"
"compress/gzip"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"strconv"
"strings"
"time"
"github.com/andybalholm/brotli"
sync "github.com/sasha-s/go-deadlock"
log "github.com/sirupsen/logrus"
"github.com/gorilla/websocket"
)
const (
	// DefaultMaxQueueSize is the default maximum size of queues. Assuming
	// NewQueue counts messages, a full queue of DefaultMaxPayloadSize
	// payloads holds about 8MB (1024 * 8KB).
	DefaultMaxQueueSize = 1024

	// DefaultMaxPayloadSize is the default maximum payload size in bytes (8KB).
	DefaultMaxPayloadSize = 8192

	// DefaultBufferLength is the default buffer length for subscriber chans.
	DefaultBufferLength = 256

	// writeWait is the time allowed to write a message to the peer.
	writeWait = 10 * time.Second

	// pongWait is the time allowed to read the next pong message from the peer.
	pongWait = 60 * time.Second

	// pingPeriod is how often pings are sent to the peer. Must be less than
	// pongWait so a pong can arrive before the read deadline expires.
	pingPeriod = (pongWait * 9) / 10
)
// upgrader turns topic GET requests that ask for "Upgrade: websocket" into
// websocket connections, using 4KB read and write buffers.
//
// CheckOrigin accepts every Origin header.
// NOTE(review): accepting any origin lets any web page open a websocket to
// the bus from a visitor's browser; confirm this is intended for deployments
// that are not public.
// TODO: Make this configurable?
var upgrader = websocket.Upgrader{
	CheckOrigin: func(_ *http.Request) bool {
		return true
	},
	ReadBufferSize:  4096,
	WriteBufferSize: 4096,
}
// HandlerFunc is a callback that handles a single message and returns a
// non-nil error if it could not.
// NOTE(review): nothing in this file uses HandlerFunc; check other files of
// the package before relying on this description.
type HandlerFunc func(msg *Message) error
// Topic is a named stream of messages on the bus.
//
// Sequence is the ID the next message created for the topic will receive;
// Created records when the topic was first registered.
type Topic struct {
	Name     string    `json:"name"`
	Sequence int       `json:"seq"`
	Created  time.Time `json:"created"`
}

// String implements fmt.Stringer by returning the topic's name.
func (t *Topic) String() string {
	name := t.Name
	return name
}
// Message is a single payload published to a topic.
type Message struct {
// ID is the topic's Sequence value at the moment NewMessage built the message.
ID int `json:"id"`
// Topic points at the shared Topic the message was published to.
Topic *Topic `json:"topic"`
// Payload holds the raw published bytes; encoding/json renders []byte as base64.
Payload []byte `json:"payload"`
// Created is when the message was built.
Created time.Time `json:"created"`
}
// SubscribeOption customizes a subscription by mutating SubscriberOptions.
type SubscribeOption func(*SubscriberOptions)

// SubscriberOptions collects the settings applied by SubscribeOption values.
type SubscriberOptions struct {
	// Index is the message index to start subscribing from.
	Index int
}

// WithIndex returns a SubscribeOption that sets the index to start
// subscribing from.
func WithIndex(index int) SubscribeOption {
	setIndex := func(opts *SubscriberOptions) {
		opts.Index = index
	}
	return setIndex
}
// SubscriberConfig configures a Subscribers set created by NewSubscribers.
type SubscriberConfig struct {
	// BufferLength is the capacity of each subscriber's message channel.
	BufferLength int
}
// Subscribers is the set of subscribers to one topic (MessageBus keeps one
// set per *Topic). Each subscriber is identified by an id string and owns a
// buffered message channel. The embedded RWMutex (from
// github.com/sasha-s/go-deadlock, imported as sync) guards both maps.
type Subscribers struct {
sync.RWMutex
// buflen is the capacity given to each new subscriber channel.
buflen int
// ids records which ids are registered; AddSubscriber and RemoveSubscriber
// keep its keys identical to those of chs.
ids map[string]bool
// chs maps each subscriber id to its message channel.
chs map[string]chan Message
}
// NewSubscribers creates an empty subscriber set whose channels are buffered
// with config.BufferLength slots.
//
// A nil config, or a non-positive BufferLength, selects DefaultBufferLength.
// AddSubscriber cannot create channels with a non-positive buffer length (it
// aborts the process), so accepting such a value here would only defer the
// failure to the first subscription.
func NewSubscribers(config *SubscriberConfig) *Subscribers {
	bufferLength := DefaultBufferLength
	if config != nil && config.BufferLength > 0 {
		bufferLength = config.BufferLength
	}

	return &Subscribers{
		buflen: bufferLength,
		ids:    make(map[string]bool),
		chs:    make(map[string]chan Message),
	}
}
// Len reports how many subscribers are currently registered in the set.
func (subs *Subscribers) Len() int {
	subs.RLock()
	count := len(subs.ids)
	subs.RUnlock()
	return count
}
// AddSubscriber registers id and returns a fresh channel, buffered with the
// set's buffer length, on which the subscriber will receive messages.
//
// If id is already registered, its entry is replaced by the new channel; the
// previous channel is not closed. If the set's buffer length is not positive,
// the process is terminated via log.Fatal.
func (subs *Subscribers) AddSubscriber(id string) chan Message {
	subs.Lock()
	defer subs.Unlock()

	if subs.buflen < 1 {
		log.Fatal("subscriber buflen <= 0")
	}

	subs.ids[id] = true
	ch := make(chan Message, subs.buflen)
	subs.chs[id] = ch
	return ch
}
// RemoveSubscriber unregisters id and closes its channel, which signals the
// reader of that channel (for example a Client's write pump) that the
// subscription has ended.
//
// Removing an id that is not registered is a no-op. Previously this closed
// the nil channel returned by the failed map lookup, which panics.
func (subs *Subscribers) RemoveSubscriber(id string) {
	subs.Lock()
	defer subs.Unlock()

	ch, ok := subs.chs[id]
	delete(subs.ids, id)
	delete(subs.chs, id)
	if ok {
		close(ch)
	}
}
// HasSubscriber reports whether id is registered in the set.
func (subs *Subscribers) HasSubscriber(id string) bool {
	subs.RLock()
	_, found := subs.ids[id]
	subs.RUnlock()
	return found
}
// GetSubscriber returns the channel registered for id and true, or nil and
// false when id is not registered.
func (subs *Subscribers) GetSubscriber(id string) (chan Message, bool) {
	subs.RLock()
	defer subs.RUnlock()

	if ch, found := subs.chs[id]; found {
		return ch, true
	}
	return nil, false
}
// NotifyAll offers message to every subscriber without blocking and returns
// how many subscribers accepted it. A subscriber whose channel buffer is full
// misses the message, and a warning naming that subscriber is logged.
func (subs *Subscribers) NotifyAll(message Message) int {
	subs.RLock()
	defer subs.RUnlock()

	delivered := 0
	for id, ch := range subs.chs {
		select {
		case ch <- message:
			delivered++
			continue
		default:
		}
		// TODO: Drop this client?
		// TODO: Retry later?
		log.Warnf("cannot publish message to %s: %+v", id, message)
	}
	return delivered
}
// Options configures a MessageBus created by New. Passing nil to New selects
// the package defaults and disables metrics.
type Options struct {
// BufferLength is the capacity of each subscriber's message channel.
BufferLength int
// MaxQueueSize is passed to NewQueue when a topic's queue is created.
MaxQueueSize int
// MaxPayloadSize is the largest publish body, in bytes, that ServeHTTP accepts.
MaxPayloadSize int
// WithMetrics makes New register the bus metrics (see MessageBus.Metrics).
WithMetrics bool
}
// MessageBus is an in-memory publish/subscribe hub. Each topic has a queue of
// published messages and a set of subscribers. The bus methods lock the
// embedded RWMutex (from github.com/sasha-s/go-deadlock) around access to the
// maps below.
type MessageBus struct {
sync.RWMutex
// metrics is nil unless the bus was created with WithMetrics.
metrics *Metrics
// bufferLength is handed to NewSubscribers for each topic's subscriber set.
bufferLength int
// maxQueueSize is passed to NewQueue when a topic's queue is first created.
maxQueueSize int
// maxPayloadSize is the largest publish body ServeHTTP accepts, in bytes.
maxPayloadSize int
// topics holds every known topic, keyed by name.
topics map[string]*Topic
// queues holds each topic's message queue; created on the first Put.
queues map[*Topic]*Queue
// subscribers holds each topic's subscriber set; created on the first Subscribe.
subscribers map[*Topic]*Subscribers
}
// New creates a MessageBus.
//
// With nil options the package defaults are used (DefaultBufferLength,
// DefaultMaxQueueSize, DefaultMaxPayloadSize) and metrics are disabled.
// With non-nil options each field is taken as given, except that a
// non-positive BufferLength or MaxPayloadSize falls back to its default: a
// zero buffer length is unusable for subscriber channels, and a zero payload
// size would make ServeHTTP reject every non-empty publish.
//
// When options.WithMetrics is set, the bus registers its metrics (see
// newBusMetrics) and exposes them through Metrics.
func New(options *Options) *MessageBus {
	bufferLength := DefaultBufferLength
	maxQueueSize := DefaultMaxQueueSize
	maxPayloadSize := DefaultMaxPayloadSize
	withMetrics := false

	if options != nil {
		if options.BufferLength > 0 {
			bufferLength = options.BufferLength
		}
		// NOTE(review): MaxQueueSize is passed through unchanged because how
		// NewQueue treats zero is not visible here.
		maxQueueSize = options.MaxQueueSize
		if options.MaxPayloadSize > 0 {
			maxPayloadSize = options.MaxPayloadSize
		}
		withMetrics = options.WithMetrics
	}

	var metrics *Metrics
	if withMetrics {
		metrics = newBusMetrics()
	}

	return &MessageBus{
		metrics:        metrics,
		bufferLength:   bufferLength,
		maxQueueSize:   maxQueueSize,
		maxPayloadSize: maxPayloadSize,
		topics:         make(map[string]*Topic),
		queues:         make(map[*Topic]*Queue),
		subscribers:    make(map[*Topic]*Subscribers),
	}
}

// newBusMetrics creates the metrics collection with NewMetrics("msgbus") and
// registers every metric that the bus and its websocket clients update.
func newBusMetrics() *Metrics {
	metrics := NewMetrics("msgbus")

	ctime := time.Now()

	// server uptime counter
	metrics.NewCounterFunc(
		"server", "uptime",
		"Number of nanoseconds the server has been running",
		func() float64 {
			return float64(time.Since(ctime).Nanoseconds())
		},
	)

	// server requests counter
	metrics.NewCounter(
		"server", "requests",
		"Number of total requests processed",
	)

	// client latency summary
	metrics.NewSummary(
		"client", "latency_seconds",
		"Client latency in seconds",
	)

	// client errors counter
	metrics.NewCounter(
		"client", "errors",
		"Number of errors publishing messages to clients",
	)

	// bus messages counter
	metrics.NewCounter(
		"bus", "messages",
		"Number of total messages exchanged",
	)

	// bus dropped counter
	metrics.NewCounter(
		"bus", "dropped",
		"Number of messages dropped to subscribers",
	)

	// bus delivered counter
	metrics.NewCounter(
		"bus", "delivered",
		"Number of messages delivered to subscribers",
	)

	// bus fetched counter
	metrics.NewCounter(
		"bus", "fetched",
		"Number of messages fetched from clients",
	)

	// bus topics counter: registered as a counter because it is incremented
	// via Counter("bus", "topics") when a topic is created.
	metrics.NewCounter(
		"bus", "topics",
		"Number of active topics registered",
	)

	// queue len gauge vec
	metrics.NewGaugeVec(
		"queue", "len",
		"Queue length of each topic",
		[]string{"topic"},
	)

	// queue size gauge vec
	// TODO: Implement this gauge by somehow getting queue sizes per topic!
	metrics.NewGaugeVec(
		"queue", "size",
		"Queue size of each topic",
		[]string{"topic"},
	)

	// bus subscribers gauge
	metrics.NewGauge(
		"bus", "subscribers",
		"Number of active subscribers",
	)

	return metrics
}
// Len returns the number of topics known to the bus.
//
// The read lock is taken because NewTopic and Subscribe insert into
// mb.topics under the write lock; reading the map length without it is a
// data race.
func (mb *MessageBus) Len() int {
	mb.RLock()
	defer mb.RUnlock()
	return len(mb.topics)
}
// Metrics returns the bus's metrics collection, or nil when the bus was
// created without metrics.
func (mb *MessageBus) Metrics() *Metrics {
	m := mb.metrics
	return m
}
// NewTopic returns the topic named topic. It creates and registers the topic
// if it does not exist yet, and in that case increments the "bus topics"
// counter when metrics are enabled.
func (mb *MessageBus) NewTopic(topic string) *Topic {
	mb.Lock()
	defer mb.Unlock()

	if existing, found := mb.topics[topic]; found {
		return existing
	}

	created := &Topic{Name: topic, Created: time.Now()}
	mb.topics[topic] = created
	if mb.metrics != nil {
		mb.metrics.Counter("bus", "topics").Inc()
	}
	return created
}
// NewMessage builds a Message for topic carrying payload. The message takes
// the topic's current Sequence as its ID, and the sequence is then advanced.
// It also increments the "bus messages" counter when metrics are enabled.
//
// The read-and-increment of Sequence happens under the bus lock. ServeHTTP
// calls NewMessage outside any lock, so without it concurrent publishers
// could be handed duplicate IDs.
//
// NewMessage does not enqueue or deliver the message; pass it to Put.
func (mb *MessageBus) NewMessage(topic *Topic, payload []byte) Message {
	mb.Lock()
	id := topic.Sequence
	topic.Sequence++
	mb.Unlock()

	if mb.metrics != nil {
		mb.metrics.Counter("bus", "messages").Inc()
	}

	return Message{
		ID:      id,
		Topic:   topic,
		Payload: payload,
		Created: time.Now(),
	}
}
// Put appends message to its topic's queue and then delivers it to the
// topic's current subscribers. On first use the queue is created with
// NewQueue(mb.maxQueueSize). When metrics are enabled, the topic's
// "queue len" gauge is incremented.
func (mb *MessageBus) Put(message Message) {
	mb.Lock()
	defer mb.Unlock()

	topic := message.Topic
	queue, exists := mb.queues[topic]
	if !exists {
		queue = NewQueue(mb.maxQueueSize)
		mb.queues[topic] = queue
	}
	queue.Push(message)

	if mb.metrics != nil {
		mb.metrics.GaugeVec("queue", "len").WithLabelValues(topic.Name).Inc()
	}

	mb.publish(message)
}
// Get pops the next message from topic t's queue, in whatever order
// Queue.Pop provides. It returns false when the topic has no queue or the
// queue is empty. When metrics are enabled, it increments "bus fetched" and
// decrements the topic's "queue len" gauge.
func (mb *MessageBus) Get(t *Topic) (Message, bool) {
	// Pop removes an element, so this is a write. It takes the exclusive
	// lock: under the read lock, concurrent Gets could Pop the same queue at
	// once. (NOTE(review): Queue's own synchronization is not visible here.)
	mb.Lock()
	defer mb.Unlock()

	q, ok := mb.queues[t]
	if !ok {
		return Message{}, false
	}

	m := q.Pop()
	if m == nil {
		return Message{}, false
	}

	if mb.metrics != nil {
		mb.metrics.Counter("bus", "fetched").Inc()
		mb.metrics.GaugeVec("queue", "len").WithLabelValues(t.Name).Dec()
	}

	return m.(Message), true
}
// publish offers message, without blocking, to every current subscriber of
// its topic.
//
// Callers must hold the bus lock: publish reads mb.subscribers without
// locking (Put calls it under mb.Lock). When some subscribers cannot take
// the message, a warning is always logged. The shortfall is also added to
// the "bus dropped" counter when metrics are enabled.
func (mb *MessageBus) publish(message Message) {
	subs, ok := mb.subscribers[message.Topic]
	if !ok {
		log.Debugf("no subscribers for %s", message.Topic.Name)
		return
	}

	log.Debug("notifying subscribers")
	n := subs.NotifyAll(message)
	total := subs.Len()
	if n == total {
		return
	}

	log.Warnf("%d/%d subscribers notified", n, total)
	if mb.metrics != nil {
		mb.metrics.Counter("bus", "dropped").Add(float64(total - n))
	}
}
// Subscribe registers subscriber id on the named topic, creating the topic if
// needed, and returns the channel on which the subscriber receives messages.
// If id is already subscribed to the topic, its existing channel is returned
// and opts are ignored.
//
// With WithIndex(i) for i >= 0, messages still held in the topic's queue with
// ID >= i are replayed into the channel before it is returned.
//
// Replay never blocks. The caller usually has not started draining the
// channel yet (Client.Start subscribes before launching its write pump), and
// this method holds the bus lock, so a blocking send on a full channel would
// stall the whole bus. Backlog messages that do not fit in the channel buffer
// are dropped, logged, and counted as dropped.
func (mb *MessageBus) Subscribe(id, topic string, opts ...SubscribeOption) chan Message {
	mb.Lock()
	defer mb.Unlock()

	t, ok := mb.topics[topic]
	if !ok {
		t = &Topic{Name: topic, Created: time.Now()}
		mb.topics[topic] = t
		// Keep the topic counter consistent with NewTopic, which cannot be
		// called here because it takes the bus lock itself.
		if mb.metrics != nil {
			mb.metrics.Counter("bus", "topics").Inc()
		}
	}

	subs, ok := mb.subscribers[t]
	if !ok {
		subs = NewSubscribers(&SubscriberConfig{BufferLength: mb.bufferLength})
		mb.subscribers[t] = subs
	}

	if subs.HasSubscriber(id) {
		log.Debugf("already have subscriber %s", id)
		ch, _ := subs.GetSubscriber(id)
		return ch
	}

	if mb.metrics != nil {
		mb.metrics.Gauge("bus", "subscribers").Inc()
	}

	o := &SubscriberOptions{}
	for _, opt := range opts {
		opt(o)
	}

	ch := subs.AddSubscriber(id)

	q, ok := mb.queues[t]
	if !ok {
		log.Debug("nothing in queue, returning ch")
		return ch
	}

	// Message IDs are per-topic sequence numbers (see NewMessage), not queue
	// positions. The index is therefore not bounded by q.Len(): any
	// non-negative index is honoured.
	if o.Index >= 0 {
		var sent, dropped int
		log.Debugf("subscriber wants to start from %d", o.Index)
		q.ForEach(func(item interface{}) error {
			msg := item.(Message)
			log.Debugf("found #%v", msg)
			if msg.ID < o.Index {
				return nil
			}
			select {
			case ch <- msg:
				sent++
			default:
				dropped++
			}
			return nil
		})
		log.Debugf("published %d messages", sent)
		if dropped > 0 {
			log.Warnf("dropped %d backlog messages for %s: subscriber buffer full", dropped, id)
			if mb.metrics != nil {
				mb.metrics.Counter("bus", "dropped").Add(float64(dropped))
			}
		}
	}

	return ch
}
// Unsubscribe removes subscriber id from the named topic, closing its channel
// and decrementing the "bus subscribers" gauge when metrics are enabled.
// Unknown topics and ids are ignored.
func (mb *MessageBus) Unsubscribe(id, topic string) {
	mb.Lock()
	defer mb.Unlock()

	t, found := mb.topics[topic]
	if !found {
		return
	}

	subs, found := mb.subscribers[t]
	if !found || !subs.HasSubscriber(id) {
		return
	}

	subs.RemoveSubscriber(id)
	if mb.metrics != nil {
		mb.metrics.Gauge("bus", "subscribers").Dec()
	}
}
// ServeHTTP implements http.Handler for the bus:
//
//   - GET / lists all known topics as a JSON object keyed by topic name.
//   - POST or PUT /<topic> publishes the request body to <topic> and replies
//     202 Accepted. The body may be br, gzip or deflate encoded
//     (Content-Encoding), and its decoded size must not exceed the bus's max
//     payload size (413 otherwise).
//   - GET /<topic> with "Upgrade: websocket" streams the topic to a websocket
//     client. The ?index= query (parsed by SafeParseInt, default -1) is
//     passed to Subscribe via WithIndex. Without the upgrade, GET pops one
//     queued message and returns it as JSON, or replies 204 when there is
//     none.
//   - DELETE /<topic> is not implemented (501).
//
// Every request that names a topic creates the topic if it does not exist.
// Other methods fall through without writing anything, so net/http sends an
// empty 200. Every response advertises the request content codings the bus
// decodes via Accept-Encoding.
func (mb *MessageBus) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	defer func() {
		if mb.metrics != nil {
			mb.metrics.Counter("server", "requests").Inc()
		}
	}()

	w.Header().Set("Access-Control-Allow-Origin", "*")
	w.Header().Set("Accept-Encoding", "br, gzip, deflate")

	if r.Method == "GET" && (r.URL.Path == "/" || r.URL.Path == "") {
		mb.serveTopicList(w)
		return
	}

	topic := strings.Trim(r.URL.Path, "/")
	t := mb.NewTopic(topic)
	log.Debugf("request for topic %#v", t.Name)

	switch r.Method {
	case "POST", "PUT":
		mb.servePublish(w, r, t)
	case "GET":
		mb.serveTopicGet(w, r, t)
	case "DELETE":
		// TODO: Implement deleting topics
		http.Error(w, "Not Implemented", http.StatusNotImplemented)
	}
}

// serveTopicList writes every known topic as JSON.
func (mb *MessageBus) serveTopicList(w http.ResponseWriter) {
	// Marshal under the read lock. NewTopic and Subscribe insert into
	// mb.topics under the write lock, and iterating a map while it is being
	// written is a fatal runtime error.
	mb.RLock()
	out, err := json.Marshal(mb.topics)
	mb.RUnlock()
	if err != nil {
		msg := fmt.Sprintf("error serializing topics: %s", err)
		http.Error(w, msg, http.StatusInternalServerError)
		return
	}

	// Headers must be set before WriteHeader; later changes are ignored.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	w.Write(out)
}

// servePublish decodes the request body according to its Content-Encoding
// and publishes it to t, enforcing mb.maxPayloadSize on the decoded bytes.
func (mb *MessageBus) servePublish(w http.ResponseWriter, r *http.Request, t *Topic) {
	defer r.Body.Close()

	// Fast rejection when the client declares an oversized body up front.
	if r.ContentLength > int64(mb.maxPayloadSize) {
		msg := "payload exceeds max-payload-size"
		http.Error(w, msg, http.StatusRequestEntityTooLarge)
		return
	}

	var rd io.Reader = r.Body
	ce := r.Header.Get("Content-Encoding")
	switch ce {
	case "":
	case "br":
		rd = brotli.NewReader(rd)
	case "gzip":
		gz, err := gzip.NewReader(rd)
		if err != nil {
			msg := fmt.Sprintf("error reading payload: %s", err)
			http.Error(w, msg, http.StatusBadRequest)
			return
		}
		defer gz.Close()
		rd = gz
	case "deflate":
		// NOTE(review): HTTP's "deflate" coding is zlib-wrapped (RFC 9110);
		// this reads raw DEFLATE, which some clients send. Confirm what
		// clients use before switching to compress/zlib.
		fl := flate.NewReader(rd)
		defer fl.Close()
		rd = fl
	default:
		msg := fmt.Sprintf("error reading payload: not acceptable: %v", ce)
		http.Error(w, msg, http.StatusNotAcceptable)
		return
	}

	// Read at most one byte past the limit. This still detects oversized
	// bodies, but chunked uploads (unknown ContentLength) and compressed
	// "bombs" can no longer make the server buffer unbounded data.
	body, err := ioutil.ReadAll(io.LimitReader(rd, int64(mb.maxPayloadSize)+1))
	if err != nil {
		msg := fmt.Sprintf("error reading payload: %s", err)
		http.Error(w, msg, http.StatusBadRequest)
		return
	}

	if len(body) > mb.maxPayloadSize {
		msg := "payload exceeds max-payload-size"
		http.Error(w, msg, http.StatusRequestEntityTooLarge)
		return
	}

	mb.Put(mb.NewMessage(t, body))
	w.WriteHeader(http.StatusAccepted)
}

// serveTopicGet either upgrades the request to a websocket subscription on t
// or pops a single queued message from t and returns it as JSON.
func (mb *MessageBus) serveTopicGet(w http.ResponseWriter, r *http.Request, t *Topic) {
	if r.Header.Get("Upgrade") == "websocket" {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			log.Errorf("error creating websocket client: %s", err)
			return
		}
		i := SafeParseInt(r.URL.Query().Get("index"), -1)
		log.Debugf("new subscriber for %s from %s", t.Name, r.RemoteAddr)
		NewClient(conn, t, i, mb).Start()
		return
	}

	message, ok := mb.Get(t)
	if !ok {
		http.Error(w, "No Messages", http.StatusNoContent)
		return
	}

	out, err := json.Marshal(message)
	if err != nil {
		msg := fmt.Sprintf("error serializing message: %s", err)
		http.Error(w, msg, http.StatusInternalServerError)
		return
	}

	// Headers must be set before WriteHeader; later changes are ignored.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	w.Write(out)
}
// Client connects one websocket peer to a single topic on the bus. Start
// subscribes it; writePump and readPump then move data in each direction.
type Client struct {
conn *websocket.Conn
topic *Topic
// index is the starting index handed to Subscribe via WithIndex.
index int
bus *MessageBus
// id identifies the subscriber on the bus; Start sets it to the
// connection's remote address.
id string
// ch receives the topic's messages; Start sets it from Subscribe.
ch chan Message
}
// NewClient wraps a websocket connection as a bus client for topic. The
// client is not subscribed until Start is called, at which point index is
// passed to Subscribe via WithIndex.
func NewClient(conn *websocket.Conn, topic *Topic, index int, bus *MessageBus) *Client {
	c := &Client{
		bus:   bus,
		conn:  conn,
		index: index,
		topic: topic,
	}
	return c
}
// readPump reads from the websocket until a read fails, then unsubscribes the
// client and closes the connection.
//
// Data frames from the peer are only logged; the reading is what lets
// gorilla/websocket process incoming control frames such as pongs. Every pong
// pushes the read deadline back by pongWait, so a peer that stops answering
// pings is dropped.
func (c *Client) readPump() {
	defer c.conn.Close()

	// Peers have no reason to send large data frames (they are only logged
	// below), so cap how much a single message can make us buffer.
	// gorilla/websocket treats a non-positive limit as "no limit".
	c.conn.SetReadLimit(int64(c.bus.maxPayloadSize))

	c.conn.SetReadDeadline(time.Now().Add(pongWait))
	c.conn.SetPongHandler(func(message string) error {
		// Any pong shows the peer is alive, so always extend the deadline.
		c.conn.SetReadDeadline(time.Now().Add(pongWait))

		// writePump sends the ping time (UnixNano) as the ping payload,
		// which the peer echoes back in its pong.
		sent, err := strconv.ParseInt(message, 10, 64)
		if err != nil {
			// Don't derive or record a latency from an unparseable payload.
			log.Warnf("garbage pong reply from %s: %s", c.id, err)
			return nil
		}

		d := time.Duration(time.Now().UnixNano() - sent)
		log.Debugf("pong latency of %s: %s", c.id, d)
		if c.bus.metrics != nil {
			c.bus.metrics.Summary("client", "latency_seconds").Observe(d.Seconds())
		}
		return nil
	})

	for {
		_, message, err := c.conn.ReadMessage()
		if err != nil {
			c.bus.Unsubscribe(c.id, c.topic.Name)
			return
		}
		log.Debugf("recieved message from %s: %s", c.id, message)
	}
}
// writePump forwards messages from the client's subscription channel to the
// websocket as JSON, and sends a ping every pingPeriod. It returns, closing
// the connection, in three cases:
//   - the bus closes the channel (a normal-closure close frame is sent first);
//   - a message cannot be written;
//   - a ping cannot be written.
func (c *Client) writePump() {
	ticker := time.NewTicker(pingPeriod)
	defer func() {
		ticker.Stop()
		c.conn.Close()
	}()

	for {
		select {
		case msg, ok := <-c.ch:
			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if !ok {
				// The bus closed the channel.
				message := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "bus closed")
				c.conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(writeWait))
				return
			}
			if err := c.conn.WriteJSON(msg); err != nil {
				log.Errorf("Error sending msg to %s: %s", c.id, err)
				if c.bus.metrics != nil {
					c.bus.metrics.Counter("client", "errors").Inc()
				}
				// Message's fields always marshal, so this is a connection
				// write failure, which gorilla/websocket treats as fatal for
				// the connection. Stop rather than failing (and counting an
				// error) for every later message. Closing the connection in
				// the deferred func makes readPump's read fail, which
				// unsubscribes the client.
				// TODO: Put the message back in the queue?
				return
			}
			if c.bus.metrics != nil {
				c.bus.metrics.Counter("bus", "delivered").Inc()
			}
		case <-ticker.C:
			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			// The ping payload is the send time in UnixNano; readPump's pong
			// handler parses it back to measure latency.
			payload := strconv.AppendInt(nil, time.Now().UnixNano(), 10)
			if err := c.conn.WriteMessage(websocket.PingMessage, payload); err != nil {
				log.Errorf("error sending ping to %s: %s", c.id, err)
				return
			}
		}
	}
}
// Start subscribes the client to its topic, using the connection's remote
// address as the subscriber id. It then launches the write and read pumps in
// their own goroutines. When the peer sends a close frame, the client is
// unsubscribed and the close frame is echoed back.
func (c *Client) Start() {
	c.id = c.conn.RemoteAddr().String()
	c.ch = c.bus.Subscribe(c.id, c.topic.Name, WithIndex(c.index))

	onClose := func(code int, text string) error {
		c.bus.Unsubscribe(c.id, c.topic.Name)
		reply := websocket.FormatCloseMessage(code, text)
		c.conn.WriteControl(websocket.CloseMessage, reply, time.Now().Add(writeWait))
		return nil
	}
	c.conn.SetCloseHandler(onClose)

	go c.writePump()
	go c.readPump()
}