Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return non-retryable errors on transit encrypt and decrypt failures #13111

Merged
merged 9 commits into from
Nov 15, 2021
34 changes: 29 additions & 5 deletions builtin/logical/transit/path_decrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/base64"
"fmt"
"net/http"

"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/errutil"
Expand Down Expand Up @@ -91,12 +92,16 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
batchResponseItems := make([]DecryptBatchResponseItem, len(batchInputItems))
contextSet := len(batchInputItems[0].Context) != 0

userErrorInBatch := false
internalErrorInBatch := false

for i, item := range batchInputItems {
if (len(item.Context) == 0 && contextSet) || (len(item.Context) != 0 && !contextSet) {
return logical.ErrorResponse("context should be set either in all the request blocks or in none"), logical.ErrInvalidRequest
}

if item.Ciphertext == "" {
userErrorInBatch = true
batchResponseItems[i].Error = "missing ciphertext to decrypt"
continue
}
Expand All @@ -105,6 +110,7 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
if len(item.Context) != 0 {
batchInputItems[i].DecodedContext, err = base64.StdEncoding.DecodeString(item.Context)
if err != nil {
userErrorInBatch = true
batchResponseItems[i].Error = err.Error()
continue
}
Expand All @@ -114,6 +120,7 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
if len(item.Nonce) != 0 {
batchInputItems[i].DecodedNonce, err = base64.StdEncoding.DecodeString(item.Nonce)
if err != nil {
userErrorInBatch = true
batchResponseItems[i].Error = err.Error()
continue
}
Expand Down Expand Up @@ -143,13 +150,13 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
plaintext, err := p.Decrypt(item.DecodedContext, item.DecodedNonce, item.Ciphertext)
if err != nil {
switch err.(type) {
case errutil.UserError:
batchResponseItems[i].Error = err.Error()
continue
case errutil.InternalError:
internalErrorInBatch = true
default:
p.Unlock()
return nil, err
userErrorInBatch = true
}
batchResponseItems[i].Error = err.Error()
continue
}
batchResponseItems[i].Plaintext = plaintext
}
Expand All @@ -162,6 +169,11 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
} else {
if batchResponseItems[0].Error != "" {
p.Unlock()

if internalErrorInBatch {
return nil, errutil.InternalError{Err: batchResponseItems[0].Error}
}

return logical.ErrorResponse(batchResponseItems[0].Error), logical.ErrInvalidRequest
}
resp.Data = map[string]interface{}{
Expand All @@ -170,6 +182,18 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
}

p.Unlock()

// Depending on the errors in the batch, different status codes should be returned. User errors
// will return a 400 and precede internal errors which return a 500. The reasoning behind this is
// that user errors are non-retryable without making changes to the request, and should be surfaced
// to the user first.
switch {
case userErrorInBatch:
return logical.RespondWithStatusCode(resp, req, http.StatusBadRequest)
case internalErrorInBatch:
return logical.RespondWithStatusCode(resp, req, http.StatusInternalServerError)
}

return resp, nil
}

Expand Down
205 changes: 157 additions & 48 deletions builtin/logical/transit/path_decrypt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@ package transit
import (
"context"
"encoding/json"
"net/http"
"reflect"
"testing"

"github.com/hashicorp/vault/sdk/helper/jsonutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/mitchellh/mapstructure"
)

func TestTransit_BatchDecryption(t *testing.T) {
Expand Down Expand Up @@ -64,74 +68,179 @@ func TestTransit_BatchDecryption(t *testing.T) {
}

func TestTransit_BatchDecryption_DerivedKey(t *testing.T) {
var req *logical.Request
var resp *logical.Response
var err error

b, s := createBackendWithStorage(t)

policyData := map[string]interface{}{
"derived": true,
}

policyReq := &logical.Request{
// Create a derived key.
req = &logical.Request{
Operation: logical.UpdateOperation,
Path: "keys/existing_key",
Storage: s,
Data: policyData,
Data: map[string]interface{}{
"derived": true,
},
}

resp, err = b.HandleRequest(context.Background(), policyReq)
resp, err = b.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}

batchInput := []interface{}{
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dGVzdGNvbnRleHQ="},
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dGVzdGNvbnRleHQ="},
// Encrypt some values for use in test cases.
plaintextItems := []struct {
plaintext, context string
}{
{plaintext: "dGhlIHF1aWNrIGJyb3duIGZveA==", context: "dGVzdGNvbnRleHQ="},
{plaintext: "anVtcGVkIG92ZXIgdGhlIGxhenkgZG9n", context: "dGVzdGNvbnRleHQy"},
}

batchData := map[string]interface{}{
"batch_input": batchInput,
}
batchReq := &logical.Request{
req = &logical.Request{
Operation: logical.UpdateOperation,
Path: "encrypt/existing_key",
Storage: s,
Data: batchData,
}
resp, err = b.HandleRequest(context.Background(), batchReq)
Data: map[string]interface{}{
"batch_input": []interface{}{
map[string]interface{}{"plaintext": plaintextItems[0].plaintext, "context": plaintextItems[0].context},
map[string]interface{}{"plaintext": plaintextItems[1].plaintext, "context": plaintextItems[1].context},
},
},
}
resp, err = b.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}

batchDecryptionInputItems := resp.Data["batch_results"].([]EncryptBatchResponseItem)

batchDecryptionInput := make([]interface{}, len(batchDecryptionInputItems))
for i, item := range batchDecryptionInputItems {
batchDecryptionInput[i] = map[string]interface{}{"ciphertext": item.Ciphertext, "context": "dGVzdGNvbnRleHQ="}
}

batchDecryptionData := map[string]interface{}{
"batch_input": batchDecryptionInput,
}

batchDecryptionReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "decrypt/existing_key",
Storage: s,
Data: batchDecryptionData,
}
resp, err = b.HandleRequest(context.Background(), batchDecryptionReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}

batchDecryptionResponseItems := resp.Data["batch_results"].([]DecryptBatchResponseItem)

plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
for _, item := range batchDecryptionResponseItems {
if item.Plaintext != plaintext {
t.Fatalf("bad: plaintext. Expected: %q, Actual: %q", plaintext, item.Plaintext)
}
encryptedItems := resp.Data["batch_results"].([]EncryptBatchResponseItem)

tests := []struct {
name string
in []interface{}
want []DecryptBatchResponseItem
shouldErr bool
wantHTTPStatus int
}{
{
name: "nil-input",
in: nil,
shouldErr: true,
},
{
name: "empty-input",
in: []interface{}{},
shouldErr: true,
},
{
name: "single-item-success",
in: []interface{}{
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[0].context},
},
want: []DecryptBatchResponseItem{
{Plaintext: plaintextItems[0].plaintext},
},
},
{
name: "single-item-invalid-ciphertext",
in: []interface{}{
map[string]interface{}{"ciphertext": "xxx", "context": plaintextItems[0].context},
},
want: []DecryptBatchResponseItem{
{Error: "invalid ciphertext: no prefix"},
},
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "single-item-wrong-context",
in: []interface{}{
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[1].context},
},
want: []DecryptBatchResponseItem{
{Error: "cipher: message authentication failed"},
},
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "batch-full-success",
in: []interface{}{
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[0].context},
map[string]interface{}{"ciphertext": encryptedItems[1].Ciphertext, "context": plaintextItems[1].context},
},
want: []DecryptBatchResponseItem{
{Plaintext: plaintextItems[0].plaintext},
{Plaintext: plaintextItems[1].plaintext},
},
},
{
name: "batch-partial-success",
in: []interface{}{
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[1].context},
map[string]interface{}{"ciphertext": encryptedItems[1].Ciphertext, "context": plaintextItems[1].context},
},
want: []DecryptBatchResponseItem{
{Error: "cipher: message authentication failed"},
{Plaintext: plaintextItems[1].plaintext},
},
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "batch-full-failure",
in: []interface{}{
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[1].context},
map[string]interface{}{"ciphertext": encryptedItems[1].Ciphertext, "context": plaintextItems[0].context},
},
want: []DecryptBatchResponseItem{
{Error: "cipher: message authentication failed"},
{Error: "cipher: message authentication failed"},
},
wantHTTPStatus: http.StatusBadRequest,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req = &logical.Request{
Operation: logical.UpdateOperation,
Path: "decrypt/existing_key",
Storage: s,
Data: map[string]interface{}{
"batch_input": tt.in,
},
}
resp, err = b.HandleRequest(context.Background(), req)

didErr := err != nil || (resp != nil && resp.IsError())
if didErr {
if !tt.shouldErr {
t.Fatalf("unexpected error err:%v, resp:%#v", err, resp)
}
} else {
if tt.shouldErr {
t.Fatal("expected error, but none occurred")
}

if rawRespBody, ok := resp.Data[logical.HTTPRawBody]; ok {
httpResp := &logical.HTTPResponse{}
err = jsonutil.DecodeJSON([]byte(rawRespBody.(string)), httpResp)
if err != nil {
t.Fatalf("failed to unmarshal nested response: err:%v, resp:%#v", err, resp)
}

if respStatus, ok := resp.Data[logical.HTTPStatusCode]; !ok || respStatus != tt.wantHTTPStatus {
t.Fatalf("HTTP response status code mismatch, want:%d, got:%d", tt.wantHTTPStatus, respStatus)
}

resp = logical.HTTPResponseToLogicalResponse(httpResp)
}

var respItems []DecryptBatchResponseItem
err = mapstructure.Decode(resp.Data["batch_results"], &respItems)
if err != nil {
t.Fatalf("problem decoding response items: err:%v, resp:%#v", err, resp)
}
if !reflect.DeepEqual(tt.want, respItems) {
t.Fatalf("response items mismatch, want:%#v, got:%#v", tt.want, respItems)
}
}
})
}
}
Loading