From 0eee3591638be746584519169099db093360e382 Mon Sep 17 00:00:00 2001 From: jrapoport <249460+jrapoport@users.noreply.github.com> Date: Tue, 12 Jan 2021 23:49:40 -0800 Subject: [PATCH] bbolt support --- .github/workflows/test.yml | 9 +- .gitignore | 1 + README.md | 100 ++++-- chestnut.go | 5 +- chestnut_test.go | 259 ++++++++++---- codecov.yml | 2 + encoding/compress/compress.go | 16 + encoding/compress/compress_test.go | 17 +- encoding/json/decode_test.go | 4 + encoding/json/encode_test.go | 5 +- encoding/json/encoders/lookup/decoder.go | 14 +- encoding/json/encoders/lookup/decoder_test.go | 23 ++ encoding/json/encoders/lookup/encoder.go | 15 +- encoding/json/encoders/lookup/encoder_test.go | 34 ++ encoding/json/encoders/secure/decoder.go | 7 +- encoding/json/encoders/secure/decoder_test.go | 69 ++++ encoding/json/encoders/secure/encoder.go | 5 +- encoding/json/encoders/secure/encoder_test.go | 84 +++++ encoding/json/packager/encoding.go | 5 +- encoding/json/packager/package_test.go | 30 +- encryptor/aes/aes_test.go | 9 + go.mod | 1 + go.sum | 2 + keystore/keyutils.go | 28 +- storage/bolt/store.go | 322 ++++++++++++++++++ storage/bolt/store_test.go | 11 + storage/nuts/store.go | 85 ++--- storage/nuts/store_test.go | 214 +----------- storage/storage.go | 3 + storage/store_test/test_suite.go | 260 ++++++++++++++ 30 files changed, 1236 insertions(+), 403 deletions(-) create mode 100644 codecov.yml create mode 100644 storage/bolt/store.go create mode 100644 storage/bolt/store_test.go create mode 100644 storage/store_test/test_suite.go diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7466f25..b589444 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,7 +1,5 @@ on: push: - branches: [master] - tags: ['*'] pull_request: types: [opened, synchronize, reopened] name: test @@ -21,7 +19,6 @@ jobs: - name: Install dependencies run: make deps - name: Lint and test - run: make all -# TEST_FLAGS="-covermode=atomic -coverpkg=./... -coverprofile=coverage.txt" -# - name: Upload coverage to Codecov -# uses: codecov/codecov-action@v1 + run: make all TEST_FLAGS="-covermode=atomic -coverpkg=./... -coverprofile=coverage.txt" + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v1 diff --git a/.gitignore b/.gitignore index 6c9f783..504e5ab 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +coverage.txt # Created by https://www.toptal.com/developers/gitignore/api/go,visualstudiocode,jetbrains+all,macos # Edit at https://www.toptal.com/developers/gitignore?templates=go,visualstudiocode,jetbrains+all,macos diff --git a/README.md b/README.md index 5cce6ea..3e9e142 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # 🌰  Chestnut -![GitHub Workflow Status](https://img.shields.io/github/workflow/status/jrapoport/chestnut/test?style=flat-square) [![Go Report Card](https://goreportcard.com/badge/github.com/jrapoport/chestnut?style=flat-square&)](https://goreportcard.com/report/github.com/jrapoport/chestnut) ![GitHub go.mod Go version](https://img.shields.io/github/go-mod/go-version/jrapoport/chestnut?style=flat-square) [![GitHub](https://img.shields.io/github/license/jrapoport/chestnut?style=flat-square)](https://github.com/jrapoport/chestnut/blob/master/LICENSE) +![GitHub Workflow Status](https://img.shields.io/github/workflow/status/jrapoport/chestnut/test?style=flat-square) [![Go Report Card](https://goreportcard.com/badge/github.com/jrapoport/chestnut?style=flat-square&)](https://goreportcard.com/report/github.com/jrapoport/chestnut) ![Codecov branch](https://img.shields.io/codecov/c/github/jrapoport/chestnut/master?style=flat-square&token=7REY4BDPHW) ![GitHub go.mod Go version](https://img.shields.io/github/go-mod/go-version/jrapoport/chestnut?style=flat-square) [![GitHub](https://img.shields.io/github/license/jrapoport/chestnut?style=flat-square)](https://github.com/jrapoport/chestnut/blob/master/LICENSE) [![Buy Me A Coffee](https://img.shields.io/badge/buy%20me%20a%20coffee-☕-6F4E37?style=flat-square)](https://www.buymeacoffee.com/jrapoport) @@ -13,17 +13,19 @@ about things like storage, compression, hashing, secrets, or encryption. Chestnut is a storage chest, and not a datastore itself. As such, Chestnut must be backed by a storage solution. -Currently, Chestnut supports [NutsDB](https://github.com/xujiajun/nutsdb) for -storage with [BBolt](https://github.com/etcd-io/bbolt) support coming soon(ish). +Currently, Chestnut supports [BBolt](https://github.com/etcd-io/bbolt) and +[NutsDB](https://github.com/xujiajun/nutsdb) as backing storage. ## Table of Contents - [Getting Started](#getting-started) * [Installing](#installing) * [Importing Chestnut](#importing-chestnut) - + [Requirments](#requirments) + + [Requirements](#requirements) - [Storage](#storage) - * [Current Support](#current-support) - * [Planned Support](#planned-support) + * [Built-in](#supported) + + [BBolt](#bbolt) + + [NutsDB](#nutsdb) + * [Planned](#planned) - [Encryption](#encryption) * [AES256-CTR](#aes256-ctr) * [Custom Encryption](#custom-encryption) @@ -94,8 +96,6 @@ $ go get -u github.com/jrapoport/chestnut To use Chestnut as an encrypted store, import as: ```go -package main - import ( "github.com/jrapoport/chestnut" "github.com/jrapoport/chestnut/encryptor/aes" @@ -118,7 +118,7 @@ defer cn.Close() ``` -#### Requirments +#### Requirements Chestnut has two requirements: 1) [Storage](#storage) that supports the `storage.Storage` interface (with a lightweight adapter). @@ -126,23 +126,59 @@ Chestnut has two requirements: ## Storage Chestnut will work seamlessly with **any** storage solution (or adapter) that -supports the`storage.Storage` interface. We picked [NutsDB](https://github.com/xujiajun/nutsdb) -to start, and plan to add [BBolt](https://github.com/etcd-io/bbolt) support soon. +supports the`storage.Storage` interface. -### Current Support +### Built-in -* [NutsDB](https://github.com/xujiajun/nutsdb) +Currently, Chestnut has built-in support for +[BBolt](https://github.com/etcd-io/bbolt) and +[NutsDB](https://github.com/xujiajun/nutsdb). -### Planned Support +#### BBolt -* [BBolt](https://github.com/etcd-io/bbolt) — soon(ish). +https://github.com/etcd-io/bbolt +Chestnut has built-in support for using +[BBolt](https://github.com/etcd-io/bbolt) as a backing store. -* [GORM](https://github.com/go-gorm/gorm) — no timeframe. - -Gorm is an ORM, and while not a datastore per se, we think it could be adapted -to support sparse encryption. The upside of Gorm is automatic support for -databases like mysql, sqlite, etc. The downside is supporting Gorm is likely a -lot of work. +To use bbolt for a backing store you can import Chestnut's `bolt` package +and call `bolt.NewStore()`: + +```go +import "github.com/jrapoport/chestnut/storage/bolt" + +//use or create a bbolt backing store at path +store := bolt.NewStore(path) + +// use bbolt for the storage chest +cn := chestnut.NewChestnut(store, ...) +``` + +#### NutsDB + +https://github.com/xujiajun/nutsdb +Chestnut has built-in support for using +[NutsDB](https://github.com/xujiajun/nutsdb) as a backing store. + +To use nutsDB for a backing store you can import Chestnut's `nuts` package +and call `nuts.NewStore()`: + +```go +import "github.com/jrapoport/chestnut/storage/nuts" + +//use or create a nutsdb backing store at path +store := nuts.NewStore(path) + +// use nutsdb for the storage chest +cn := chestnut.NewChestnut(store, ...) +``` + +### Planned + +[GORM](https://github.com/go-gorm/gorm) + Gorm is an ORM, and while not a datastore per se, we think it could be adapted + to support sparse encryption. The upside of Gorm is automatic support for + databases like mysql, sqlite, etc. The downside is supporting Gorm is likely a + lot of work, so no timeframe. ## Encryption Chestnut supports several flavors of AES out of the box: @@ -602,6 +638,15 @@ To get a list of all the keys for a namespace you can call `Chestnut.List()`: keys, err := cn.List("my-namespace") ``` +#### ListAll + +To get a mapped list of all keys in the store organized by namespace you can call +`Chestnut.ListAll()`: + +```go +keymap, err := cn.ListAll() +``` + #### Export To export the storage chest to another path you can call `Chestnut.Export()`: @@ -652,7 +697,6 @@ type MySecureStruct struct { ValueA int `json:",secure"` // *will* be encrypted ValueB struct{} `json:"value_b,secure"` // *will* be encrypted ValueC string `json:",omitempty,secure"` // *will* be encrypted - ... PlaintextA string // will *not* be encrypted PlaintextB int `json:""` // will *not* be encrypted PlaintextC int `json:"-"` // will *not* be encrypted @@ -710,18 +754,18 @@ var myStruct = &MyStructD{ `myStruct` will be encrypted by Chestnut as: ```go -*main.MyStructD { - ValueD: '****' +*MyStructD { + ValueD: **** Embed1: main.MyStructA{ - ValueA: '****' + ValueA: **** }, Embed2: main.MyStructB{ MyStructA: main.MyStructA{ - ValueA: '****' + ValueA: **** }, - ValueB: "baz" + ValueB: **** }, - Embed3: '****' + Embed3: **** } ``` where `'****'` represents an encrypted value. diff --git a/chestnut.go b/chestnut.go index 3d75471..3362468 100644 --- a/chestnut.go +++ b/chestnut.go @@ -31,7 +31,7 @@ func NewChestnut(store storage.Storage, opt ...ChestOption) *Chestnut { logger := log.Named(opts.log, logName) cn := &Chestnut{opts, store, logger} if err := cn.validConfig(); err != nil { - logger.Fatal(err) + logger.Panic(err) return nil } return cn @@ -53,6 +53,9 @@ func (cn *Chestnut) validConfig() error { if cn.opts.compression == compress.Custom && cn.opts.decompressor == nil { return errors.New("decompressor is required") } + if !cn.opts.compression.Valid() { + return errors.New("invalid compression format") + } return nil } diff --git a/chestnut_test.go b/chestnut_test.go index 4f9d2a8..061dd6e 100644 --- a/chestnut_test.go +++ b/chestnut_test.go @@ -1,6 +1,7 @@ package chestnut import ( + "errors" "path/filepath" "reflect" "sort" @@ -14,6 +15,7 @@ import ( "github.com/jrapoport/chestnut/encryptor/crypto" "github.com/jrapoport/chestnut/log" "github.com/jrapoport/chestnut/storage" + "github.com/jrapoport/chestnut/storage/bolt" "github.com/jrapoport/chestnut/storage/nuts" "github.com/jrapoport/chestnut/value" "github.com/stretchr/testify/assert" @@ -188,24 +190,37 @@ func newKey() string { return uuid.New().String() } -func nutsDBStore(t *testing.T) storage.Storage { - path := t.TempDir() +func nutsStore(t *testing.T, path string) storage.Storage { store := nuts.NewStore(path) assert.NotNil(t, store) return store } +func boltStore(t *testing.T, path string) storage.Storage { + store := bolt.NewStore(path) + assert.NotNil(t, store) + return store +} + +type StoreFunc = func(t *testing.T, path string) storage.Storage + type ChestnutTestSuite struct { suite.Suite - cn *Chestnut + storeFunc StoreFunc + cn *Chestnut } func TestChestnut(t *testing.T) { - suite.Run(t, new(ChestnutTestSuite)) + testStores := []StoreFunc{nutsStore, boltStore} + for _, test := range testStores { + ts := new(ChestnutTestSuite) + ts.storeFunc = test + suite.Run(t, ts) + } } func (ts *ChestnutTestSuite) SetupTest() { - store := nutsDBStore(ts.T()) + store := ts.storeFunc(ts.T(), ts.T().TempDir()) assert.NotNil(ts.T(), store) ts.cn = NewChestnut(store, encryptorOpt) assert.NotNil(ts.T(), ts.cn) @@ -431,19 +446,19 @@ func (ts *ChestnutTestSuite) TestStore_SecureEntry() { } } -func TestChestnut_OverwritesDisabled(t *testing.T) { - testOptionDisableOverwrites(t, false) +func (ts *ChestnutTestSuite) TestChestnut_OverwritesDisabled() { + ts.testOptionDisableOverwrites(false) } -func TestChestnut_OverwritesEnabled(t *testing.T) { - testOptionDisableOverwrites(t, true) +func (ts *ChestnutTestSuite) TestChestnut_OverwritesEnabled() { + ts.testOptionDisableOverwrites(true) } -func testOptionDisableOverwrites(t *testing.T, enabled bool) { +func (ts *ChestnutTestSuite) testOptionDisableOverwrites(enabled bool) { key := newKey() - path := filepath.Join(t.TempDir()) - store := nuts.NewStore(path) - assert.NotNil(t, store) + path := filepath.Join(ts.T().TempDir()) + store := ts.storeFunc(ts.T(), path) + assert.NotNil(ts.T(), store) opts := []ChestOption{ encryptorOpt, } @@ -453,26 +468,26 @@ func testOptionDisableOverwrites(t *testing.T, enabled bool) { opts = append(opts, OverwritesForbidden()) } cn := NewChestnut(store, opts...) - assert.NotNil(t, cn) - assert.Equal(t, enabled, cn.opts.overwrites) + assert.NotNil(ts.T(), cn) + assert.Equal(ts.T(), enabled, cn.opts.overwrites) defer func() { err := cn.Close() - assert.NoError(t, err) + assert.NoError(ts.T(), err) }() err := cn.Open() - assert.NoError(t, err) + assert.NoError(ts.T(), err) err = cn.Put(testName, []byte(key), []byte(testValue)) - assert.NoError(t, err) + assert.NoError(ts.T(), err) // this should fail with an error if overwrites are disabled err = cn.Put(testName, []byte(key), []byte(testValue)) - assertErr(t, err) + assertErr(ts.T(), err) } -func TestChestnut_ChainedEncryptor(t *testing.T) { +func (ts *ChestnutTestSuite) TestChestnut_ChainedEncryptor() { var operation = "encrypting" // initialize a keystore with a chained encryptor openSecret := func(s crypto.Secret) []byte { - t.Logf("%s with secret %s", operation, s.ID()) + ts.T().Logf("%s with secret %s", operation, s.ID()) return []byte(s.ID()) } managedSecret := crypto.NewManagedSecret(uuid.New().String(), "i-am-a-managed-secret") @@ -483,101 +498,105 @@ func TestChestnut_ChainedEncryptor(t *testing.T) { encryptor.NewAESEncryptor(crypto.Key192, aes.CTR, managedSecret), encryptor.NewAESEncryptor(crypto.Key256, aes.GCM, secureSecret2), ) - path := t.TempDir() - store := nuts.NewStore(path) - assert.NotNil(t, store) + path := ts.T().TempDir() + store := ts.storeFunc(ts.T(), path) + assert.NotNil(ts.T(), store) cn := NewChestnut(store, encryptorChainOpt) - assert.NotNil(t, cn) + assert.NotNil(ts.T(), cn) defer func() { err := cn.Close() - assert.NoError(t, err) + assert.NoError(ts.T(), err) }() err := cn.Open() - assert.NoError(t, err) + assert.NoError(ts.T(), err) key := newKey() err = cn.Put(testName, []byte(key), []byte(testValue)) - assert.NoError(t, err) + assert.NoError(ts.T(), err) operation = "decrypting" v, err := cn.Get(testName, []byte(key)) - assert.NotEmpty(t, v) - assert.NoError(t, err) - assert.Equal(t, []byte(testValue), v) + assert.NotEmpty(ts.T(), v) + assert.NoError(ts.T(), err) + assert.Equal(ts.T(), []byte(testValue), v) err = cn.Delete(testName, []byte(key)) - assert.NoError(t, err) + assert.NoError(ts.T(), err) e := value.NewSecureValue(uuid.New().String(), []byte(testValue)) err = cn.Save(testName, []byte(key), e) - assert.NoError(t, err) + assert.NoError(ts.T(), err) se1 := &value.Secure{} err = cn.Sparse(testName, []byte(key), se1) - assert.NoError(t, err) + assert.NoError(ts.T(), err) se2 := &value.Secure{} err = cn.Load(testName, []byte(key), se2) - assert.NoError(t, err) + assert.NoError(ts.T(), err) } -func TestChestnut_Compression(t *testing.T) { +func (ts *ChestnutTestSuite) TestChestnut_Compression() { compOpt := WithCompression(compress.Zstd) key := newKey() - path := filepath.Join(t.TempDir()) - store := nuts.NewStore(path) - assert.NotNil(t, store) + path := filepath.Join(ts.T().TempDir()) + store := ts.storeFunc(ts.T(), path) + assert.NotNil(ts.T(), store) cn := NewChestnut(store, encryptorOpt, compOpt) - assert.NotNil(t, cn) + assert.NotNil(ts.T(), cn) defer func() { err := cn.Close() - assert.NoError(t, err) + assert.NoError(ts.T(), err) }() err := cn.Open() - assert.NoError(t, err) + assert.NoError(ts.T(), err) err = cn.Put(testName, []byte(key), []byte(lorumIpsum)) - assert.NoError(t, err) + assert.NoError(ts.T(), err) val, err := cn.Get(testName, []byte(key)) - assert.NoError(t, err) - assert.Equal(t, lorumIpsum, string(val)) + assert.NoError(ts.T(), err) + assert.Equal(ts.T(), lorumIpsum, string(val)) } -func TestChestnut_Compressors(t *testing.T) { +func (ts *ChestnutTestSuite) TestChestnut_Compressors() { compOpt := WithCompressors(zstd.Compress, zstd.Decompress) key := newKey() - path := filepath.Join(t.TempDir()) - store := nuts.NewStore(path) - assert.NotNil(t, store) + path := filepath.Join(ts.T().TempDir()) + store := ts.storeFunc(ts.T(), path) + assert.NotNil(ts.T(), store) cn := NewChestnut(store, encryptorOpt, compOpt) - assert.NotNil(t, cn) + assert.NotNil(ts.T(), cn) defer func() { err := cn.Close() - assert.NoError(t, err) + assert.NoError(ts.T(), err) }() err := cn.Open() - assert.NoError(t, err) + assert.NoError(ts.T(), err) err = cn.Put(testName, []byte(key), []byte(lorumIpsum)) - assert.NoError(t, err) + assert.NoError(ts.T(), err) val, err := cn.Get(testName, []byte(key)) - assert.NoError(t, err) - assert.Equal(t, lorumIpsum, string(val)) + assert.NoError(ts.T(), err) + assert.Equal(ts.T(), lorumIpsum, string(val)) } -func TestChestnut_OpenErr(t *testing.T) { +func (ts *ChestnutTestSuite) TestChestnut_OpenErr() { cn := &Chestnut{} err := cn.Open() - assert.Error(t, err) + assert.Error(ts.T(), err) } -func TestChestnut_SetLogger(t *testing.T) { - path := t.TempDir() - store := nuts.NewStore(path) - assert.NotNil(t, store) +func (ts *ChestnutTestSuite) TestChestnut_SetLogger() { + path := ts.T().TempDir() + store := ts.storeFunc(ts.T(), path) + assert.NotNil(ts.T(), store) cn := NewChestnut(store, encryptorOpt) - cn.SetLogger(log.NewZapLoggerWithLevel(log.DebugLevel)) - defer func() { - err := cn.Close() - assert.NoError(t, err) - }() - err := cn.Open() - assert.NoError(t, err) + logTests := []log.Logger{ + nil, + log.NewZapLoggerWithLevel(log.DebugLevel), + } + for _, test := range logTests { + cn.SetLogger(test) + err := cn.Open() + assert.NoError(ts.T(), err) + err = cn.Close() + assert.NoError(ts.T(), err) + } } -func TestChestnut_WithLogger(t *testing.T) { +func (ts *ChestnutTestSuite) TestChestnut_WithLogger() { levels := []log.Level{ log.DebugLevel, log.InfoLevel, @@ -591,17 +610,109 @@ func TestChestnut_WithLogger(t *testing.T) { WithStdLogger, WithZapLogger, } - path := t.TempDir() - store := nuts.NewStore(path) - assert.NotNil(t, store) + path := ts.T().TempDir() + store := ts.storeFunc(ts.T(), path) + assert.NotNil(ts.T(), store) for _, level := range levels { for _, logOpt := range logOpts { opt := logOpt(level) cn := NewChestnut(store, encryptorOpt, opt) err := cn.Open() - assert.NoError(t, err) + assert.NoError(ts.T(), err) err = cn.Close() - assert.NoError(t, err) + assert.NoError(ts.T(), err) } } } + +func (ts *ChestnutTestSuite) TestChestnut_BadConfig() { + store := ts.storeFunc(ts.T(), ts.T().TempDir()) + assert.Panics(ts.T(), func() { + _ = NewChestnut(nil, encryptorOpt) + }) + assert.Panics(ts.T(), func() { + _ = NewChestnut(store) + }) + assert.Panics(ts.T(), func() { + _ = NewChestnut(store, encryptorOpt, WithCompression("X")) + }) + assert.Panics(ts.T(), func() { + _ = NewChestnut(store, encryptorOpt, WithCompressors(nil, nil)) + }) + assert.Panics(ts.T(), func() { + _ = NewChestnut(store, encryptorOpt, WithCompressors(compress.PassthroughCompressor, nil)) + }) + assert.Panics(ts.T(), func() { + _ = NewChestnut(store, encryptorOpt, WithCompressors(nil, compress.PassthroughDecompressor)) + }) +} + +type badEncryptor struct {} + +func (b badEncryptor) ID() string { + return "a" +} + +func (b badEncryptor) Name() string { + return "a" +} + +func (b badEncryptor) Encrypt([]byte) ([]byte, error) { + return nil, errors.New("an error") +} + +func (b badEncryptor) Decrypt([]byte) ([]byte,error) { + return nil, errors.New("an error") +} + +var _ crypto.Encryptor = (*badEncryptor)(nil) + +func (ts *ChestnutTestSuite) TestChestnut_BadEncryptor() { + var testGood = []byte("test-good") + var testBad = []byte("test-bad") + badCompress := func(data []byte) (compressed []byte, err error) { + return nil, errors.New("error") + } + store := ts.storeFunc(ts.T(), ts.T().TempDir()) + assert.Panics(ts.T(), func() { + _ = NewChestnut(store, WithEncryptor(nil)) + }) + cn := NewChestnut(store, encryptorOpt) + err := cn.Open() + assert.NoError(ts.T(), err) + err = cn.Put(testName, testGood, testGood) + assert.NoError(ts.T(), err) + err = cn.Close() + assert.NoError(ts.T(), err) + + cn = NewChestnut(store, WithEncryptor(&badEncryptor{})) + err = cn.Open() + assert.NoError(ts.T(), err) + err = cn.Put(testName, testBad, testBad) + assert.Error(ts.T(), err) + _, err = cn.Get(testName, testGood) + assert.Error(ts.T(), err) + err = cn.Close() + assert.NoError(ts.T(), err) + + compOpt := WithCompressors(compress.PassthroughCompressor, compress.PassthroughDecompressor) + cn = NewChestnut(store, encryptorOpt, compOpt) + err = cn.Open() + assert.NoError(ts.T(), err) + err = cn.Put(testName, testGood, testGood) + assert.NoError(ts.T(), err) + err = cn.Close() + assert.NoError(ts.T(), err) + + cn = NewChestnut(store, encryptorOpt, WithCompressors(badCompress, badCompress)) + err = cn.Open() + assert.NoError(ts.T(), err) + err = cn.Put(testName, testBad, testBad) + assert.Error(ts.T(), err) + assert.Error(ts.T(), err) + _, err = cn.Get(testName, testGood) + assert.Error(ts.T(), err) + err = cn.Close() + assert.NoError(ts.T(), err) +} + diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000..990b020 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,2 @@ +ignore: + - "examples/" diff --git a/encoding/compress/compress.go b/encoding/compress/compress.go index 7f0cf83..3147aa2 100644 --- a/encoding/compress/compress.go +++ b/encoding/compress/compress.go @@ -21,6 +21,22 @@ const ( Zstd Format = "zstd" ) +func (f Format) Valid() bool { + switch f { + case None: + break + case Custom: + break + case Zstd: + break + default: + return false + } + return true +} + + + // CompressorFunc is the function the prototype for compression. type CompressorFunc func(data []byte) (compressed []byte, err error) diff --git a/encoding/compress/compress_test.go b/encoding/compress/compress_test.go index 5404f3f..50b53ca 100644 --- a/encoding/compress/compress_test.go +++ b/encoding/compress/compress_test.go @@ -1,6 +1,7 @@ package compress import ( + "bytes" "testing" "github.com/stretchr/testify/assert" @@ -23,8 +24,9 @@ var ( extraFmt = []byte{ 0xb, 0xa, 0xd, 0xa, 0x5, 0x5, 0x5, 0xb, 0x1e, 0x7a, 0x73, 0x74, 0x64, 0x1e, 0x69, 0x2d, 0x61, 0x6d, 0x2d, 0x1e, 0x2d, 0x74, 0x65, 0x73, 0x1e, 0x2d, 0x69, 0x6e} - badFmt = []byte{0xb, 0xa, 0xd, 0xa, 0x5, 0x5, 0x5, 0xb, 0x1e, 0xa, 0x73, 0x74, 0x64, 0x1e, 0x69, + badFmt1 = []byte{0xb, 0xa, 0xd, 0xa, 0x5, 0x5, 0x5, 0xb, 0x1e, 0xa, 0x73, 0x74, 0x64, 0x1e, 0x69, 0x2d, 0x61, 0x6d, 0x2d, 0x61, 0x2d, 0x74, 0x65, 0x73, 0x74, 0x2d, 0x69, 0x6e} + badFmt2 = bytes.Join([][]byte{formatTag, empty}, formatSep) ) func TestEncodeFormat(t *testing.T) { @@ -61,7 +63,8 @@ func TestDecodeFormat(t *testing.T) { {valueFmt, value, Zstd}, {compFmt, comp, Zstd}, {extraFmt, extra, Zstd}, - {badFmt, badFmt, None}, + {badFmt1, badFmt1, None}, + {badFmt2, badFmt2, None}, } for _, test := range tests { out, format := DecodeFormat(test.in) @@ -69,3 +72,13 @@ func TestDecodeFormat(t *testing.T) { assert.Equal(t, test.out, out) } } + +func TestPassthrough(t *testing.T) { + testString := []byte("test-string") + c, err := PassthroughCompressor(testString) + assert.NoError(t, err) + assert.NotEmpty(t, c) + d, err := PassthroughDecompressor(c) + assert.NoError(t, err) + assert.Equal(t, testString, d) +} diff --git a/encoding/json/decode_test.go b/encoding/json/decode_test.go index e433907..e179b0d 100644 --- a/encoding/json/decode_test.go +++ b/encoding/json/decode_test.go @@ -36,4 +36,8 @@ func TestSecureUnmarshal_Error(t *testing.T) { assert.Error(t, err) err = SecureUnmarshal([]byte("bad encoding"), secureObj, decrypt) assert.Error(t, err) + var p chan bool + err = SecureUnmarshal(familyEnc, &p, decrypt) + assert.Error(t, err) } + diff --git a/encoding/json/encode_test.go b/encoding/json/encode_test.go index dd52dbf..5f17d10 100644 --- a/encoding/json/encode_test.go +++ b/encoding/json/encode_test.go @@ -17,7 +17,6 @@ func TestSecureMarshal(t *testing.T) { assert.Equal(t, familyComp, bytes) } - func TestSecureMarshal_Error(t *testing.T) { assert.Panics(t, func() { _, _ = SecureMarshal(family, nil) @@ -28,4 +27,8 @@ func TestSecureMarshal_Error(t *testing.T) { bytes, err = SecureMarshal(nil, encrypt) assert.Error(t, err) assert.Nil(t, bytes) + var p chan bool + bytes, err = SecureMarshal(p, encrypt) + assert.Error(t, err) + assert.Nil(t, bytes) } diff --git a/encoding/json/encoders/lookup/decoder.go b/encoding/json/encoders/lookup/decoder.go index 4aa30b2..79cc024 100644 --- a/encoding/json/encoders/lookup/decoder.go +++ b/encoding/json/encoders/lookup/decoder.go @@ -27,15 +27,23 @@ type Decoder struct { func NewLookupDecoder(ctx *Context, typ reflect2.Type, decoder jsoniter.ValDecoder) jsoniter.ValDecoder { logger := log.Log if decoder == nil { - logger.Fatal(errors.New("value encoder required")) + logger.Panic(errors.New("value encoder required")) + return nil + } + if typ == nil { + logger.Panic(errors.New("decoder typ required")) return nil } if ctx == nil { - logger.Fatal(errors.New("lookup context required")) + logger.Panic(errors.New("lookup context required")) + return nil + } + if ctx.Token == "" { + logger.Panic(errors.New("lookup token required")) return nil } if ctx.Stream == nil { - logger.Fatal(errors.New("lookup stream required")) + logger.Panic(errors.New("lookup stream required")) return nil } return &Decoder{ diff --git a/encoding/json/encoders/lookup/decoder_test.go b/encoding/json/encoders/lookup/decoder_test.go index e151de1..9c5d0cd 100644 --- a/encoding/json/encoders/lookup/decoder_test.go +++ b/encoding/json/encoders/lookup/decoder_test.go @@ -77,3 +77,26 @@ func TestLookupDecoder_Decode(t *testing.T) { assert.NotEqual(t, jsoniter.InvalidValue, any.ValueType()) } } + +func TestLookupEncoder_NewLookupDecoder(t *testing.T) { + encoder := encoders.NewEncoder() + str := "a-string" + typ := reflect2.TypeOf(&str) + enc := encoder.DecoderOf(typ) + bad1 := &Context{} + bad2 := &Context{InvalidToken, newTestStream(t)} + bad3 := &Context{"a-string-value",nil} + good := &Context{"a-string-value", newTestStream(t)} + for _, ctx := range []*Context {nil, bad1, bad2, bad3, good} { + for _, tp := range []reflect2.Type{nil, typ} { + for _, ve := range []jsoniter.ValDecoder{nil, enc} { + if ctx == good && tp == typ && ve == enc { + continue + } + assert.Panics(t, func() { + _ = NewLookupDecoder(ctx, tp, ve) + }, ctx, tp, enc) + } + } + } +} diff --git a/encoding/json/encoders/lookup/encoder.go b/encoding/json/encoders/lookup/encoder.go index a496aa5..f4bf181 100644 --- a/encoding/json/encoders/lookup/encoder.go +++ b/encoding/json/encoders/lookup/encoder.go @@ -2,7 +2,6 @@ package lookup import ( "errors" - "fmt" "unsafe" "github.com/jrapoport/chestnut/encoding/json/encoders" @@ -31,19 +30,23 @@ type Encoder struct { func NewLookupEncoder(ctx *Context, typ reflect2.Type, encoder jsoniter.ValEncoder) jsoniter.ValEncoder { logger := log.Log if encoder == nil { - logger.Fatal(errors.New("value encoder required")) + logger.Panic(errors.New("value encoder required")) + return nil + } + if typ == nil { + logger.Panic(errors.New("encoder type required")) return nil } if ctx == nil { - logger.Fatal(errors.New("lookup context required")) + logger.Panic(errors.New("lookup context required")) return nil } if ctx.Token == InvalidToken { - logger.Fatal(errors.New("lookup token required")) + logger.Panic(errors.New("lookup token required")) return nil } if ctx.Stream == nil { - logger.Fatal(errors.New("lookup stream required")) + logger.Panic(errors.New("lookup stream required")) return nil } return &Encoder{ @@ -76,8 +79,6 @@ func (e *Encoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { e.log.Debugf("use sub-encoder type %s", e.valType) // use the clean encoder to encode to our own stream. subEncoder.Encode(ptr, stream) - } else { - e.log.Error(fmt.Errorf("sub-encoder for type %s not found", e.valType)) } return } diff --git a/encoding/json/encoders/lookup/encoder_test.go b/encoding/json/encoders/lookup/encoder_test.go index 83623d6..815d6f7 100644 --- a/encoding/json/encoders/lookup/encoder_test.go +++ b/encoding/json/encoders/lookup/encoder_test.go @@ -2,6 +2,8 @@ package lookup import ( "fmt" + "github.com/jrapoport/chestnut/log" + jsoniter "github.com/json-iterator/go" "testing" "github.com/jrapoport/chestnut/encoding/json/encoders" @@ -87,3 +89,35 @@ func TestLookupEncoder_IsEmpty(t *testing.T) { test.assertEmpty(t, empty, "value: %v", test.value) } } + +func TestLookupEncoder_NewLookupEncoder(t *testing.T) { + encoder := encoders.NewEncoder() + typ := reflect2.TypeOf("a-string") + enc := encoder.EncoderOf(typ) + bad1 := &Context{} + bad2 := &Context{InvalidToken, newTestStream(t)} + bad3 := &Context{"a-string-value",nil} + good := &Context{"a-string-value", newTestStream(t)} + for _, ctx := range []*Context {nil, bad1, bad2, bad3, good} { + for _, tp := range []reflect2.Type{nil, typ} { + for _, ve := range []jsoniter.ValEncoder{nil, enc} { + if ctx == good && tp == typ && ve == enc { + continue + } + assert.Panics(t, func() { + _ = NewLookupEncoder(ctx, tp, ve) + }, ctx, tp, enc) + } + } + } +} + +func TestLookupEncoder_Fallback(t *testing.T) { + strVal := "not-empty" + stream := newTestStream(t) + encoder := encoders.NewEncoder() + kty := reflect2.TypeOf("a-string") + enc := encoder.EncoderOf(kty) + le := &Encoder{stream: stream, valType: kty, encoder: enc, log: log.Log} + le.Encode(reflect2.PtrOf(strVal), stream) +} \ No newline at end of file diff --git a/encoding/json/encoders/secure/decoder.go b/encoding/json/encoders/secure/decoder.go index abab652..b47b375 100644 --- a/encoding/json/encoders/secure/decoder.go +++ b/encoding/json/encoders/secure/decoder.go @@ -94,10 +94,6 @@ func (ext *DecoderExtension) Unseal(encoded []byte) ([]byte, error) { if err != nil { return nil, ext.logError(err) } - if err = pkg.Valid(); err != nil { - err = fmt.Errorf("invalid encoding %w", err) - return nil, ext.logError(err) - } compressed := pkg.Compressed ext.log.Debugf("package data is compressed: %t", compressed) // IF we have an encoder ID, check that it matches the package @@ -267,6 +263,9 @@ func (ext *DecoderExtension) openLookupStream() error { } func (ext *DecoderExtension) setupLookupContext(stream *jsoniter.Stream) { + if ext.lookupCtx == nil { + return + } ext.log.Debugf("setup lookup context: %s", ext.lookupCtx.Token) stream.Attachment = ext.encoder.Get(ext.lookupBuffer) ext.lookupCtx.Stream = stream diff --git a/encoding/json/encoders/secure/decoder_test.go b/encoding/json/encoders/secure/decoder_test.go index 44e3176..04b60fb 100644 --- a/encoding/json/encoders/secure/decoder_test.go +++ b/encoding/json/encoders/secure/decoder_test.go @@ -1,6 +1,7 @@ package secure import ( + "errors" "reflect" "testing" @@ -83,4 +84,72 @@ func TestSecureDecoderExtension(t *testing.T) { decoderExt.Close() }) } + d := NewSecureDecoderExtension(encoders.InvalidID, PassthroughDecryption) + assert.NotNil(t, d) + assert.Empty(t, d.encoderID ) + assert.Panics(t, func() { + _ = NewSecureDecoderExtension(encoders.InvalidID, nil) + }) +} + +func TestSecureDecoderExtension_BadUnseal(t *testing.T) { + var i int + badCompressor := func(data []byte) (compressed []byte, err error) { + if i % 2 != 0 && i < 10 { + i++ + return nil, errors.New("compression error") + } + i++ + return nil, err + } + bade := true + ext := NewSecureDecoderExtension(testEncoderID, func(plaintext []byte) (ciphertext []byte, err error) { + if bade { + return nil, errors.New("encryption error") + } + return nil, err + }, + WithCompressor(badCompressor)) + err := ext.Open() + assert.NoError(t, err) + err = ext.Open() + assert.Error(t, err) + _, err = ext.Unseal(bothEncoded) + assert.Error(t, err) + ext.Close() + _, err = ext.Unseal(bothEncoded) + assert.Error(t, err) + _, err = ext.Unseal(bothSealed) + assert.Error(t, err) + bade = false + _, err = ext.Unseal(bothComp) + assert.Error(t, err) + i = 1 + _, err = ext.Unseal(bothComp) + i = 0 + ext.Close() + encoder := encoders.NewEncoder() + encoder.RegisterExtension(ext) + err = encoder.Unmarshal(allComp, &None{}) + assert.Error(t, err) + err = ext.Open() + assert.NoError(t, err) + assert.Panics(t, func() { + ext.decryptFunc = nil + _, err = ext.Unseal(bothComp) + assert.Error(t, err) + }) +} + +func TestSecureDecoderExtension_BadOpen(t *testing.T) { + ext := NewSecureDecoderExtension(testEncoderID, PassthroughDecryption) + err := ext.Open() + assert.NoError(t, err) + err = ext.Open() + assert.Error(t, err) + ext.Close() + ext.lookupCtx = nil + err = ext.Open() + assert.Error(t, err) + ext.Close() } diff --git a/encoding/json/encoders/secure/encoder.go b/encoding/json/encoders/secure/encoder.go index e4ff3e8..095ea45 100644 --- a/encoding/json/encoders/secure/encoder.go +++ b/encoding/json/encoders/secure/encoder.go @@ -73,7 +73,7 @@ func NewSecureEncoderExtension(encoderID string, efn EncryptionFunction, opt ... ext.encoder = encoder ext.lookupCtx = &lookup.Context{Token: token} if encoder == nil { - ext.log.Fatal(errors.New("encoder not found")) + ext.log.Panic(errors.New("encoder not found")) } if efn == nil { ext.log.Panic(errors.New("encryption required")) @@ -269,6 +269,9 @@ func (ext *EncoderExtension) openLookupStream() error { } func (ext *EncoderExtension) setupLookupContext(stream *jsoniter.Stream) { + if ext.lookupCtx == nil { + return + } ext.log.Debugf("setup lookup context: %s", ext.lookupCtx.Token) // reset the lookup index to 0 stream.Attachment = 0 diff --git a/encoding/json/encoders/secure/encoder_test.go b/encoding/json/encoders/secure/encoder_test.go index b63a63d..7ad64b8 100644 --- a/encoding/json/encoders/secure/encoder_test.go +++ b/encoding/json/encoders/secure/encoder_test.go @@ -1,6 +1,8 @@ package secure import ( + "errors" + "github.com/jrapoport/chestnut/log" "reflect" "testing" @@ -42,6 +44,7 @@ func TestSecureEncoderExtension(t *testing.T) { // register encoding extension encoderExt := NewSecureEncoderExtension(testEncoderID, PassthroughEncryption, + WithLogger(log.Log), test.compressed) encoder.RegisterExtension(encoderExt) // open the encoder @@ -63,4 +66,85 @@ func TestSecureEncoderExtension(t *testing.T) { assert.NoError(t, pkg.Valid()) }) } + e := NewSecureEncoderExtension(encoders.InvalidID, PassthroughEncryption) + assert.NotNil(t, e) + assert.NotEmpty(t, e.encoderID ) + assert.Panics(t, func() { + _ = NewSecureEncoderExtension(encoders.InvalidID, nil) + }) } + +func TestSecureEncoderExtension_BadSeal(t *testing.T) { + var i int + badCompressor := func(data []byte) (compressed []byte, err error) { + if i % 2 != 0 && i < 10 { + i++ + return nil, errors.New("compression error") + } + i++ + return nil, err + } + bade := true + ext := NewSecureEncoderExtension(testEncoderID, func(plaintext []byte) (ciphertext []byte, err error) { + if bade { + return nil, errors.New("encryption error") + } + return nil, err + }, + WithCompressor(badCompressor)) + err := ext.Open() + assert.NoError(t, err) + i = 0 + ext.Close() + ext.lookupBuffer = []byte("121343546432343546576453423142534653423142536435243142536463524") + _, err = ext.Seal(bothEncoded) + i = 1 + ext.Close() + ext.lookupBuffer = []byte("121343546432343546576453423142534653423142536435243142536463524") + _, err = ext.Seal(bothEncoded) + i = 10 + ext.Close() + assert.Error(t, err) + ext.lookupBuffer = []byte("121343546432343546576453423142534653423142536435243142536463524") + _, err = ext.Seal(bothEncoded) + assert.Error(t, err) + i = 10 + bade = false + ext.Close() + assert.Error(t, err) + ext.lookupBuffer = []byte("121343546432343546576453423142534653423142536435243142536463524") + ext.encoderID = encoders.InvalidID + _, err = ext.Seal(bothEncoded) + assert.Error(t, err) + i = 10 + bade = false + ext.Close() + assert.Error(t, err) + ext.lookupBuffer = []byte("121343546432343546576453423142534653423142536435243142536463524") + ext.encoderID = testEncoderID + ext.lookupCtx.Stream = nil + _, err = ext.Seal(bothEncoded) + assert.Error(t, err) +} + +func TestSecureEncoderExtension_BadOpen(t *testing.T) { + ext := NewSecureEncoderExtension(testEncoderID, PassthroughEncryption) + err := ext.Open() + assert.NoError(t, err) + err = ext.Open() + assert.Error(t, err) + ext.Close() + ctx := ext.lookupCtx + ext.lookupCtx = nil + err = ext.Open() + assert.Error(t, err) + ext.lookupCtx = ctx + ext.lookupCtx.Token = encoders.InvalidID + err = ext.Open() + assert.Error(t, err) + ext.lookupCtx = ctx + ext.lookupCtx.Stream = nil + err = ext.Open() + assert.Error(t, err) +} + diff --git a/encoding/json/packager/encoding.go b/encoding/json/packager/encoding.go index 2b1d0ca..8d34c97 100644 --- a/encoding/json/packager/encoding.go +++ b/encoding/json/packager/encoding.go @@ -3,11 +3,14 @@ package packager import ( "bytes" "encoding/gob" + "errors" + "github.com/jrapoport/chestnut/encoding/json/encoders" ) // EncodePackage returns a valid binary enc package for storage. func EncodePackage(encoderID, token string, cipher, encoded []byte, compressed bool) ([]byte, error) { - if encoderID == "" { + if encoderID == encoders.InvalidID { + return nil, errors.New("invalid encoder id") } format := Secure // are we sparse? diff --git a/encoding/json/packager/package_test.go b/encoding/json/packager/package_test.go index 30e9996..ccee435 100644 --- a/encoding/json/packager/package_test.go +++ b/encoding/json/packager/package_test.go @@ -1,6 +1,8 @@ package packager import ( + "bytes" + "encoding/gob" "testing" "github.com/stretchr/testify/assert" @@ -22,6 +24,7 @@ var ( zstd = []byte("KLUv/QQAAQEAeyJ0ZXN0X29iamVjdCI6eyJjbmMxZmY3NzU1IjowfX1hE1Nm") emptyZstd = []byte("KLUv/QQACQAAII1jaLY=") badVer = "999.999.999" + badVer2 = ".*" badFormat = Format("invalid") badData = []byte("==") badZstd = []byte("bm9wZQ") @@ -50,14 +53,20 @@ var tests = []TestCase{ assert.Error, assert.Error}, {badVer, "", empty, empty, noComp, nil, nil, assert.Error, assert.Error}, + {badVer2, "", empty, empty, noComp, nil, nil, + assert.Error, assert.Error}, {ver, "", empty, empty, noComp, nil, nil, assert.Error, assert.Error}, {ver, badFormat, empty, empty, noComp, nil, nil, assert.Error, assert.Error}, + {ver, badFormat, id, empty, noComp, nil, nil, + assert.Error, assert.Error}, {ver, Secure, id, empty, noComp, nil, nil, assert.Error, assert.Error}, {ver, Sparse, empty, empty, noComp, nil, nil, assert.Error, assert.Error}, + {ver, Sparse, id, empty, noComp, nil, nil, + assert.Error, assert.Error}, // valid packages {ver, Secure, id, empty, noComp, sec, nil, assert.NoError, assert.NoError}, @@ -172,8 +181,25 @@ func (ts *PackageTestSuite) TestPackage_Decode() { Cipher: test.sec, Encoded: test.enc, } - bytes, err := encode(testPkg) - pkg, err := DecodePackage(bytes) + _, err := encode(testPkg) + test.unwrapErr(ts.T(), err) + } + + for _, test := range tests { + testPkg := &Package{ + Version: test.ver, + Format: test.fmt, + Compressed: test.comp, + EncoderID: test.id, + Token: test.token, + Cipher: test.sec, + Encoded: test.enc, + } + b := bytes.Buffer{} + e := gob.NewEncoder(&b) + err := e.Encode(testPkg) + assert.NoError(ts.T(), err) + pkg, err := DecodePackage(b.Bytes()) test.unwrapErr(ts.T(), err) if err != nil { assert.Nil(ts.T(), pkg) diff --git a/encryptor/aes/aes_test.go b/encryptor/aes/aes_test.go index 9915053..e7a7268 100644 --- a/encryptor/aes/aes_test.go +++ b/encryptor/aes/aes_test.go @@ -1,6 +1,7 @@ package aes import ( + "math" "testing" "github.com/jrapoport/chestnut/encryptor/crypto" @@ -67,4 +68,12 @@ func testCipher(t *testing.T, encryptCall, decryptCall CipherCall) { _, err = decryptCall(crypto.Key256, []byte(secret), bd) assert.Error(t, err) } + for _, bd := range badData { + _, err = decryptCall(0, nil, bd) + assert.Error(t, err) + } + for _, bd := range badData { + _, err = decryptCall(math.MaxInt64, nil, bd) + assert.Error(t, err) + } } diff --git a/go.mod b/go.mod index 1b198fa..b3e4b27 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/sirupsen/logrus v1.7.0 github.com/stretchr/testify v1.6.1 github.com/xujiajun/nutsdb v0.5.0 + go.etcd.io/bbolt v1.3.5 go.uber.org/zap v1.16.0 golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad ) diff --git a/go.sum b/go.sum index 45b096e..ec05ab1 100644 --- a/go.sum +++ b/go.sum @@ -860,6 +860,8 @@ github.com/xujiajun/nutsdb v0.5.0/go.mod h1:owdwN0tW084RxEodABLbO7h4Z2s9WiAjZGZF github.com/xujiajun/utils v0.0.0-20190123093513-8bf096c4f53b h1:jKG9OiL4T4xQN3IUrhUpc1tG+HfDXppkgVcrAiiaI/0= github.com/xujiajun/utils v0.0.0-20190123093513-8bf096c4f53b/go.mod h1:AZd87GYJlUzl82Yab2kTjx1EyXSQCAfZDhpTo1SQC4k= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0= +go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= diff --git a/keystore/keyutils.go b/keystore/keyutils.go index 132f931..d0d7b6f 100644 --- a/keystore/keyutils.go +++ b/keystore/keyutils.go @@ -27,11 +27,8 @@ func PrivKeyToRSAPrivateKey(privKey crypto.PrivKey) *rsa.PrivateKey { // RSAPrivateKeyToPrivKey converts standard library rsa // private keys to libp2p/go-libp2p-core/crypto private keys. func RSAPrivateKeyToPrivKey(privateKey *rsa.PrivateKey) crypto.PrivKey { - pk, _, err := crypto.KeyPairFromStdKey(privateKey) - if err != nil { - log.Panic(err) - return nil - } + // because we are strongly typing the interface it will never fail + pk, _, _ := crypto.KeyPairFromStdKey(privateKey) return pk } @@ -52,11 +49,8 @@ func PrivKeyToECDSAPrivateKey(privKey crypto.PrivKey) *ecdsa.PrivateKey { // ECDSAPrivateKeyToPrivKey converts standard library ecdsa // private keys to libp2p/go-libp2p-core/crypto private keys. func ECDSAPrivateKeyToPrivKey(privateKey *ecdsa.PrivateKey) crypto.PrivKey { - pk, _, err := crypto.KeyPairFromStdKey(privateKey) - if err != nil { - log.Panic(err) - return nil - } + // because we are strongly typing the interface it will never fail + pk, _, _ := crypto.KeyPairFromStdKey(privateKey) return pk } @@ -77,11 +71,8 @@ func PrivKeyToEd25519PrivateKey(privKey crypto.PrivKey) *ed25519.PrivateKey { // Ed25519PrivateKeyToPrivKey converts ed25519 private keys // to libp2p/go-libp2p-core/crypto private keys. func Ed25519PrivateKeyToPrivKey(privateKey *ed25519.PrivateKey) crypto.PrivKey { - pk, _, err := crypto.KeyPairFromStdKey(privateKey) - if err != nil { - log.Panic(err) - return nil - } + // because we are strongly typing the interface it will never fail + pk, _, _ := crypto.KeyPairFromStdKey(privateKey) return pk } @@ -104,10 +95,7 @@ func PrivKeyToBTCECPrivateKey(privKey crypto.PrivKey) *btcec.PrivateKey { // private keys to libp2p/go-libp2p-core/crypto private keys. Internally // equivalent to (*crypto.Secp256k1PrivateKey)(privateKey). func BTCECPrivateKeyToPrivKey(privateKey *btcec.PrivateKey) crypto.PrivKey { - pk, _, err := crypto.KeyPairFromStdKey(privateKey) - if err != nil { - log.Panic(err) - return nil - } + // because we are strongly typing the interface it will never fail + pk, _, _ := crypto.KeyPairFromStdKey(privateKey) return pk } diff --git a/storage/bolt/store.go b/storage/bolt/store.go new file mode 100644 index 0000000..1496fb1 --- /dev/null +++ b/storage/bolt/store.go @@ -0,0 +1,322 @@ +package bolt + +import ( + "errors" + "fmt" + "os" + "path/filepath" + + "github.com/jrapoport/chestnut/log" + "github.com/jrapoport/chestnut/storage" + jsoniter "github.com/json-iterator/go" + bolt "go.etcd.io/bbolt" +) + +const ( + logName = "bolt" + storeName = "chest.db" +) + +// boltStore is an implementation the Storage interface for bbolt +// https://github.com/etcd-io/bbolt. +type boltStore struct { + opts storage.StoreOptions + path string + db *bolt.DB + log log.Logger +} + +var _ storage.Storage = (*boltStore)(nil) + +// NewStore is used to instantiate a datastore backed by bbolt. +func NewStore(path string, opt ...storage.StoreOption) storage.Storage { + opts := storage.ApplyOptions(storage.DefaultStoreOptions, opt...) + logger := log.Named(opts.Logger(), logName) + if path == "" { + logger.Fatal("store path required") + } + return &boltStore{path: path, opts: opts, log: logger} +} + +// Options returns the configuration options for the store. +func (s *boltStore) Options() storage.StoreOptions { + return s.opts +} + +// Open opens the store. +func (s *boltStore) Open() (err error) { + s.log.Debugf("opening store at path: %s", s.path) + var path string + path, err = ensureDBPath(s.path) + if err != nil { + err = s.logError("open", err) + return + } + s.db, err = bolt.Open(path, 0600, nil) + if err != nil { + err = s.logError("open", err) + return + } + if s.db == nil { + err = errors.New("unable to open backing store") + err = s.logError("open", err) + return + } + s.log.Infof("opened store at path: %s", s.path) + return +} + +// Put an entry in the store. +func (s *boltStore) Put(name string, key []byte, value []byte) error { + s.log.Debugf("put: %d value bytes to key: %s", len(value), key) + if err := storage.ValidKey(name, key); err != nil { + return s.logError("put", err) + } else if len(value) <= 0 { + err = errors.New("value cannot be empty") + return s.logError("put", err) + } + putValue := func(tx *bolt.Tx) error { + s.log.Debugf("put: tx %d bytes to key: %s.%s", + len(value), name, string(key)) + b, err := tx.CreateBucketIfNotExists([]byte(name)) + if err != nil { + return err + } + return b.Put(key, value) + } + return s.logError("put", s.db.Update(putValue)) +} + +// Get a value from the store. +func (s *boltStore) Get(name string, key []byte) ([]byte, error) { + s.log.Debugf("get: value at key: %s", key) + if err := storage.ValidKey(name, key); err != nil { + return nil, s.logError("get", err) + } + var value []byte + getValue := func(tx *bolt.Tx) error { + s.log.Debugf("get: tx key: %s.%s", name, key) + b := tx.Bucket([]byte(name)) + if b == nil { + return fmt.Errorf("bucket not found: %s", name) + } + v := b.Get(key) + if len(v) <= 0 { + return errors.New("nil value") + } + value = v + s.log.Debugf("get: tx key: %s.%s value (%d bytes)", + name, string(key), len(value)) + return nil + } + if err := s.db.View(getValue); err != nil { + return nil, s.logError("get", err) + } + return value, nil +} + +// Save the value in v and store the result at key. +func (s *boltStore) Save(name string, key []byte, v interface{}) error { + b, err := jsoniter.Marshal(v) + if err != nil { + return s.logError("save", err) + } + return s.Put(name, key, b) +} + +// Load the value at key and stores the result in v. +func (s *boltStore) Load(name string, key []byte, v interface{}) error { + b, err := s.Get(name, key) + if err != nil { + return s.logError("load", err) + } + return s.logError("load", jsoniter.Unmarshal(b, v)) +} + +// Has checks for a key in the store. +func (s *boltStore) Has(name string, key []byte) (bool, error) { + s.log.Debugf("has: key: %s", key) + if err := storage.ValidKey(name, key); err != nil { + return false, s.logError("has", err) + } + var has bool + hasKey := func(tx *bolt.Tx) error { + s.log.Debugf("has: tx get namespace: %s", name) + b := tx.Bucket([]byte(name)) + if b == nil { + err := fmt.Errorf("bucket not found: %s", name) + return err + } + v := b.Get(key) + has = len(v) > 0 + if has { + s.log.Debugf("has: tx key found: %s.%s", name, string(key)) + } + return nil + } + if err := s.db.View(hasKey); err != nil { + return false, s.logError("has", err) + } + s.log.Debugf("has: found key %s: %t", key, has) + return has, nil +} + +// Delete removes a key from the store. +func (s *boltStore) Delete(name string, key []byte) error { + s.log.Debugf("delete: key: %s", key) + if err := storage.ValidKey(name, key); err != nil { + return s.logError("delete", err) + } + del := func(tx *bolt.Tx) error { + s.log.Debugf("delete: tx key: %s.%s", name, string(key)) + b := tx.Bucket([]byte(name)) + if b == nil { + err := fmt.Errorf("bucket not found: %s", name) + // an error just means we couldn't find the bucket + s.log.Warn(err) + return nil + } + return b.Delete(key) + } + return s.logError("delete", s.db.Update(del)) +} + +// List returns a list of all keys in the namespace. +func (s *boltStore) List(name string) (keys [][]byte, err error) { + s.log.Debugf("list: keys in namespace: %s", name) + listKeys := func(tx *bolt.Tx) error { + b := tx.Bucket([]byte(name)) + if b == nil { + err = fmt.Errorf("bucket not found: %s", name) + return err + } + keys, err = s.listKeys(name, b) + return err + } + if err = s.db.View(listKeys); err != nil { + return nil, s.logError("list", err) + } + s.log.Debugf("list: found %d keys: %s", len(keys), keys) + return +} + +func (s *boltStore) listKeys(name string, b *bolt.Bucket) ([][]byte, error) { + if b == nil { + err := fmt.Errorf("invalid bucket: %s", name) + return nil, err + } + var keys [][]byte + s.log.Debugf("list: tx scan namespace: %s", name) + count := b.Stats().KeyN + keys = make([][]byte, count) + s.log.Debugf("list: tx found %d keys in: %s", count, name) + var i int + _ = b.ForEach(func(k, _ []byte) error { + s.log.Debugf("list: tx found key: %s.%s", name, string(k)) + keys[i] = k + i++ + return nil + }) + return keys, nil +} + +// ListAll returns a mapped list of all keys in the store. +func (s *boltStore) ListAll() (map[string][][]byte, error) { + s.log.Debugf("list: all keys") + var total int + allKeys := map[string][][]byte{} + listKeys := func(tx *bolt.Tx) error { + err := tx.ForEach(func(name []byte, b *bolt.Bucket) error { + keys, err := s.listKeys(string(name), b) + if err != nil { + return err + } + if len(keys) <= 0 { + return nil + } + allKeys[string(name)] = keys + total += len(keys) + return nil + }) + return err + } + if err := s.db.View(listKeys); err != nil { + return nil, s.logError("list", err) + } + s.log.Debugf("list: found %d keys: %s", total, allKeys) + return allKeys, nil +} + +// Export copies the datastore to directory at path. +func (s *boltStore) Export(path string) error { + s.log.Debugf("export: to path: %s", path) + if path == "" { + err := fmt.Errorf("invalid path: %s", path) + return s.logError("export", err) + } else if s.path == path { + err := fmt.Errorf("path cannot be store path: %s", path) + return s.logError("export", err) + } + var err error + path, err = ensureDBPath(path) + if err != nil { + return s.logError("export", err) + } + err = s.db.View(func(tx *bolt.Tx) error { + return tx.CopyFile(path, 0600) + }) + if err != nil { + return s.logError("export", err) + } + s.log.Debugf("export: to path complete: %s", path) + return nil +} + +// Close closes the datastore and releases all db resources. +func (s *boltStore) Close() error { + s.log.Debugf("closing store at path: %s", s.path) + err := s.db.Close() + s.db = nil + s.log.Info("store closed") + return s.logError("close", err) +} + +func (s *boltStore) logError(name string, err error) error { + if err == nil { + return nil + } + if name != "" { + err = fmt.Errorf("%s: %w", name, err) + } + s.log.Error(err) + return err +} + +func ensureDBPath(path string) (string, error) { + if path == "" { + return "", errors.New("path not found") + } + // does the path exist? + _, err := os.Stat(path) + exists := !os.IsNotExist(err) + if err != nil && exists { + return "", err + } + if !exists { + // make sure the directory path exists + if err = os.MkdirAll(path, 0700); err != nil { + return "", err + } + } + // is the path a directory? + d, err := os.Stat(path) + if err != nil { + return "", err + } + if !d.Mode().IsDir() { + return path, nil + } + // if we have a directory, then append our default name + path = filepath.Join(path, storeName) + return path, nil +} diff --git a/storage/bolt/store_test.go b/storage/bolt/store_test.go new file mode 100644 index 0000000..ddecc7a --- /dev/null +++ b/storage/bolt/store_test.go @@ -0,0 +1,11 @@ +package bolt + +import ( + "testing" + + "github.com/jrapoport/chestnut/storage/store_test" +) + +func TestStore(t *testing.T) { + store_test.TestStore(t, NewStore) +} diff --git a/storage/nuts/store.go b/storage/nuts/store.go index c334b0e..a0ce692 100644 --- a/storage/nuts/store.go +++ b/storage/nuts/store.go @@ -13,34 +13,34 @@ import ( const logName = "nutsdb" -// Store is an implementation the Storage interface for nutsdb +// nutsDBStore is an implementation the Storage interface for nutsdb // https://github.com/xujiajun/nutsdb. -type Store struct { +type nutsDBStore struct { opts storage.StoreOptions path string db *nutsdb.DB log log.Logger } -var _ storage.Storage = (*Store)(nil) +var _ storage.Storage = (*nutsDBStore)(nil) // NewStore is used to instantiate a datastore backed by nutsdb. -func NewStore(path string, opt ...storage.StoreOption) *Store { +func NewStore(path string, opt ...storage.StoreOption) storage.Storage { opts := storage.ApplyOptions(storage.DefaultStoreOptions, opt...) logger := log.Named(opts.Logger(), logName) if path == "" { logger.Fatal("store path required") } - return &Store{path: path, opts: opts, log: logger} + return &nutsDBStore{path: path, opts: opts, log: logger} } // Options returns the configuration options for the store. -func (s *Store) Options() storage.StoreOptions { +func (s *nutsDBStore) Options() storage.StoreOptions { return s.opts } // Open opens the store. -func (s *Store) Open() (err error) { +func (s *nutsDBStore) Open() (err error) { s.log.Debugf("opening store at path: %s", s.path) opt := nutsdb.DefaultOptions opt.Dir = s.path @@ -48,12 +48,17 @@ func (s *Store) Open() (err error) { err = s.logError("open", err) return } + if s.db == nil { + err = errors.New("unable to open backing store") + err = s.logError("open", err) + return + } s.log.Infof("opened store at path: %s", s.path) return } // Put an entry in the store. -func (s *Store) Put(name string, key []byte, value []byte) error { +func (s *nutsDBStore) Put(name string, key []byte, value []byte) error { s.log.Debugf("put: %d value bytes to key: %s", len(value), key) if err := storage.ValidKey(name, key); err != nil { return s.logError("put", err) @@ -62,23 +67,22 @@ func (s *Store) Put(name string, key []byte, value []byte) error { return s.logError("put", err) } putValue := func(tx *nutsdb.Tx) error { - s.log.Debugf("put: tx key: %s.%s value (%d bytes)", - name, string(key), len(value)) + s.log.Debugf("put: tx %d bytes to key: %s.%s", + len(value), name, string(key)) return tx.Put(name, key, value, 0) } return s.logError("put", s.db.Update(putValue)) } // Get a value from the store. -func (s *Store) Get(name string, key []byte) ([]byte, error) { +func (s *nutsDBStore) Get(name string, key []byte) ([]byte, error) { s.log.Debugf("get: value at key: %s", key) if err := storage.ValidKey(name, key); err != nil { return nil, s.logError("get", err) } var value []byte getValue := func(tx *nutsdb.Tx) error { - s.log.Debugf("get: tx key: %s.%s", - name, key) + s.log.Debugf("get: tx key: %s.%s", name, key) e, err := tx.Get(name, key) if err != nil { return err @@ -95,25 +99,25 @@ func (s *Store) Get(name string, key []byte) ([]byte, error) { } // Save the value in v and store the result at key. -func (s *Store) Save(name string, key []byte, v interface{}) error { - bytes, err := jsoniter.Marshal(v) +func (s *nutsDBStore) Save(name string, key []byte, v interface{}) error { + b, err := jsoniter.Marshal(v) if err != nil { return s.logError("save", err) } - return s.Put(name, key, bytes) + return s.Put(name, key, b) } // Load the value at key and stores the result in v. -func (s *Store) Load(name string, key []byte, v interface{}) error { - bytes, err := s.Get(name, key) +func (s *nutsDBStore) Load(name string, key []byte, v interface{}) error { + b, err := s.Get(name, key) if err != nil { return s.logError("load", err) } - return s.logError("load", jsoniter.Unmarshal(bytes, v)) + return s.logError("load", jsoniter.Unmarshal(b, v)) } // Has checks for a key in the store. -func (s *Store) Has(name string, key []byte) (bool, error) { +func (s *nutsDBStore) Has(name string, key []byte) (bool, error) { s.log.Debugf("has: key: %s", key) if err := storage.ValidKey(name, key); err != nil { return false, s.logError("has", err) @@ -143,7 +147,7 @@ func (s *Store) Has(name string, key []byte) (bool, error) { } // Delete removes a key from the store. -func (s *Store) Delete(name string, key []byte) error { +func (s *nutsDBStore) Delete(name string, key []byte) error { s.log.Debugf("delete: key: %s", key) if err := storage.ValidKey(name, key); err != nil { return s.logError("delete", err) @@ -156,12 +160,11 @@ func (s *Store) Delete(name string, key []byte) error { } // List returns a list of all keys in the namespace. -func (s *Store) List(name string) (keys [][]byte, err error) { +func (s *nutsDBStore) List(name string) (keys [][]byte, err error) { s.log.Debugf("list: keys in namespace: %s", name) listKeys := func(tx *nutsdb.Tx) error { - var txErr error - keys, txErr = s.list(tx, name) - return txErr + keys, err = s.listKeys(name, tx) + return err } if err = s.db.View(listKeys); err != nil { return nil, s.logError("list", err) @@ -170,7 +173,7 @@ func (s *Store) List(name string) (keys [][]byte, err error) { return } -func (s *Store) list(tx *nutsdb.Tx, name string) ([][]byte, error) { +func (s *nutsDBStore) listKeys(name string, tx *nutsdb.Tx) ([][]byte, error) { var keys [][]byte s.log.Debugf("list: tx scan namespace: %s", name) entries, err := tx.GetAll(name) @@ -186,16 +189,19 @@ func (s *Store) list(tx *nutsdb.Tx, name string) ([][]byte, error) { return keys, nil } -// ListAll returns a list of all keys in the store. -func (s *Store) ListAll() (map[string][][]byte, error) { +// ListAll returns a mapped list of all keys in the store. +func (s *nutsDBStore) ListAll() (map[string][][]byte, error) { s.log.Debugf("list: all keys") var total int allKeys := map[string][][]byte{} listKeys := func(tx *nutsdb.Tx) error { for name := range s.db.BPTreeIdx { - keys, txErr := s.list(tx, name) - if txErr != nil { - return txErr + keys, err := s.listKeys(name, tx) + if err != nil { + return err + } + if len(keys) <= 0 { + continue } allKeys[name] = keys total += len(keys) @@ -210,7 +216,7 @@ func (s *Store) ListAll() (map[string][][]byte, error) { } // Export copies the datastore to directory at path. -func (s *Store) Export(path string) error { +func (s *nutsDBStore) Export(path string) error { s.log.Debugf("export: to path: %s", path) if path == "" { err := fmt.Errorf("invalid path: %s", path) @@ -227,18 +233,15 @@ func (s *Store) Export(path string) error { } // Close closes the datastore and releases all db resources. -func (s *Store) Close() error { +func (s *nutsDBStore) Close() error { s.log.Debugf("closing store at path: %s", s.path) - defer func() { - // this is fine because the only possible error - // is ErrDBClosed if the db is *already* closed - s.db = nil - s.log.Info("store closed") - }() - return s.logError("close", s.db.Close()) + err := s.db.Close() + s.db = nil + s.log.Info("store closed") + return s.logError("close", err) } -func (s *Store) logError(name string, err error) error { +func (s *nutsDBStore) logError(name string, err error) error { if err == nil { return nil } diff --git a/storage/nuts/store_test.go b/storage/nuts/store_test.go index 9c79b99..70a9098 100644 --- a/storage/nuts/store_test.go +++ b/storage/nuts/store_test.go @@ -1,221 +1,11 @@ package nuts import ( - "fmt" - "sort" "testing" - "github.com/google/uuid" - "github.com/jrapoport/chestnut/log" - "github.com/jrapoport/chestnut/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/suite" + "github.com/jrapoport/chestnut/storage/store_test" ) -type testCase struct { - name string - key string - value string - err assert.ErrorAssertionFunc - has assert.BoolAssertionFunc -} - -type TestObject struct { - Value string -} - -var ( - testName = "test-name" - testKey = "test-key" - testValue = "test-value" - testObj = &TestObject{"hello"} -) - -var putTests = []testCase{ - {"", "", "", assert.Error, assert.False}, - {"a", testKey, "", assert.Error, assert.False}, - {"b", testKey, testValue, assert.NoError, assert.True}, - {"c/c", testKey, testValue, assert.NoError, assert.True}, - {".d", testKey, testValue, assert.NoError, assert.True}, - {testName, "", "", assert.Error, assert.False}, - {testName, "a", "", assert.Error, assert.False}, - {testName, "b", testValue, assert.NoError, assert.True}, - {testName, "c/c", testValue, assert.NoError, assert.True}, - {testName, ".d", testValue, assert.NoError, assert.True}, - {testName, testKey, testValue, assert.NoError, assert.True}, -} - -var tests = append(putTests, - testCase{testName, "not-found", "", assert.Error, assert.False}, -) - -type StoreTestSuite struct { - suite.Suite - store *Store -} - func TestStore(t *testing.T) { - suite.Run(t, new(StoreTestSuite)) -} - -func (ts *StoreTestSuite) SetupTest() { - ts.store = NewStore(ts.T().TempDir()) - err := ts.store.Open() - assert.NoError(ts.T(), err) -} - -func (ts *StoreTestSuite) TearDownTest() { - err := ts.store.Close() - assert.NoError(ts.T(), err) -} - -func (ts *StoreTestSuite) BeforeTest(_, testName string) { - switch testName { - case "TestStore_Put", - "TestStore_Save", - "TestStore_Load", - "TestStore_List", - "TestStore_ListAll": - break - default: - ts.TestStore_Put() - } -} - -func (ts *StoreTestSuite) TestStore_Put() { - for i, test := range putTests { - err := ts.store.Put(test.name, []byte(test.key), []byte(test.value)) - test.err(ts.T(), err, "%d test name: %s key: %s", i, test.name, test.key) - } -} - -func (ts *StoreTestSuite) TestStore_Save() { - err := ts.store.Save(testName, []byte(testKey), testObj) - assert.NoError(ts.T(), err) -} - -func (ts *StoreTestSuite) TestStore_Load() { - ts.T().Run("Setup", func(t *testing.T) { - ts.TestStore_Save() - }) - to := &TestObject{} - err := ts.store.Load(testName, []byte(testKey), to) - assert.NoError(ts.T(), err) - assert.Equal(ts.T(), testObj, to) -} - -func (ts *StoreTestSuite) TestStore_Get() { - for i, test := range tests { - value, err := ts.store.Get(test.name, []byte(test.key)) - test.err(ts.T(), err, "%d test name: %s key: %s", i, test.name, test.key) - assert.Equal(ts.T(), test.value, string(value), - "%d test key: %s", i, test.key) - } -} - -func (ts *StoreTestSuite) TestStore_Has() { - for i, test := range tests { - has, _ := ts.store.Has(test.name, []byte(test.key)) - test.has(ts.T(), has, "%d test key: %s", i, test.key) - } -} - -func (ts *StoreTestSuite) TestStore_List() { - const listLen = 100 - list := make([]string, listLen) - for i := 0; i < listLen; i++ { - list[i] = uuid.New().String() - err := ts.store.Put(testName, []byte(list[i]), []byte(testValue)) - assert.NoError(ts.T(), err) - } - keys, err := ts.store.List(testName) - assert.NoError(ts.T(), err) - assert.Len(ts.T(), keys, listLen) - // put both lists in the same order so we can compare them - strKeys := make([]string, len(keys)) - for i, k := range keys { - strKeys[i] = string(k) - } - sort.Strings(list) - sort.Strings(strKeys) - assert.Equal(ts.T(), list, strKeys) -} - -func (ts *StoreTestSuite) TestStore_ListAll() { - const listLen = 100 - list := make([]string, listLen) - for i := 0; i < listLen; i++ { - list[i] = uuid.New().String() - ns := fmt.Sprintf("%s%d", testName, i) - err := ts.store.Put(ns, []byte(list[i]), []byte(testValue)) - assert.NoError(ts.T(), err) - } - keyMap, err := ts.store.ListAll() - assert.NoError(ts.T(), err) - var keys []string - for _, ks := range keyMap { - for _, k := range ks { - keys = append(keys, string(k)) - } - } - assert.Len(ts.T(), keys, listLen) - sort.Strings(list) - sort.Strings(keys) - assert.Equal(ts.T(), list, keys) -} - -func (ts *StoreTestSuite) TestStore_Delete() { - var deleteTests = []struct { - key string - err assert.ErrorAssertionFunc - }{ - {"", assert.Error}, - {"a", assert.NoError}, - {"b", assert.NoError}, - {"c/c", assert.NoError}, - {".d", assert.NoError}, - {"eee", assert.NoError}, - {"not-found", assert.NoError}, - } - for i, test := range deleteTests { - err := ts.store.Delete(testName, []byte(test.key)) - test.err(ts.T(), err, "%d test key: %s", i, test.key) - } -} - -func (ts *StoreTestSuite) TestStore_Export() { - err := ts.store.Export("") - assert.Error(ts.T(), err) - err = ts.store.Export(ts.store.path) - assert.Error(ts.T(), err) - err = ts.store.Export(ts.T().TempDir()) - assert.NoError(ts.T(), err) -} - -func TestStore_WithLogger(t *testing.T) { - levels := []log.Level{ - log.DebugLevel, - log.InfoLevel, - log.WarnLevel, - log.ErrorLevel, - log.PanicLevel, - } - type LoggerOpt func(log.Level) storage.StoreOption - logOpts := []LoggerOpt{ - storage.WithLogrusLogger, - storage.WithStdLogger, - storage.WithZapLogger, - } - path := t.TempDir() - for _, level := range levels { - for _, logOpt := range logOpts { - opt := logOpt(level) - store := NewStore(path, opt) - assert.NotNil(t, store) - err := store.Open() - assert.NoError(t, err) - err = store.Close() - assert.NoError(t, err) - } - } + store_test.TestStore(t, NewStore) } diff --git a/storage/storage.go b/storage/storage.go index 1d69e10..133989a 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -28,6 +28,9 @@ type Storage interface { // List returns a list of all keys in the namespace. List(namespace string) ([][]byte, error) + // ListAll returns a mapped list of all keys in the store. + ListAll() (map[string][][]byte, error) + // Delete removes a key from the store. Delete(name string, key []byte) error diff --git a/storage/store_test/test_suite.go b/storage/store_test/test_suite.go new file mode 100644 index 0000000..b1fae15 --- /dev/null +++ b/storage/store_test/test_suite.go @@ -0,0 +1,260 @@ +package store_test + +import ( + "fmt" + "sort" + "testing" + + "github.com/google/uuid" + "github.com/jrapoport/chestnut/log" + "github.com/jrapoport/chestnut/storage" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type testCase struct { + name string + key string + value string + err assert.ErrorAssertionFunc + has assert.BoolAssertionFunc +} + +type testObject struct { + Value string +} + +var ( + testName = "test-name" + testKey = "test-key" + testValue = "test-value" + testObj = &testObject{"hello"} +) + +var putTests = []testCase{ + {"", "", "", assert.Error, assert.False}, + {"a", testKey, "", assert.Error, assert.False}, + {"b", testKey, testValue, assert.NoError, assert.True}, + {"c/c", testKey, testValue, assert.NoError, assert.True}, + {".d", testKey, testValue, assert.NoError, assert.True}, + {testName, "", "", assert.Error, assert.False}, + {testName, "a", "", assert.Error, assert.False}, + {testName, "b", testValue, assert.NoError, assert.True}, + {testName, "c/c", testValue, assert.NoError, assert.True}, + {testName, ".d", testValue, assert.NoError, assert.True}, + {testName, testKey, testValue, assert.NoError, assert.True}, +} + +var tests = append(putTests, + testCase{testName, "not-found", "", assert.Error, assert.False}, +) + +type storeFunc = func(string, ...storage.StoreOption) storage.Storage + +type storeTestSuite struct { + suite.Suite + storeFunc + store storage.Storage + path string +} + +// TestStore tests a store +func TestStore(t *testing.T, fn storeFunc) { + ts := new(storeTestSuite) + ts.storeFunc = fn + suite.Run(t, ts) +} + +// SetupTest +func (ts *storeTestSuite) SetupTest() { + ts.path = ts.T().TempDir() + ts.store = ts.storeFunc(ts.path) + err := ts.store.Open() + assert.NoError(ts.T(), err) +} + +// TearDownTest +func (ts *storeTestSuite) TearDownTest() { + err := ts.store.Close() + assert.NoError(ts.T(), err) +} + +// BeforeTest +func (ts *storeTestSuite) BeforeTest(_, testName string) { + switch testName { + case "TestStorePut", + "TestStoreSave", + "TestStoreLoad", + "TestStoreList", + "TestStoreListAll", + "TestStoreWithLogger": + break + default: + ts.TestStorePut() + } +} + +// TestStorePut +func (ts *storeTestSuite) TestStorePut() { + for i, test := range putTests { + err := ts.store.Put(test.name, []byte(test.key), []byte(test.value)) + test.err(ts.T(), err, "%d test name: %s key: %s", i, test.name, test.key) + } +} + +// TestStoreSave +func (ts *storeTestSuite) TestStoreSave() { + err := ts.store.Save(testName, []byte(testKey), testObj) + assert.NoError(ts.T(), err) +} + +// TestStoreLoad +func (ts *storeTestSuite) TestStoreLoad() { + ts.T().Run("Setup", func(t *testing.T) { + ts.TestStoreSave() + }) + to := &testObject{} + err := ts.store.Load(testName, []byte(testKey), to) + assert.NoError(ts.T(), err) + assert.Equal(ts.T(), testObj, to) +} + +// TestStoreGet +func (ts *storeTestSuite) TestStoreGet() { + for i, test := range tests { + value, err := ts.store.Get(test.name, []byte(test.key)) + test.err(ts.T(), err, "%d test name: %s key: %s", i, test.name, test.key) + assert.Equal(ts.T(), test.value, string(value), + "%d test key: %s", i, test.key) + } +} + +// TestStoreHas +func (ts *storeTestSuite) TestStoreHas() { + for i, test := range tests { + has, _ := ts.store.Has(test.name, []byte(test.key)) + test.has(ts.T(), has, "%d test key: %s", i, test.key) + } +} + +// TestStoreList +func (ts *storeTestSuite) TestStoreList() { + const listLen = 100 + list := make([]string, listLen) + for i := 0; i < listLen; i++ { + list[i] = uuid.New().String() + err := ts.store.Put(testName, []byte(list[i]), []byte(testValue)) + assert.NoError(ts.T(), err) + } + keys, err := ts.store.List(testName) + assert.NoError(ts.T(), err) + assert.Len(ts.T(), keys, listLen) + // put both lists in the same order so we can compare them + strKeys := make([]string, len(keys)) + for i, k := range keys { + strKeys[i] = string(k) + } + sort.Strings(list) + sort.Strings(strKeys) + assert.Equal(ts.T(), list, strKeys) +} + +// TestStoreListAll +func (ts *storeTestSuite) TestStoreListAll() { + const listLen = 100 + list := make([]string, listLen) + for i := 0; i < listLen; i++ { + list[i] = uuid.New().String() + ns := fmt.Sprintf("%s%d", testName, i) + err := ts.store.Put(ns, []byte(list[i]), []byte(testValue)) + assert.NoError(ts.T(), err) + } + keyMap, err := ts.store.ListAll() + assert.NoError(ts.T(), err) + var keys []string + for _, ks := range keyMap { + for _, k := range ks { + keys = append(keys, string(k)) + } + } + assert.Len(ts.T(), keys, listLen) + sort.Strings(list) + sort.Strings(keys) + assert.Equal(ts.T(), list, keys) +} + +// TestStoreDelete +func (ts *storeTestSuite) TestStoreDelete() { + var deleteTests = []struct { + key string + err assert.ErrorAssertionFunc + }{ + {"", assert.Error}, + {"a", assert.NoError}, + {"b", assert.NoError}, + {"c/c", assert.NoError}, + {".d", assert.NoError}, + {"eee", assert.NoError}, + {"not-found", assert.NoError}, + } + for i, test := range deleteTests { + err := ts.store.Delete(testName, []byte(test.key)) + test.err(ts.T(), err, "%d test key: %s", i, test.key) + } +} + +// TestStoreExport +func (ts *storeTestSuite) TestStoreExport() { + exTests := []struct { + path string + Err assert.ErrorAssertionFunc + }{ + {"", assert.Error}, + {ts.path, assert.Error}, + {ts.T().TempDir(), assert.NoError}, + } + for _, test := range exTests { + err := ts.store.Export(test.path) + test.Err(ts.T(), err) + if err == nil { + s2 := ts.storeFunc(test.path) + assert.NotNil(ts.T(), s2) + err = s2.Open() + assert.NoError(ts.T(), err) + keys, err := s2.ListAll() + assert.NoError(ts.T(), err) + assert.NotEmpty(ts.T(), keys) + err = s2.Close() + assert.NoError(ts.T(), err) + } + } +} + +// TestStoreWithLogger +func (ts *storeTestSuite) TestStoreWithLogger() { + levels := []log.Level{ + log.DebugLevel, + log.InfoLevel, + log.WarnLevel, + log.ErrorLevel, + log.PanicLevel, + } + type LoggerOpt func(log.Level) storage.StoreOption + logOpts := []LoggerOpt{ + storage.WithLogrusLogger, + storage.WithStdLogger, + storage.WithZapLogger, + } + path := ts.T().TempDir() + for _, level := range levels { + for _, logOpt := range logOpts { + opt := logOpt(level) + store := ts.storeFunc(path, opt) + assert.NotNil(ts.T(), store) + err := store.Open() + assert.NoError(ts.T(), err) + err = store.Close() + assert.NoError(ts.T(), err) + } + } +}