Skip to content

Commit

Permalink
feat(vault): Support kv v1 and decode base64 key. (#5708) (#5725)
Browse files Browse the repository at this point in the history
Fixes DGRAPH-1722
Fixes DGRAPH-1723

This PR adds support for Vault kv v1 in addition to v2. Also, it allows for Base64 encoded or raw keys fetched from vault.

(cherry picked from commit eff3dd9)
  • Loading branch information
parasssh authored Jun 25, 2020
1 parent 939338b commit 73a9ddd
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 11 deletions.
12 changes: 11 additions & 1 deletion ee/enc/util_ee_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ func resetConfig(config *viper.Viper) {
config.Set(vaultAddr, "http://localhost:8200")
config.Set(vaultRoleIDFile, "")
config.Set(vaultSecretIDFile, "")
config.Set(vaultPath, "dgraph")
config.Set(vaultPath, "secret/data/dgraph")
config.Set(vaultField, "enc_key")
config.Set(vaultFormat, "base64")
}

// TODO: The function below allows instantiating a real Vault server. But results in go.mod issues.
Expand Down Expand Up @@ -117,6 +118,15 @@ func TestNewKeyReader(t *testing.T) {
require.Nil(t, k)
require.Error(t, err)

// Bad vault_format. Must be raw or base64.
resetConfig(config)
config.Set(vaultRoleIDFile, "./test-fixtures/dummy_role_id_file")
config.Set(vaultSecretIDFile, "./test-fixtures/dummy_secret_id_file")
config.Set(vaultFormat, "foo") // error.
kr, err = newKeyReader(config)
require.Error(t, err)
require.Nil(t, kr)

// RoleID and SecretID given but RoleID file and SecretID file exists and is valid.
resetConfig(config)
//nl, _ := startVaultServer(t, "dgraph", "enc_key", "1234567890123456")
Expand Down
38 changes: 28 additions & 10 deletions ee/enc/vault_ee.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
package enc

import (
"encoding/base64"
"io/ioutil"

"github.com/dgraph-io/dgraph/x"
"github.com/golang/glog"
"github.com/hashicorp/vault/api"
"github.com/pkg/errors"
"github.com/spf13/pflag"
Expand All @@ -35,6 +37,7 @@ const (
vaultSecretIDFile = "vault_secretid_file"
vaultPath = "vault_path"
vaultField = "vault_field"
vaultFormat = "vault_format"
)

// RegisterVaultFlags registers the required flags to integrate with Vault.
Expand All @@ -46,10 +49,12 @@ func registerVaultFlags(flag *pflag.FlagSet) {
"File containing Vault role-id used for approle auth.")
flag.String(vaultSecretIDFile, "",
"File containing Vault secret-id used for approle auth.")
flag.String(vaultPath, "dgraph",
"Vault kv store path.")
flag.String(vaultPath, "secret/data/dgraph",
"Vault kv store path. e.g. secret/data/dgraph for kv-v2, kv/dgraph for kv-v1.")
flag.String(vaultField, "enc_key",
"Vault kv store field whose value is the encryption key.")
"Vault kv store field whose value is the Base64 encoded encryption key.")
flag.String(vaultFormat, "base64",
"Vault field format. raw or base64")
}

// vaultKeyReader implements the KeyReader interface. It reads the key from vault server.
Expand All @@ -59,6 +64,7 @@ type vaultKeyReader struct {
secretID string
path string
field string
format string
}

func newVaultKeyReader(cfg *viper.Viper) (*vaultKeyReader, error) {
Expand All @@ -68,10 +74,15 @@ func newVaultKeyReader(cfg *viper.Viper) (*vaultKeyReader, error) {
secretID: cfg.GetString(vaultSecretIDFile),
path: cfg.GetString(vaultPath),
field: cfg.GetString(vaultField),
format: cfg.GetString(vaultFormat),
}

if v.addr == "" || v.path == "" || v.field == "" {
return nil, errors.Errorf("%v, %v or %v is missing", vaultAddr, vaultPath, vaultField)
if v.addr == "" || v.path == "" || v.field == "" || v.format == "" {
return nil, errors.Errorf("%v, %v, %v or %v is missing",
vaultAddr, vaultPath, vaultField, vaultFormat)
}
if v.format != "base64" && v.format != "raw" {
return nil, errors.Errorf("vault_format = %v; must be one of base64 or raw", v.format)
}

if v.roleID != "" && v.secretID != "" {
Expand Down Expand Up @@ -120,25 +131,32 @@ func (vkr *vaultKeyReader) readKey() (x.SensitiveByteSlice, error) {
}
client.SetToken(resp.Auth.ClientToken)

// Read from KV store
secret, err := client.Logical().Read("secret/data/" + vkr.path)
// Read from KV store. The given path must be v1 or v2 format. We use it as is.
secret, err := client.Logical().Read(vkr.path)
if err != nil || secret == nil {
return nil, errors.Errorf("error or nil secret on reading key at %v: "+
"err %v", vkr.path, err)
}

// Parse key from response
var m map[string]interface{}
m, ok := secret.Data["data"].(map[string]interface{})
if !ok {
return nil, errors.Errorf("kv store read response from vault is bad")
glog.Infof("Unable to extract key from kv v2 response. Trying kv v1.")
m = secret.Data
}
kVal, ok := m[vkr.field]
if !ok {
return nil, errors.Errorf("secret key not found at %v", vkr.field)
}
kbyte := []byte(kVal.(string))

// Validate key length suitable for AES
if vkr.format == "base64" {
kbyte, err = base64.StdEncoding.DecodeString(kVal.(string))
if err != nil {
return nil, errors.Errorf("Unable to decode the Base64 Encoded key: err %v", err)
}
}
// Validate key length suitable for AES.
klen := len(kbyte)
if klen != 16 && klen != 32 && klen != 64 {
return nil, errors.Errorf("bad key length %v from vault", klen)
Expand Down

0 comments on commit 73a9ddd

Please sign in to comment.