Skip to content

Commit bc7f4b3

Browse files
authored
secrets/hashivault: Add support for VAULT_ADDR and VAULT_TOKEN as aliases for existing env variables
1 parent 448a545 commit bc7f4b3

File tree

2 files changed

+122
-9
lines changed

2 files changed

+122
-9
lines changed

secrets/hashivault/vault.go

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
// URLs
2020
//
2121
// For secrets.OpenKeeper, hashivault registers for the scheme "hashivault".
22-
// The default URL opener will dial a Vault server using the environment
23-
// variables "VAULT_SERVER_URL" and "VAULT_SERVER_TOKEN".
22+
// The default URL opener will dial a Vault server using the environment variables
23+
// "VAULT_SERVER_URL" (or "VAULT_ADDR") and "VAULT_SERVER_TOKEN" (or "VAULT_TOKEN").
2424
// To customize the URL opener, or for more details on the URL format,
2525
// see URLOpener.
2626
// See https://gocloud.dev/concepts/urls/ for background information.
@@ -74,8 +74,41 @@ func init() {
7474
secrets.DefaultURLMux().RegisterKeeper(Scheme, new(defaultDialer))
7575
}
7676

77+
// getVaultURL ensures that we check both VAULT_SERVER_URL and VAULT_ADDR environment
78+
// variables for the API address for vault. VAULT_SERVER_URL takes precedence over VAULT_ADDR.
79+
func getVaultURL() (string, error) {
80+
serverURL := os.Getenv("VAULT_SERVER_URL")
81+
if serverURL != "" {
82+
return serverURL, nil
83+
}
84+
85+
vaultAddr := os.Getenv("VAULT_ADDR")
86+
if vaultAddr != "" {
87+
return vaultAddr, nil
88+
}
89+
90+
return "", errors.New("neither VAULT_SERVER_URL nor VAULT_ADDR environment variables are set")
91+
}
92+
93+
// getVaultToken ensures that we check both VAULT_SERVER_TOKEN and VAULT_TOKEN environment
94+
// variables for the API token for vault. VAULT_SERVER_TOKEN takes precedence over VAULT_TOKEN.
95+
// If neither environment variables are found, then we return an empty string as token is not required.
96+
func getVaultToken() string {
97+
serverToken := os.Getenv("VAULT_SERVER_TOKEN")
98+
if serverToken != "" {
99+
return serverToken
100+
}
101+
102+
vaultToken := os.Getenv("VAULT_TOKEN")
103+
if vaultToken != "" {
104+
return vaultToken
105+
}
106+
107+
return ""
108+
}
109+
77110
// defaultDialer dials a default Vault server based on the environment variables
78-
// VAULT_SERVER_URL and VAULT_SERVER_TOKEN.
111+
// VAULT_SERVER_URL / VAULT_ADDR and VAULT_SERVER_TOKEN / VAULT_TOKEN
79112
type defaultDialer struct {
80113
init sync.Once
81114
opener *URLOpener
@@ -84,12 +117,12 @@ type defaultDialer struct {
84117

85118
func (o *defaultDialer) OpenKeeperURL(ctx context.Context, u *url.URL) (*secrets.Keeper, error) {
86119
o.init.Do(func() {
87-
serverURL := os.Getenv("VAULT_SERVER_URL")
88-
if serverURL == "" {
89-
o.err = errors.New("VAULT_SERVER_URL environment variable is not set")
120+
serverURL, err := getVaultURL()
121+
if err != nil {
122+
o.err = err
90123
return
91124
}
92-
token := os.Getenv("VAULT_SERVER_TOKEN") // token is not required
125+
token := getVaultToken()
93126
cfg := Config{Token: token, APIConfig: api.Config{Address: serverURL}}
94127
client, err := Dial(ctx, &cfg)
95128
if err != nil {

secrets/hashivault/vault_test.go

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,11 @@ func newHarness(ctx context.Context, t *testing.T) (drivertest.Harness, error) {
7171
enableTransit.Do(func() {
7272
s, err := c.Logical().Read("sys/mounts")
7373
if err != nil {
74-
t.Fatal(err, "; run secrets/vault/localvault.sh to start a dev vault container")
74+
t.Fatal(err, "; run secrets/hashivault/localvault.sh to start a dev vault container")
7575
}
7676
if _, ok := s.Data["transit/"]; !ok {
7777
if _, err := c.Logical().Write("sys/mounts/transit", map[string]interface{}{"type": "transit"}); err != nil {
78-
t.Fatal(err, "; run secrets/vault/localvault.sh to start a dev vault container")
78+
t.Fatal(err, "; run secrets/hashivault/localvault.sh to start a dev vault container")
7979
}
8080
}
8181
})
@@ -144,6 +144,34 @@ func fakeConnectionStringInEnv() func() {
144144
}
145145
}
146146

147+
func alternativeConnectionStringEnvVars() func() {
148+
oldURLVal := os.Getenv("VAULT_ADDR")
149+
oldTokenVal := os.Getenv("VAULT_TOKEN")
150+
os.Setenv("VAULT_ADDR", "http://myalternativevaultserver")
151+
os.Setenv("VAULT_TOKEN", "faketoken2")
152+
return func() {
153+
os.Setenv("VAULT_ADDR", oldURLVal)
154+
os.Setenv("VAULT_TOKEN", oldTokenVal)
155+
}
156+
}
157+
158+
func unsetConnectionStringEnvVars() func() {
159+
oldURLVal := os.Getenv("VAULT_ADDR")
160+
oldTokenVal := os.Getenv("VAULT_TOKEN")
161+
oldServerURLVal := os.Getenv("VAULT_SERVER_URL")
162+
oldServerTokenVal := os.Getenv("VAULT_SERVER_TOKEN")
163+
os.Unsetenv("VAULT_ADDR")
164+
os.Unsetenv("VAULT_TOKEN")
165+
os.Unsetenv("VAULT_SERVER_URL")
166+
os.Unsetenv("VAULT_SERVER_TOKEN")
167+
return func() {
168+
os.Setenv("VAULT_ADDR", oldURLVal)
169+
os.Setenv("VAULT_SERVER_URL", oldServerURLVal)
170+
os.Setenv("VAULT_TOKEN", oldTokenVal)
171+
os.Setenv("VAULT_SERVER_TOKEN", oldServerTokenVal)
172+
}
173+
}
174+
147175
func TestOpenKeeper(t *testing.T) {
148176
cleanup := fakeConnectionStringInEnv()
149177
defer cleanup()
@@ -171,3 +199,55 @@ func TestOpenKeeper(t *testing.T) {
171199
}
172200
}
173201
}
202+
203+
func TestGetVaultConnectionDetails(t *testing.T) {
204+
t.Run("Test Current Env Vars", func(t *testing.T) {
205+
cleanup := fakeConnectionStringInEnv()
206+
defer cleanup()
207+
208+
serverUrl, err := getVaultURL()
209+
if err != nil {
210+
t.Errorf("got unexpected error: %v", err)
211+
}
212+
if serverUrl != "http://myvaultserver" {
213+
t.Errorf("expected 'http://myvaultserver': got %q", serverUrl)
214+
}
215+
216+
vaultToken := getVaultToken()
217+
if vaultToken != "faketoken" {
218+
t.Errorf("export 'faketoken': got %q", vaultToken)
219+
}
220+
})
221+
222+
t.Run("Test Alternative Env Vars", func(t *testing.T) {
223+
cleanup := alternativeConnectionStringEnvVars()
224+
defer cleanup()
225+
226+
serverUrl, err := getVaultURL()
227+
if err != nil {
228+
t.Errorf("got unexpected error: %v", err)
229+
}
230+
if serverUrl != "http://myalternativevaultserver" {
231+
t.Errorf("export '': got %q", serverUrl)
232+
}
233+
234+
vaultToken := getVaultToken()
235+
if vaultToken != "faketoken2" {
236+
t.Errorf("export 'faketoken2': got %q", vaultToken)
237+
}
238+
})
239+
t.Run("Test Unset Env Vars Throws Error", func(t *testing.T) {
240+
cleanup := unsetConnectionStringEnvVars()
241+
defer cleanup()
242+
243+
serverUrl, err := getVaultURL()
244+
if err == nil {
245+
t.Errorf("expected error but got a url: %s", serverUrl)
246+
}
247+
248+
vaultToken := getVaultToken()
249+
if vaultToken != "" {
250+
t.Errorf("export '': got %q", vaultToken)
251+
}
252+
})
253+
}

0 commit comments

Comments
 (0)