chestnut/encoding/json/encoders/secure/encoder.go

336 lines
10 KiB
Go

package secure
import (
"encoding/hex"
"errors"
"fmt"
"reflect"
"sync"
"github.com/jrapoport/chestnut/encoding/json/encoders"
"github.com/jrapoport/chestnut/encoding/json/encoders/hash"
"github.com/jrapoport/chestnut/encoding/json/encoders/lookup"
"github.com/jrapoport/chestnut/encoding/json/packager"
"github.com/jrapoport/chestnut/encoding/tags"
"github.com/jrapoport/chestnut/log"
"github.com/json-iterator/go"
)
// SecureLookupPrefix will format the secure lookup token to "[prefix]-[encoder id]-[index]".
const SecureLookupPrefix = "cn"
// EncryptionFunction defines the prototype for the encryption callback.
// See WARNING regarding use of PassthroughEncryption.
type EncryptionFunction func(plaintext []byte) (ciphertext []byte, err error)
// PassthroughEncryption is a dummy function for development and testing *ONLY*.
/*
* WARNING: DO NOT USE IN PRODUCTION.
* PassthroughEncryption is *NOT* encryption and *DOES NOT* encrypt data.
*/
var PassthroughEncryption EncryptionFunction = func(plaintext []byte) ([]byte, error) {
return []byte(hex.EncodeToString(plaintext)), nil
}
// EncoderExtension is a JSON encoder extension for the encryption and decryption of JSON
// encoded data. It supports full encryption / decryption of the encoded block in in
// addition to sparse encryption and hashing of structs on a per field basis via supplementary
// JSON struct field tag options. For additional information on sparse encryption & hashing, please
// SEE: https://github.com/jrapoport/chestnut/blob/master/README.md
//
// For additional information on json-iterator extensions, please
// SEE: https://github.com/json-iterator/go/wiki/Extension
type EncoderExtension struct {
jsoniter.EncoderExtension
opts Options
encoderID string
encoder jsoniter.API
lookupCtx *lookup.Context
lookupBuffer []byte
open bool
encryptFunc EncryptionFunction
log log.Logger
mu sync.RWMutex
}
// NewSecureEncoderExtension returns a new EncoderExtension using the supplied
// EncryptionFunction. If no encoder id is supplied, a new random encoder id will be used.
func NewSecureEncoderExtension(encoderID string, efn EncryptionFunction, opt ...Option) *EncoderExtension {
const encoderName = "encoder"
if encoderID == encoders.InvalidID {
encoderID = encoders.NewEncoderID()
}
opts := DefaultOptions
opts = applyOptions(opts, opt...)
encoder := encoders.NewEncoder()
logName := fmt.Sprintf("%s [%s]", encoderName, encoderID)
token := lookup.NewLookupToken(SecureLookupPrefix, encoderID)
ext := new(EncoderExtension)
ext.opts = opts
ext.log = log.Named(opts.log, logName)
ext.encoderID = encoderID
ext.encryptFunc = efn
ext.encoder = encoder
ext.lookupCtx = &lookup.Context{Token: token}
if encoder == nil {
ext.log.Panic(errors.New("encoder not found"))
}
if efn == nil {
ext.log.Panic(errors.New("encryption required"))
}
return ext
}
// Seal encrypts and returns the encoded value as a sealed package.
func (ext *EncoderExtension) Seal(encoded []byte) ([]byte, error) {
ext.mu.Lock()
defer ext.mu.Unlock()
ext.log.Debugf("sealing %d encoded bytes: %s", len(encoded), string(encoded))
/// must do this first
if ext.open {
ext.log.Debug("encoder is open, closing it")
ext.close()
}
token := ext.lookupCtx.Token
ext.log.Debugf("package token: %s", token)
plaintext := ext.lookupBuffer
if ext.isSparse() {
ext.log.Debug("sparse encoding data")
} else {
ext.log.Debug("secure encoding data")
plaintext = encoded
token = ""
encoded = nil
}
if ext.hasCompressor() {
var err error
ext.log.Debugf("compress %d plaintext bytes", len(plaintext))
if plaintext, err = ext.compress(plaintext); err != nil {
return nil, ext.logError(err)
}
ext.log.Debugf("compressed %d plaintext bytes", len(plaintext))
ext.log.Debugf("compress %d encoded bytes", len(encoded))
if encoded, err = ext.compress(encoded); err != nil {
return nil, ext.logError(err)
}
ext.log.Debugf("compressed %d encoded bytes", len(encoded))
}
ext.log.Debugf("encrypting %d plaintext bytes: %s",
len(plaintext), string(plaintext))
// encrypt the blocks
ciphertext, err := ext.encrypt(plaintext)
if err != nil {
return nil, ext.logError(err)
}
ext.log.Debugf("encrypted %d bytes", len(ciphertext))
comp := ext.hasCompressor()
ext.log.Debug("sealing package")
pkg, err := packager.EncodePackage(ext.encoderID, token, ciphertext, encoded, comp)
if err != nil {
return nil, ext.logError(err)
}
ext.log.Debugf("sealed %d encoded bytes", len(pkg))
return pkg, nil
}
func (ext *EncoderExtension) hasCompressor() bool {
return ext.opts.compressor != nil
}
func (ext *EncoderExtension) compress(data []byte) ([]byte, error) {
if len(data) <= 0 {
return nil, nil
}
if !ext.hasCompressor() {
return data, nil
}
return ext.opts.compressor(data)
}
// encrypt calls the EncryptionFunction if set, otherwise panic.
// See WARNING regarding the use of PassthroughEncryption.
func (ext *EncoderExtension) encrypt(plaintext []byte) ([]byte, error) {
if ext.encryptFunc == nil {
ext.log.Panic(errors.New("encryption function required"))
}
return ext.encryptFunc(plaintext)
}
// UpdateStructDescriptor customizes the encoding by specifying alternate
// lookup encoder for secure struct field tags and hash struct field strings.
func (ext *EncoderExtension) UpdateStructDescriptor(structDescriptor *jsoniter.StructDescriptor) {
if !ext.isOpen() {
ext.log.Debug("encoder is not open, cannot update struct descriptor")
return
}
ext.log.Debugf("updating struct: %s", structDescriptor.Type)
for _, binding := range structDescriptor.Fields {
field := binding.Field
typ := field.Type()
ext.log.Debugf("updating struct field %s.%s", structDescriptor.Type, field.Name())
tag, has := binding.Field.Tag().Lookup(tags.JSONTag)
if !has {
ext.log.Debug("json tag not found, ignore")
continue
}
name, opts := tags.ParseJSONTag(tag)
ext.log.Debugf("json tag name: %s options: %s", name, opts)
if tags.IgnoreField(name) {
ext.log.Debugf("json tag name %s, ignore", name)
binding.ToNames = []string{}
continue
}
hashName := tags.HashName(opts)
secure := tags.IsSecure(opts)
if !secure && hashName == tags.HashNone {
ext.log.Debug("tag options not found, ignore")
continue
}
encoder := binding.Encoder
if secure {
ext.log.Debugf("added lookup encoder to secure field %s", field.Name())
encoder = lookup.NewLookupEncoder(ext.lookupCtx, typ, encoder)
if enc, ok := encoder.(*lookup.Encoder); ok {
enc.SetLogger(log.Named(ext.log, typ.String()))
}
}
if hashName != tags.HashNone && typ.Kind() == reflect.String {
// if the hash name is unsupported hashFn will be nil
if hashFn := hash.FunctionForName(hashName); hashFn != nil {
ext.log.Debugf("added %s hash encoder for field %s", field.Name(), hashName)
encoder = hash.NewHashEncoder(hashName.String(), hashFn, encoder)
if enc, ok := encoder.(*hash.Encoder); ok {
enc.SetLogger(log.Named(ext.log, hashName.String()))
}
} else {
ext.log.Warnf("%s hash encoder not found", hashName)
}
}
binding.Encoder = encoder
}
}
// Open should be called before Marshal to prepare the encoder.
func (ext *EncoderExtension) Open() error {
ext.mu.Lock()
defer ext.mu.Unlock()
ext.log.Debug("opening encoder")
if ext.open {
return ext.logError(errors.New("encoder already open"))
}
if err := ext.openLookupStream(); err != nil {
err = fmt.Errorf("failed to open encoder %w", err)
return ext.logError(err)
}
ext.open = true
ext.log.Debug("encoder open")
return nil
}
func (ext *EncoderExtension) isOpen() bool {
ext.mu.RLock()
defer ext.mu.RUnlock()
return ext.open
}
// Close should be called after Marshal, but before Seal. Calling
// Seal before Close will call Close automatically if necessary.
func (ext *EncoderExtension) Close() {
ext.mu.Lock()
defer ext.mu.Unlock()
ext.close()
}
// close is the non-locking internal close call.
func (ext *EncoderExtension) close() {
ext.log.Debug("closing encoder")
ext.open = false
ext.closeLookupStream()
ext.log.Debug("encoder closed")
}
func (ext *EncoderExtension) openLookupStream() error {
ext.log.Debug("opening lookup stream")
stream := ext.encoder.BorrowStream(nil)
if stream == nil {
return ext.logError(errors.New("lookup stream is nil"))
}
if err := stream.Flush(); err != nil {
err = fmt.Errorf("cannot flush lookup stream %w", err)
return ext.logError(err)
}
ext.setupLookupContext(stream)
if !ext.validLookupContext() {
return ext.logError(errors.New("invalid lookup context"))
}
ext.log.Debug("lookup stream open")
return nil
}
func (ext *EncoderExtension) setupLookupContext(stream *jsoniter.Stream) {
if ext.lookupCtx == nil {
return
}
ext.log.Debugf("setup lookup context: %s", ext.lookupCtx.Token)
// reset the lookup index to 0
stream.Attachment = 0
stream.WriteObjectStart()
ext.lookupCtx.Stream = stream
}
func (ext *EncoderExtension) validLookupContext() bool {
if ext.lookupCtx == nil {
ext.log.Error(errors.New("lookup context is nil"))
return false
}
if ext.lookupCtx.Stream == nil {
ext.log.Error(errors.New("lookup stream is nil"))
return false
}
if ext.lookupCtx.Token == lookup.InvalidToken {
ext.log.Error(errors.New("lookup token is invalid"))
return false
}
sa := ext.lookupCtx.Stream.Attachment
if sa == nil || sa.(int) != 0 {
ext.log.Error(errors.New("lookup index is invalid"))
return false
}
return true
}
func (ext *EncoderExtension) closeLookupStream() {
ext.log.Debug("closing lookup stream")
if ext.lookupCtx == nil {
ext.log.Warn("lookup context is nil")
return
}
stream := ext.lookupCtx.Stream
if stream == nil {
ext.log.Warn("lookup stream is nil")
return
}
stream.WriteObjectEnd()
ext.lookupBuffer = stream.Buffer()
stream.Attachment = nil
ext.encoder.ReturnStream(stream)
ext.lookupCtx.Stream = nil
ext.log.Debug("lookup stream closed")
}
// isSparse checks to see if the value used sparse encryption. If the encoded struct
// used struct tags to secure specific fields, we should have a lookup table.
func (ext *EncoderExtension) isSparse() bool {
const emptyBuffer = "{}"
return len(ext.lookupBuffer) > len(emptyBuffer)
}
func (ext *EncoderExtension) logError(e error) error {
if e == nil {
return e
}
ext.log.Error(e)
return e
}