Skip to content

Commit

Permalink
Merge pull request #289 from chainifynet/feature/serialization-refactor
Browse files Browse the repository at this point in the history
refactor(serialization): AAD, EDK refactor, improve tests
  • Loading branch information
wobondar authored Mar 12, 2024
2 parents a0549c8 + f5740ac commit 9522038
Show file tree
Hide file tree
Showing 26 changed files with 1,655 additions and 499 deletions.
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ vet:
unit: lint unit-pkg

unit-pkg:
@gotestsum -f ${GOTESTSUM_FMT} -- -timeout=1m ${BUILD_TAGS} ${SDK_PKGS}
@gotestsum -f ${GOTESTSUM_FMT} -- -timeout=4m ${BUILD_TAGS} ${SDK_PKGS}

unit-race:
@gotestsum -f ${GOTESTSUM_FMT} -- -timeout=2m -cpu=4 -race -count=1 ${BUILD_TAGS} ${SDK_PKGS}
@gotestsum -f ${GOTESTSUM_FMT} -- -timeout=6m -cpu=4 -race -count=1 ${BUILD_TAGS} ${SDK_PKGS}

##
# Integration tests
Expand Down Expand Up @@ -121,7 +121,7 @@ e2e-test-slow:

test-cover:
@#CGO_ENABLED=1 go test -count=1 -coverpkg=./... -covermode=atomic -coverprofile coverage.out ./...
@CGO_ENABLED=1 go test -race -tags example,mocks,codegen,integration -count=1 -coverpkg=./... -covermode=atomic -coverprofile=coverage.out ./pkg/...
@CGO_ENABLED=1 go test -timeout=10m -tags example,mocks,codegen,integration -cpu=2 -count=1 -coverpkg=./... -covermode=atomic -coverprofile=coverage.out ./pkg/...
@#CGO_ENABLED=1 go test -tags ${CI_TAGS} -count=1 -coverpkg=./... -covermode=atomic -coverprofile coverage.out ./pkg/...
@grep -v -E -f .covignore coverage.out > coverage.filtered.out
@mv coverage.filtered.out coverage.out
Expand Down
5 changes: 0 additions & 5 deletions pkg/internal/crypto/signature/signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ import (
"fmt"
"hash"

"github.com/rs/zerolog/log"

"github.com/chainifynet/aws-encryption-sdk-go/pkg/internal/crypto/hasher"
"github.com/chainifynet/aws-encryption-sdk-go/pkg/utils/rand"
)
Expand Down Expand Up @@ -56,9 +54,6 @@ func (s *ECCSigner) Sign() ([]byte, error) {
signature = sig
break
}
log.Trace().Int("expectedLen", s.signLen).
Int("actualLen", len(sig)).
Msg("sign is not desired length. recalculating")
continue
}
return signature, nil
Expand Down
42 changes: 2 additions & 40 deletions pkg/internal/providers/keyprovider/keyprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,14 @@ import (
"errors"
"fmt"

"github.com/rs/zerolog"

"github.com/chainifynet/aws-encryption-sdk-go/pkg/helpers/structs"
"github.com/chainifynet/aws-encryption-sdk-go/pkg/keys"
"github.com/chainifynet/aws-encryption-sdk-go/pkg/logger"
"github.com/chainifynet/aws-encryption-sdk-go/pkg/model"
"github.com/chainifynet/aws-encryption-sdk-go/pkg/model/types"
"github.com/chainifynet/aws-encryption-sdk-go/pkg/providers"
"github.com/chainifynet/aws-encryption-sdk-go/pkg/suite"
)

var (
log = logger.L().Level(zerolog.TraceLevel) //nolint:gochecknoglobals
)

type KeyProvider struct {
providerID string
providerKind types.ProviderKind
Expand Down Expand Up @@ -78,11 +71,6 @@ func (kp *KeyProvider) DecryptDataKey(ctx context.Context, MKP model.MasterKeyPr
if kp.vendOnDecrypt {
decryptMasterKey, err := MKP.MasterKeyForDecrypt(ctx, encryptedDataKey.KeyProvider())
if err != nil {
log.Trace().
Stringer("EDK", encryptedDataKey.KeyProvider()).
Str("MKP", MKP.ProviderID()).
Str("method", "DecryptDataKey").
Err(err).Msgf("cant reach MasterKey for EDK keyID: %v", encryptedDataKey.KeyID())
if errors.Is(err, providers.ErrMasterKeyProviderDecryptForbidden) {
return nil, fmt.Errorf("DecryptDataKey MKP.MasterKeyForDecrypt is forbidden for keyID %q with MKP %q: %w", encryptedDataKey.KeyID(), MKP.ProviderID(), errors.Join(providers.ErrMasterKeyProviderDecrypt, err))
}
Expand All @@ -96,14 +84,7 @@ func (kp *KeyProvider) DecryptDataKey(ctx context.Context, MKP model.MasterKeyPr
// ref https://github.com/awslabs/aws-encryption-sdk-specification/blob/master/framework/aws-kms/aws-kms-mrk-aware-master-key.md#decrypt-data-key
// For each encrypted data key in the filtered set, one at a time,
// the master key MUST attempt to decrypt the data key.
for i, memberKey := range allMembers {
log.Trace().
Int("memberI", i).
Stringer("EDK", encryptedDataKey.KeyProvider()).
Str("MKP", MKP.ProviderID()).
Str("keyID", memberKey.KeyID()).
Str("method", "DecryptDataKey").
Msg("Provider: DecryptDataKey")
for _, memberKey := range allMembers {
if !memberKey.OwnsDataKey(encryptedDataKey) {
// if memberKey does not own encryptedDataKey, try to decrypt next provider member key
continue
Expand All @@ -120,13 +101,6 @@ func (kp *KeyProvider) DecryptDataKey(ctx context.Context, MKP model.MasterKeyPr
errMemberKey = errDecrypt
// if MasterKey returns keys.ErrDecryptKey, try to decrypt next provider member key
if errors.Is(errDecrypt, keys.ErrDecryptKey) {
log.Trace().
Int("memberI", i).
Stringer("EDK", encryptedDataKey.KeyProvider()).
Str("MKP", MKP.ProviderID()).
Str("keyID", memberKey.KeyID()).
Str("method", "DecryptDataKey").
Err(errDecrypt).Msgf("cant decrypt data key by BaseKey %v, for EDK keyID: %v", memberKey.KeyID(), encryptedDataKey.KeyID())
continue
} else { //nolint:revive
break
Expand All @@ -150,20 +124,8 @@ func (kp *KeyProvider) DecryptDataKeyFromList(ctx context.Context, MKP model.Mas
var dataKey model.DataKeyI

var errDecryptDataKey error
for i, edk := range encryptedDataKeys {
log.Trace().
Int("edkI", i).
Stringer("EDK", edk.KeyProvider()). // EncryptedDataKeyI KeyMeta (ProviderID and KeyID)
Str("MKP", MKP.ProviderID()). // MasterKeyProvider ProviderID with which we try to decrypt EncryptedDataKeyI
Str("method", "DecryptDataKeyFromList").
Msg("DecryptDataKeyFromList")
for _, edk := range encryptedDataKeys {
if err := MKP.ValidateProviderID(edk.KeyProvider().ProviderID); err != nil {
log.Trace().Err(err).
Int("edkI", i).
Stringer("EDK", edk.KeyProvider()).
Str("MKP", MKP.ProviderID()).
Str("method", "DecryptDataKeyFromList").
Msg("DecryptDataKeyFromList validate expected error")
errDecryptDataKey = fmt.Errorf("DecryptDataKeyFromList validate expected error: %w", errors.Join(providers.ErrMasterKeyProviderDecrypt, err))
continue
}
Expand Down
16 changes: 1 addition & 15 deletions pkg/internal/providers/keyprovider/keyprovider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,6 @@ func TestKeyProvider_DecryptDataKey(t *testing.T) {
Return(dk, nil)

mkp.EXPECT().MasterKeysForDecryption().Return([]model.MasterKey{masterKey})

mkp.EXPECT().ProviderID().Return("raw")
},
kp: &KeyProvider{providerID: "raw", providerKind: types.Raw, vendOnDecrypt: false},
alg: suite.AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384,
Expand All @@ -161,12 +159,9 @@ func TestKeyProvider_DecryptDataKey(t *testing.T) {
name: "Valid decryption with second MasterKey",
setupMocks: func(t *testing.T, mkp *mocks.MockMasterKeyProvider, edk *mocks.MockEncryptedDataKey, dk model.DataKeyI) {
edk.EXPECT().KeyProvider().Return(model.KeyMeta{ProviderID: "raw", KeyID: "test2"})
edk.EXPECT().KeyID().Return("test2")

mkp.EXPECT().ValidateProviderID(mock.Anything).Return(nil)

mkp.EXPECT().ProviderID().Return("raw")

masterKey := mocks.NewMockMasterKey(t)
masterKey.EXPECT().KeyID().Return("test1")
masterKey.EXPECT().OwnsDataKey(mock.Anything).Return(true)
Expand Down Expand Up @@ -208,8 +203,6 @@ func TestKeyProvider_DecryptDataKey(t *testing.T) {
setupMocks: func(t *testing.T, mkp *mocks.MockMasterKeyProvider, edk *mocks.MockEncryptedDataKey, dk model.DataKeyI) {
edk.EXPECT().KeyProvider().Return(model.KeyMeta{ProviderID: "aws-kms", KeyID: "key3"})

mkp.EXPECT().ProviderID().Return("aws-kms")

mkp.EXPECT().ValidateProviderID(mock.Anything).Return(nil)

mkp.EXPECT().MasterKeysForDecryption().Return([]model.MasterKey{})
Expand Down Expand Up @@ -257,9 +250,6 @@ func TestKeyProvider_DecryptDataKey(t *testing.T) {
name: "Error in MasterKeyForDecrypt",
setupMocks: func(t *testing.T, mkp *mocks.MockMasterKeyProvider, edk *mocks.MockEncryptedDataKey, dk model.DataKeyI) {
edk.EXPECT().KeyProvider().Return(model.KeyMeta{ProviderID: "aws-kms", KeyID: "key3"})
edk.EXPECT().KeyID().Return("key3")

mkp.EXPECT().ProviderID().Return("aws-kms")

mkp.EXPECT().ValidateProviderID(mock.Anything).Return(nil)

Expand Down Expand Up @@ -297,9 +287,7 @@ func TestKeyProvider_DecryptDataKey(t *testing.T) {
name: "Error during data key decryption by a MasterKey",
setupMocks: func(t *testing.T, mkp *mocks.MockMasterKeyProvider, edk *mocks.MockEncryptedDataKey, dk model.DataKeyI) {
edk.EXPECT().KeyProvider().Return(model.KeyMeta{ProviderID: "aws-kms", KeyID: "key3"})
edk.EXPECT().KeyID().Return("key3")

mkp.EXPECT().ProviderID().Return("aws-kms")
mkp.EXPECT().ValidateProviderID(mock.Anything).Return(nil)

masterKey := mocks.NewMockMasterKey(t)
Expand All @@ -321,9 +309,7 @@ func TestKeyProvider_DecryptDataKey(t *testing.T) {
name: "Error during data key decryption by two MasterKeys",
setupMocks: func(t *testing.T, mkp *mocks.MockMasterKeyProvider, edk *mocks.MockEncryptedDataKey, dk model.DataKeyI) {
edk.EXPECT().KeyProvider().Return(model.KeyMeta{ProviderID: "aws-kms", KeyID: "key3"})
edk.EXPECT().KeyID().Return("key3")

mkp.EXPECT().ProviderID().Return("aws-kms")
mkp.EXPECT().ValidateProviderID(mock.Anything).Return(nil)

masterKey := mocks.NewMockMasterKey(t)
Expand Down Expand Up @@ -494,7 +480,7 @@ func TestKeyProvider_DecryptDataKeyFromList(t *testing.T) { //nolint:gocognit
Return(fmt.Errorf("validate provider ID error")).Once()
for _, edk := range edks {
edk.EXPECT().KeyProvider().Return(model.KeyMeta{ProviderID: "invalid", KeyID: "key2"}).
Times(3)
Times(1)
}
},
kp: &KeyProvider{providerID: "raw", providerKind: types.Raw, vendOnDecrypt: false},
Expand Down
16 changes: 0 additions & 16 deletions pkg/keys/kms/kms.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,14 @@ import (
typesaws "github.com/aws/aws-sdk-go-v2/service/kms/types"
"github.com/aws/smithy-go"
"github.com/aws/smithy-go/transport/http"
"github.com/rs/zerolog"

"github.com/chainifynet/aws-encryption-sdk-go/pkg/helpers/arn"
"github.com/chainifynet/aws-encryption-sdk-go/pkg/keys"
"github.com/chainifynet/aws-encryption-sdk-go/pkg/logger"
"github.com/chainifynet/aws-encryption-sdk-go/pkg/model"
"github.com/chainifynet/aws-encryption-sdk-go/pkg/model/types"
"github.com/chainifynet/aws-encryption-sdk-go/pkg/suite"
)

var (
log = logger.L().Level(zerolog.TraceLevel) //nolint:gochecknoglobals
)

var (
ErrKmsClient = errors.New("KMSClient error")
)
Expand Down Expand Up @@ -142,10 +136,6 @@ func (kmsMK *MasterKey) EncryptDataKey(ctx context.Context, dataKey model.DataKe
if len(encryptOutput.CiphertextBlob) == 0 {
return nil, fmt.Errorf("KMSMasterKey error: %w", errors.Join(keys.ErrEncryptKey, fmt.Errorf("dataKeyOutput.CiphertextBlob length %d is empty", len(encryptOutput.CiphertextBlob))))
}
log.Trace().
Stringer("MK", kmsMK.Metadata()).
Stringer("DK", dataKey.KeyProvider()).
Msg("MasterKey: EncryptDataKey")

return model.NewEncryptedDataKey(
kmsMK.Metadata(),
Expand Down Expand Up @@ -210,17 +200,11 @@ func (kmsMK *MasterKey) decryptDataKey(ctx context.Context, encryptedDataKey mod
// ref github.com/aws/aws-sdk-go-v2/service/kms@v1.18.5/types/errors.go
// ref2 https://github.com/aws/aws-sdk-go-v2/issues/1110
// that is normal behaviour, we'll try to decrypt with other MasterKey in MasterKeyProvider
log.Trace().Caller().AnErr("kmsErr", kmsErr).Msg("KMS expected error")
return nil, fmt.Errorf("KMSMasterKey expected error: %w", errors.Join(keys.ErrDecryptKey, ErrKmsClient, kmsErr))
}
}
}

log.Trace().Caller().
Err(err).
Stringer("MK", kmsMK.Metadata()).
Stringer("EDK", encryptedDataKey.KeyProvider()).
Msg("MasterKey: DecryptDataKey")
return nil, fmt.Errorf("KMSMasterKey error: %w", errors.Join(keys.ErrDecryptKey, ErrKmsClient, err))
}

Expand Down
12 changes: 0 additions & 12 deletions pkg/providers/kmsprovider/kmsprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,17 @@ import (
"fmt"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/rs/zerolog"

"github.com/chainifynet/aws-encryption-sdk-go/pkg/helpers/arn"
"github.com/chainifynet/aws-encryption-sdk-go/pkg/helpers/structs"
"github.com/chainifynet/aws-encryption-sdk-go/pkg/keys/kms"
"github.com/chainifynet/aws-encryption-sdk-go/pkg/logger"
"github.com/chainifynet/aws-encryption-sdk-go/pkg/model"
"github.com/chainifynet/aws-encryption-sdk-go/pkg/model/types"
"github.com/chainifynet/aws-encryption-sdk-go/pkg/providers"
"github.com/chainifynet/aws-encryption-sdk-go/pkg/providers/common"
"github.com/chainifynet/aws-encryption-sdk-go/pkg/suite"
)

var (
log = logger.L().Level(zerolog.TraceLevel) //nolint:gochecknoglobals
)

type KmsProvider interface {
model.MasterKeyProvider
getClient(ctx context.Context, keyID string) (model.KMSClient, error)
Expand Down Expand Up @@ -182,10 +176,6 @@ func (kmsKP *KmsKeyProvider[KT]) getClient(ctx context.Context, keyID string) (m
if err := kmsKP.addRegionalClient(ctx, regionName); err != nil {
return nil, fmt.Errorf("KMS client error: %w", err)
}
log.Trace().
Str("region", regionName).
Str("keyID", keyID).
Msg("GET regional KMS client")
return kmsKP.regionalClients[regionName], nil
}

Expand All @@ -201,8 +191,6 @@ func (kmsKP *KmsKeyProvider[KT]) addRegionalClient(ctx context.Context, region s
if err != nil {
return fmt.Errorf("unable to load AWS config: %w", err)
}
log.Trace().Str("region", region).
Msg("Register new regional KMS client")
kmsClient := kmsKP.options.clientFactory.NewFromConfig(cfg)
kmsKP.regionalClients[region] = kmsClient
return nil
Expand Down
Loading

0 comments on commit 9522038

Please sign in to comment.