318 lines
9.9 KiB
Go
318 lines
9.9 KiB
Go
package secure
|
|
|
|
import (
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
|
|
"github.com/jrapoport/chestnut/encoding/json/encoders"
|
|
"github.com/jrapoport/chestnut/encoding/json/encoders/lookup"
|
|
"github.com/jrapoport/chestnut/encoding/json/packager"
|
|
"github.com/jrapoport/chestnut/log"
|
|
jsoniter "github.com/json-iterator/go"
|
|
"github.com/modern-go/reflect2"
|
|
)
|
|
|
|
// DecryptionFunction is the signature of the callback that turns ciphertext
// back into plaintext. It returns the plaintext, or an error if the
// ciphertext could not be processed.
// See the WARNING on PassthroughDecryption before using it.
type DecryptionFunction func(ciphertext []byte) (plaintext []byte, err error)

// PassthroughDecryption is a placeholder for development and testing *ONLY*.
/*
 * WARNING: DO NOT USE IN PRODUCTION.
 * PassthroughDecryption is *NOT* decryption and *DOES NOT* decrypt data.
 * It merely hex-decodes its input.
 */
var PassthroughDecryption DecryptionFunction = func(ciphertext []byte) ([]byte, error) {
	// Decode into a fresh buffer so the caller's slice is never modified.
	decoded := make([]byte, hex.DecodedLen(len(ciphertext)))
	n, err := hex.Decode(decoded, ciphertext)
	// On malformed input n is the number of bytes decoded before the error.
	return decoded[:n], err
}
|
|
|
|
// DecoderExtension is a JSON encoder extension for the encryption and decryption of JSON
|
|
// encoded data. It supports full encryption / decryption of the encoded block in in
|
|
// addition to sparse encryption and hashing of structs on a per field basis via supplementary
|
|
// JSON struct field tag options. For addition information sparse encryption & hashing, please
|
|
// SEE: https://github.com/jrapoport/chestnut/blob/master/README.md
|
|
//
|
|
// For additional information on json-iterator extensions, please
|
|
// SEE: https://github.com/json-iterator/go/wiki/Extension
|
|
type DecoderExtension struct {
|
|
jsoniter.DecoderExtension
|
|
opts Options
|
|
encoderID string
|
|
encoder jsoniter.API
|
|
lookupCtx *lookup.Context
|
|
lookupBuffer []byte
|
|
open bool
|
|
decryptFunc DecryptionFunction
|
|
log log.Logger
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// NewSecureDecoderExtension returns a new DecoderExtension using the supplied DecryptionFunction. If
|
|
// an encoder id is supplied, this decoder will restrict itself to packages with a matching id.
|
|
func NewSecureDecoderExtension(encoderID string, dfn DecryptionFunction, opt ...Option) *DecoderExtension {
|
|
const decoderName = "decoder"
|
|
opts := DefaultOptions
|
|
opts = applyOptions(opts, opt...)
|
|
encoder := encoders.NewEncoder()
|
|
logName := decoderName
|
|
if encoderID != encoders.InvalidID {
|
|
logName += fmt.Sprintf(" [%s]", encoderID)
|
|
}
|
|
ext := new(DecoderExtension)
|
|
ext.opts = opts
|
|
ext.log = log.Named(opts.log, logName)
|
|
ext.encoderID = encoderID
|
|
ext.decryptFunc = dfn
|
|
ext.encoder = encoder
|
|
ext.lookupCtx = &lookup.Context{}
|
|
if ext.encoder == nil {
|
|
ext.log.Fatal(errors.New("encoder not found"))
|
|
}
|
|
if ext.decryptFunc == nil {
|
|
ext.log.Panic(errors.New("decryption function required"))
|
|
}
|
|
return ext
|
|
}
|
|
|
|
// Unseal decrypts and returns the encoded value as an unsealed package. If sparse
|
|
// is true AND the data format is sparse, the data will not be decrypted the struct
|
|
// will be decoded with empty values in place of secure fields.
|
|
// TODO: We could hash the encoded data and add that to our plaintext block before we
|
|
// encrypt it as a tamper check. Not sure that is necessary or useful right now though.
|
|
func (ext *DecoderExtension) Unseal(encoded []byte) ([]byte, error) {
|
|
ext.mu.Lock()
|
|
defer ext.mu.Unlock()
|
|
ext.log.Debugf("unsealing encoded %d bytes", len(encoded))
|
|
/// must do this first
|
|
if ext.open {
|
|
ext.log.Debug("decoder is open, closing it")
|
|
ext.close()
|
|
}
|
|
// unwrap the package
|
|
pkg, err := packager.DecodePackage(encoded)
|
|
if err != nil {
|
|
return nil, ext.logError(err)
|
|
}
|
|
compressed := pkg.Compressed
|
|
ext.log.Debugf("package data is compressed: %t", compressed)
|
|
// IF we have an encoder ID, check that it matches the package
|
|
ext.log.Debugf("checking encoding id %s", pkg.EncoderID)
|
|
if ext.encoderID != encoders.DefaultID &&
|
|
ext.encoderID != pkg.EncoderID {
|
|
err = fmt.Errorf(" encoder %s package %s id mismatch", ext.encoderID, pkg.EncoderID)
|
|
return nil, ext.logError(err)
|
|
}
|
|
ext.log.Debugf("sparse option set: %t", ext.opts.sparse)
|
|
isSparse := pkg.Format == packager.Sparse && ext.opts.sparse
|
|
ext.log.Debugf("sparse decoding: %t", isSparse)
|
|
if !isSparse {
|
|
// decrypt the data unless we are sparse decoding
|
|
ext.log.Debugf("decrypting %d ciphertext bytes", len(pkg.Cipher))
|
|
if pkg.Cipher, err = ext.decrypt(pkg.Cipher); err != nil {
|
|
return nil, ext.logError(err)
|
|
}
|
|
ext.log.Debugf("decrypted %d bytes", len(pkg.Cipher))
|
|
if compressed {
|
|
ext.log.Debug("ciphertext is compressed")
|
|
if !ext.hasDecompressor() {
|
|
err = errors.New("compressed package requires decompressor")
|
|
return nil, ext.logError(err)
|
|
}
|
|
ext.log.Debugf("decompress %d ciphertext bytes", len(pkg.Cipher))
|
|
pkg.Cipher, err = ext.decompress(pkg.Cipher)
|
|
if err != nil {
|
|
return nil, ext.logError(err)
|
|
}
|
|
ext.log.Debugf("decompressed %d ciphertext bytes", len(pkg.Cipher))
|
|
}
|
|
}
|
|
switch pkg.Format {
|
|
// the format is secure, we are done
|
|
case packager.Secure:
|
|
ext.log.Debugf("unsealed %d secure data bytes: %s", len(pkg.Cipher), string(pkg.Cipher))
|
|
return pkg.Cipher, nil
|
|
case packager.Sparse:
|
|
// set the lookup context
|
|
ext.log.Debugf("unsealed sparse token: %s", pkg.Token)
|
|
ext.lookupCtx.Token = pkg.Token
|
|
if !isSparse {
|
|
ext.log.Debugf("unsealed %d lookup data bytes: %s", len(pkg.Cipher), string(pkg.Cipher))
|
|
ext.lookupBuffer = pkg.Cipher
|
|
}
|
|
break
|
|
default:
|
|
return nil, ext.logError(errors.New("unknown package format"))
|
|
}
|
|
if compressed {
|
|
ext.log.Debug("encoded data is compressed")
|
|
if !ext.hasDecompressor() {
|
|
err = errors.New("compressed package requires decompressor")
|
|
return nil, ext.logError(err)
|
|
}
|
|
ext.log.Debugf("decompress %d encoded bytes", len(pkg.Encoded))
|
|
pkg.Encoded, err = ext.decompress(pkg.Encoded)
|
|
if err != nil {
|
|
return nil, ext.logError(err)
|
|
}
|
|
ext.log.Debugf("decompressed %d encoded bytes", len(pkg.Encoded))
|
|
}
|
|
if len(pkg.Encoded) > 0 {
|
|
ext.log.Debugf("unsealed %d sparse data bytes: %s", len(pkg.Encoded), string(pkg.Encoded))
|
|
}
|
|
return pkg.Encoded, nil
|
|
}
|
|
|
|
func (ext *DecoderExtension) hasDecompressor() bool {
|
|
return ext.opts.decompressor != nil
|
|
}
|
|
|
|
func (ext *DecoderExtension) decompress(data []byte) ([]byte, error) {
|
|
if len(data) <= 0 {
|
|
return nil, nil
|
|
}
|
|
if !ext.hasDecompressor() {
|
|
return data, nil
|
|
}
|
|
return ext.opts.decompressor(data)
|
|
}
|
|
|
|
// decrypt calls the DecryptionFunction if set, otherwise panic.
|
|
// See WARNING regarding the use of PassthroughDecryption.
|
|
func (ext *DecoderExtension) decrypt(ciphertext []byte) ([]byte, error) {
|
|
if ext.decryptFunc == nil {
|
|
ext.log.Panic(errors.New("decryption function required"))
|
|
}
|
|
return ext.decryptFunc(ciphertext)
|
|
}
|
|
|
|
// DecorateDecoder customizes the decoding by specifying alternate lookup table decoder that
|
|
// recognizes previously encoded lookup table keys and replaces them with decoded values.
|
|
func (ext *DecoderExtension) DecorateDecoder(typ reflect2.Type, decoder jsoniter.ValDecoder) jsoniter.ValDecoder {
|
|
if !ext.isOpen() {
|
|
ext.log.Debug("decoder is not open, cannot decorate decoder")
|
|
return decoder
|
|
}
|
|
if ext.lookupCtx == nil || ext.lookupCtx.Token == lookup.InvalidToken {
|
|
ext.log.Debug("decoding is not sparse, do not add lookup decoder")
|
|
return decoder
|
|
}
|
|
ext.log.Debugf("added lookup decoder for type: %s", typ)
|
|
decoder = lookup.NewLookupDecoder(ext.lookupCtx, typ, decoder)
|
|
if dec, ok := decoder.(*lookup.Decoder); ok {
|
|
dec.SetLogger(log.Named(ext.log, typ.String()))
|
|
}
|
|
return decoder
|
|
}
|
|
|
|
// Open should be called before Unmarshal to prepare the decoder.
|
|
func (ext *DecoderExtension) Open() error {
|
|
ext.mu.Lock()
|
|
defer ext.mu.Unlock()
|
|
ext.log.Debug("opening decoder")
|
|
if ext.open {
|
|
return ext.logError(errors.New("decoder already open"))
|
|
}
|
|
if err := ext.openLookupStream(); err != nil {
|
|
err = fmt.Errorf("failed to open decoder %w", err)
|
|
return ext.logError(err)
|
|
}
|
|
ext.open = true
|
|
ext.log.Debug("decoder open")
|
|
return nil
|
|
}
|
|
|
|
func (ext *DecoderExtension) isOpen() bool {
|
|
ext.mu.RLock()
|
|
defer ext.mu.RUnlock()
|
|
return ext.open
|
|
}
|
|
|
|
// Close should be called after Unmarshal.
|
|
func (ext *DecoderExtension) Close() {
|
|
ext.mu.Lock()
|
|
defer ext.mu.Unlock()
|
|
ext.close()
|
|
}
|
|
|
|
// close is the non-locking internal close call.
|
|
func (ext *DecoderExtension) close() {
|
|
ext.log.Debug("closing decoder")
|
|
ext.closeLookupStream()
|
|
ext.open = false
|
|
ext.log.Debug("decoder closed")
|
|
}
|
|
|
|
func (ext *DecoderExtension) openLookupStream() error {
|
|
ext.log.Debug("opening lookup stream")
|
|
stream := ext.encoder.BorrowStream(nil)
|
|
if stream == nil {
|
|
return ext.logError(errors.New("lookup stream is nil"))
|
|
}
|
|
if err := stream.Flush(); err != nil {
|
|
err = fmt.Errorf("cannot flush lookup stream %w", err)
|
|
return ext.logError(err)
|
|
}
|
|
// setup the lookup context
|
|
ext.setupLookupContext(stream)
|
|
if !ext.validLookupContext() {
|
|
return ext.logError(errors.New("invalid lookup context"))
|
|
}
|
|
ext.log.Debug("lookup stream open")
|
|
return nil
|
|
}
|
|
|
|
func (ext *DecoderExtension) setupLookupContext(stream *jsoniter.Stream) {
|
|
if ext.lookupCtx == nil {
|
|
return
|
|
}
|
|
ext.log.Debugf("setup lookup context: %s", ext.lookupCtx.Token)
|
|
stream.Attachment = ext.encoder.Get(ext.lookupBuffer)
|
|
ext.lookupCtx.Stream = stream
|
|
ext.lookupBuffer = nil
|
|
}
|
|
|
|
func (ext *DecoderExtension) validLookupContext() bool {
|
|
if ext.lookupCtx == nil {
|
|
ext.log.Error(errors.New("lookup context is nil"))
|
|
return false
|
|
}
|
|
if ext.lookupCtx.Stream == nil {
|
|
ext.log.Error(errors.New("lookup stream is nil"))
|
|
return false
|
|
}
|
|
if ext.lookupCtx.Token != lookup.InvalidToken &&
|
|
ext.lookupCtx.Stream.Attachment == nil {
|
|
ext.log.Error(errors.New("lookup table is nil"))
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (ext *DecoderExtension) closeLookupStream() {
|
|
ext.log.Debug("closing lookup stream")
|
|
ext.lookupBuffer = nil
|
|
if ext.lookupCtx == nil {
|
|
ext.log.Warn("lookup context is nil")
|
|
return
|
|
}
|
|
stream := ext.lookupCtx.Stream
|
|
if stream == nil {
|
|
ext.log.Warn("lookup stream is nil")
|
|
return
|
|
}
|
|
stream.Attachment = nil
|
|
ext.encoder.ReturnStream(stream)
|
|
ext.lookupCtx.Token = lookup.InvalidToken
|
|
ext.lookupCtx.Stream = nil
|
|
ext.log.Debug("lookup stream closed")
|
|
}
|
|
|
|
func (ext *DecoderExtension) logError(e error) error {
|
|
if e == nil {
|
|
return e
|
|
}
|
|
ext.log.Error(e)
|
|
return e
|
|
}
|