Skip to content

Commit

Permalink
Merge pull request #76 from drmdrew/issue-75-vault-redirects
Browse files Browse the repository at this point in the history
Handle redirects from vault server versions earlier than v0.6.2
  • Loading branch information
hairyhenderson authored Nov 19, 2016
2 parents 3d0c34f + dbdd898 commit c160666
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 23 deletions.
22 changes: 18 additions & 4 deletions vault/app-id_strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,34 @@ func NewAppIDAuthStrategy() *AppIDAuthStrategy {
return nil
}

// GetToken - log in to the app-id auth backend and return the client token
func (a *AppIDAuthStrategy) GetToken(addr *url.URL) (string, error) {
// GetHTTPClient configures the HTTP client with a timeout
func (a *AppIDAuthStrategy) GetHTTPClient() *http.Client {
if a.hc == nil {
a.hc = &http.Client{Timeout: time.Second * 5}
}
client := a.hc
return a.hc
}

// SetToken is a no-op for AppIDAuthStrategy as a token hasn't been acquired yet
func (a *AppIDAuthStrategy) SetToken(req *http.Request) {
// no-op
}

// Do wraps http.Client.Do
func (a *AppIDAuthStrategy) Do(req *http.Request) (*http.Response, error) {
hc := a.GetHTTPClient()
return hc.Do(req)
}

// GetToken - log in to the app-id auth backend and return the client token
func (a *AppIDAuthStrategy) GetToken(addr *url.URL) (string, error) {
buf := new(bytes.Buffer)
json.NewEncoder(buf).Encode(&a)

u := &url.URL{}
*u = *addr
u.Path = "/v1/auth/app-id/login"
res, err := client.Post(u.String(), "application/json; charset=utf-8", buf)
res, err := requestAndFollow(a, "POST", u, buf.Bytes())
if err != nil {
return "", err
}
Expand Down
49 changes: 30 additions & 19 deletions vault/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,31 @@ func getAuthStrategy() AuthStrategy {
return nil
}

// GetHTTPClient returns a client configured w/X-Vault-Token header
func (c *Client) GetHTTPClient() *http.Client {
if c.hc == nil {
c.hc = &http.Client{
Timeout: time.Second * 5,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
c.SetToken(req)
return nil
},
}
}
return c.hc
}

// SetToken adds an X-Vault-Token header to the request
func (c *Client) SetToken(req *http.Request) {
req.Header.Set("X-Vault-Token", c.token)
}

// Do wraps http.Client.Do
func (c *Client) Do(req *http.Request) (*http.Response, error) {
hc := c.GetHTTPClient()
return hc.Do(req)
}

// Login - log in to Vault with the discovered auth backend and save the token
func (c *Client) Login() error {
token, err := c.Auth.GetToken(c.Addr)
Expand All @@ -72,17 +97,12 @@ func (c *Client) RevokeToken() {
return
}

if c.hc == nil {
c.hc = &http.Client{Timeout: time.Second * 5}
}

u := &url.URL{}
*u = *c.Addr
u.Path = "/v1/auth/token/revoke-self"
req, _ := http.NewRequest("POST", u.String(), nil)
req.Header.Set("X-Vault-Token", c.token)

res, err := c.hc.Do(req)
res, err := requestAndFollow(c, "POST", u, nil)

if err != nil {
log.Println("Error while revoking Vault Token", err)
}
Expand All @@ -94,32 +114,23 @@ func (c *Client) RevokeToken() {

func (c *Client) Read(path string) ([]byte, error) {
path = normalizeURLPath(path)
if c.hc == nil {
c.hc = &http.Client{Timeout: time.Second * 5}
}

u := &url.URL{}
*u = *c.Addr
u.Path = "/v1" + path
req, err := http.NewRequest("GET", u.String(), nil)
if err != nil {
return nil, err
}
req.Header.Set("X-Vault-Token", c.token)

res, err := c.hc.Do(req)
res, err := requestAndFollow(c, "GET", u, nil)
if err != nil {
return nil, err
}

body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
return nil, err
}

if res.StatusCode != 200 {
err = fmt.Errorf("Unexpected HTTP status %d on Read from %s: %s", res.StatusCode, u, body)
err = fmt.Errorf("Unexpected HTTP status %d on Read from %s: %s", res.StatusCode, path, body)
return nil, err
}

Expand All @@ -131,7 +142,7 @@ func (c *Client) Read(path string) ([]byte, error) {
}

if _, ok := response["data"]; !ok {
return nil, fmt.Errorf("Unexpected HTTP body on Read for %s: %s", u, body)
return nil, fmt.Errorf("Unexpected HTTP body on Read for %s: %s", path, body)
}

return json.Marshal(response["data"])
Expand Down
47 changes: 47 additions & 0 deletions vault/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package vault

import (
"bytes"
"net/http"
"net/url"
)

// httpClient
type httpClient interface {
GetHTTPClient() *http.Client
SetToken(req *http.Request)
Do(req *http.Request) (*http.Response, error)
}

func requestAndFollow(hc httpClient, method string, u *url.URL, body []byte) (*http.Response, error) {
var res *http.Response
var err error
for attempts := 0; attempts < 2; attempts++ {
reader := bytes.NewReader(body)
req, err := http.NewRequest(method, u.String(), reader)

if err != nil {
return nil, err
}
hc.SetToken(req)
if method == "POST" {
req.Header.Set("Content-Type", "application/json; charset=utf-8")
}

res, err = hc.Do(req)
if err != nil {
return nil, err
}
if res.StatusCode == http.StatusTemporaryRedirect {
res.Body.Close()
location, errLocation := res.Location()
if errLocation != nil {
return nil, errLocation
}
u.Host = location.Host
} else {
break
}
}
return res, err
}
68 changes: 68 additions & 0 deletions vault/http_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package vault

import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/stretchr/testify/assert"
)

type testClient struct{}

func (tc *testClient) GetHTTPClient() *http.Client {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqStr := fmt.Sprintf("%s %s", r.Method, r.URL)
switch reqStr {
case "POST http://vaultA:8500/v1/foo":
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Location", "http://vaultB:8500/v1/foo")
w.WriteHeader(http.StatusTemporaryRedirect)
fmt.Fprintln(w, "{}")
case "POST http://vaultB:8500/v1/foo":
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
fmt.Fprintln(w, "{}")
default:
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(w, "{ 'message': 'Unexpected request: %s'}", reqStr)
}
}))
return &http.Client{
Transport: &http.Transport{
Proxy: func(req *http.Request) (*url.URL, error) {
return url.Parse(server.URL)
},
},
}
}

func (tc *testClient) SetToken(req *http.Request) {
req.Header.Set("X-Vault-Token", "dead-beef-cafe-babe")
}

func (tc *testClient) Do(req *http.Request) (*http.Response, error) {
hc := tc.GetHTTPClient()
return hc.Do(req)
}

func TestRequestAndFollow_GetWithRedirect(t *testing.T) {
tc := &testClient{}
u, _ := url.Parse("http://vaultA:8500/v1/foo")

res, err := requestAndFollow(tc, "POST", u, nil)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)

}

func TestRequestAndFollow_GetNoRedirect(t *testing.T) {
tc := &testClient{}
u, _ := url.Parse("http://vaultB:8500/v1/foo")

res, err := requestAndFollow(tc, "POST", u, nil)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
}

0 comments on commit c160666

Please sign in to comment.