Skip to content

Refactor auth to increase encapsulation #36

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package kinesis

import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"os"
"strings"
"time"
)

const (
ACCESS_ENV_KEY = "AWS_ACCESS_KEY"
SECRET_ENV_KEY = "AWS_SECRET_KEY"

AWS_METADATA_SERVER = "169.254.169.254"
AWS_IAM_CREDS_PATH = "/latest/meta-data/iam/security-credentials"
AWS_IAM_CREDS_URL = "http://" + AWS_METADATA_SERVER + AWS_IAM_CREDS_PATH
)

// Auth interface for authentication credentials and information
type Auth interface {
GetToken() string
GetExpiration() time.Time
GetSecretKey() string
GetAccessKey() string
HasExpiration() bool
Renew() error
Sign(*Service, time.Time) []byte
}

type auth struct {
// accessKey, secretKey are the standard AWS auth credentials
accessKey, secretKey, token string

// expiry indicates the time at which these credentials expire. If this is set
// to anything other than the zero value, indicates that the credentials are
// temporary (and probably fetched from an IAM role from the metadata server)
expiry time.Time
}

func NewAuth(accessKey, secretKey string) Auth {
return &auth{
accessKey: accessKey,
secretKey: secretKey,
}
}

// NewAuthFromEnv retrieves auth credentials from environment vars
func NewAuthFromEnv() (Auth, error) {
accessKey := os.Getenv(ACCESS_ENV_KEY)
secretKey := os.Getenv(SECRET_ENV_KEY)
if accessKey == "" {
return nil, fmt.Errorf("Unable to retrieve access key from %s env variable", ACCESS_ENV_KEY)
}
if secretKey == "" {
return nil, fmt.Errorf("Unable to retrieve secret key from %s env variable", SECRET_ENV_KEY)
}

return NewAuth(accessKey, secretKey), nil
}

// NewAuthFromMetadata retrieves auth credentials from the metadata
// server. If an IAM role is associated with the instance we are running on, the
// metadata server will expose credentials for that role under a known endpoint.
//
// TODO: specify custom network (connect, read) timeouts, else this will block
// for the default timeout durations.
func NewAuthFromMetadata() (Auth, error) {
auth := &auth{}
if err := auth.Renew(); err != nil {
return nil, err
}
return auth, nil
}

// HasExpiration returns true if the expiration time is non-zero and false otherwise
func (a *auth) HasExpiration() bool {
return !a.expiry.IsZero()
}

// GetExpiration retrieves the current expiration time
func (a *auth) GetExpiration() time.Time {
return a.expiry
}

// GetToken returns the token
func (a *auth) GetToken() string {
return a.token
}

// GetSecretKey returns the secret key
func (a *auth) GetSecretKey() string {
return a.secretKey
}

// GetAccessKey returns the access key
func (a *auth) GetAccessKey() string {
return a.accessKey
}

// Renew retrieves a new token and mutates it on an instance of the Auth struct
func (a *auth) Renew() error {
role, err := retrieveIAMRole()
if err != nil {
return err
}

data, err := retrieveAWSCredentials(role)
if err != nil {
return err
}

// Ignore the error, it just means we won't be able to refresh the
// credentials when they expire.
expiry, _ := time.Parse(time.RFC3339, data["Expiration"])

a.expiry = expiry
a.accessKey = data["AccessKeyId"]
a.secretKey = data["SecretAccessKey"]
a.token = data["Token"]
return nil
}

// Sign API request by
// http://docs.amazonwebservices.com/general/latest/gr/signature-version-4.html

func (a *auth) Sign(s *Service, t time.Time) []byte {
h := ghmac([]byte("AWS4"+a.GetSecretKey()), []byte(t.Format(iSO8601BasicFormatShort)))
h = ghmac(h, []byte(s.Region))
h = ghmac(h, []byte(s.Name))
h = ghmac(h, []byte(AWS4_URL))
return h
}

func retrieveAWSCredentials(role string) (map[string]string, error) {
var bodybytes []byte
// Retrieve the json for this role
resp, err := http.Get(AWS_IAM_CREDS_URL + "/" + role)
if err != nil || resp.StatusCode != http.StatusOK {
return nil, err
}
defer resp.Body.Close()

bodybytes, err = ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}

jsondata := make(map[string]string)
err = json.Unmarshal(bodybytes, &jsondata)
if err != nil {
return nil, err
}

return jsondata, nil
}

func retrieveIAMRole() (string, error) {
var bodybytes []byte

resp, err := http.Get(AWS_IAM_CREDS_URL)
if err != nil || resp.StatusCode != http.StatusOK {
return "", err
}
defer resp.Body.Close()

bodybytes, err = ioutil.ReadAll(resp.Body)
if err != nil {
return "", err
}

// pick the first IAM role
role := strings.Split(string(bodybytes), "\n")[0]
if len(role) == 0 {
return "", errors.New("Unable to retrieve IAM role")
}

return role, nil
}
40 changes: 40 additions & 0 deletions auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package kinesis

import (
"os"
"testing"
)

func TestGetSecretKey(t *testing.T) {
auth := NewAuth("BAD_ACCESS_KEY", "BAD_SECRET_KEY")

if auth.GetAccessKey() != "BAD_ACCESS_KEY" {
t.Error("incorrect value for auth#accessKey")
}
}

func TestGetAccessKey(t *testing.T) {
auth := NewAuth("BAD_ACCESS_KEY", "BAD_SECRET_KEY")

if auth.GetSecretKey() != "BAD_SECRET_KEY" {
t.Error("incorrect value for auth#secretKey")
}
}

func TestNewAuthFromEnv(t *testing.T) {
os.Setenv(ACCESS_ENV_KEY, "asdf")
os.Setenv(SECRET_ENV_KEY, "asdf")

auth, _ := NewAuthFromEnv()

if auth.GetAccessKey() != "asdf" {
t.Error("Expected AccessKey to be inferred as \"asdf\"")
}

if auth.GetSecretKey() != "asdf" {
t.Error("Expected SecretKey to be inferred as \"asdf\"")
}

os.Setenv(ACCESS_ENV_KEY, "") // Use Unsetenv with go1.4
os.Setenv(SECRET_ENV_KEY, "") // Use Unsetenv with go1.4
}
141 changes: 15 additions & 126 deletions client.go
Original file line number Diff line number Diff line change
@@ -1,160 +1,49 @@
package kinesis

import (
"encoding/json"
"io/ioutil"
"net/http"
"os"
"strings"
"time"
)

const (
ACCESS_ENV_KEY = "AWS_ACCESS_KEY"
SECRET_ENV_KEY = "AWS_SECRET_KEY"
REGION_ENV_NAME = "AWS_REGION_NAME"

AWS_METADATA_SERVER = "169.254.169.254"
AWS_IAM_CREDS_PATH = "/latest/meta-data/iam/security-credentials"
AWS_IAM_CREDS_URL = "http://" + AWS_METADATA_SERVER + AWS_IAM_CREDS_PATH

AWS_SECURITY_TOKEN_HEADER = "X-Amz-Security-Token"
)

// Auth store information about AWS Credentials
type Auth struct {
// AccessKey, SecretKey are the standard AWS auth credentials
AccessKey, SecretKey, Token string
// Expiry indicates the time at which these credentials expire. If this is set
// to anything other than the zero value, indicates that the credentials are
// temporary (and probably fetched from an IAM role from the metadata server)
Expiry time.Time
}
const AWS_SECURITY_TOKEN_HEADER = "X-Amz-Security-Token"

// Client is like http.Client, but signs all requests using Auth.
type Client struct {
// Auth holds the credentials for this client instance
Auth *Auth
auth Auth
// The http client to make requests with. If nil, http.DefaultClient is used.
Client *http.Client
}

// NewAuth returns a new Auth object whose members (AccessKey, SecretKey, etc)
// have been initialized by inspecting the environment or querying the AWS
// metadata server (in that order).
func NewAuth() (auth Auth) {
// first try grabbing the credentials from the environment
if auth.AccessKey == "" || auth.SecretKey == "" {
auth.InferCredentialsFromEnv()
}

// if they're still not set, try the metadata server
if auth.AccessKey == "" || auth.SecretKey == "" {
auth.InferCredentialsFromMetadata()
}

return
}

// InferCredentialsFromEnv retrieves auth credentials from environment vars
func (auth *Auth) InferCredentialsFromEnv() {
auth.AccessKey = os.Getenv(ACCESS_ENV_KEY)
auth.SecretKey = os.Getenv(SECRET_ENV_KEY)
}

// InferCredentialsFromMetadata retrieves auth credentials from the metadata
// server. If an IAM role is associated with the instance we are running on, the
// metadata server will expose credentials for that role under a known endpoint.
//
// TODO: specify custom network (connect, read) timeouts, else this will block
// for the default timeout durations.
func (auth *Auth) InferCredentialsFromMetadata() {
resp1, err := http.Get(AWS_IAM_CREDS_URL)
if err != nil || resp1.StatusCode != http.StatusOK {
return
}
defer resp1.Body.Close()

bodybytes, err := ioutil.ReadAll(resp1.Body)
if err != nil {
return
}

// pick the first IAM role
role := strings.Split(string(bodybytes), "\n")[0]
if len(role) == 0 {
return
}

// Retrieve the json for this role
resp2, err := http.Get(AWS_IAM_CREDS_URL + "/" + role)
if err != nil || resp2.StatusCode != http.StatusOK {
return
}
defer resp2.Body.Close()

bodybytes, err = ioutil.ReadAll(resp2.Body)
if err != nil {
return
}

jsondata := make(map[string]string)
err = json.Unmarshal(bodybytes, &jsondata)
if err != nil {
return
}

expiry, _ := time.Parse(time.RFC3339, jsondata["Expiration"])
// Ignore the error, it just means we won't be able to refresh the
// credentials when they expire.

auth.Expiry = expiry
auth.AccessKey = jsondata["AccessKeyId"]
auth.SecretKey = jsondata["SecretAccessKey"]
auth.Token = jsondata["Token"]
client *http.Client
}

// NewClient creates a new Client that uses the credentials in the specified
// Auth object.
//
// This function assumes the Auth object has been sanely initialized. If you
// wish to infer auth credentials from the environment, refer to NewAuth
func NewClient(auth *Auth) *Client {
return &Client{Auth: auth}
func NewClient(auth Auth) *Client {
return &Client{auth: auth, client: http.DefaultClient}
}

// GetRegion returns the region name string
func GetRegion(region Region) string {
if region.Name == "" {
return os.Getenv(REGION_ENV_NAME)
}
return region.Name
}

// get the http client we use to communicate with the server
func (c *Client) client() *http.Client {
if c.Client == nil {
return http.DefaultClient
}
return c.Client
func NewClientWithHTTPClient(auth Auth, httpClient *http.Client) *Client {
return &Client{auth: auth, client: httpClient}
}

// Do some request, but sign it before sending
func (c *Client) Do(req *http.Request) (resp *http.Response, err error) {
err = Sign(c.Auth, req)
func (c *Client) Do(req *http.Request) (*http.Response, error) {
err := Sign(c.auth, req)
if err != nil {
return nil, err
}

if !c.Auth.Expiry.IsZero() {
if time.Now().After(c.Auth.Expiry) {
c.Auth.InferCredentialsFromMetadata() // TODO: (see above) may be slow
if c.auth.HasExpiration() && time.Now().After(c.auth.GetExpiration()) {
if err = c.auth.Renew(); err != nil { // TODO: (see auth.go#Renew) may be slow
return nil, err
}
}

if len(c.Auth.Token) != 0 {
req.Header.Add(AWS_SECURITY_TOKEN_HEADER, c.Auth.Token)
if c.auth.GetToken() != "" {
req.Header.Add(AWS_SECURITY_TOKEN_HEADER, c.auth.GetToken())
}

return c.client().Do(req)
return c.client.Do(req)
}
Loading