Add support for a write-ahead-log (WAL) to persist messages (#33)

Closes #31

Adds support for a write-ahead-log (WAL) for messages per queue/topic. This is now the new default behaviour and adds a new CLI flag `-l/--log-path` and Env var `LOG_PATH` to configure where the logs are stored.

On startup, the message bus will refill the queues with the contents of messages from persisted log files with the most recent `-Q/--max-queue-size` number of items.

That is, on startup/crash the queues/topics will always contain the same messages as if the message bus had never restarted or crashed in the first place.

This has a benefit of actually making the per-topic sequence number _actually_ monotic increasing integers and something that can be relied upon when indexing into a queue/topic for subscribers with the `-i/--index` / `Index` option.

Co-authored-by: James Mills <prologic@shortcircuit.net.au>
Reviewed-on: https://git.mills.io/prologic/msgbus/pulls/33
This commit is contained in:
James Mills 2022-04-03 15:59:38 +00:00
부모 95505d5e2b
커밋 6bfb669347
14개의 변경된 파일693개의 추가작업 그리고 178개의 파일을 삭제

3
.gitignore vendored
파일 보기

@ -6,7 +6,8 @@
**/.DS_Store
/dist
/coverage.txt
/logs
/coverage.*
/msgbus
/msgbusd
/cmd/msgbus/msgbus

파일 보기

@ -139,7 +139,7 @@ func (c *Client) Publish(topic, message string) error {
}
// Subscribe ...
func (c *Client) Subscribe(topic string, index int, handler msgbus.HandlerFunc) *Subscriber {
func (c *Client) Subscribe(topic string, index int64, handler msgbus.HandlerFunc) *Subscriber {
return NewSubscriber(c, topic, index, handler)
}
@ -152,7 +152,7 @@ type Subscriber struct {
client *Client
topic string
index int
index int64
handler msgbus.HandlerFunc
@ -162,7 +162,7 @@ type Subscriber struct {
}
// NewSubscriber ...
func NewSubscriber(client *Client, topic string, index int, handler msgbus.HandlerFunc) *Subscriber {
func NewSubscriber(client *Client, topic string, index int64, handler msgbus.HandlerFunc) *Subscriber {
if handler == nil {
handler = noopHandler
}
@ -180,7 +180,7 @@ func NewSubscriber(client *Client, topic string, index int, handler msgbus.Handl
u.Path += fmt.Sprintf("/%s", topic)
q := u.Query()
q.Set("index", strconv.Itoa(index))
q.Set("index", strconv.FormatInt(index, 10))
u.RawQuery = q.Encode()
url := u.String()

파일 보기

@ -1,33 +1,41 @@
package client
import (
"io/ioutil"
"net/http/httptest"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.mills.io/prologic/msgbus"
)
func TestClientPublish(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
mb := msgbus.New(nil)
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := msgbus.NewMessageBus(msgbus.WithLogPath(testdir))
require.NoError(err)
defer os.RemoveAll(testdir)
server := httptest.NewServer(mb)
defer server.Close()
client := NewClient(server.URL, nil)
err := client.Publish("hello", "hello world")
err = client.Publish("hello", "hello world")
assert.NoError(err)
topic := mb.NewTopic("hello")
expected := msgbus.Message{Topic: topic, Payload: []byte("hello world")}
actual, ok := mb.Get(topic)
assert.True(ok)
assert.Equal(actual.ID, expected.ID)
assert.Equal(actual.Topic, expected.Topic)
assert.Equal(actual.Payload, expected.Payload)
msg, ok := mb.Get(topic)
require.True(ok)
assert.Equal(int64(1), msg.ID)
assert.Equal(topic, msg.Topic)
assert.Equal([]byte("hello world"), msg.Payload)
}

파일 보기

@ -46,5 +46,7 @@ func pull(client *client.Client, topic string) {
fmt.Fprintf(os.Stderr, "error reading message: %s\n", err)
os.Exit(2)
}
fmt.Printf("%s\n", msg.Payload)
if msg != nil {
fmt.Printf("%s\n", msg.Payload)
}
}

파일 보기

@ -36,7 +36,7 @@ reset to zero on message bus restarts.`,
client := client.NewClient(uri, nil)
topic := args[0]
index := viper.GetInt("index")
index := viper.GetInt64("index")
var (
command string
@ -101,7 +101,7 @@ func handler(command string, args []string) msgbus.HandlerFunc {
}
}
func subscribe(client *client.Client, topic string, index int, command string, args []string) {
func subscribe(client *client.Client, topic string, index int64, command string, args []string) {
if topic == "" {
topic = defaultTopic
}

파일 보기

@ -27,7 +27,9 @@ Valid optinos:
var (
version bool
debug bool
bind string
logPath string
bufferLength int
maxQueueSize int
@ -43,10 +45,17 @@ func init() {
}
flag.BoolVarP(&debug, "debug", "d", false, "enable debug logging")
flag.StringVarP(&bind, "bind", "b", "0.0.0.0:8000", "[int]:<port> to bind to")
flag.BoolVarP(&version, "version", "v", false, "display version information")
// Basic options
flag.StringVarP(
&bind, "bind", "b", msgbus.DefaultBind,
"[int]:<port> to bind to",
)
flag.StringVarP(
&logPath, "log-path", "l", msgbus.DefaultLogPath,
"path to write log files to (wal)",
)
flag.IntVarP(
&bufferLength, "buffer-length", "B", msgbus.DefaultBufferLength,
"set the buffer length for subscribers before messages are dropped",
@ -87,6 +96,8 @@ func corsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "*")
w.Header().Set("Access-Control-Allow-Headers", "*")
w.Header().Set("Access-Control-Expose-Headers", "*")
next.ServeHTTP(w, r)
})
}
@ -109,13 +120,16 @@ func main() {
go professor.Launch(":6060")
}
opts := msgbus.Options{
BufferLength: bufferLength,
MaxQueueSize: maxQueueSize,
MaxPayloadSize: maxPayloadSize,
WithMetrics: true,
mb, err := msgbus.NewMessageBus(
msgbus.WithLogPath(logPath),
msgbus.WithBufferLength(bufferLength),
msgbus.WithMaxQueueSize(maxQueueSize),
msgbus.WithMaxPayloadSize(maxPayloadSize),
msgbus.WithMetrics(true),
)
if err != nil {
log.WithError(err).Fatal("error configuring message bus")
}
mb := msgbus.New(&opts)
http.Handle("/", corsMiddleware(mb))
http.Handle("/metrics", mb.Metrics().Handler())

파일 보기

@ -7,7 +7,11 @@ import (
)
func main() {
m := msgbus.New(nil)
m, err := msgbus.NewMessageBus(nil)
if err != nil {
log.Fatal(err)
}
t := m.NewTopic("foo")
m.Put(m.NewMessage(t, []byte("Hello World!")))

8
go.mod
파일 보기

@ -4,6 +4,7 @@ go 1.18
require (
github.com/andybalholm/brotli v1.0.4
github.com/cyphar/filepath-securejoin v0.2.3
github.com/gorilla/websocket v1.5.0
github.com/jpillora/backoff v1.0.0
github.com/mitchellh/go-homedir v1.1.0
@ -15,6 +16,8 @@ require (
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.10.1
github.com/stretchr/testify v1.7.0
github.com/tidwall/wal v1.1.7
github.com/vmihailenco/msgpack/v5 v5.3.5
nhooyr.io/websocket v1.8.7
)
@ -41,6 +44,11 @@ require (
github.com/spf13/cast v1.4.1 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/subosito/gotenv v1.2.0 // indirect
github.com/tidwall/gjson v1.10.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/tinylru v1.1.0 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8 // indirect
golang.org/x/text v0.3.7 // indirect
google.golang.org/protobuf v1.27.1 // indirect

17
go.sum
파일 보기

@ -61,6 +61,8 @@ github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGX
github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/cyphar/filepath-securejoin v0.2.3 h1:YX6ebbZCZP7VkM3scTTokDgBL2TY741X51MTk3ycuNI=
github.com/cyphar/filepath-securejoin v0.2.3/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxGGx79pTxQpKOJNYHHl4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -278,14 +280,29 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s=
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
github.com/tidwall/gjson v1.10.2 h1:APbLGOM0rrEkd8WBw9C24nllro4ajFuJu0Sc9hRz8Bo=
github.com/tidwall/gjson v1.10.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/tinylru v1.1.0 h1:XY6IUfzVTU9rpwdhKUF6nQdChgCdGjkMfLzbWyiau6I=
github.com/tidwall/tinylru v1.1.0/go.mod h1:3+bX+TJ2baOLMWTnlyNWHh4QMnFyARg2TLTQ6OFbzw8=
github.com/tidwall/wal v1.1.7 h1:emc1TRjIVsdKKSnpwGBAcsAGg0767SvUk8+ygx7Bb+4=
github.com/tidwall/wal v1.1.7/go.mod h1:r6lR1j27W9EPalgHiB7zLJDYu3mzW5BQP5KrzBpYY/E=
github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs=
github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU=
github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=

0
logs/.gitkeep Normal file
파일 보기

240
msgbus.go
파일 보기

@ -6,29 +6,26 @@ import (
"encoding/json"
"fmt"
"io"
"io/fs"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/andybalholm/brotli"
securejoin "github.com/cyphar/filepath-securejoin"
sync "github.com/sasha-s/go-deadlock"
log "github.com/sirupsen/logrus"
"github.com/tidwall/wal"
msgpack "github.com/vmihailenco/msgpack/v5"
"github.com/gorilla/websocket"
)
const (
// DefaultMaxQueueSize is the default maximum size of queues
DefaultMaxQueueSize = 1024 // ~8MB per queue (1000 * 4KB)
// DefaultMaxPayloadSize is the default maximum payload size
DefaultMaxPayloadSize = 8192 // 8KB
// DefaultBufferLength is the default buffer length for subscriber chans
DefaultBufferLength = 256
// Time allowed to write a message to the peer.
writeWait = 10 * time.Second
@ -54,7 +51,7 @@ type HandlerFunc func(msg *Message) error
// Topic ...
type Topic struct {
Name string `json:"name"`
Sequence int `json:"seq"`
Sequence int64 `json:"seq"`
Created time.Time `json:"created"`
}
@ -64,22 +61,31 @@ func (t *Topic) String() string {
// Message ...
type Message struct {
ID int `json:"id"`
ID int64 `json:"id"`
Topic *Topic `json:"topic"`
Payload []byte `json:"payload"`
Created time.Time `json:"created"`
}
func LoadMessage(data []byte) (m Message, err error) {
err = msgpack.Unmarshal(data, &m)
return
}
func (m Message) Bytes() ([]byte, error) {
return msgpack.Marshal(m)
}
// SubscribeOption ...
type SubscribeOption func(*SubscriberOptions)
// SubscriberOptions ...
type SubscriberOptions struct {
Index int
Index int64
}
// WithIndex sets the index to start subscribing from
func WithIndex(index int) SubscribeOption {
func WithIndex(index int64) SubscribeOption {
return func(o *SubscriberOptions) { o.Index = index }
}
@ -195,53 +201,33 @@ func (subs *Subscribers) NotifyAll(message Message) int {
return i
}
// Options ...
type Options struct {
BufferLength int
MaxQueueSize int
MaxPayloadSize int
WithMetrics bool
}
// MessageBus ...
type MessageBus struct {
sync.RWMutex
options *Options
metrics *Metrics
bufferLength int
maxQueueSize int
maxPayloadSize int
topics map[string]*Topic
queues map[*Topic]*Queue
logs map[*Topic]*wal.Log
topics map[string]*Topic
queues map[*Topic]*Queue
subscribers map[*Topic]*Subscribers
}
// New ...
func New(options *Options) *MessageBus {
var (
bufferLength int
maxQueueSize int
maxPayloadSize int
withMetrics bool
)
// NewMessageBus creates a new message bus with the provided options
func NewMessageBus(opts ...Option) (*MessageBus, error) {
options := DefaultOptions
if options != nil {
bufferLength = options.BufferLength
maxQueueSize = options.MaxQueueSize
maxPayloadSize = options.MaxPayloadSize
withMetrics = options.WithMetrics
} else {
bufferLength = DefaultBufferLength
maxQueueSize = DefaultMaxQueueSize
maxPayloadSize = DefaultMaxPayloadSize
withMetrics = false
for _, opt := range opts {
if err := opt(options); err != nil {
return nil, err
}
}
var metrics *Metrics
if withMetrics {
if options.Metrics {
metrics = NewMetrics("msgbus")
ctime := time.Now()
@ -325,17 +311,100 @@ func New(options *Options) *MessageBus {
)
}
return &MessageBus{
mb := &MessageBus{
options: options,
metrics: metrics,
bufferLength: bufferLength,
maxQueueSize: maxQueueSize,
maxPayloadSize: maxPayloadSize,
topics: make(map[string]*Topic),
queues: make(map[*Topic]*Queue),
logs: make(map[*Topic]*wal.Log),
topics: make(map[string]*Topic),
queues: make(map[*Topic]*Queue),
subscribers: make(map[*Topic]*Subscribers),
}
if err := mb.readLogs(); err != nil {
return nil, fmt.Errorf("error reading logs: %w", err)
}
return mb, nil
}
func (mb *MessageBus) readLog(f fs.FileInfo) error {
log.Debugf("reading log %s", f.Name())
l, err := wal.Open(filepath.Join(mb.options.LogPath, f.Name()), nil)
if err != nil {
return fmt.Errorf("error opening log %s: %w", f, err)
}
t := mb.newTopic(f.Name())
t.Created = f.ModTime()
first, err := l.FirstIndex()
if err != nil {
return fmt.Errorf("error reading first index: %w", err)
}
last, err := l.LastIndex()
if err != nil {
return fmt.Errorf("error reading last index: %w", err)
}
log.Debugf("first index: %d", first)
log.Debugf("last index: %d", last)
t.Sequence = int64(last)
q := NewQueue(mb.options.MaxQueueSize)
start := int64(last) - int64(mb.options.MaxQueueSize)
if start < 0 {
start = int64(first)
}
end := int64(last)
log.Debugf("start: %d", start)
log.Debugf("end: %d", end)
for i := start; i <= end; i++ {
data, err := l.Read(uint64(i))
if err != nil {
return fmt.Errorf("error reading log %d: %w", i, err)
}
msg, err := LoadMessage(data)
if err != nil {
return fmt.Errorf("error deserialing log %d: %w", i, err)
}
q.Push(msg)
}
mb.queues[t] = q
return nil
}
func (mb *MessageBus) readLogs() error {
mb.Lock()
defer mb.Unlock()
log.Debug("reading logs...")
dirs, err := os.ReadDir(mb.options.LogPath)
if err != nil {
return fmt.Errorf("error listing logs: %w", err)
}
for _, dir := range dirs {
if dir.IsDir() {
info, err := dir.Info()
if err != nil {
return fmt.Errorf("error reading log path %s: %w", dir, err)
}
if err := mb.readLog(info); err != nil {
return err
}
}
}
return nil
}
// Len ...
@ -348,11 +417,7 @@ func (mb *MessageBus) Metrics() *Metrics {
return mb.metrics
}
// NewTopic ...
func (mb *MessageBus) NewTopic(topic string) *Topic {
mb.Lock()
defer mb.Unlock()
func (mb *MessageBus) newTopic(topic string) *Topic {
t, ok := mb.topics[topic]
if !ok {
t = &Topic{Name: topic, Created: time.Now()}
@ -364,6 +429,14 @@ func (mb *MessageBus) NewTopic(topic string) *Topic {
return t
}
// NewTopic ...
func (mb *MessageBus) NewTopic(topic string) *Topic {
mb.Lock()
defer mb.Unlock()
return mb.newTopic(topic)
}
// NewMessage ...
func (mb *MessageBus) NewMessage(topic *Topic, payload []byte) Message {
defer func() {
@ -374,7 +447,7 @@ func (mb *MessageBus) NewMessage(topic *Topic, payload []byte) Message {
}()
return Message{
ID: topic.Sequence,
ID: topic.Sequence + 1,
Topic: topic,
Payload: payload,
Created: time.Now(),
@ -382,15 +455,41 @@ func (mb *MessageBus) NewMessage(topic *Topic, payload []byte) Message {
}
// Put ...
func (mb *MessageBus) Put(message Message) {
func (mb *MessageBus) Put(message Message) error {
mb.Lock()
defer mb.Unlock()
t := message.Topic
l, ok := mb.logs[t]
if !ok {
fn, err := securejoin.SecureJoin(mb.options.LogPath, t.Name)
if err != nil {
return fmt.Errorf("error creating logfile filename for %s: %w", t.Name, err)
}
l, err = wal.Open(fn, nil)
if err != nil {
return fmt.Errorf("error opening logfile %s: %w", fn, err)
}
mb.logs[t] = l
}
id := uint64(message.ID)
data, err := message.Bytes()
if err != nil {
return fmt.Errorf("error serializing message %d: %w", id, err)
}
if err := l.Write(id, data); err != nil {
return fmt.Errorf("error writing message %d to logfile: %w", message.ID, err)
}
q, ok := mb.queues[t]
if !ok {
q = NewQueue(mb.maxQueueSize)
mb.queues[message.Topic] = q
q = NewQueue(mb.options.MaxQueueSize)
mb.queues[t] = q
}
q.Push(message)
@ -399,6 +498,8 @@ func (mb *MessageBus) Put(message Message) {
}
mb.publish(message)
return nil
}
// Get ...
@ -453,7 +554,7 @@ func (mb *MessageBus) Subscribe(id, topic string, opts ...SubscribeOption) chan
subs, ok := mb.subscribers[t]
if !ok {
subs = NewSubscribers(&SubscriberConfig{BufferLength: mb.bufferLength})
subs = NewSubscribers(&SubscriberConfig{BufferLength: mb.options.BufferLength})
mb.subscribers[t] = subs
}
@ -482,12 +583,12 @@ func (mb *MessageBus) Subscribe(id, topic string, opts ...SubscribeOption) chan
// IF client requests to start from index >= 0 (-1 from the head)
// AND the topic's sequence number hasn't been reset (< Index) THEN
if o.Index >= 0 && o.Index <= t.Sequence {
if o.Index > 0 && o.Index <= t.Sequence {
var n int
log.Debugf("subscriber wants to start from %d", o.Index)
q.ForEach(func(item interface{}) error {
if msg, ok := item.(Message); ok && msg.ID >= o.Index {
log.Debugf("found #%v", msg)
log.Debugf("found msg #%d", msg.ID)
ch <- msg
n++
}
@ -496,6 +597,7 @@ func (mb *MessageBus) Subscribe(id, topic string, opts ...SubscribeOption) chan
log.Debugf("published %d messages", n)
} else {
// ELSE start from the beginning (invalid Index or Topic was reset)
// NB: This should not happen with write-ahead-logs (WAL)
var n int
log.Debugf("subscriber wanted to start from invalid %d (topic is at %d)", o.Index, t.Sequence)
q.ForEach(func(item interface{}) error {
@ -564,13 +666,13 @@ func (mb *MessageBus) ServeHTTP(w http.ResponseWriter, r *http.Request) {
topic := strings.Trim(r.URL.Path, "/")
t := mb.NewTopic(topic)
log.Debugf("request for topic %#v", t.Name)
log.Debugf("request for topic %s", t.Name)
switch r.Method {
case "POST", "PUT":
defer r.Body.Close()
if r.ContentLength > int64(mb.maxPayloadSize) {
if r.ContentLength > int64(mb.options.MaxPayloadSize) {
msg := "payload exceeds max-payload-size"
http.Error(w, msg, http.StatusRequestEntityTooLarge)
return
@ -607,7 +709,7 @@ func (mb *MessageBus) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, msg, http.StatusBadRequest)
return
}
if len(body) > mb.maxPayloadSize {
if len(body) > mb.options.MaxPayloadSize {
msg := "payload exceeds max-payload-size"
http.Error(w, msg, http.StatusRequestEntityTooLarge)
return
@ -623,7 +725,7 @@ func (mb *MessageBus) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
i := SafeParseInt(r.URL.Query().Get("index"), -1)
i := SafeParseInt64(r.URL.Query().Get("index"), -1)
log.Debugf("new subscriber for %s from %s", t.Name, r.RemoteAddr)
@ -658,7 +760,7 @@ func (mb *MessageBus) ServeHTTP(w http.ResponseWriter, r *http.Request) {
type Client struct {
conn *websocket.Conn
topic *Topic
index int
index int64
bus *MessageBus
id string
@ -666,7 +768,7 @@ type Client struct {
}
// NewClient ...
func NewClient(conn *websocket.Conn, topic *Topic, index int, bus *MessageBus) *Client {
func NewClient(conn *websocket.Conn, topic *Topic, index int64, bus *MessageBus) *Client {
return &Client{conn: conn, topic: topic, index: index, bus: bus}
}

파일 보기

@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"flag"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
@ -22,28 +23,53 @@ var (
)
func TestMessageBusLen(t *testing.T) {
mb := New(nil)
assert.Equal(t, mb.Len(), 0)
assert := assert.New(t)
require := require.New(t)
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
assert.Equal(0, mb.Len())
}
func TestMessage(t *testing.T) {
mb := New(nil)
assert.Equal(t, mb.Len(), 0)
assert := assert.New(t)
require := require.New(t)
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
assert.Equal(0, mb.Len())
topic := mb.NewTopic("foo")
expected := mb.NewMessage(topic, []byte("bar"))
mb.Put(expected)
err = mb.Put(expected)
require.NoError(err)
actual, ok := mb.Get(topic)
assert.True(t, ok)
assert.Equal(t, actual, expected)
require.True(ok)
assert.Equal(expected, actual)
}
func TestMessageIds(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
mb := New(nil)
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
assert.Equal(0, mb.Len())
topic := mb.NewTopic("foo")
@ -57,48 +83,58 @@ func TestMessageIds(t *testing.T) {
mb.Put(mb.NewMessage(topic, []byte("bar")))
msg, ok := mb.Get(topic)
require.True(ok)
assert.Equal(msg.ID, 1)
assert.Equal(msg.ID, int64(2))
}
func TestMessageGetEmpty(t *testing.T) {
mb := New(nil)
assert.Equal(t, mb.Len(), 0)
assert := assert.New(t)
require := require.New(t)
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
assert.Equal(0, mb.Len())
topic := mb.NewTopic("foo")
msg, ok := mb.Get(topic)
assert.False(t, ok)
assert.Equal(t, msg, Message{})
require.False(ok)
assert.Equal(Message{}, msg)
}
func TestMessageBusPutGet(t *testing.T) {
mb := New(nil)
assert := assert.New(t)
require := require.New(t)
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
topic := mb.NewTopic("foo")
expected := mb.NewMessage(topic, []byte("foo"))
mb.Put(expected)
actual, ok := mb.Get(topic)
assert.True(t, ok)
assert.Equal(t, actual, expected)
require.True(ok)
assert.Equal(expected, actual)
}
func TestMessageBusSubscribe(t *testing.T) {
mb := New(nil)
msgs := mb.Subscribe("id1", "foo")
topic := mb.NewTopic("foo")
expected := mb.NewMessage(topic, []byte("foo"))
mb.Put(expected)
actual := <-msgs
assert.Equal(t, actual, expected)
}
func TestMessageBusSubscribeWithIndex(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
mb := New(nil)
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
msgs := mb.Subscribe("id1", "foo")
@ -108,23 +144,98 @@ func TestMessageBusSubscribeWithIndex(t *testing.T) {
actual := <-msgs
assert.Equal(expected, actual)
assert.Equal(0, actual.ID)
}
func TestMessageBusSubscribeWithIndex(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
msgs := mb.Subscribe("id1", "foo")
topic := mb.NewTopic("foo")
expected := mb.NewMessage(topic, []byte("foo")) // ID == 1
mb.Put(expected)
actual := <-msgs
assert.Equal(expected, actual)
assert.Equal(int64(1), actual.ID)
mb.Unsubscribe("id1", "foo")
mb.Put(mb.NewMessage(topic, []byte("bar"))) // ID == 1
mb.Put(mb.NewMessage(topic, []byte("baz"))) // ID == 2
mb.Put(mb.NewMessage(topic, []byte("bar"))) // ID == 2
mb.Put(mb.NewMessage(topic, []byte("baz"))) // ID == 3
msgs = mb.Subscribe("id1", "foo", WithIndex(1))
assert.Equal([]byte("foo"), (<-msgs).Payload)
assert.Equal([]byte("bar"), (<-msgs).Payload)
assert.Equal([]byte("baz"), (<-msgs).Payload)
}
func TestMessageBusWAL(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
msgs := mb.Subscribe("id1", "hello")
topic := mb.NewTopic("hello")
mb.Put(mb.NewMessage(topic, []byte("foo"))) // ID == 1
mb.Put(mb.NewMessage(topic, []byte("bar"))) // ID == 2
mb.Put(mb.NewMessage(topic, []byte("baz"))) // ID == 3
assert.Equal([]byte("foo"), (<-msgs).Payload)
assert.Equal([]byte("bar"), (<-msgs).Payload)
assert.Equal([]byte("baz"), (<-msgs).Payload)
assert.Equal(int64(3), topic.Sequence)
mb.Unsubscribe("id1", "foo")
// Now ensure when we start back up we've re-filled the queues and retain the same
// message ids and topic sequence number
mb, err = NewMessageBus(WithLogPath(testdir))
require.NoError(err)
// we have to tell the bus we want to subscribe from the start
msgs = mb.Subscribe("id1", "hello", WithIndex(1))
topic = mb.NewTopic("hello")
assert.Equal(int64(3), topic.Sequence)
assert.Equal([]byte("foo"), (<-msgs).Payload)
assert.Equal([]byte("bar"), (<-msgs).Payload)
assert.Equal([]byte("baz"), (<-msgs).Payload)
msg := mb.NewMessage(topic, []byte("foobar"))
assert.Equal(int64(4), msg.ID)
}
func TestServeHTTPGETIndexEmpty(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
mb := New(nil)
w := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/", nil)
@ -135,8 +246,14 @@ func TestServeHTTPGETIndexEmpty(t *testing.T) {
func TestServeHTTPGETTopics(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
mb := New(nil)
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
mb.Put(mb.NewMessage(mb.NewTopic("foo"), []byte("foo")))
mb.Put(mb.NewMessage(mb.NewTopic("hello"), []byte("hello world")))
@ -152,8 +269,15 @@ func TestServeHTTPGETTopics(t *testing.T) {
func TestServeHTTPGETEmptyQueue(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
mb := New(nil)
w := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/hello", nil)
@ -163,8 +287,15 @@ func TestServeHTTPGETEmptyQueue(t *testing.T) {
func TestServeHTTPPOST(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
mb := New(nil)
w := httptest.NewRecorder()
b := bytes.NewBufferString("hello world")
r, _ := http.NewRequest("POST", "/hello", b)
@ -175,8 +306,15 @@ func TestServeHTTPPOST(t *testing.T) {
func TestServeHTTPMaxPayloadSize(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
mb := New(nil)
w := httptest.NewRecorder()
b := bytes.NewBuffer(bytes.Repeat([]byte{'X'}, (DefaultMaxPayloadSize * 2)))
r, _ := http.NewRequest("POST", "/hello", b)
@ -188,8 +326,14 @@ func TestServeHTTPMaxPayloadSize(t *testing.T) {
func TestServeHTTPSimple(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
mb := New(nil)
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
w := httptest.NewRecorder()
b := bytes.NewBufferString("hello world")
@ -206,29 +350,62 @@ func TestServeHTTPSimple(t *testing.T) {
var msg *Message
json.Unmarshal(w.Body.Bytes(), &msg)
assert.Equal(msg.ID, 0)
assert.Equal(msg.Topic.Name, "hello")
assert.Equal(msg.Payload, []byte("hello world"))
assert.Equal(int64(1), msg.ID)
assert.Equal("hello", msg.Topic.Name, "hello")
assert.Equal([]byte("hello world"), msg.Payload)
}
func BenchmarkServeHTTPPOST(b *testing.B) {
mb := New(nil)
func BenchmarkServeHTTP_POST(b *testing.B) {
require := require.New(b)
b.ResetTimer()
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
b := bytes.NewBufferString("hello world")
r, _ := http.NewRequest("POST", "/hello", b)
b.Run("Sync", func(b *testing.B) {
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
b.ResetTimer()
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
b := bytes.NewBufferString("hello world")
r, _ := http.NewRequest("POST", "/hello", b)
mb.ServeHTTP(w, r)
}
})
b.Run("NoSync", func(b *testing.B) {
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir), WithNoSync(true))
require.NoError(err)
b.ResetTimer()
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
b := bytes.NewBufferString("hello world")
r, _ := http.NewRequest("POST", "/hello", b)
mb.ServeHTTP(w, r)
}
})
mb.ServeHTTP(w, r)
}
}
func TestServeHTTPSubscriber(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
mb := New(nil)
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
s := httptest.NewServer(mb)
defer s.Close()
@ -263,16 +440,21 @@ func TestServeHTTPSubscriber(t *testing.T) {
defer r.Body.Close()
msg := <-msgs
assert.Equal(msg.ID, 0)
assert.Equal(msg.Topic.Name, "hello")
assert.Equal(msg.Payload, []byte("hello world"))
assert.Equal(int64(1), msg.ID)
assert.Equal("hello", msg.Topic.Name)
assert.Equal([]byte("hello world"), msg.Payload)
}
func TestServeHTTPSubscriberReconnect(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
mb := New(nil)
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
s := httptest.NewServer(mb)
@ -309,34 +491,71 @@ func TestServeHTTPSubscriberReconnect(t *testing.T) {
defer r.Body.Close()
msg := <-msgs
assert.Equal(msg.ID, 0)
assert.Equal(msg.Topic.Name, "hello")
assert.Equal(msg.Payload, []byte("hello world"))
assert.Equal(int64(1), msg.ID)
assert.Equal("hello", msg.Topic.Name)
assert.Equal([]byte("hello world"), msg.Payload)
}
func TestMsgBusMetrics(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
opts := Options{
WithMetrics: true,
}
mb := New(&opts)
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
assert.IsType(&Metrics{}, mb.Metrics())
}
func BenchmarkMessageBusPut(b *testing.B) {
mb := New(nil)
topic := mb.NewTopic("foo")
msg := mb.NewMessage(topic, []byte("foo"))
b.ResetTimer()
for i := 0; i < b.N; i++ {
mb.Put(msg)
}
require := require.New(b)
b.Run("Sync", func(b *testing.B) {
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
topic := mb.NewTopic("foo")
msg := mb.NewMessage(topic, []byte("foo"))
b.ResetTimer()
for i := 0; i < b.N; i++ {
mb.Put(msg)
}
})
b.Run("NoSync", func(b *testing.B) {
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir), WithNoSync(true))
require.NoError(err)
topic := mb.NewTopic("foo")
msg := mb.NewMessage(topic, []byte("foo"))
b.ResetTimer()
for i := 0; i < b.N; i++ {
mb.Put(msg)
}
})
}
func BenchmarkMessageBusGet(b *testing.B) {
mb := New(nil)
require := require.New(b)
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
topic := mb.NewTopic("foo")
msg := mb.NewMessage(topic, []byte("foo"))
for i := 0; i < b.N; i++ {
@ -349,7 +568,15 @@ func BenchmarkMessageBusGet(b *testing.B) {
}
func BenchmarkMessageBusGetEmpty(b *testing.B) {
mb := New(nil)
require := require.New(b)
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
topic := mb.NewTopic("foo")
b.ResetTimer()
for i := 0; i < b.N; i++ {
@ -358,14 +585,42 @@ func BenchmarkMessageBusGetEmpty(b *testing.B) {
}
func BenchmarkMessageBusPutGet(b *testing.B) {
mb := New(nil)
topic := mb.NewTopic("foo")
msg := mb.NewMessage(topic, []byte("foo"))
b.ResetTimer()
for i := 0; i < b.N; i++ {
mb.Put(msg)
mb.Get(topic)
}
require := require.New(b)
b.Run("Sync", func(b *testing.B) {
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir))
require.NoError(err)
topic := mb.NewTopic("foo")
msg := mb.NewMessage(topic, []byte("foo"))
b.ResetTimer()
for i := 0; i < b.N; i++ {
mb.Put(msg)
mb.Get(topic)
}
})
b.Run("NoSync", func(b *testing.B) {
testdir, err := ioutil.TempDir("", "msgbus-logs-*")
require.NoError(err)
defer os.RemoveAll(testdir)
mb, err := NewMessageBus(WithLogPath(testdir), WithNoSync(true))
require.NoError(err)
topic := mb.NewTopic("foo")
msg := mb.NewMessage(topic, []byte("foo"))
b.ResetTimer()
for i := 0; i < b.N; i++ {
mb.Put(msg)
mb.Get(topic)
}
})
}
func TestMain(m *testing.M) {

104
options.go Normal file
파일 보기

@ -0,0 +1,104 @@
package msgbus
import (
"fmt"
"os"
"github.com/tidwall/wal"
)
const (
// DefaultBind is the default bind address
DefaultBind = ":8000"
// DefaultLogPath is the default path to write logs to (wal)
DefaultLogPath = "./logs"
// DefaultMaxQueueSize is the default maximum size of queues
DefaultMaxQueueSize = 1024 // ~8MB per queue (1000 * 4KB)
// DefaultMaxPayloadSize is the default maximum payload size
DefaultMaxPayloadSize = 8192 // 8KB
// DefaultBufferLength is the default buffer length for subscriber chans
DefaultBufferLength = 256
// DefaultMetrics is the default for whether to enable metrics
DefaultMetrics = false
// DefaultNoSync is the default for whether to disable faync after writing
// messages to the write-ahead-log (wal) files. The default is `false` which
// is safer and will prevent corruption in event of crahses or power failure,
// but is slower.
DefaultNoSync = false
)
var DefaultOptions = &Options{
LogPath: DefaultLogPath,
BufferLength: DefaultBufferLength,
MaxQueueSize: DefaultMaxQueueSize,
MaxPayloadSize: DefaultMaxPayloadSize,
Metrics: DefaultMetrics,
NoSync: DefaultNoSync,
}
// Options ...
type Options struct {
LogPath string
BufferLength int
MaxQueueSize int
MaxPayloadSize int
Metrics bool
NoSync bool
}
type Option func(opts *Options) error
func WithLogPath(logPath string) Option {
return func(opts *Options) error {
if err := os.MkdirAll(logPath, 0755); err != nil {
return fmt.Errorf("error creating log path %s: %w", logPath, err)
}
opts.LogPath = logPath
return nil
}
}
func WithBufferLength(bufferLength int) Option {
return func(opts *Options) error {
opts.BufferLength = bufferLength
return nil
}
}
func WithMaxQueueSize(maxQueueSize int) Option {
return func(opts *Options) error {
opts.MaxQueueSize = maxQueueSize
return nil
}
}
func WithMaxPayloadSize(maxPayloadSize int) Option {
return func(opts *Options) error {
opts.MaxPayloadSize = maxPayloadSize
return nil
}
}
func WithMetrics(metrics bool) Option {
return func(opts *Options) error {
opts.Metrics = metrics
return nil
}
}
func WithNoSync(noSync bool) Option {
return func(opts *Options) error {
wal.DefaultOptions.NoSync = noSync
return nil
}
}

파일 보기

@ -2,9 +2,9 @@ package msgbus
import "strconv"
// SafeParseInt ...
func SafeParseInt(s string, d int) int {
n, e := strconv.Atoi(s)
// SafeParseInt64 ...
func SafeParseInt64(s string, d int64) int64 {
n, e := strconv.ParseInt(s, 10, 64)
if e != nil {
return d
}