prologic-msgbus/msgbus.go

891 lines
18 KiB
Go

package msgbus
import (
"compress/flate"
"compress/gzip"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/andybalholm/brotli"
securejoin "github.com/cyphar/filepath-securejoin"
sync "github.com/sasha-s/go-deadlock"
log "github.com/sirupsen/logrus"
"github.com/tidwall/wal"
msgpack "github.com/vmihailenco/msgpack/v5"
"github.com/gorilla/websocket"
)
// WebSocket timing parameters used by the client read and write pumps.
const (
// writeWait is the time allowed to write a message to the peer.
writeWait = 10 * time.Second
// pongWait is the time allowed to read the next pong message from the peer.
pongWait = 60 * time.Second
// pingPeriod is the interval at which pings are sent to the peer. It must be
// less than pongWait and is derived as 9/10 of it.
pingPeriod = (pongWait * 9) / 10
)
var (
// ErrBufferFull is the error logged in Subscribe() when a subscriber's
// buffer is full and messages can no longer be enqueued for delivery.
ErrBufferFull = errors.New("error: subscriber buffer full")
)
// upgrader upgrades incoming HTTP requests to WebSocket connections using
// 4096-byte read and write buffers.
// TODO: Make this configurable?
var upgrader = websocket.Upgrader{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
// CheckOrigin accepts every request regardless of its Origin header.
// NOTE(review): this disables the upgrader's origin check, so any web page
// can open a subscription - confirm this is intended.
CheckOrigin: func(r *http.Request) bool {
return true
},
}
// HandlerFunc is a callback that receives a message and returns an error.
// NOTE(review): it is not referenced in this file; presumably it is used by
// client code elsewhere - confirm.
type HandlerFunc func(msg *Message) error
// Topic is a named message stream on the bus.
type Topic struct {
// Name is the topic's name; it is the key used in the bus's topic map.
Name string `json:"name"`
// Sequence is the ID of the most recent message created for the topic.
Sequence int64 `json:"seq"`
// Created is when the topic was created, or the modification time of its
// log directory when the topic was restored from disk.
Created time.Time `json:"created"`
}
// String implements fmt.Stringer and yields the topic's name.
func (t *Topic) String() string {
	name := t.Name
	return name
}
// Message is a single payload published to a topic.
type Message struct {
// ID is the message's sequence number within its topic.
ID int64 `json:"id"`
// Topic is the topic the message was published to.
Topic *Topic `json:"topic"`
// Payload is the raw message body.
Payload []byte `json:"payload"`
// Created is when the message was created.
Created time.Time `json:"created"`
}
// LoadMessage decodes a msgpack-encoded message; it is the inverse of
// Message.Bytes. When err is non-nil the returned message may be partially
// populated and must not be used.
func LoadMessage(data []byte) (m Message, err error) {
	if err = msgpack.Unmarshal(data, &m); err != nil {
		return m, err
	}
	return m, nil
}
// String renders the message as msg#<id>[<topic>]@<unix-seconds>(<payload>),
// with the payload shown as a quoted string.
func (m Message) String() string {
	const layout = "msg#%d[%s]@%d(%q)"
	created := m.Created.Unix()
	body := string(m.Payload)
	return fmt.Sprintf(layout, m.ID, m.Topic, created, body)
}
// Bytes serializes the message with msgpack; LoadMessage is the inverse.
func (m Message) Bytes() ([]byte, error) {
	encoded, err := msgpack.Marshal(m)
	return encoded, err
}
// SubscribeOption mutates a SubscriberOptions value; options are passed to
// Subscribe to customize a subscription.
type SubscribeOption func(*SubscriberOptions)

// SubscriberOptions collects the settings a subscription can be created with.
type SubscriberOptions struct {
	// Index is the message ID to start delivering from.
	Index int64
}

// WithIndex returns an option that sets the index to start subscribing from.
func WithIndex(index int64) SubscribeOption {
	apply := func(opts *SubscriberOptions) {
		opts.Index = index
	}
	return apply
}
// SubscriberConfig configures a Subscribers set created by NewSubscribers.
type SubscriberConfig struct {
// BufferLength is the capacity of each subscriber's message channel.
BufferLength int
}
// Subscribers is a set of subscribers keyed by subscriber ID. Its methods
// take the embedded RWMutex before touching the maps.
type Subscribers struct {
sync.RWMutex
// buflen is the capacity of each subscriber channel.
buflen int
// ids records which subscriber IDs are registered.
ids map[string]bool
// chs maps subscriber IDs to their delivery channels.
chs map[string]chan Message
}
// NewSubscribers builds an empty subscriber set. The per-subscriber channel
// capacity comes from config.BufferLength, or from DefaultBufferLength when
// config is nil.
func NewSubscribers(config *SubscriberConfig) *Subscribers {
	var capacity int
	switch {
	case config == nil:
		capacity = DefaultBufferLength
	default:
		capacity = config.BufferLength
	}

	subs := &Subscribers{
		ids: make(map[string]bool),
		chs: make(map[string]chan Message),
	}
	subs.buflen = capacity
	return subs
}
// Len reports how many subscribers are currently registered.
func (subs *Subscribers) Len() int {
	subs.RLock()
	count := len(subs.ids)
	subs.RUnlock()
	return count
}
// AddSubscriber registers id and returns a new buffered channel on which the
// subscriber receives messages. The channel's capacity is the set's buflen;
// a non-positive buflen is treated as fatal and terminates the process via
// log.Fatal.
func (subs *Subscribers) AddSubscriber(id string) chan Message {
	subs.Lock()
	defer subs.Unlock()

	if capacity := subs.buflen; capacity <= 0 {
		log.Fatal("subscriber buflen <= 0")
	}

	queue := make(chan Message, subs.buflen)
	subs.chs[id] = queue
	subs.ids[id] = true
	return queue
}
// RemoveSubscriber unregisters id and closes its delivery channel so that
// readers of the channel observe the removal. Removing an id that is not
// registered is a no-op.
func (subs *Subscribers) RemoveSubscriber(id string) {
	subs.Lock()
	defer subs.Unlock()

	delete(subs.ids, id)

	// Closing the nil channel returned for a missing key would panic, so
	// only close channels that are actually registered.
	ch, ok := subs.chs[id]
	if !ok {
		return
	}
	close(ch)
	delete(subs.chs, id)
}
// HasSubscriber reports whether id is registered.
func (subs *Subscribers) HasSubscriber(id string) bool {
	subs.RLock()
	defer subs.RUnlock()

	if _, present := subs.ids[id]; present {
		return true
	}
	return false
}
// GetSubscriber looks up the delivery channel registered for id. The second
// result is false, and the channel nil, when id is not registered.
func (subs *Subscribers) GetSubscriber(id string) (chan Message, bool) {
	subs.RLock()
	defer subs.RUnlock()

	if ch, found := subs.chs[id]; found {
		return ch, true
	}
	return nil, false
}
// NotifyAll offers message to every subscriber without blocking and returns
// the number of subscribers whose channel accepted it. A subscriber whose
// buffer is full is skipped and a warning is logged.
func (subs *Subscribers) NotifyAll(message Message) int {
	subs.RLock()
	defer subs.RUnlock()

	delivered := 0
	for subscriber, queue := range subs.chs {
		select {
		case queue <- message:
			delivered++
		default:
			// TODO: Drop this client?
			// TODO: Retry later?
			log.Warnf("cannot publish message %s#%d to %s", message.Topic.Name, message.ID, subscriber)
		}
	}
	return delivered
}
// MessageBus is the topic registry and message router. The embedded RWMutex
// is used by the bus's methods to guard the maps below.
type MessageBus struct {
sync.RWMutex
// options holds the configuration the bus was created with.
options *Options
// metrics is the metrics registry; it is nil when metrics are disabled.
metrics *Metrics
// topics maps topic names to topics.
topics map[string]*Topic
// queues holds the in-memory message queue of each topic.
queues map[*Topic]*Queue
// logs holds the write-ahead log opened for each topic.
logs map[*Topic]*wal.Log
// subscribers holds the subscriber set of each topic.
subscribers map[*Topic]*Subscribers
}
// NewMessageBus creates a new message bus with the provided options.
//
// Options are applied on top of NewDefaultOptions; the first option that
// fails aborts construction. When metrics are enabled the bus registers its
// server, client, bus and queue metrics. Finally the topics stored under the
// configured log path are restored; a failure to do so is returned wrapped.
func NewMessageBus(opts ...Option) (*MessageBus, error) {
	options := NewDefaultOptions()
	for _, opt := range opts {
		if err := opt(options); err != nil {
			return nil, err
		}
	}

	var metrics *Metrics
	if options.Metrics {
		metrics = NewMetrics("msgbus")

		started := time.Now()

		// server uptime counter
		metrics.NewCounterFunc(
			"server", "uptime",
			"Number of nanoseconds the server has been running",
			func() float64 {
				return float64(time.Since(started).Nanoseconds())
			},
		)

		// server requests counter
		metrics.NewCounter(
			"server", "requests",
			"Number of total requests processed",
		)

		// client latency summary
		metrics.NewSummary(
			"client", "latency_seconds",
			"Client latency in seconds",
		)

		// client errors counter
		metrics.NewCounter(
			"client", "errors",
			"Number of errors publishing messages to clients",
		)

		// bus messages counter
		metrics.NewCounter(
			"bus", "messages",
			"Number of total messages exchanged",
		)

		// bus dropped counter
		metrics.NewCounter(
			"bus", "dropped",
			"Number of messages dropped to subscribers",
		)

		// bus delivered counter
		metrics.NewCounter(
			"bus", "delivered",
			"Number of messages delivered to subscribers",
		)

		// bus fetched counter
		metrics.NewCounter(
			"bus", "fetched",
			"Number of messages fetched from clients",
		)

		// bus topics counter; it is registered as a counter because topic
		// creation looks it up with Counter("bus", "topics").
		metrics.NewCounter(
			"bus", "topics",
			"Number of active topics registered",
		)

		// queue len gauge vec
		metrics.NewGaugeVec(
			"queue", "len",
			"Queue length of each topic",
			[]string{"topic"},
		)

		// queue size gauge vec
		// TODO: Implement this gauge by somehow getting queue sizes per topic!
		metrics.NewGaugeVec(
			"queue", "size",
			"Queue size of each topic",
			[]string{"topic"},
		)

		// bus subscribers gauge
		metrics.NewGauge(
			"bus", "subscribers",
			"Number of active subscribers",
		)
	}

	mb := &MessageBus{
		options: options,
		metrics: metrics,

		topics:      make(map[string]*Topic),
		queues:      make(map[*Topic]*Queue),
		logs:        make(map[*Topic]*wal.Log),
		subscribers: make(map[*Topic]*Subscribers),
	}

	if err := mb.readLogs(); err != nil {
		return nil, fmt.Errorf("error reading logs: %w", err)
	}

	return mb, nil
}
// readLog restores the topic stored in the write-ahead log directory
// described by f. It loads the most recent messages of the log into a new
// in-memory queue, registers the topic under the directory's name and keeps
// the opened log in mb.logs so that Put appends to it instead of opening the
// same log a second time.
//
// The topic's sequence is set to the log's last index and its creation time
// to the directory's modification time. An empty log yields an empty queue.
// On error the log is closed again and nothing is registered.
//
// The caller must hold mb's write lock.
func (mb *MessageBus) readLog(f fs.FileInfo) error {
	name := f.Name()
	log.Debugf("reading log %s", name)

	l, err := wal.Open(filepath.Join(mb.options.LogPath, name), nil)
	if err != nil {
		return fmt.Errorf("error opening log %s: %w", name, err)
	}

	// fail releases the opened log before handing err back to the caller.
	fail := func(err error) error {
		if cerr := l.Close(); cerr != nil {
			log.WithError(cerr).Warnf("error closing log %s", name)
		}
		return err
	}

	first, err := l.FirstIndex()
	if err != nil {
		return fail(fmt.Errorf("error reading first index: %w", err))
	}

	last, err := l.LastIndex()
	if err != nil {
		return fail(fmt.Errorf("error reading last index: %w", err))
	}

	log.Debugf("first index: %d", first)
	log.Debugf("last index: %d", last)

	q := NewQueue(mb.options.MaxQueueSize)

	// A last index of zero means the log holds no entries; there is nothing
	// to read and index 0 itself is never a valid entry.
	if last > 0 {
		// Start MaxQueueSize entries before the end, but never before the
		// first entry the log still holds.
		start := int64(last) - int64(mb.options.MaxQueueSize)
		if start < int64(first) {
			start = int64(first)
		}
		end := int64(last)

		log.Debugf("start: %d", start)
		log.Debugf("end: %d", end)

		for i := start; i <= end; i++ {
			data, err := l.Read(uint64(i))
			if err != nil {
				return fail(fmt.Errorf("error reading log %d: %w", i, err))
			}
			msg, err := LoadMessage(data)
			if err != nil {
				return fail(fmt.Errorf("error deserialing log %d: %w", i, err))
			}
			q.Push(msg)
		}
	}

	t := mb.newTopic(name)
	t.Created = f.ModTime()
	t.Sequence = int64(last)

	mb.queues[t] = q
	mb.logs[t] = l

	return nil
}
// readLogs restores every topic found under the configured log path. Each
// sub-directory of the log path is treated as one topic's write-ahead log;
// other entries are ignored. The first failure aborts the scan and is
// returned. The bus's write lock is held for the whole scan.
func (mb *MessageBus) readLogs() error {
	mb.Lock()
	defer mb.Unlock()

	log.Debug("reading logs...")

	entries, err := os.ReadDir(mb.options.LogPath)
	if err != nil {
		return fmt.Errorf("error listing logs: %w", err)
	}

	for _, entry := range entries {
		if !entry.IsDir() {
			continue
		}
		info, err := entry.Info()
		if err != nil {
			return fmt.Errorf("error reading log path %s: %w", entry.Name(), err)
		}
		if err := mb.readLog(info); err != nil {
			return err
		}
	}

	return nil
}
// Len returns the number of topics registered on the bus. It takes the bus's
// read lock because the topic map is written by other methods under the
// write lock.
func (mb *MessageBus) Len() int {
	mb.RLock()
	defer mb.RUnlock()

	return len(mb.topics)
}
// Metrics exposes the bus's metrics registry; it is nil when the bus was
// created with metrics disabled.
func (mb *MessageBus) Metrics() *Metrics {
	registry := mb.metrics
	return registry
}
// newTopic returns the topic registered under the given name, creating and
// registering it first when it does not exist yet. Creating a topic bumps
// the bus topics counter when metrics are enabled. It does no locking of
// its own.
func (mb *MessageBus) newTopic(topic string) *Topic {
	if existing, found := mb.topics[topic]; found {
		return existing
	}

	created := &Topic{Name: topic, Created: time.Now()}
	mb.topics[topic] = created
	if mb.metrics != nil {
		mb.metrics.Counter("bus", "topics").Inc()
	}
	return created
}
// NewTopic returns the topic with the given name, creating it if needed.
// It is the locking wrapper around newTopic.
func (mb *MessageBus) NewTopic(topic string) *Topic {
	mb.Lock()
	defer mb.Unlock()

	t := mb.newTopic(topic)
	return t
}
// NewMessage creates a message for topic carrying payload. The message gets
// the topic's next sequence number as its ID and the current time as its
// creation time; the bus messages counter is bumped when metrics are enabled.
//
// The sequence number is advanced under the bus's write lock so that
// concurrent callers (for example concurrent HTTP requests) never receive
// the same ID. The lock is released before returning, so NewMessage must
// not be called while the bus lock is already held.
//
// NOTE(review): messages created concurrently for one topic may still reach
// Put out of ID order - confirm whether the log tolerates that.
func (mb *MessageBus) NewMessage(topic *Topic, payload []byte) Message {
	mb.Lock()
	topic.Sequence++
	id := topic.Sequence
	mb.Unlock()

	if mb.metrics != nil {
		mb.metrics.Counter("bus", "messages").Inc()
	}

	return Message{
		ID:      id,
		Topic:   topic,
		Payload: payload,
		Created: time.Now(),
	}
}
// Put persists message and delivers it. Under the bus's write lock it:
//
//  1. opens the topic's write-ahead log on first use (the path is built from
//     the log path and the topic name with SecureJoin),
//  2. appends the serialized message to the log at index message.ID,
//  3. pushes the message onto the topic's in-memory queue, creating the
//     queue on first use, and
//  4. publishes the message to the topic's subscribers.
//
// An error is returned when the log cannot be opened or the message cannot
// be serialized or written; in that case the message is neither queued nor
// published.
func (mb *MessageBus) Put(message Message) error {
	mb.Lock()
	defer mb.Unlock()

	topic := message.Topic

	wlog, opened := mb.logs[topic]
	if !opened {
		path, err := securejoin.SecureJoin(mb.options.LogPath, topic.Name)
		if err != nil {
			return fmt.Errorf("error creating logfile filename for %s: %w", topic.Name, err)
		}
		if wlog, err = wal.Open(path, nil); err != nil {
			return fmt.Errorf("error opening logfile %s: %w", path, err)
		}
		mb.logs[topic] = wlog
	}

	index := uint64(message.ID)
	encoded, err := message.Bytes()
	if err != nil {
		return fmt.Errorf("error serializing message %d: %w", index, err)
	}
	if err = wlog.Write(index, encoded); err != nil {
		return fmt.Errorf("error writing message %d to logfile: %w", message.ID, err)
	}

	queue, queued := mb.queues[topic]
	if !queued {
		queue = NewQueue(mb.options.MaxQueueSize)
		mb.queues[topic] = queue
	}
	queue.Push(message)

	if mb.metrics != nil {
		mb.metrics.GaugeVec("queue", "len").WithLabelValues(topic.Name).Inc()
	}

	mb.publish(message)

	return nil
}
// Get pops the next message from the queue of topic t. The second result is
// false when the topic has no queue or the queue is empty. A successful pop
// updates the fetched counter and the queue length gauge when metrics are
// enabled.
//
// NOTE(review): the queue is popped while only the bus's read lock is held;
// this relies on Queue doing its own synchronization - confirm.
func (mb *MessageBus) Get(t *Topic) (Message, bool) {
	mb.RLock()
	defer mb.RUnlock()

	var none Message

	queue, found := mb.queues[t]
	if !found {
		return none, false
	}

	item := queue.Pop()
	if item == nil {
		return none, false
	}

	if mb.metrics != nil {
		mb.metrics.Counter("bus", "fetched").Inc()
		mb.metrics.GaugeVec("queue", "len").WithLabelValues(t.Name).Dec()
	}

	return item.(Message), true
}
// publish offers message to every subscriber of its topic. Subscribers whose
// buffer is full miss the message; such drops are logged as a warning and,
// when metrics are enabled, added to the bus dropped counter. It does no
// locking of the bus itself.
func (mb *MessageBus) publish(message Message) {
	subs, ok := mb.subscribers[message.Topic]
	if !ok {
		log.Debugf("no subscribers for %s", message.Topic.Name)
		return
	}

	log.Debug("notifying subscribers")
	notified := subs.NotifyAll(message)
	total := subs.Len()
	if notified == total {
		return
	}

	// Warn about dropped deliveries whether or not metrics are enabled; only
	// the counter update depends on the metrics registry.
	log.Warnf("%d/%d subscribers notified", notified, total)
	if mb.metrics != nil {
		mb.metrics.Counter("bus", "dropped").Add(float64(total - notified))
	}
}
// Subscribe registers subscriber id on the named topic and returns the
// channel on which it receives messages. The topic is created if it does not
// exist. Subscribing an id that is already registered returns its existing
// channel unchanged.
//
// With WithIndex(n), where 0 < n <= the topic's current sequence, the
// messages still held in the topic's queue with ID >= n are replayed into
// the channel before it is returned. Replay stops with ErrBufferFull (which
// is logged, not returned) once the channel's buffer is full. Any other
// index starts the subscription from the head: only messages published
// after the call are delivered.
func (mb *MessageBus) Subscribe(id, topic string, opts ...SubscribeOption) chan Message {
	mb.Lock()
	defer mb.Unlock()

	// Go through newTopic so that topics created by subscribing are counted
	// exactly like topics created through NewTopic.
	t := mb.newTopic(topic)

	subs, ok := mb.subscribers[t]
	if !ok {
		subs = NewSubscribers(&SubscriberConfig{BufferLength: mb.options.BufferLength})
		mb.subscribers[t] = subs
	}

	if ch, ok := subs.GetSubscriber(id); ok {
		log.Debugf("already have subscriber %s", id)
		return ch
	}

	if mb.metrics != nil {
		mb.metrics.Gauge("bus", "subscribers").Inc()
	}

	o := &SubscriberOptions{}
	for _, opt := range opts {
		opt(o)
	}

	ch := subs.AddSubscriber(id)

	q, ok := mb.queues[t]
	if !ok {
		log.Debug("nothing in queue, returning ch")
		return ch
	}

	// Start from the head when no positive index was requested, or when the
	// requested index lies beyond the topic's current sequence (for example
	// because the sequence has been reset).
	if o.Index <= 0 || o.Index > t.Sequence {
		return ch
	}

	var n int
	log.Debugf("subscriber wants to start from %d", o.Index)
	err := q.ForEach(func(item interface{}) error {
		msg, ok := item.(Message)
		if !ok {
			return nil
		}
		log.Debugf("found msg %s", msg)
		if msg.ID < o.Index {
			log.Debugf("msg %s before requested index %d", msg, o.Index)
			return nil
		}
		select {
		case ch <- msg:
			n++
		default:
			if mb.metrics != nil {
				mb.metrics.Counter("bus", "dropped").Inc()
			}
			return ErrBufferFull
		}
		return nil
	})
	if err != nil {
		log.WithError(err).Error("error publishing messages to new subscriber")
	}
	log.Debugf("published %d messages", n)

	return ch
}
// Unsubscribe removes subscriber id from the named topic and lowers the
// subscribers gauge when metrics are enabled. Unknown topics, topics without
// subscribers and unknown ids are ignored.
func (mb *MessageBus) Unsubscribe(id, topic string) {
	mb.Lock()
	defer mb.Unlock()

	t, known := mb.topics[topic]
	if !known {
		return
	}

	subs, present := mb.subscribers[t]
	if !present {
		return
	}

	if !subs.HasSubscriber(id) {
		return
	}

	// The subscriber is known to exist at this point.
	subs.RemoveSubscriber(id)
	if mb.metrics != nil {
		mb.metrics.Gauge("bus", "subscribers").Dec()
	}
}
// ServeHTTP implements http.Handler and exposes the bus over HTTP:
//
//   - GET /            lists all topics as JSON.
//   - POST|PUT /topic  publishes the request body (optionally br, gzip or
//     deflate encoded) as a new message on the topic and answers 202.
//   - GET /topic       upgrades to a WebSocket subscription when requested
//     (the optional "index" query parameter selects the start index);
//     otherwise it pops the next queued message and returns it as JSON, or
//     answers 204 when there is none.
//   - DELETE /topic    is not implemented and answers 501.
//
// Any request for a topic path registers the topic if it does not exist yet.
// Every request bumps the server requests counter when metrics are enabled.
func (mb *MessageBus) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	defer func() {
		if mb.metrics != nil {
			mb.metrics.Counter("server", "requests").Inc()
		}
	}()

	w.Header().Set("Access-Control-Allow-Origin", "*")
	w.Header().Set("Accept-Encoding", "br, gzip, deflate")

	if r.Method == "GET" && (r.URL.Path == "/" || r.URL.Path == "") {
		// The topic map is written under the bus lock by other requests, so
		// hold the read lock while it is serialized.
		mb.RLock()
		out, err := json.Marshal(mb.topics)
		mb.RUnlock()
		if err != nil {
			msg := fmt.Sprintf("error serializing topics: %s", err)
			http.Error(w, msg, http.StatusInternalServerError)
			return
		}

		// Headers must be set before WriteHeader; later changes are ignored.
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		w.Write(out)
		return
	}

	topic := strings.Trim(r.URL.Path, "/")

	t := mb.NewTopic(topic)
	log.Debugf("request for topic %s", t.Name)

	switch r.Method {
	case "POST", "PUT":
		defer r.Body.Close()

		limit := int64(mb.options.MaxPayloadSize)
		if r.ContentLength > limit {
			msg := "payload exceeds max-payload-size"
			http.Error(w, msg, http.StatusRequestEntityTooLarge)
			return
		}

		var rd io.Reader = r.Body

		ce := r.Header.Get("Content-Encoding")
		switch ce {
		case "":
		case "br":
			rd = brotli.NewReader(rd)
		case "gzip":
			gz, err := gzip.NewReader(rd)
			if err != nil {
				msg := fmt.Sprintf("error reading payload: %s", err)
				http.Error(w, msg, http.StatusBadRequest)
				return
			}
			defer gz.Close()
			rd = gz
		case "deflate":
			fl := flate.NewReader(rd)
			defer fl.Close()
			rd = fl
		default:
			msg := fmt.Sprintf("error reading payload: not acceptable: %v", ce)
			http.Error(w, msg, http.StatusNotAcceptable)
			return
		}

		// Read at most one byte more than the limit: enough to detect an
		// oversized payload without inflating a small compressed body into
		// an arbitrarily large buffer first.
		body, err := ioutil.ReadAll(io.LimitReader(rd, limit+1))
		if err != nil {
			msg := fmt.Sprintf("error reading payload: %s", err)
			http.Error(w, msg, http.StatusBadRequest)
			return
		}
		if int64(len(body)) > limit {
			msg := "payload exceeds max-payload-size"
			http.Error(w, msg, http.StatusRequestEntityTooLarge)
			return
		}

		// Only acknowledge the message once it has actually been stored.
		if err := mb.Put(mb.NewMessage(t, body)); err != nil {
			log.WithError(err).Errorf("error publishing message to %s", t.Name)
			msg := fmt.Sprintf("error publishing message: %s", err)
			http.Error(w, msg, http.StatusInternalServerError)
			return
		}

		w.WriteHeader(http.StatusAccepted)
	case "GET":
		// Header values are compared case-insensitively so that a client
		// sending e.g. "WebSocket" is not served a queued message instead.
		if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
			conn, err := upgrader.Upgrade(w, r, nil)
			if err != nil {
				log.Errorf("error creating websocket client: %s", err)
				return
			}

			i := SafeParseInt64(r.URL.Query().Get("index"), -1)

			log.Debugf("new subscriber for %s from %s", t.Name, r.RemoveAddr)

			NewClient(conn, t, i, mb).Start()
			return
		}

		message, ok := mb.Get(t)
		if !ok {
			// A 204 response carries no body, so only the status is sent.
			w.WriteHeader(http.StatusNoContent)
			return
		}

		out, err := json.Marshal(message)
		if err != nil {
			msg := fmt.Sprintf("error serializing message: %s", err)
			http.Error(w, msg, http.StatusInternalServerError)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		w.Write(out)
	case "DELETE":
		http.Error(w, "Not Implemented", http.StatusNotImplemented)
		// TODO: Implement deleting topics
	}
}
// Client is a WebSocket subscriber attached to a single topic on the bus.
type Client struct {
// conn is the underlying WebSocket connection.
conn *websocket.Conn
// topic is the topic the client subscribes to.
topic *Topic
// index is the start index requested for the subscription; Start passes it
// to Subscribe via WithIndex.
index int64
// bus is the message bus the client subscribes to.
bus *MessageBus
// id is the subscriber ID, generated by MustGenerateULID in NewClient.
id string
// ch is the channel returned by Subscribe; writePump forwards its messages
// to conn.
ch chan Message
}
// NewClient wraps conn in a Client for the given topic, start index and bus,
// and assigns it a freshly generated ULID as its subscriber ID. The client
// does nothing until Start is called.
func NewClient(conn *websocket.Conn, topic *Topic, index int64, bus *MessageBus) *Client {
	c := new(Client)
	c.conn = conn
	c.topic = topic
	c.index = index
	c.bus = bus
	c.id = MustGenerateULID()
	return c
}
// readPump reads from the WebSocket connection until it fails, then
// unsubscribes the client and closes the connection. Incoming messages are
// only logged. Each pong extends the read deadline by pongWait.
//
// The pong payload is expected to be the UnixNano timestamp sent with the
// ping by writePump; the round trip time derived from it is logged and,
// when metrics are enabled, recorded in the client latency summary. A
// payload that does not parse is logged and otherwise ignored, so it never
// records a bogus latency.
func (c *Client) readPump() {
	defer func() {
		c.conn.Close()
	}()

	c.conn.SetReadDeadline(time.Now().Add(pongWait))

	c.conn.SetPongHandler(func(message string) error {
		// Any pong, well-formed or not, proves the peer is alive.
		c.conn.SetReadDeadline(time.Now().Add(pongWait))

		sent, err := strconv.ParseInt(message, 10, 64)
		if err != nil {
			log.Warnf("garbage pong reply from %s: %s", c.id, err)
			return nil
		}

		latency := time.Duration(time.Now().UnixNano() - sent)
		log.Debugf("pong latency of %s: %s", c.id, latency)

		if c.bus.metrics != nil {
			v := c.bus.metrics.Summary("client", "latency_seconds")
			v.Observe(latency.Seconds())
		}

		return nil
	})

	for {
		_, message, err := c.conn.ReadMessage()
		if err != nil {
			c.bus.Unsubscribe(c.id, c.topic.Name)
			return
		}
		log.Debugf("recieved message from %s: %s", c.id, message)
	}
}
// writePump forwards messages from the client's subscription channel to the
// WebSocket connection as JSON and sends a ping every pingPeriod. The ping
// payload is the current UnixNano timestamp in decimal.
//
// It returns, stopping the ticker and closing the connection, when the
// subscription channel is closed (after sending a normal-closure close
// frame) or when a ping cannot be written. A failed message write is logged
// and counted but does not end the loop.
func (c *Client) writePump() {
	pinger := time.NewTicker(pingPeriod)
	defer func() {
		pinger.Stop()
		c.conn.Close()
	}()

	for {
		select {
		case msg, open := <-c.ch:
			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if !open {
				// The bus closed the channel.
				closing := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "bus closed")
				c.conn.WriteControl(websocket.CloseMessage, closing, time.Now().Add(writeWait))
				return
			}

			if err := c.conn.WriteJSON(msg); err != nil {
				// TODO: Retry? Put the message back in the queue?
				log.Errorf("Error sending msg to %s: %s", c.id, err)
				if c.bus.metrics != nil {
					c.bus.metrics.Counter("client", "errors").Inc()
				}
				continue
			}
			if c.bus.metrics != nil {
				c.bus.metrics.Counter("bus", "delivered").Inc()
			}
		case <-pinger.C:
			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			stamp := []byte(strconv.FormatInt(time.Now().UnixNano(), 10))
			if err := c.conn.WriteMessage(websocket.PingMessage, stamp); err != nil {
				log.Errorf("error sending ping to %s: %s", c.id, err)
				return
			}
		}
	}
}
// Start subscribes the client to its topic from its requested index and
// launches the write and read pumps in their own goroutines. It also
// installs a close handler that unsubscribes the client and echoes the
// peer's close frame back.
func (c *Client) Start() {
	c.ch = c.bus.Subscribe(c.id, c.topic.Name, WithIndex(c.index))

	onClose := func(code int, text string) error {
		c.bus.Unsubscribe(c.id, c.topic.Name)
		reply := websocket.FormatCloseMessage(code, text)
		c.conn.WriteControl(websocket.CloseMessage, reply, time.Now().Add(writeWait))
		return nil
	}
	c.conn.SetCloseHandler(onClose)

	go c.writePump()
	go c.readPump()
}