Add support for subscribers to start from an index (#26)

Closes #20

Co-authored-by: James Mills <prologic@shortcircuit.net.au>
Co-authored-by: xuu <xuu@noreply@mills.io>
Reviewed-on: https://git.mills.io/prologic/msgbus/pulls/26
This commit is contained in:
James Mills 2022-04-02 14:05:15 +00:00
parent 4bbe613486
commit 7b71102aa8
14 changed files with 352 additions and 121 deletions

View File

@ -1,19 +1,35 @@
---
kind: pipeline kind: pipeline
name: default name: default
steps: steps:
- name: build - name: build-and-test
image: golang:latest image: r.mills.io/prologic/golang-alpine:latest
commands: commands:
- make build
- make test - make test
- name: coverage - name: build-image-push
image: plugins/codecov image: plugins/kaniko
settings: settings:
token: repo: prologic/msgbus
from_secret: codecov-token tags: latest
build_args:
- VERSION=latest
- COMMIT=${DRONE_COMMIT_SHA:0:8}
username:
from_secret: dockerhub_username
password:
from_secret: dockerhub_password
depends_on:
- build-and-test
when:
branch:
- master
event:
- push
- name: notify - name: notify-irc
image: plugins/webhook image: plugins/webhook
settings: settings:
urls: urls:
@ -22,3 +38,11 @@ steps:
status: status:
- success - success
- failure - failure
trigger:
branch:
- master
event:
- tag
- push
- pull_request

View File

@ -35,7 +35,7 @@ cli:
-ldflags "-w \ -ldflags "-w \
-X $(shell go list).Version=$(VERSION) \ -X $(shell go list).Version=$(VERSION) \
-X $(shell go list).Commit=$(COMMIT)" \ -X $(shell go list).Commit=$(COMMIT)" \
./cmd/msgbus/ ./cmd/msgbus/...
server: generate server: generate
@$(GOCMD) build $(FLAGS) -tags "netgo static_build" -installsuffix netgo \ @$(GOCMD) build $(FLAGS) -tags "netgo static_build" -installsuffix netgo \

View File

@ -1,11 +1,7 @@
# msgbus # msgbus
[![Build Status](https://cloud.drone.io/api/badges/prologic/msgbus/status.svg)](https://cloud.drone.io/prologic/msgbus) [![Build Status](https://ci.mills.io/api/badges/prologic/msgbus/status.svg)](https://ci.mills.io/prologic/msgbus)
[![CodeCov](https://codecov.io/gh/prologic/msgbus/branch/master/graph/badge.svg)](https://codecov.io/gh/prologic/msgbus) [![Go Reference](https://pkg.go.dev/git.mills.io/prologic/msgbus?status.svg)](https://pkg.go.dev/git.mills.io/prologic/msgbus)
[![Go Report Card](https://goreportcard.com/badge/prologic/msgbus)](https://goreportcard.com/report/prologic/msgbus)
[![GoDoc](https://godoc.org/git.mills.io/prologic/msgbus?status.svg)](https://godoc.org/git.mills.io/prologic/msgbus)
[![GitHub license](https://img.shields.io/github/license/prologic/msgbus.svg)](https://git.mills.io/prologic/msgbus)
[![Sourcegraph](https://sourcegraph.com/git.mills.io/prologic/msgbus/-/badge.svg)](https://sourcegraph.com/git.mills.io/prologic/msgbus?badge)
A real-time message bus server and library written in Go. A real-time message bus server and library written in Go.
@ -20,7 +16,7 @@ A real-time message bus server and library written in Go.
## Install ## Install
```#!bash ```#!bash
$ go install git.mills.io/prologic/msgbus/... $ go install git.mills.io/prologic/msgbus/cmd/...
``` ```
## Use Cases ## Use Cases

View File

@ -7,11 +7,12 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/jpillora/backoff" "github.com/jpillora/backoff"
sync "github.com/sasha-s/go-deadlock"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"nhooyr.io/websocket" "nhooyr.io/websocket"
"nhooyr.io/websocket/wsjson" "nhooyr.io/websocket/wsjson"
@ -138,8 +139,8 @@ func (c *Client) Publish(topic, message string) error {
} }
// Subscribe ... // Subscribe ...
func (c *Client) Subscribe(topic string, handler msgbus.HandlerFunc) *Subscriber { func (c *Client) Subscribe(topic string, index int, handler msgbus.HandlerFunc) *Subscriber {
return NewSubscriber(c, topic, handler) return NewSubscriber(c, topic, index, handler)
} }
// Subscriber ... // Subscriber ...
@ -150,7 +151,9 @@ type Subscriber struct {
client *Client client *Client
topic string topic string
index int
handler msgbus.HandlerFunc handler msgbus.HandlerFunc
url string url string
@ -159,7 +162,7 @@ type Subscriber struct {
} }
// NewSubscriber ... // NewSubscriber ...
func NewSubscriber(client *Client, topic string, handler msgbus.HandlerFunc) *Subscriber { func NewSubscriber(client *Client, topic string, index int, handler msgbus.HandlerFunc) *Subscriber {
if handler == nil { if handler == nil {
handler = noopHandler handler = noopHandler
} }
@ -176,12 +179,16 @@ func NewSubscriber(client *Client, topic string, handler msgbus.HandlerFunc) *Su
} }
u.Path += fmt.Sprintf("/%s", topic) u.Path += fmt.Sprintf("/%s", topic)
q := u.Query()
q.Set("index", strconv.Itoa(index))
u.RawQuery = q.Encode()
url := u.String() url := u.String()
return &Subscriber{ return &Subscriber{
client: client, client: client,
topic: topic, topic: topic,
index: index,
handler: handler, handler: handler,
url: url, url: url,

View File

@ -23,13 +23,20 @@ var subCmd = &cobra.Command{
Short: "Subscribe to a topic", Short: "Subscribe to a topic",
Long: `This subscribes to the given topic and for every message published Long: `This subscribes to the given topic and for every message published
to the topic, the message is printed to standard output (default) or the to the topic, the message is printed to standard output (default) or the
supplied command is executed with the contents of the message as stdin.`, supplied command is executed with the contents of the message as stdin.
If the -i/--index option is supplied with a valid value (>= 0) then the
subscription will start from that position in the topic's sequence
(which are monotonic increasing integers). It is the responsibility of
the client to keep track of its last index into a topic and indexes
reset to zero on message bus restarts.`,
Args: cobra.MinimumNArgs(1), Args: cobra.MinimumNArgs(1),
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
uri := viper.GetString("uri") uri := viper.GetString("uri")
client := client.NewClient(uri, nil) client := client.NewClient(uri, nil)
topic := args[0] topic := args[0]
index := viper.GetInt("index")
var ( var (
command string command string
@ -40,12 +47,20 @@ supplied command is executed with the contents of the message as stdin.`,
args = args[2:] args = args[2:]
} }
subscribe(client, topic, command, args) subscribe(client, topic, index, command, args)
}, },
} }
func init() { func init() {
RootCmd.AddCommand(subCmd) RootCmd.AddCommand(subCmd)
subCmd.Flags().IntP(
"index", "i", -1,
"position in the topic's sequence to start subscribing from (-1 indicates end)",
)
viper.BindPFlag("index", subCmd.Flags().Lookup("index"))
viper.SetDefault("index", -1)
} }
func handler(command string, args []string) msgbus.HandlerFunc { func handler(command string, args []string) msgbus.HandlerFunc {
@ -86,12 +101,12 @@ func handler(command string, args []string) msgbus.HandlerFunc {
} }
} }
func subscribe(client *client.Client, topic, command string, args []string) { func subscribe(client *client.Client, topic string, index int, command string, args []string) {
if topic == "" { if topic == "" {
topic = defaultTopic topic = defaultTopic
} }
s := client.Subscribe(topic, handler(command, args)) s := client.Subscribe(topic, index, handler(command, args))
s.Start() s.Start()
sigs := make(chan os.Signal, 1) sigs := make(chan os.Signal, 1)

View File

@ -42,7 +42,7 @@ func init() {
flag.PrintDefaults() flag.PrintDefaults()
} }
flag.BoolVarP(&debug, "debug", "D", false, "enable debug logging") flag.BoolVarP(&debug, "debug", "d", false, "enable debug logging")
flag.StringVarP(&bind, "bind", "b", "0.0.0.0:8000", "[int]:<port> to bind to") flag.StringVarP(&bind, "bind", "b", "0.0.0.0:8000", "[int]:<port> to bind to")
flag.BoolVarP(&version, "version", "v", false, "display version information") flag.BoolVarP(&version, "version", "v", false, "display version information")

2
go.mod
View File

@ -9,6 +9,7 @@ require (
github.com/mitchellh/go-homedir v1.1.0 github.com/mitchellh/go-homedir v1.1.0
github.com/mmcloughlin/professor v0.0.0-20170922221822-6b97112ab8b3 github.com/mmcloughlin/professor v0.0.0-20170922221822-6b97112ab8b3
github.com/prometheus/client_golang v1.12.1 github.com/prometheus/client_golang v1.12.1
github.com/sasha-s/go-deadlock v0.3.1
github.com/sirupsen/logrus v1.8.1 github.com/sirupsen/logrus v1.8.1
github.com/spf13/cobra v1.4.0 github.com/spf13/cobra v1.4.0
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
@ -31,6 +32,7 @@ require (
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/mitchellh/mapstructure v1.4.3 // indirect github.com/mitchellh/mapstructure v1.4.3 // indirect
github.com/pelletier/go-toml v1.9.4 // indirect github.com/pelletier/go-toml v1.9.4 // indirect
github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.2.0 // indirect github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.32.1 // indirect github.com/prometheus/common v0.32.1 // indirect

4
go.sum
View File

@ -221,6 +221,8 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/pelletier/go-toml v1.9.4 h1:tjENF6MfZAg8e4ZmZTeWaWiT2vXtsoO6+iuOjFhECwM= github.com/pelletier/go-toml v1.9.4 h1:tjENF6MfZAg8e4ZmZTeWaWiT2vXtsoO6+iuOjFhECwM=
github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 h1:q2e307iGHPdTGp0hoxKjt1H5pDo6utceo3dQVK3I5XQ=
github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5/go.mod h1:jvVRKCrJTQWu0XVbaOlby/2lO20uSCHEMzzplHXte1o=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@ -251,6 +253,8 @@ github.com/prometheus/procfs v0.7.3 h1:4jVXhlkAyzOScmCkXBTOLRLTz8EeU+eyjrwB/EPq0
github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sasha-s/go-deadlock v0.3.1 h1:sqv7fDNShgjcaxkO0JNcOAlr8B9+cV5Ey/OB71efZx0=
github.com/sasha-s/go-deadlock v0.3.1/go.mod h1:F73l+cr82YSh10GxyRI6qZiCgK64VaZjwesgfQ1/iLM=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=

View File

@ -3,8 +3,8 @@ package msgbus
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"sync"
sync "github.com/sasha-s/go-deadlock"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"

199
msgbus.go
View File

@ -10,10 +10,10 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/andybalholm/brotli" "github.com/andybalholm/brotli"
sync "github.com/sasha-s/go-deadlock"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -54,7 +54,7 @@ type HandlerFunc func(msg *Message) error
// Topic ... // Topic ...
type Topic struct { type Topic struct {
Name string `json:"name"` Name string `json:"name"`
Sequence uint64 `json:"seq"` Sequence int `json:"seq"`
Created time.Time `json:"created"` Created time.Time `json:"created"`
} }
@ -64,19 +64,32 @@ func (t *Topic) String() string {
// Message ... // Message ...
type Message struct { type Message struct {
ID uint64 `json:"id"` ID int `json:"id"`
Topic *Topic `json:"topic"` Topic *Topic `json:"topic"`
Payload []byte `json:"payload"` Payload []byte `json:"payload"`
Created time.Time `json:"created"` Created time.Time `json:"created"`
} }
// ListenerOptions ... // SubscribeOption ...
type ListenerOptions struct { type SubscribeOption func(*SubscriberOptions)
// SubscriberOptions ...
type SubscriberOptions struct {
Index int
}
// WithIndex sets the index to start subscribing from
func WithIndex(index int) SubscribeOption {
return func(o *SubscriberOptions) { o.Index = index }
}
// SubscribersConfig ...
type SubscriberConfig struct {
BufferLength int BufferLength int
} }
// Listeners ... // Subscribers ...
type Listeners struct { type Subscribers struct {
sync.RWMutex sync.RWMutex
buflen int buflen int
@ -85,19 +98,19 @@ type Listeners struct {
chs map[string]chan Message chs map[string]chan Message
} }
// NewListeners ... // NewSubscribers ...
func NewListeners(options *ListenerOptions) *Listeners { func NewSubscribers(config *SubscriberConfig) *Subscribers {
var ( var (
bufferLength int bufferLength int
) )
if options != nil { if config != nil {
bufferLength = options.BufferLength bufferLength = config.BufferLength
} else { } else {
bufferLength = DefaultBufferLength bufferLength = DefaultBufferLength
} }
return &Listeners{ return &Subscribers{
buflen: bufferLength, buflen: bufferLength,
ids: make(map[string]bool), ids: make(map[string]bool),
@ -105,50 +118,57 @@ func NewListeners(options *ListenerOptions) *Listeners {
} }
} }
// Length ... // Len ...
func (ls *Listeners) Length() int { func (subs *Subscribers) Len() int {
ls.RLock() subs.RLock()
defer ls.RUnlock() defer subs.RUnlock()
return len(ls.ids) return len(subs.ids)
} }
// Add ... // AddSubscriber ...
func (ls *Listeners) Add(id string) chan Message { func (subs *Subscribers) AddSubscriber(id string) chan Message {
ls.Lock() subs.Lock()
defer ls.Unlock() defer subs.Unlock()
ls.ids[id] = true if subs.buflen <= 0 {
ls.chs[id] = make(chan Message, ls.buflen) log.Fatal("subscriber buflen <= 0")
return ls.chs[id] }
ch := make(chan Message, subs.buflen)
subs.ids[id] = true
subs.chs[id] = ch
return ch
} }
// Remove ... // RemoveSubscriber ...
func (ls *Listeners) Remove(id string) { func (subs *Subscribers) RemoveSubscriber(id string) {
ls.Lock() subs.Lock()
defer ls.Unlock() defer subs.Unlock()
delete(ls.ids, id) delete(subs.ids, id)
close(ls.chs[id]) close(subs.chs[id])
delete(ls.chs, id) delete(subs.chs, id)
} }
// Exists ... // HasSubscriber ...
func (ls *Listeners) Exists(id string) bool { func (subs *Subscribers) HasSubscriber(id string) bool {
ls.RLock() subs.RLock()
defer ls.RUnlock() defer subs.RUnlock()
_, ok := ls.ids[id] _, ok := subs.ids[id]
return ok return ok
} }
// Get ... // GetSubscriber ...
func (ls *Listeners) Get(id string) (chan Message, bool) { func (subs *Subscribers) GetSubscriber(id string) (chan Message, bool) {
ls.RLock() subs.RLock()
defer ls.RUnlock() defer subs.RUnlock()
ch, ok := ls.chs[id] ch, ok := subs.chs[id]
if !ok { if !ok {
return nil, false return nil, false
} }
@ -156,12 +176,12 @@ func (ls *Listeners) Get(id string) (chan Message, bool) {
} }
// NotifyAll ... // NotifyAll ...
func (ls *Listeners) NotifyAll(message Message) int { func (subs *Subscribers) NotifyAll(message Message) int {
ls.RLock() subs.RLock()
defer ls.RUnlock() defer subs.RUnlock()
i := 0 i := 0
for id, ch := range ls.chs { for id, ch := range subs.chs {
select { select {
case ch <- message: case ch <- message:
i++ i++
@ -193,9 +213,9 @@ type MessageBus struct {
maxQueueSize int maxQueueSize int
maxPayloadSize int maxPayloadSize int
topics map[string]*Topic topics map[string]*Topic
queues map[*Topic]*Queue queues map[*Topic]*Queue
listeners map[*Topic]*Listeners subscribers map[*Topic]*Subscribers
} }
// New ... // New ...
@ -312,9 +332,9 @@ func New(options *Options) *MessageBus {
maxQueueSize: maxQueueSize, maxQueueSize: maxQueueSize,
maxPayloadSize: maxPayloadSize, maxPayloadSize: maxPayloadSize,
topics: make(map[string]*Topic), topics: make(map[string]*Topic),
queues: make(map[*Topic]*Queue), queues: make(map[*Topic]*Queue),
listeners: make(map[*Topic]*Listeners), subscribers: make(map[*Topic]*Subscribers),
} }
} }
@ -406,20 +426,22 @@ func (mb *MessageBus) Get(t *Topic) (Message, bool) {
// publish ... // publish ...
func (mb *MessageBus) publish(message Message) { func (mb *MessageBus) publish(message Message) {
ls, ok := mb.listeners[message.Topic] subs, ok := mb.subscribers[message.Topic]
if !ok { if !ok {
log.Debugf("no subscribers for %s", message.Topic.Name)
return return
} }
n := ls.NotifyAll(message) log.Debug("notifying subscribers")
if n != ls.Length() && mb.metrics != nil { n := subs.NotifyAll(message)
log.Warnf("%d/%d subscribers notified", n, ls.Length()) if n != subs.Len() && mb.metrics != nil {
mb.metrics.Counter("bus", "dropped").Add(float64(ls.Length() - n)) log.Warnf("%d/%d subscribers notified", n, subs.Len())
mb.metrics.Counter("bus", "dropped").Add(float64(subs.Len() - n))
} }
} }
// Subscribe ... // Subscribe ...
func (mb *MessageBus) Subscribe(id, topic string) chan Message { func (mb *MessageBus) Subscribe(id, topic string, opts ...SubscribeOption) chan Message {
mb.Lock() mb.Lock()
defer mb.Unlock() defer mb.Unlock()
@ -429,15 +451,16 @@ func (mb *MessageBus) Subscribe(id, topic string) chan Message {
mb.topics[topic] = t mb.topics[topic] = t
} }
ls, ok := mb.listeners[t] subs, ok := mb.subscribers[t]
if !ok { if !ok {
ls = NewListeners(&ListenerOptions{BufferLength: mb.bufferLength}) subs = NewSubscribers(&SubscriberConfig{BufferLength: mb.bufferLength})
mb.listeners[t] = ls mb.subscribers[t] = subs
} }
if ls.Exists(id) { if subs.HasSubscriber(id) {
// Already verified the listener exists // Already verified the listener exists
ch, _ := ls.Get(id) log.Debugf("already have subscriber %s", id)
ch, _ := subs.GetSubscriber(id)
return ch return ch
} }
@ -445,7 +468,34 @@ func (mb *MessageBus) Subscribe(id, topic string) chan Message {
mb.metrics.Gauge("bus", "subscribers").Inc() mb.metrics.Gauge("bus", "subscribers").Inc()
} }
return ls.Add(id) o := &SubscriberOptions{}
for _, opt := range opts {
opt(o)
}
ch := subs.AddSubscriber(id)
q, ok := mb.queues[t]
if !ok {
log.Debug("nothing in queue, returning ch")
return ch
}
if o.Index >= 0 && o.Index <= q.Len() {
var n int
log.Debugf("subscriber wants to start from %d", o.Index)
q.ForEach(func(item interface{}) error {
msg := item.(Message)
log.Debugf("found #%v", msg)
if msg.ID >= o.Index {
ch <- msg
n++
}
return nil
})
log.Debugf("published %d messages", n)
}
return ch
} }
// Unsubscribe ... // Unsubscribe ...
@ -458,14 +508,14 @@ func (mb *MessageBus) Unsubscribe(id, topic string) {
return return
} }
ls, ok := mb.listeners[t] subs, ok := mb.subscribers[t]
if !ok { if !ok {
return return
} }
if ls.Exists(id) { if subs.HasSubscriber(id) {
// Already verified the listener exists // Already verified the listener exists
ls.Remove(id) subs.RemoveSubscriber(id)
if mb.metrics != nil { if mb.metrics != nil {
mb.metrics.Gauge("bus", "subscribers").Dec() mb.metrics.Gauge("bus", "subscribers").Dec()
@ -498,10 +548,10 @@ func (mb *MessageBus) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
topic := strings.TrimLeft(r.URL.Path, "/") topic := strings.Trim(r.URL.Path, "/")
topic = strings.TrimRight(topic, "/")
t := mb.NewTopic(topic) t := mb.NewTopic(topic)
log.Debugf("request for topic %#v", t.Name)
switch r.Method { switch r.Method {
case "POST", "PUT": case "POST", "PUT":
@ -560,7 +610,11 @@ func (mb *MessageBus) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
NewClient(conn, t, mb).Start() i := SafeParseInt(r.URL.Query().Get("index"), -1)
log.Debugf("new subscriber for %s from %s", t.Name, r.RemoteAddr)
NewClient(conn, t, i, mb).Start()
return return
} }
@ -591,6 +645,7 @@ func (mb *MessageBus) ServeHTTP(w http.ResponseWriter, r *http.Request) {
type Client struct { type Client struct {
conn *websocket.Conn conn *websocket.Conn
topic *Topic topic *Topic
index int
bus *MessageBus bus *MessageBus
id string id string
@ -598,8 +653,8 @@ type Client struct {
} }
// NewClient ... // NewClient ...
func NewClient(conn *websocket.Conn, topic *Topic, bus *MessageBus) *Client { func NewClient(conn *websocket.Conn, topic *Topic, index int, bus *MessageBus) *Client {
return &Client{conn: conn, topic: topic, bus: bus} return &Client{conn: conn, topic: topic, index: index, bus: bus}
} }
func (c *Client) readPump() { func (c *Client) readPump() {
@ -684,7 +739,7 @@ func (c *Client) writePump() {
// Start ... // Start ...
func (c *Client) Start() { func (c *Client) Start() {
c.id = c.conn.RemoteAddr().String() c.id = c.conn.RemoteAddr().String()
c.ch = c.bus.Subscribe(c.id, c.topic.Name) c.ch = c.bus.Subscribe(c.id, c.topic.Name, WithIndex(c.index))
c.conn.SetCloseHandler(func(code int, text string) error { c.conn.SetCloseHandler(func(code int, text string) error {
c.bus.Unsubscribe(c.id, c.topic.Name) c.bus.Unsubscribe(c.id, c.topic.Name)

View File

@ -4,16 +4,23 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"flag"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"testing" "testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"nhooyr.io/websocket" "nhooyr.io/websocket"
"nhooyr.io/websocket/wsjson" "nhooyr.io/websocket/wsjson"
) )
var (
debug = flag.Bool("d", false, "enable debug logging")
)
func TestMessageBusLen(t *testing.T) { func TestMessageBusLen(t *testing.T) {
mb := New(nil) mb := New(nil)
assert.Equal(t, mb.Len(), 0) assert.Equal(t, mb.Len(), 0)
@ -24,7 +31,7 @@ func TestMessage(t *testing.T) {
assert.Equal(t, mb.Len(), 0) assert.Equal(t, mb.Len(), 0)
topic := mb.NewTopic("foo") topic := mb.NewTopic("foo")
expected := Message{Topic: topic, Payload: []byte("bar")} expected := mb.NewMessage(topic, []byte("bar"))
mb.Put(expected) mb.Put(expected)
actual, ok := mb.Get(topic) actual, ok := mb.Get(topic)
@ -32,6 +39,28 @@ func TestMessage(t *testing.T) {
assert.Equal(t, actual, expected) assert.Equal(t, actual, expected)
} }
func TestMessageIds(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
mb := New(nil)
assert.Equal(0, mb.Len())
topic := mb.NewTopic("foo")
expected := mb.NewMessage(topic, []byte("bar"))
mb.Put(expected)
actual, ok := mb.Get(topic)
require.True(ok)
assert.Equal(expected, actual)
mb.Put(mb.NewMessage(topic, []byte("bar")))
msg, ok := mb.Get(topic)
require.True(ok)
assert.Equal(msg.ID, 1)
}
func TestMessageGetEmpty(t *testing.T) { func TestMessageGetEmpty(t *testing.T) {
mb := New(nil) mb := New(nil)
assert.Equal(t, mb.Len(), 0) assert.Equal(t, mb.Len(), 0)
@ -45,7 +74,7 @@ func TestMessageGetEmpty(t *testing.T) {
func TestMessageBusPutGet(t *testing.T) { func TestMessageBusPutGet(t *testing.T) {
mb := New(nil) mb := New(nil)
topic := mb.NewTopic("foo") topic := mb.NewTopic("foo")
expected := Message{Topic: topic, Payload: []byte("foo")} expected := mb.NewMessage(topic, []byte("foo"))
mb.Put(expected) mb.Put(expected)
actual, ok := mb.Get(topic) actual, ok := mb.Get(topic)
@ -59,13 +88,39 @@ func TestMessageBusSubscribe(t *testing.T) {
msgs := mb.Subscribe("id1", "foo") msgs := mb.Subscribe("id1", "foo")
topic := mb.NewTopic("foo") topic := mb.NewTopic("foo")
expected := Message{Topic: topic, Payload: []byte("foo")} expected := mb.NewMessage(topic, []byte("foo"))
mb.Put(expected) mb.Put(expected)
actual := <-msgs actual := <-msgs
assert.Equal(t, actual, expected) assert.Equal(t, actual, expected)
} }
func TestMessageBusSubscribeWithIndex(t *testing.T) {
assert := assert.New(t)
mb := New(nil)
msgs := mb.Subscribe("id1", "foo")
topic := mb.NewTopic("foo")
expected := mb.NewMessage(topic, []byte("foo"))
mb.Put(expected)
actual := <-msgs
assert.Equal(expected, actual)
assert.Equal(0, actual.ID)
mb.Unsubscribe("id1", "foo")
mb.Put(mb.NewMessage(topic, []byte("bar"))) // ID == 1
mb.Put(mb.NewMessage(topic, []byte("baz"))) // ID == 2
msgs = mb.Subscribe("id1", "foo", WithIndex(1))
assert.Equal([]byte("bar"), (<-msgs).Payload)
assert.Equal([]byte("baz"), (<-msgs).Payload)
}
func TestServeHTTPGETIndexEmpty(t *testing.T) { func TestServeHTTPGETIndexEmpty(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
@ -83,8 +138,8 @@ func TestServeHTTPGETTopics(t *testing.T) {
mb := New(nil) mb := New(nil)
mb.Put(Message{Topic: mb.NewTopic("foo"), Payload: []byte("foo")}) mb.Put(mb.NewMessage(mb.NewTopic("foo"), []byte("foo")))
mb.Put(Message{Topic: mb.NewTopic("hello"), Payload: []byte("hello world")}) mb.Put(mb.NewMessage(mb.NewTopic("hello"), []byte("hello world")))
w := httptest.NewRecorder() w := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/", nil) r, _ := http.NewRequest("GET", "/", nil)
@ -151,7 +206,7 @@ func TestServeHTTPSimple(t *testing.T) {
var msg *Message var msg *Message
json.Unmarshal(w.Body.Bytes(), &msg) json.Unmarshal(w.Body.Bytes(), &msg)
assert.Equal(msg.ID, uint64(0)) assert.Equal(msg.ID, 0)
assert.Equal(msg.Topic.Name, "hello") assert.Equal(msg.Topic.Name, "hello")
assert.Equal(msg.Payload, []byte("hello world")) assert.Equal(msg.Payload, []byte("hello world"))
} }
@ -178,11 +233,11 @@ func TestServeHTTPSubscriber(t *testing.T) {
s := httptest.NewServer(mb) s := httptest.NewServer(mb)
defer s.Close() defer s.Close()
msgs := make(chan *Message) msgs := make(chan Message, 10)
ready := make(chan bool, 1) ready := make(chan bool, 1)
consumer := func() { consumer := func() {
var msg *Message var msg Message
// u := fmt.Sprintf("ws%s/hello", strings.TrimPrefix(s.URL, "http")) // u := fmt.Sprintf("ws%s/hello", strings.TrimPrefix(s.URL, "http"))
ws, _, err := websocket.Dial(context.Background(), s.URL+"/hello", nil) ws, _, err := websocket.Dial(context.Background(), s.URL+"/hello", nil)
@ -208,7 +263,7 @@ func TestServeHTTPSubscriber(t *testing.T) {
defer r.Body.Close() defer r.Body.Close()
msg := <-msgs msg := <-msgs
assert.Equal(msg.ID, uint64(0)) assert.Equal(msg.ID, 0)
assert.Equal(msg.Topic.Name, "hello") assert.Equal(msg.Topic.Name, "hello")
assert.Equal(msg.Payload, []byte("hello world")) assert.Equal(msg.Payload, []byte("hello world"))
} }
@ -221,11 +276,11 @@ func TestServeHTTPSubscriberReconnect(t *testing.T) {
s := httptest.NewServer(mb) s := httptest.NewServer(mb)
msgs := make(chan *Message) msgs := make(chan Message, 10)
ready := make(chan bool, 1) ready := make(chan bool, 1)
consumer := func() { consumer := func() {
var msg *Message var msg Message
ws, _, err := websocket.Dial(context.Background(), s.URL+"/hello", nil) ws, _, err := websocket.Dial(context.Background(), s.URL+"/hello", nil)
require.NoError(err) require.NoError(err)
@ -254,7 +309,7 @@ func TestServeHTTPSubscriberReconnect(t *testing.T) {
defer r.Body.Close() defer r.Body.Close()
msg := <-msgs msg := <-msgs
assert.Equal(msg.ID, uint64(0)) assert.Equal(msg.ID, 0)
assert.Equal(msg.Topic.Name, "hello") assert.Equal(msg.Topic.Name, "hello")
assert.Equal(msg.Payload, []byte("hello world")) assert.Equal(msg.Payload, []byte("hello world"))
} }
@ -273,7 +328,7 @@ func TestMsgBusMetrics(t *testing.T) {
func BenchmarkMessageBusPut(b *testing.B) { func BenchmarkMessageBusPut(b *testing.B) {
mb := New(nil) mb := New(nil)
topic := mb.NewTopic("foo") topic := mb.NewTopic("foo")
msg := Message{Topic: topic, Payload: []byte("foo")} msg := mb.NewMessage(topic, []byte("foo"))
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
mb.Put(msg) mb.Put(msg)
@ -283,7 +338,7 @@ func BenchmarkMessageBusPut(b *testing.B) {
func BenchmarkMessageBusGet(b *testing.B) { func BenchmarkMessageBusGet(b *testing.B) {
mb := New(nil) mb := New(nil)
topic := mb.NewTopic("foo") topic := mb.NewTopic("foo")
msg := Message{Topic: topic, Payload: []byte("foo")} msg := mb.NewMessage(topic, []byte("foo"))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
mb.Put(msg) mb.Put(msg)
} }
@ -305,10 +360,24 @@ func BenchmarkMessageBusGetEmpty(b *testing.B) {
func BenchmarkMessageBusPutGet(b *testing.B) { func BenchmarkMessageBusPutGet(b *testing.B) {
mb := New(nil) mb := New(nil)
topic := mb.NewTopic("foo") topic := mb.NewTopic("foo")
msg := Message{Topic: topic, Payload: []byte("foo")} msg := mb.NewMessage(topic, []byte("foo"))
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
mb.Put(msg) mb.Put(msg)
mb.Get(topic) mb.Get(topic)
} }
} }
func TestMain(m *testing.M) {
flag.Parse()
if *debug {
log.SetLevel(log.DebugLevel)
} else {
log.SetLevel(log.WarnLevel)
}
result := m.Run()
os.Exit(result)
}

View File

@ -1,7 +1,7 @@
package msgbus package msgbus
import ( import (
"sync" sync "github.com/sasha-s/go-deadlock"
) )
// minCapacity is the smallest capacity that queue may have. // minCapacity is the smallest capacity that queue may have.
@ -112,6 +112,25 @@ func (q *Queue) Peek() interface{} {
return q.buf[q.head] return q.buf[q.head]
} }
// ForEach applys the function `f` over each item in the queue for read-only
// access into the queue in O(n) time for indexining into the queue.
func (q *Queue) ForEach(f func(elem interface{}) error) error {
q.Lock()
defer q.Unlock()
if q.count <= 0 {
return nil
}
for i := 0; i < q.count; i++ {
if err := f(q.buf[i]); err != nil {
return err
}
}
return nil
}
// next returns the next buffer position wrapping around buffer. // next returns the next buffer position wrapping around buffer.
func (q *Queue) next(i int) int { func (q *Queue) next(i int) int {
return (i + 1) & (len(q.buf) - 1) // bitwise modulus return (i + 1) & (len(q.buf) - 1) // bitwise modulus
@ -123,7 +142,7 @@ func (q *Queue) growIfFull() {
q.buf = make([]interface{}, minCapacity) q.buf = make([]interface{}, minCapacity)
return return
} }
if q.count == len(q.buf) && q.count < q.maxlen { if q.count == len(q.buf) && (q.maxlen == 0 || q.count < q.maxlen) {
q.resize() q.resize()
} }
} }

View File

@ -1,13 +1,15 @@
package msgbus package msgbus
import ( import (
"fmt"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestEmpty(t *testing.T) { func TestEmpty(t *testing.T) {
q := Queue{} q := NewQueue()
assert.Zero(t, q.Len()) assert.Zero(t, q.Len())
assert.True(t, q.Empty()) assert.True(t, q.Empty())
} }
@ -15,11 +17,12 @@ func TestEmpty(t *testing.T) {
func TestSimple(t *testing.T) { func TestSimple(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
q := Queue{} q := NewQueue()
for i := 0; i < minCapacity; i++ { for i := 0; i < minCapacity; i++ {
q.Push(i) q.Push(i)
} }
assert.Equal(minCapacity, q.Len())
for i := 0; i < minCapacity; i++ { for i := 0; i < minCapacity; i++ {
assert.Equal(q.Peek(), i) assert.Equal(q.Peek(), i)
@ -27,6 +30,31 @@ func TestSimple(t *testing.T) {
} }
} }
func TestForEach(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
q := NewQueue()
ys := []int{0, 1, 2, 3}
for _, y := range ys {
q.Push(y)
}
assert.Equal(4, q.Len())
var xs []int
err := q.ForEach(func(e interface{}) error {
i, ok := e.(int)
if !ok {
return fmt.Errorf("unexpected type %T", e)
}
xs = append(xs, i)
return nil
})
require.NoError(err)
assert.Equal(ys, xs)
}
func TestMaxLen(t *testing.T) { func TestMaxLen(t *testing.T) {
q := Queue{maxlen: minCapacity} q := Queue{maxlen: minCapacity}
assert.Equal(t, q.MaxLen(), minCapacity) assert.Equal(t, q.MaxLen(), minCapacity)
@ -43,7 +71,7 @@ func TestFull(t *testing.T) {
} }
func TestBufferWrap(t *testing.T) { func TestBufferWrap(t *testing.T) {
q := Queue{} q := NewQueue()
for i := 0; i < minCapacity; i++ { for i := 0; i < minCapacity; i++ {
q.Push(i) q.Push(i)
@ -63,7 +91,7 @@ func TestBufferWrap(t *testing.T) {
func TestLen(t *testing.T) { func TestLen(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
q := Queue{} q := NewQueue()
assert.Zero(q.Len()) assert.Zero(q.Len())
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
@ -78,14 +106,14 @@ func TestLen(t *testing.T) {
} }
func BenchmarkPush(b *testing.B) { func BenchmarkPush(b *testing.B) {
q := Queue{} q := NewQueue()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
q.Push(i) q.Push(i)
} }
} }
func BenchmarkPushPop(b *testing.B) { func BenchmarkPushPop(b *testing.B) {
q := Queue{} q := NewQueue()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
q.Push(i) q.Push(i)
} }

12
utils.go Normal file
View File

@ -0,0 +1,12 @@
package msgbus
import "strconv"
// SafeParseInt ...
func SafeParseInt(s string, d int) int {
n, e := strconv.Atoi(s)
if e != nil {
return d
}
return n
}