Skip to content

Commit

Permalink
Trying to make awsauth thread safe by getting rid of Keys global.
Browse files Browse the repository at this point in the history
  • Loading branch information
Dane committed Jun 2, 2014
1 parent b7d78fd commit 3c21e95
Show file tree
Hide file tree
Showing 11 changed files with 127 additions and 87 deletions.
65 changes: 33 additions & 32 deletions awsauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@ import (
"time"
)

// Keys stores the authentication credentials to be used when signing requests.
// You can set them manually or leave it to awsauth to use environment variables.
var Keys *Credentials

// Credentials stores the information necessary to authorize with AWS and it
// is from this information that requests are signed.
type Credentials struct {
Expand Down Expand Up @@ -59,15 +55,16 @@ func Sign(req *http.Request, cred ...Credentials) *http.Request {
func Sign4(req *http.Request, cred ...Credentials) *http.Request {
signMutex.Lock()
defer signMutex.Unlock()
var keys Credentials
if len(cred) == 0 {
checkKeys()
keys = newKeys()
} else {
Keys = &cred[0]
keys = cred[0]
}

// Add the X-Amz-Security-Token header when using STS
if Keys.SecurityToken != "" {
req.Header.Set("X-Amz-Security-Token", Keys.SecurityToken)
if keys.SecurityToken != "" {
req.Header.Set("X-Amz-Security-Token", keys.SecurityToken)
}

prepareRequestV4(req)
Expand All @@ -80,10 +77,10 @@ func Sign4(req *http.Request, cred ...Credentials) *http.Request {
stringToSign := stringToSignV4(req, hashedCanonReq, meta)

// Task 3
signingKey := signingKeyV4(Keys.SecretAccessKey, meta.date, meta.region, meta.service)
signingKey := signingKeyV4(keys.SecretAccessKey, meta.date, meta.region, meta.service)
signature := signatureV4(signingKey, stringToSign)

req.Header.Set("Authorization", buildAuthHeaderV4(signature, meta))
req.Header.Set("Authorization", buildAuthHeaderV4(signature, meta, keys))

return req
}
Expand All @@ -93,15 +90,16 @@ func Sign4(req *http.Request, cred ...Credentials) *http.Request {
func Sign3(req *http.Request, cred ...Credentials) *http.Request {
signMutex.Lock()
defer signMutex.Unlock()
var keys Credentials
if len(cred) == 0 {
checkKeys()
keys = newKeys()
} else {
Keys = &cred[0]
keys = cred[0]
}

// Add the X-Amz-Security-Token header when using STS
if Keys.SecurityToken != "" {
req.Header.Set("X-Amz-Security-Token", Keys.SecurityToken)
if keys.SecurityToken != "" {
req.Header.Set("X-Amz-Security-Token", keys.SecurityToken)
}

prepareRequestV3(req)
Expand All @@ -110,10 +108,10 @@ func Sign3(req *http.Request, cred ...Credentials) *http.Request {
stringToSign := stringToSignV3(req)

// Task 2
signature := signatureV3(stringToSign)
signature := signatureV3(stringToSign, keys)

// Task 3
req.Header.Set("X-Amzn-Authorization", buildAuthHeaderV3(signature))
req.Header.Set("X-Amzn-Authorization", buildAuthHeaderV3(signature, keys))

return req
}
Expand All @@ -123,25 +121,26 @@ func Sign3(req *http.Request, cred ...Credentials) *http.Request {
func Sign2(req *http.Request, cred ...Credentials) *http.Request {
signMutex.Lock()
defer signMutex.Unlock()
var keys Credentials
if len(cred) == 0 {
checkKeys()
keys = newKeys()
} else {
Keys = &cred[0]
keys = cred[0]
}

// Add the SecurityToken parameter when using STS
// This must be added before the signature is calculated
if Keys.SecurityToken != "" {
if keys.SecurityToken != "" {
v := url.Values{}
v.Set("SecurityToken", Keys.SecurityToken)
v.Set("SecurityToken", keys.SecurityToken)
augmentRequestQuery(req, v)

}

prepareRequestV2(req)
prepareRequestV2(req, keys)

stringToSign := stringToSignV2(req)
signature := signatureV2(stringToSign)
signature := signatureV2(stringToSign, keys)

values := url.Values{}
values.Set("Signature", signature)
Expand All @@ -156,23 +155,24 @@ func Sign2(req *http.Request, cred ...Credentials) *http.Request {
func SignS3(req *http.Request, cred ...Credentials) *http.Request {
signMutex.Lock()
defer signMutex.Unlock()
var keys Credentials
if len(cred) == 0 {
checkKeys()
keys = newKeys()
} else {
Keys = &cred[0]
keys = cred[0]
}

// Add the X-Amz-Security-Token header when using STS
if Keys.SecurityToken != "" {
req.Header.Set("X-Amz-Security-Token", Keys.SecurityToken)
if keys.SecurityToken != "" {
req.Header.Set("X-Amz-Security-Token", keys.SecurityToken)
}

prepareRequestS3(req)

stringToSign := stringToSignS3(req)
signature := signatureS3(stringToSign)
signature := signatureS3(stringToSign, keys)

authHeader := "AWS " + Keys.AccessKeyID + ":" + signature
authHeader := "AWS " + keys.AccessKeyID + ":" + signature
req.Header.Set("Authorization", authHeader)

return req
Expand All @@ -185,17 +185,18 @@ func SignS3(req *http.Request, cred ...Credentials) *http.Request {
func SignS3Url(req *http.Request, expire time.Time, cred ...Credentials) *http.Request {
signMutex.Lock()
defer signMutex.Unlock()
var keys Credentials
if len(cred) == 0 {
checkKeys()
keys = newKeys()
} else {
Keys = &cred[0]
keys = cred[0]
}

stringToSign := stringToSignS3Url("GET", expire, req.URL.Path)
signature := signatureS3(stringToSign)
signature := signatureS3(stringToSign, keys)

qs := req.URL.Query()
qs.Set("AWSAccessKeyId", Keys.AccessKeyID)
qs.Set("AWSAccessKeyId", keys.AccessKeyID)
qs.Set("Signature", signature)
qs.Set("Expires", timeToUnixEpochString(expire))
req.URL.RawQuery = qs.Encode()
Expand Down
41 changes: 39 additions & 2 deletions awsauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,42 @@ func TestSign(t *testing.T) {
So(signedReq.Header.Get("Authorization"), ShouldContainSubstring, ", Signature=")
}
})

var keys Credentials
keys = newKeys()
Convey("Requests to services using existing credentials Version 2 should be signed accordingly", t, func() {
reqs := []*http.Request{
newRequest("GET", "https://ec2.amazonaws.com", url.Values{}),
newRequest("GET", "https://elasticache.amazonaws.com/", url.Values{}),
}
for _, req := range reqs {
signedReq := Sign(req, keys)
So(signedReq.URL.Query().Get("SignatureVersion"), ShouldEqual, "2")
}
})

Convey("Requests to services using existing credentials Version 3 should be signed accordingly", t, func() {
reqs := []*http.Request{
newRequest("GET", "https://route53.amazonaws.com", url.Values{}),
newRequest("GET", "https://email.us-east-1.amazonaws.com/", url.Values{}),
}
for _, req := range reqs {
signedReq := Sign(req, keys)
So(signedReq.Header.Get("X-Amzn-Authorization"), ShouldNotBeBlank)
}
})

Convey("Requests to services using existing credentials Version 4 should be signed accordingly", t, func() {
reqs := []*http.Request{
newRequest("POST", "https://sqs.amazonaws.com/", url.Values{}),
newRequest("GET", "https://iam.amazonaws.com", url.Values{}),
newRequest("GET", "https://s3.amazonaws.com", url.Values{}),
}
for _, req := range reqs {
signedReq := Sign(req, keys)
So(signedReq.Header.Get("Authorization"), ShouldContainSubstring, ", Signature=")
}
})
}

func TestExpiration(t *testing.T) {
Expand All @@ -201,8 +237,9 @@ func TestExpiration(t *testing.T) {
}

func credentialsSet() bool {
checkKeys()
if Keys.AccessKeyID == "" {
var keys Credentials
keys = newKeys()
if keys.AccessKeyID == "" {
return false
} else {
return true
Expand Down
22 changes: 10 additions & 12 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,24 +57,22 @@ func serviceAndRegion(host string) (service string, region string) {
return
}

func checkKeys() {
if Keys == nil {
Keys = &Credentials{
AccessKeyID: os.Getenv(envAccessKeyID),
SecretAccessKey: os.Getenv(envSecretAccessKey),
SecurityToken: os.Getenv(envSecurityToken),
}
}
func newKeys() (newCredentials Credentials) {

newCredentials.AccessKeyID = os.Getenv(envAccessKeyID)
newCredentials.SecretAccessKey = os.Getenv(envSecretAccessKey)
newCredentials.SecurityToken = os.Getenv(envSecurityToken)

// If there is no Access Key and you are on EC2, get the key from the role
if Keys.AccessKeyID == "" && onEC2() {
Keys = getIAMRoleCredentials()
if newCredentials.AccessKeyID == "" && onEC2() {
newCredentials = *getIAMRoleCredentials()
}

// If the key is expiring, get a new key
if Keys.expired() && onEC2() {
Keys = getIAMRoleCredentials()
if newCredentials.expired() && onEC2() {
newCredentials = *getIAMRoleCredentials()
}
return newCredentials
}

// onEC2 checks to see if the program is running on an EC2 instance.
Expand Down
4 changes: 2 additions & 2 deletions s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
"time"
)

func signatureS3(stringToSign string) string {
hashed := hmacSHA1([]byte(Keys.SecretAccessKey), stringToSign)
func signatureS3(stringToSign string, keys Credentials) string {
hashed := hmacSHA1([]byte(keys.SecretAccessKey), stringToSign)
return base64.StdEncoding.EncodeToString(hashed)
}

Expand Down
14 changes: 7 additions & 7 deletions s3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func TestSignatureS3(t *testing.T) {
// (but signed URL requests still utilize a lot of the same functionality)

Convey("Given a GET request to Amazon S3", t, func() {
Keys = testCredS3
keys := *testCredS3
req := test_plainRequestS3()

// Mock time
Expand Down Expand Up @@ -46,13 +46,13 @@ func TestSignatureS3(t *testing.T) {
})

Convey("The final signature string should be exactly correct", func() {
actual := signatureS3(stringToSignS3(req))
actual := signatureS3(stringToSignS3(req), keys)
So(actual, ShouldEqual, "bWq2s1WEIj+Ydj0vQ697zp+IXMU=")
})
})

Convey("Given a GET request for a resource on S3 for query string authentication", t, func() {
Keys = testCredS3
keys := *testCredS3
req, _ := http.NewRequest("GET", "https://johnsmith.s3.amazonaws.com/johnsmith/photos/puppy.jpg", nil)

now = func() time.Time {
Expand All @@ -66,13 +66,13 @@ func TestSignatureS3(t *testing.T) {
})

Convey("The signature of string to sign should be correct", func() {
actual := signatureS3(expectedStringToSignS3Url)
actual := signatureS3(expectedStringToSignS3Url, keys)
So(actual, ShouldEqual, "R2K/+9bbnBIbVDCs7dqlz3XFtBQ=")
})

Convey("The finished signed URL should be correct", func() {
expiry := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
So(SignS3Url(req, expiry).URL.String(), ShouldEqual, expectedSignedS3Url)
So(SignS3Url(req, expiry, keys).URL.String(), ShouldEqual, expectedSignedS3Url)
})
})
}
Expand All @@ -82,10 +82,10 @@ func TestS3STSRequestPreparer(t *testing.T) {
req := test_plainRequestS3()

Convey("And a set of credentials with an STS token", func() {
Keys = testCredS3WithSTS
keys := *testCredS3WithSTS

Convey("It should include an X-Amz-Security-Token when the request is signed", func() {
actualSigned := SignS3(req)
actualSigned := SignS3(req, keys)
actual := actualSigned.Header.Get("X-Amz-Security-Token")

So(actual, ShouldNotBeBlank)
Expand Down
11 changes: 4 additions & 7 deletions sign2.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,9 @@ import (
"strings"
)

func prepareRequestV2(req *http.Request) *http.Request {
keyID := ""
func prepareRequestV2(req *http.Request, keys Credentials) *http.Request {

if Keys != nil {
keyID = Keys.AccessKeyID
}
keyID := keys.AccessKeyID

values := url.Values{}
values.Set("AWSAccessKeyId", keyID)
Expand All @@ -37,8 +34,8 @@ func stringToSignV2(req *http.Request) string {
return str
}

func signatureV2(strToSign string) string {
hashed := hmacSHA256([]byte(Keys.SecretAccessKey), strToSign)
func signatureV2(strToSign string, keys Credentials) string {
hashed := hmacSHA256([]byte(keys.SecretAccessKey), strToSign)
return base64.StdEncoding.EncodeToString(hashed)
}

Expand Down
Loading

0 comments on commit 3c21e95

Please sign in to comment.