Skip to content

Commit

Permalink
Rearrange code in favor of clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
juan-carl0s committed Feb 21, 2020
1 parent ec91d77 commit 63a3435
Show file tree
Hide file tree
Showing 9 changed files with 349 additions and 425 deletions.
71 changes: 67 additions & 4 deletions api.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,73 @@
package sfdcclient

import (
"context"
"net/http"
"fmt"
"strings"
)

type Client interface {
SendRequest(ctx context.Context, method, relURL string, headers http.Header, requestBody []byte) (int, []byte, error)
/*****************************************/
/* Salesforce auth token response type */
/*****************************************/

// AccessTokenResponse represents a response of requests to salesforce OAuth token endpoint
// https://${yourInstance}.salesforce.com/services/oauth2/token
type AccessTokenResponse struct {
// Success response fields
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
Instance string `json:"instance_url"`
ID string `json:"id"`
TokenType string `json:"token_type"`
}

// OAuthErr represents an error that occurs during the authorization flow
// https://help.salesforce.com/articleView?id=remoteaccess_oauth_flow_errors.htm&type=5
type OAuthErr struct {
// Error response fields
Code string `json:"error"`
Description string `json:"error_description"`
}

func (e *OAuthErr) Error() string {
return fmt.Sprintf("error code: %s, description: %s", e.Code, e.Description)
}

/**********************************************/
/* Salesforce REST API error response types */
/**********************************************/

// APIErrs represents an error response from salesforce REST API endpoints
// Example:
// [
// {
// "statusCode": "MALFORMED_ID",
// "message": "SomeSaleforceObject ID: id value of incorrect type: 1234",
// "fields": [
// "Id"
// ]
// }
// ]
type APIErrs []APIErr

func (e *APIErrs) Error() string {
var str []string
if e != nil {
for _, e := range *e {
str = append(str, e.Error())
}
}
return strings.Join(str, "|")
}

type APIErr struct {
Message string `json:"message"`
ErrCode string `json:"errorCode"`
Fields []string `json:"fields"`
}

func (e *APIErr) Error() string {
if len(e.Fields) > 0 {
return fmt.Sprintf("error code: %s, message: %s, fields: %s", e.ErrCode, e.Message, strings.Join(e.Fields, ","))
}
return fmt.Sprintf("error code: %s, message: %s", e.ErrCode, e.Message)
}
26 changes: 13 additions & 13 deletions response_test.go → api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestOAuthErr_Error(t *testing.T) {
}
}

func TestErrorObject_Error(t *testing.T) {
func TestAPIErr_Error(t *testing.T) {
type fields struct {
Message string
ErrCode string
Expand All @@ -57,21 +57,21 @@ func TestErrorObject_Error(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &ErrorObject{
e := &APIErr{
Message: tt.fields.Message,
ErrCode: tt.fields.ErrCode,
}
if got := e.Error(); got != tt.want {
t.Errorf("ErrorObject.Error() = %+v, want %+v", got, tt.want)
t.Errorf("APIErr.Error() = %+v, want %+v", got, tt.want)
}
})
}
}

func TestErrorObjects_Error(t *testing.T) {
func TestAPIErrs_Error(t *testing.T) {
tests := []struct {
name string
e *ErrorObjects
e *APIErrs
want string
}{
{
Expand All @@ -81,13 +81,13 @@ func TestErrorObjects_Error(t *testing.T) {
},
{
name: "EmptyErrsSlice",
e: &ErrorObjects{},
e: &APIErrs{},
want: "",
},
{
name: "OneErr",
e: &ErrorObjects{
ErrorObject{
e: &APIErrs{
APIErr{
ErrCode: "123",
Message: "message",
Fields: []string{"field"},
Expand All @@ -97,25 +97,25 @@ func TestErrorObjects_Error(t *testing.T) {
},
{
name: "MultipleErrs",
e: &ErrorObjects{
ErrorObject{
e: &APIErrs{
APIErr{
ErrCode: "123",
Message: "message",
Fields: []string{"field"},
},
ErrorObject{
APIErr{
ErrCode: "456",
Message: "otherMessage",
Fields: []string{"otherField"},
},
},
want: "error code: 123, message: message, fields: field\nerror code: 456, message: otherMessage, fields: otherField",
want: "error code: 123, message: message, fields: field|error code: 456, message: otherMessage, fields: otherField",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.e.Error(); got != tt.want {
t.Errorf("ErrorObjects.Error() = %+v, want %+v", got, tt.want)
t.Errorf("APIErrs.Error() = %+v, want %+v", got, tt.want)
}
})
}
Expand Down
226 changes: 2 additions & 224 deletions client.go
Original file line number Diff line number Diff line change
@@ -1,232 +1,10 @@
package sfdcclient

import (
"bytes"
"context"
"crypto/rsa"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"strings"
"sync"
"time"

"github.com/dgrijalva/jwt-go"
)

type jwtBearer struct {
// Underlying http client used for making all HTTP requests to salesforce
// note that the configuration of this HTTP client will affect all HTTP
// requests done by this struct (including the OAuth requests)
client http.Client

// URL of server where the salesforce organization lives
instanceURL string

// Variables needed for the generation and signing of the JWT token
rsaPrivateKey *rsa.PrivateKey
consumerKey string
username string
authServerURL string
tokenExpTimeout time.Duration

// Authentication token issued by Salesforce
accessToken string
accessTokenMutex *sync.RWMutex
}

func NewClientWithJWTBearer(sandbox bool, instanceURL, consumerKey, username string, privateKey []byte, tokenExpTimeout time.Duration, httpClient http.Client) (Client, error) {
baseSFURL := "https://%s.salesforce.com"

var authServerURL string
if sandbox {
authServerURL = fmt.Sprintf(baseSFURL, "test")
} else {
authServerURL = fmt.Sprintf(baseSFURL, "login")
}

jwtBearer := jwtBearer{
client: httpClient,
instanceURL: instanceURL,
authServerURL: authServerURL,
// oauthTokenURL: instanceURL + "/services/oauth2/token",
consumerKey: consumerKey,
username: username,
accessTokenMutex: &sync.RWMutex{},
}

if sandbox {
jwtBearer.authServerURL = fmt.Sprintf(baseSFURL, "test")
} else {
jwtBearer.authServerURL = fmt.Sprintf(baseSFURL, "login")
}

var err error

if jwtBearer.rsaPrivateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKey); err != nil {
return nil, err
}

if err = jwtBearer.newAccessToken(); err != nil {
return nil, err
}

return &jwtBearer, nil
}

// newAccessToken updates the client's accessToken if salesforce successfully grants one
// This function implements "OAuth 2.0 JWT Bearer Flow for Server-to-Server Integration"
// see https://help.salesforce.com/articleView?id=remoteaccess_oauth_jwt_flow.htm
func (c *jwtBearer) newAccessToken() error {
// Create JWT
token := jwt.NewWithClaims(
jwt.SigningMethodRS256,
jwt.StandardClaims{
Issuer: c.consumerKey,
Audience: c.authServerURL,
Subject: c.username,
ExpiresAt: time.Now().Add(c.tokenExpTimeout).UTC().Unix(),
},
)
// Sign JWT with the private key
signedJWT, err := token.SignedString(c.rsaPrivateKey)
if err != nil {
return err
}

oauthTokenURL := c.instanceURL + "/services/oauth2/token"
// oauthTokenURL := c.authServerURL + "/services/oauth2/token"
// Request new access token from salesforce's OAuth endpoint
req, err := http.NewRequest(
"POST",
oauthTokenURL,
strings.NewReader(
fmt.Sprintf("grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer&assertion=%s", signedJWT),
),
)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
res, err := c.client.Do(req)
if err != nil {
return err
}

resBytes, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
return err
}

switch res.StatusCode {
case http.StatusOK:
break
case http.StatusBadRequest:
var errRes OAuthErr
if err := json.Unmarshal(resBytes, &errRes); err != nil {
return err
}
return &errRes
default:
return fmt.Errorf("%s responded with an unexpected HTTP status code: %d", oauthTokenURL, res.StatusCode)
}

var tokenRes AccessTokenResponse
if err := json.Unmarshal(resBytes, &tokenRes); err != nil {
return err
}

c.accessTokenMutex.Lock()
c.accessToken = tokenRes.AccessToken
c.accessTokenMutex.Unlock()

return nil
}

// SendRequest sends a n HTTP request as specified by its function parameters
// If the server responds with an unauthorized 401 HTTP status code, the client attempts
// to get a new authorization access token and retries the same request one more time
func (c jwtBearer) SendRequest(ctx context.Context, method, relURL string, headers http.Header, requestBody []byte) (int, []byte, error) {
url := c.instanceURL + relURL
var err error

// Issue the request to salesforce
statusCode, resBody, err := c.sendRequest(ctx, method, url, headers, requestBody)
if err != nil {
// Check if the error came from salesforce's API
if _, ok := err.(*ErrorObjects); ok {
// If the status code returned is Unauthorized (401)
// Presumably, the current client's access token we have has expired,
// hence, we attempt to update the client's access token and retry the request once
// see https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/errorcodes.htm
if statusCode == http.StatusUnauthorized {
err := c.newAccessToken()
if err != nil {
return statusCode, nil, err
}

// Retry the original request
statusCode, resBody, err = c.sendRequest(ctx, method, url, headers, requestBody)
if err != nil {
return statusCode, nil, err
}
}
} else {
return statusCode, nil, err
}
}

return statusCode, resBody, nil
}

func (c jwtBearer) sendRequest(ctx context.Context, method, url string, headers http.Header, requestBody []byte) (int, []byte, error) {
var req *http.Request
var err error
if requestBody == nil {
req, err = http.NewRequestWithContext(ctx, method, url, nil)
} else {
req, err = http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(requestBody))
}
if err != nil {
return -1, nil, err
}

c.accessTokenMutex.RLock()
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.accessToken))
c.accessTokenMutex.RUnlock()
for hKey, hVals := range headers {
for _, hVal := range hVals {
req.Header.Add(hKey, hVal)
}
}

res, err := c.client.Do(req)
if err != nil {
return -1, nil, err
}
resBytes, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
return -1, nil, err
}

var errs ErrorObjects
switch res.StatusCode {
// Salesforce HTTP status codes and error responses:
// https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/errorcodes.htm
case http.StatusOK, http.StatusCreated, http.StatusNoContent,
http.StatusMultipleChoices, http.StatusNotModified:
break
case http.StatusBadRequest, http.StatusUnauthorized, http.StatusForbidden,
http.StatusNotFound, http.StatusMethodNotAllowed, http.StatusUnsupportedMediaType,
http.StatusInternalServerError:
err = json.Unmarshal(resBytes, &errs)
if err != nil {
return res.StatusCode, nil, err
}
return res.StatusCode, nil, &errs
default:
return res.StatusCode, nil, fmt.Errorf("unexpected HTTP status code: %d", res.StatusCode)
}

return res.StatusCode, resBytes, nil
type Client interface {
SendRequest(ctx context.Context, method, relURL string, headers http.Header, requestBody []byte) (int, []byte, error)
}
File renamed without changes.
Loading

0 comments on commit 63a3435

Please sign in to comment.