Add Subscriber test (fixed some race conditions)
This commit is contained in:
parent
9f01db6002
commit
bbb7445d58
42
msgbus.go
42
msgbus.go
|
@ -61,6 +61,8 @@ type Message struct {
|
|||
|
||||
// Listeners ...
|
||||
type Listeners struct {
|
||||
sync.RWMutex
|
||||
|
||||
ids map[string]bool
|
||||
chs map[string]chan Message
|
||||
}
|
||||
|
@ -75,11 +77,17 @@ func NewListeners() *Listeners {
|
|||
|
||||
// Length ...
|
||||
func (ls *Listeners) Length() int {
|
||||
ls.RLock()
|
||||
defer ls.RUnlock()
|
||||
|
||||
return len(ls.ids)
|
||||
}
|
||||
|
||||
// Add ...
|
||||
func (ls *Listeners) Add(id string) chan Message {
|
||||
ls.Lock()
|
||||
defer ls.Unlock()
|
||||
|
||||
ls.ids[id] = true
|
||||
ls.chs[id] = make(chan Message)
|
||||
return ls.chs[id]
|
||||
|
@ -87,6 +95,9 @@ func (ls *Listeners) Add(id string) chan Message {
|
|||
|
||||
// Remove ...
|
||||
func (ls *Listeners) Remove(id string) {
|
||||
ls.Lock()
|
||||
defer ls.Unlock()
|
||||
|
||||
delete(ls.ids, id)
|
||||
|
||||
close(ls.chs[id])
|
||||
|
@ -95,12 +106,18 @@ func (ls *Listeners) Remove(id string) {
|
|||
|
||||
// Exists ...
|
||||
func (ls *Listeners) Exists(id string) bool {
|
||||
ls.RLock()
|
||||
defer ls.RUnlock()
|
||||
|
||||
_, ok := ls.ids[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Get ...
|
||||
func (ls *Listeners) Get(id string) (chan Message, bool) {
|
||||
ls.RLock()
|
||||
defer ls.RUnlock()
|
||||
|
||||
ch, ok := ls.chs[id]
|
||||
if !ok {
|
||||
return nil, false
|
||||
|
@ -110,6 +127,9 @@ func (ls *Listeners) Get(id string) (chan Message, bool) {
|
|||
|
||||
// NotifyAll ...
|
||||
func (ls *Listeners) NotifyAll(message Message) int {
|
||||
ls.RLock()
|
||||
defer ls.RUnlock()
|
||||
|
||||
i := 0
|
||||
for id, ch := range ls.chs {
|
||||
select {
|
||||
|
@ -135,7 +155,7 @@ type Options struct {
|
|||
|
||||
// MessageBus ...
|
||||
type MessageBus struct {
|
||||
sync.Mutex
|
||||
sync.RWMutex
|
||||
|
||||
metrics *Metrics
|
||||
|
||||
|
@ -308,6 +328,9 @@ func (mb *MessageBus) NewMessage(topic *Topic, payload []byte) Message {
|
|||
|
||||
// Put ...
|
||||
func (mb *MessageBus) Put(message Message) {
|
||||
mb.Lock()
|
||||
defer mb.Unlock()
|
||||
|
||||
log.Debugf(
|
||||
"[msgbus] PUT id=%d topic=%s payload=%s",
|
||||
message.ID, message.Topic.Name, message.Payload,
|
||||
|
@ -325,11 +348,14 @@ func (mb *MessageBus) Put(message Message) {
|
|||
mb.metrics.GaugeVec("queue", "len").WithLabelValues(t.Name).Inc()
|
||||
}
|
||||
|
||||
mb.NotifyAll(message)
|
||||
mb.publish(message)
|
||||
}
|
||||
|
||||
// Get ...
|
||||
func (mb *MessageBus) Get(t *Topic) (Message, bool) {
|
||||
mb.RLock()
|
||||
defer mb.RUnlock()
|
||||
|
||||
log.Debugf("[msgbus] GET topic=%s", t)
|
||||
|
||||
q, ok := mb.queues[t]
|
||||
|
@ -350,10 +376,10 @@ func (mb *MessageBus) Get(t *Topic) (Message, bool) {
|
|||
return m.(Message), true
|
||||
}
|
||||
|
||||
// NotifyAll ...
|
||||
func (mb *MessageBus) NotifyAll(message Message) {
|
||||
// publish ...
|
||||
func (mb *MessageBus) publish(message Message) {
|
||||
log.Debugf(
|
||||
"[msgbus] NotifyAll id=%d topic=%s payload=%s",
|
||||
"[msgbus] publish id=%d topic=%s payload=%s",
|
||||
message.ID, message.Topic.Name, message.Payload,
|
||||
)
|
||||
ls, ok := mb.listeners[message.Topic]
|
||||
|
@ -370,6 +396,9 @@ func (mb *MessageBus) NotifyAll(message Message) {
|
|||
|
||||
// Subscribe ...
|
||||
func (mb *MessageBus) Subscribe(id, topic string) chan Message {
|
||||
mb.Lock()
|
||||
defer mb.Unlock()
|
||||
|
||||
log.Debugf("[msgbus] Subscribe id=%s topic=%s", id, topic)
|
||||
|
||||
t, ok := mb.topics[topic]
|
||||
|
@ -399,6 +428,9 @@ func (mb *MessageBus) Subscribe(id, topic string) chan Message {
|
|||
|
||||
// Unsubscribe ...
|
||||
func (mb *MessageBus) Unsubscribe(id, topic string) {
|
||||
mb.Lock()
|
||||
defer mb.Unlock()
|
||||
|
||||
log.Debugf("[msgbus] Unsubscribe id=%s topic=%s", id, topic)
|
||||
|
||||
t, ok := mb.topics[topic]
|
||||
|
|
|
@ -3,10 +3,14 @@ package msgbus
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
|
@ -113,6 +117,53 @@ func TestServeHTTPSimple(t *testing.T) {
|
|||
assert.Equal(msg.Payload, []byte("hello world"))
|
||||
}
|
||||
|
||||
func TestServeHTTPSubscriber(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
mb := New(nil)
|
||||
|
||||
s := httptest.NewServer(mb)
|
||||
defer s.Close()
|
||||
|
||||
msgs := make(chan *Message)
|
||||
ready := make(chan bool, 1)
|
||||
|
||||
consumer := func() {
|
||||
var msg *Message
|
||||
|
||||
u := fmt.Sprintf("ws%s/hello", strings.TrimPrefix(s.URL, "http"))
|
||||
|
||||
ws, _, err := websocket.DefaultDialer.Dial(u, nil)
|
||||
assert.NoError(err)
|
||||
defer ws.Close()
|
||||
|
||||
ready <- true
|
||||
|
||||
err = ws.ReadJSON(&msg)
|
||||
ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
|
||||
msgs <- msg
|
||||
}
|
||||
|
||||
go consumer()
|
||||
|
||||
<-ready
|
||||
|
||||
c := s.Client()
|
||||
b := bytes.NewBufferString("hello world")
|
||||
r, err := c.Post(s.URL+"/hello", "text/plain", b)
|
||||
assert.NoError(err)
|
||||
defer r.Body.Close()
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
assert.NoError(err)
|
||||
assert.Regexp(`message successfully published to hello with sequence \d+`, string(body))
|
||||
|
||||
msg := <-msgs
|
||||
assert.Equal(msg.ID, uint64(0))
|
||||
assert.Equal(msg.Topic.Name, "hello")
|
||||
assert.Equal(msg.Payload, []byte("hello world"))
|
||||
}
|
||||
|
||||
func BenchmarkMessageBusPut(b *testing.B) {
|
||||
mb := New(nil)
|
||||
topic := mb.NewTopic("foo")
|
||||
|
|
Loading…
Reference in New Issue