package msgbus

import (
	"bytes"
	"context"
	"encoding/json"
	"flag"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"os"
	"testing"

	log "github.com/sirupsen/logrus"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"nhooyr.io/websocket"
	"nhooyr.io/websocket/wsjson"
)

var (
	// debug switches the logger to debug level when the tests are run with -d.
	debug = flag.Bool("d", false, "enable debug logging")
)

// startTestConsumer dials the websocket endpoint at url from a background
// goroutine and reads a single JSON-encoded Message from the connection.
//
// ready receives a value once the connection has been established, msgs
// receives the decoded message, and errs receives the first dial or read
// error. Every channel is buffered so the goroutine never blocks on a send.
//
// Failures are handed back over errs instead of being asserted inside the
// goroutine because FailNow (and therefore the require helpers) must only be
// called from the goroutine that runs the test.
func startTestConsumer(url string) (ready <-chan bool, msgs <-chan Message, errs <-chan error) {
	readyCh := make(chan bool, 1)
	msgCh := make(chan Message, 10)
	errCh := make(chan error, 1)

	go func() {
		ws, _, err := websocket.Dial(context.Background(), url, nil)
		if err != nil {
			errCh <- err
			return
		}
		defer ws.Close(websocket.StatusNormalClosure, "")

		readyCh <- true

		var msg Message
		if err := wsjson.Read(context.Background(), ws, &msg); err != nil {
			errCh <- err
			return
		}
		msgCh <- msg
	}()

	return readyCh, msgCh, errCh
}

// TestMessageBusLen checks that a newly created bus reports a length of zero.
func TestMessageBusLen(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	testdir, err := ioutil.TempDir("", "msgbus-logs-*")
	require.NoError(err)
	defer os.RemoveAll(testdir)

	mb, err := NewMessageBus(WithLogPath(testdir))
	require.NoError(err)

	assert.Equal(0, mb.Len())
}

// TestMessage checks that a message put on a topic is returned unchanged by
// Get on the same topic.
func TestMessage(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	testdir, err := ioutil.TempDir("", "msgbus-logs-*")
	require.NoError(err)
	defer os.RemoveAll(testdir)

	mb, err := NewMessageBus(WithLogPath(testdir))
	require.NoError(err)
	assert.Equal(0, mb.Len())

	topic := mb.NewTopic("foo")
	expected := mb.NewMessage(topic, []byte("bar"))

	err = mb.Put(expected)
	require.NoError(err)

	actual, ok := mb.Get(topic)
	require.True(ok)
	assert.Equal(expected, actual)
}

// TestMessageIds checks that the second message created on a topic is
// assigned ID 2.
func TestMessageIds(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	testdir, err := ioutil.TempDir("", "msgbus-logs-*")
	require.NoError(err)
	defer os.RemoveAll(testdir)

	mb, err := NewMessageBus(WithLogPath(testdir))
	require.NoError(err)
	assert.Equal(0, mb.Len())

	topic := mb.NewTopic("foo")
	expected := mb.NewMessage(topic, []byte("bar"))
	require.NoError(mb.Put(expected))

	actual, ok := mb.Get(topic)
	require.True(ok)
	assert.Equal(expected, actual)

	require.NoError(mb.Put(mb.NewMessage(topic, []byte("bar"))))

	msg, ok := mb.Get(topic)
	require.True(ok)
	assert.Equal(int64(2), msg.ID)
}

// TestMessageGetEmpty checks that Get on a topic without messages reports
// ok == false and returns the zero Message.
func TestMessageGetEmpty(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	testdir, err := ioutil.TempDir("", "msgbus-logs-*")
	require.NoError(err)
	defer os.RemoveAll(testdir)

	mb, err := NewMessageBus(WithLogPath(testdir))
	require.NoError(err)
	assert.Equal(0, mb.Len())

	topic := mb.NewTopic("foo")

	msg, ok := mb.Get(topic)
	require.False(ok)
	assert.Equal(Message{}, msg)
}

// TestMessageBusPutGet checks a single Put followed by a Get on the same
// topic.
func TestMessageBusPutGet(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	testdir, err := ioutil.TempDir("", "msgbus-logs-*")
	require.NoError(err)
	defer os.RemoveAll(testdir)

	mb, err := NewMessageBus(WithLogPath(testdir))
	require.NoError(err)

	topic := mb.NewTopic("foo")
	expected := mb.NewMessage(topic, []byte("foo"))
	require.NoError(mb.Put(expected))

	actual, ok := mb.Get(topic)
	require.True(ok)
	assert.Equal(expected, actual)
}

// TestMessageBusSubscribe checks that a subscriber receives a message put on
// the topic it subscribed to.
func TestMessageBusSubscribe(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	testdir, err := ioutil.TempDir("", "msgbus-logs-*")
	require.NoError(err)
	defer os.RemoveAll(testdir)

	mb, err := NewMessageBus(WithLogPath(testdir))
	require.NoError(err)

	msgs := mb.Subscribe("id1", "foo")

	topic := mb.NewTopic("foo")
	expected := mb.NewMessage(topic, []byte("foo"))
	require.NoError(mb.Put(expected))

	actual := <-msgs
	assert.Equal(expected, actual)
}

// TestMessageBusSubscribeFullBuffer puts three messages on a bus created with
// a buffer length of two and then subscribes to the topic. It makes no
// assertion about what the subscriber receives; it only exercises the calls.
func TestMessageBusSubscribeFullBuffer(t *testing.T) {
	require := require.New(t)

	testdir, err := ioutil.TempDir("", "msgbus-logs-*")
	require.NoError(err)
	defer os.RemoveAll(testdir)

	mb, err := NewMessageBus(WithLogPath(testdir), WithBufferLength(2))
	require.NoError(err)

	topic := mb.NewTopic("hello")
	require.NoError(mb.Put(mb.NewMessage(topic, []byte("foo")))) // ID == 1
	require.NoError(mb.Put(mb.NewMessage(topic, []byte("bar")))) // ID == 2
	require.NoError(mb.Put(mb.NewMessage(topic, []byte("baz")))) // ID == 3

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	msgs := mb.Subscribe("id1", "hello")

	// Receive at most one message in the background; the deferred cancel
	// releases the goroutine if nothing arrives before the test returns.
	go func(ctx context.Context) {
		select {
		case <-msgs:
		case <-ctx.Done():
			return
		}
	}(ctx)

	//err = mb.Put(mb.NewMessage(topic, []byte("baz"))) // ID == 3
	//require.NoError(err)
}

// TestMessageBusSubscribeWithIndex checks that re-subscribing with
// WithIndex(1) delivers the messages "foo", "bar" and "baz" in that order,
// including the ones put while no subscriber was attached.
func TestMessageBusSubscribeWithIndex(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	testdir, err := ioutil.TempDir("", "msgbus-logs-*")
	require.NoError(err)
	defer os.RemoveAll(testdir)

	mb, err := NewMessageBus(WithLogPath(testdir))
	require.NoError(err)

	msgs := mb.Subscribe("id1", "foo")

	topic := mb.NewTopic("foo")
	expected := mb.NewMessage(topic, []byte("foo")) // ID == 1
	require.NoError(mb.Put(expected))

	actual := <-msgs
	assert.Equal(expected, actual)
	assert.Equal(int64(1), actual.ID)

	mb.Unsubscribe("id1", "foo")

	require.NoError(mb.Put(mb.NewMessage(topic, []byte("bar")))) // ID == 2
	require.NoError(mb.Put(mb.NewMessage(topic, []byte("baz")))) // ID == 3

	msgs = mb.Subscribe("id1", "foo", WithIndex(1))

	assert.Equal("foo", string((<-msgs).Payload))
	assert.Equal("bar", string((<-msgs).Payload))
	assert.Equal("baz", string((<-msgs).Payload))
}

// TestMessageBusWAL checks that a second bus opened on the same log path
// replays the earlier messages and continues the topic sequence and message
// IDs where the first bus left off.
func TestMessageBusWAL(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	testdir, err := ioutil.TempDir("", "msgbus-logs-*")
	require.NoError(err)
	defer os.RemoveAll(testdir)

	mb, err := NewMessageBus(WithLogPath(testdir))
	require.NoError(err)

	msgs := mb.Subscribe("id1", "hello")

	topic := mb.NewTopic("hello")
	require.NoError(mb.Put(mb.NewMessage(topic, []byte("foo")))) // ID == 1
	require.NoError(mb.Put(mb.NewMessage(topic, []byte("bar")))) // ID == 2
	require.NoError(mb.Put(mb.NewMessage(topic, []byte("baz")))) // ID == 3

	assert.Equal([]byte("foo"), (<-msgs).Payload)
	assert.Equal([]byte("bar"), (<-msgs).Payload)
	assert.Equal([]byte("baz"), (<-msgs).Payload)

	assert.Equal(int64(3), topic.Sequence)

	// Unsubscribe from the topic that was actually subscribed to above.
	mb.Unsubscribe("id1", "hello")

	// Now ensure when we start back up we've re-filled the queues and retain
	// the same message ids and topic sequence number
	mb, err = NewMessageBus(WithLogPath(testdir))
	require.NoError(err)

	// we have to tell the bus we want to subscribe from the start
	msgs = mb.Subscribe("id1", "hello", WithIndex(1))

	topic = mb.NewTopic("hello")
	assert.Equal(int64(3), topic.Sequence)

	assert.Equal("foo", string((<-msgs).Payload))
	assert.Equal("bar", string((<-msgs).Payload))
	assert.Equal("baz", string((<-msgs).Payload))

	msg := mb.NewMessage(topic, []byte("foobar"))
	assert.Equal(int64(4), msg.ID)
}

// TestServeHTTPGETIndexEmpty checks that GET / on an empty bus answers
// 200 OK with the body "{}".
func TestServeHTTPGETIndexEmpty(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	testdir, err := ioutil.TempDir("", "msgbus-logs-*")
	require.NoError(err)
	defer os.RemoveAll(testdir)

	mb, err := NewMessageBus(WithLogPath(testdir))
	require.NoError(err)

	w := httptest.NewRecorder()
	r, err := http.NewRequest("GET", "/", nil)
	require.NoError(err)

	mb.ServeHTTP(w, r)
	assert.Equal(http.StatusOK, w.Code)
	assert.Equal("{}", w.Body.String())
}

// TestServeHTTPGETTopics checks that GET / mentions every topic that has had
// a message put on it.
func TestServeHTTPGETTopics(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	testdir, err := ioutil.TempDir("", "msgbus-logs-*")
	require.NoError(err)
	defer os.RemoveAll(testdir)

	mb, err := NewMessageBus(WithLogPath(testdir))
	require.NoError(err)

	require.NoError(mb.Put(mb.NewMessage(mb.NewTopic("foo"), []byte("foo"))))
	require.NoError(mb.Put(mb.NewMessage(mb.NewTopic("hello"), []byte("hello world"))))

	w := httptest.NewRecorder()
	r, err := http.NewRequest("GET", "/", nil)
	require.NoError(err)

	mb.ServeHTTP(w, r)
	assert.Equal(http.StatusOK, w.Code)
	assert.Contains(w.Body.String(), "foo")
	assert.Contains(w.Body.String(), "hello")
}

// TestServeHTTPGETEmptyQueue checks that GET on a topic without messages
// answers 204 No Content.
func TestServeHTTPGETEmptyQueue(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	testdir, err := ioutil.TempDir("", "msgbus-logs-*")
	require.NoError(err)
	defer os.RemoveAll(testdir)

	mb, err := NewMessageBus(WithLogPath(testdir))
	require.NoError(err)

	w := httptest.NewRecorder()
	r, err := http.NewRequest("GET", "/hello", nil)
	require.NoError(err)

	mb.ServeHTTP(w, r)
	assert.Equal(http.StatusNoContent, w.Code)
}

// TestServeHTTPPOST checks that POSTing a payload to a topic answers
// 202 Accepted.
func TestServeHTTPPOST(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	testdir, err := ioutil.TempDir("", "msgbus-logs-*")
	require.NoError(err)
	defer os.RemoveAll(testdir)

	mb, err := NewMessageBus(WithLogPath(testdir))
	require.NoError(err)

	w := httptest.NewRecorder()
	b := bytes.NewBufferString("hello world")
	r, err := http.NewRequest("POST", "/hello", b)
	require.NoError(err)

	mb.ServeHTTP(w, r)
	assert.Equal(http.StatusAccepted, w.Code)
}

// TestServeHTTPMaxPayloadSize checks that a POST body twice the size of
// DefaultMaxPayloadSize is rejected with 413 and an explanatory message.
func TestServeHTTPMaxPayloadSize(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	testdir, err := ioutil.TempDir("", "msgbus-logs-*")
	require.NoError(err)
	defer os.RemoveAll(testdir)

	mb, err := NewMessageBus(WithLogPath(testdir))
	require.NoError(err)

	w := httptest.NewRecorder()
	b := bytes.NewBuffer(bytes.Repeat([]byte{'X'}, (DefaultMaxPayloadSize * 2)))
	r, err := http.NewRequest("POST", "/hello", b)
	require.NoError(err)

	mb.ServeHTTP(w, r)
	assert.Equal(http.StatusRequestEntityTooLarge, w.Code)
	assert.Regexp(`payload exceeds max-payload-size`, w.Body.String())
}

// TestServeHTTPSimple POSTs a payload to a topic and checks that a following
// GET on that topic returns it as JSON with ID 1, the topic name and the
// original payload.
func TestServeHTTPSimple(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	testdir, err := ioutil.TempDir("", "msgbus-logs-*")
	require.NoError(err)
	defer os.RemoveAll(testdir)

	mb, err := NewMessageBus(WithLogPath(testdir))
	require.NoError(err)

	w := httptest.NewRecorder()
	b := bytes.NewBufferString("hello world")
	r, err := http.NewRequest("POST", "/hello", b)
	require.NoError(err)

	mb.ServeHTTP(w, r)
	assert.Equal(http.StatusAccepted, w.Code)

	w = httptest.NewRecorder()
	r, err = http.NewRequest("GET", "/hello", nil)
	require.NoError(err)

	mb.ServeHTTP(w, r)
	assert.Equal(http.StatusOK, w.Code)

	// Fail here on a malformed body instead of asserting on a zero message.
	var msg Message
	require.NoError(json.Unmarshal(w.Body.Bytes(), &msg))
	assert.Equal(int64(1), msg.ID)
	assert.Equal("hello", msg.Topic.Name)
	assert.Equal([]byte("hello world"), msg.Payload)
}

// BenchmarkServeHTTP_POST measures POSTing a small payload through ServeHTTP,
// once on a bus with default options and once on a bus created with
// WithNoSync(true).
func BenchmarkServeHTTP_POST(b *testing.B) {
	b.Run("Sync", func(b *testing.B) {
		// Bind the assertions to the sub-benchmark so that failures are
		// reported on, and stop, the benchmark that is actually running.
		require := require.New(b)

		testdir, err := ioutil.TempDir("", "msgbus-logs-*")
		require.NoError(err)
		defer os.RemoveAll(testdir)

		mb, err := NewMessageBus(WithLogPath(testdir))
		require.NoError(err)

		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			w := httptest.NewRecorder()
			body := bytes.NewBufferString("hello world")
			r, err := http.NewRequest("POST", "/hello", body)
			if err != nil {
				b.Fatal(err)
			}

			mb.ServeHTTP(w, r)
		}
	})

	b.Run("NoSync", func(b *testing.B) {
		require := require.New(b)

		testdir, err := ioutil.TempDir("", "msgbus-logs-*")
		require.NoError(err)
		defer os.RemoveAll(testdir)

		mb, err := NewMessageBus(WithLogPath(testdir), WithNoSync(true))
		require.NoError(err)

		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			w := httptest.NewRecorder()
			body := bytes.NewBufferString("hello world")
			r, err := http.NewRequest("POST", "/hello", body)
			if err != nil {
				b.Fatal(err)
			}

			mb.ServeHTTP(w, r)
		}
	})
}

// TestServeHTTPSubscriber checks that a websocket client connected to a topic
// receives a message that is POSTed to that topic afterwards.
func TestServeHTTPSubscriber(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	testdir, err := ioutil.TempDir("", "msgbus-logs-*")
	require.NoError(err)
	defer os.RemoveAll(testdir)

	mb, err := NewMessageBus(WithLogPath(testdir))
	require.NoError(err)

	s := httptest.NewServer(mb)
	defer s.Close()

	// u := fmt.Sprintf("ws%s/hello", strings.TrimPrefix(s.URL, "http"))
	ready, msgs, errs := startTestConsumer(s.URL + "/hello")

	// Wait for the consumer to connect, or fail on the test goroutine if the
	// dial did not succeed.
	select {
	case <-ready:
	case err := <-errs:
		require.NoError(err)
	}

	c := s.Client()
	b := bytes.NewBufferString("hello world")
	r, err := c.Post(s.URL+"/hello", "text/plain", b)
	require.NoError(err)
	defer r.Body.Close()

	var msg Message
	select {
	case msg = <-msgs:
	case err := <-errs:
		require.NoError(err)
	}

	assert.Equal(int64(1), msg.ID)
	assert.Equal("hello", msg.Topic.Name)
	assert.Equal([]byte("hello world"), msg.Payload)
}

// TestServeHTTPSubscriberReconnect connects a websocket client, closes the
// HTTP test server, starts a new server on the same bus and checks that a
// message POSTed through the new server still reaches the client.
func TestServeHTTPSubscriberReconnect(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	testdir, err := ioutil.TempDir("", "msgbus-logs-*")
	require.NoError(err)
	defer os.RemoveAll(testdir)

	mb, err := NewMessageBus(WithLogPath(testdir))
	require.NoError(err)

	s := httptest.NewServer(mb)

	ready, msgs, errs := startTestConsumer(s.URL + "/hello")

	// Wait for the consumer to connect, or fail on the test goroutine if the
	// dial did not succeed.
	select {
	case <-ready:
	case err := <-errs:
		require.NoError(err)
	}

	s.Close()

	s = httptest.NewServer(mb)
	defer s.Close()

	c := s.Client()
	b := bytes.NewBufferString("hello world")
	r, err := c.Post(s.URL+"/hello", "text/plain", b)
	require.NoError(err)
	defer r.Body.Close()

	var msg Message
	select {
	case msg = <-msgs:
	case err := <-errs:
		require.NoError(err)
	}

	assert.Equal(int64(1), msg.ID)
	assert.Equal("hello", msg.Topic.Name)
	assert.Equal([]byte("hello world"), msg.Payload)
}

// TestMsgBusMetrics checks that Metrics returns a *Metrics value.
func TestMsgBusMetrics(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	testdir, err := ioutil.TempDir("", "msgbus-logs-*")
	require.NoError(err)
	defer os.RemoveAll(testdir)

	mb, err := NewMessageBus(WithLogPath(testdir))
	require.NoError(err)

	assert.IsType(&Metrics{}, mb.Metrics())
}

// BenchmarkMessageBusPut measures Put, once on a bus with default options and
// once on a bus created with WithNoSync(true).
//
// NOTE(review): the same message value (and therefore the same ID) is put on
// every iteration and the result of Put is left unchecked, as in the original
// benchmark — confirm that Put accepts repeated IDs, otherwise this measures
// an error path.
func BenchmarkMessageBusPut(b *testing.B) {
	b.Run("Sync", func(b *testing.B) {
		// Bind the assertions to the sub-benchmark that is running.
		require := require.New(b)

		testdir, err := ioutil.TempDir("", "msgbus-logs-*")
		require.NoError(err)
		defer os.RemoveAll(testdir)

		mb, err := NewMessageBus(WithLogPath(testdir))
		require.NoError(err)

		topic := mb.NewTopic("foo")
		msg := mb.NewMessage(topic, []byte("foo"))

		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			mb.Put(msg)
		}
	})

	b.Run("NoSync", func(b *testing.B) {
		require := require.New(b)

		testdir, err := ioutil.TempDir("", "msgbus-logs-*")
		require.NoError(err)
		defer os.RemoveAll(testdir)

		mb, err := NewMessageBus(WithLogPath(testdir), WithNoSync(true))
		require.NoError(err)

		topic := mb.NewTopic("foo")
		msg := mb.NewMessage(topic, []byte("foo"))

		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			mb.Put(msg)
		}
	})
}

// BenchmarkMessageBusGet measures Get on a topic that was pre-filled with b.N
// Put calls before the timer is reset.
//
// NOTE(review): the pre-fill puts the same message value repeatedly and does
// not check the result of Put, as in the original benchmark.
func BenchmarkMessageBusGet(b *testing.B) {
	require := require.New(b)

	testdir, err := ioutil.TempDir("", "msgbus-logs-*")
	require.NoError(err)
	defer os.RemoveAll(testdir)

	mb, err := NewMessageBus(WithLogPath(testdir))
	require.NoError(err)

	topic := mb.NewTopic("foo")
	msg := mb.NewMessage(topic, []byte("foo"))
	for i := 0; i < b.N; i++ {
		mb.Put(msg)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		mb.Get(topic)
	}
}

// BenchmarkMessageBusGetEmpty measures Get on a topic that never had a
// message put on it.
func BenchmarkMessageBusGetEmpty(b *testing.B) {
	require := require.New(b)

	testdir, err := ioutil.TempDir("", "msgbus-logs-*")
	require.NoError(err)
	defer os.RemoveAll(testdir)

	mb, err := NewMessageBus(WithLogPath(testdir))
	require.NoError(err)

	topic := mb.NewTopic("foo")

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		mb.Get(topic)
	}
}

// BenchmarkMessageBusPutGet measures a Put immediately followed by a Get,
// once on a bus with default options and once on a bus created with
// WithNoSync(true).
//
// NOTE(review): the same message value is put on every iteration and the
// result of Put is left unchecked, as in the original benchmark.
func BenchmarkMessageBusPutGet(b *testing.B) {
	b.Run("Sync", func(b *testing.B) {
		// Bind the assertions to the sub-benchmark that is running.
		require := require.New(b)

		testdir, err := ioutil.TempDir("", "msgbus-logs-*")
		require.NoError(err)
		defer os.RemoveAll(testdir)

		mb, err := NewMessageBus(WithLogPath(testdir))
		require.NoError(err)

		topic := mb.NewTopic("foo")
		msg := mb.NewMessage(topic, []byte("foo"))

		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			mb.Put(msg)
			mb.Get(topic)
		}
	})

	b.Run("NoSync", func(b *testing.B) {
		require := require.New(b)

		testdir, err := ioutil.TempDir("", "msgbus-logs-*")
		require.NoError(err)
		defer os.RemoveAll(testdir)

		mb, err := NewMessageBus(WithLogPath(testdir), WithNoSync(true))
		require.NoError(err)

		topic := mb.NewTopic("foo")
		msg := mb.NewMessage(topic, []byte("foo"))

		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			mb.Put(msg)
			mb.Get(topic)
		}
	})
}

// TestMain parses the command-line flags, sets the log level (debug when -d
// is given, warn otherwise), runs the tests and exits with their result code.
func TestMain(m *testing.M) {
	flag.Parse()

	level := log.WarnLevel
	if *debug {
		level = log.DebugLevel
	}
	log.SetLevel(level)

	os.Exit(m.Run())
}