-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
juan-carl0s
committed
Feb 21, 2020
1 parent
ec91d77
commit 63a3435
Showing
9 changed files
with
349 additions
and
425 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,73 @@ | ||
package sfdcclient | ||
|
||
import ( | ||
"context" | ||
"net/http" | ||
"fmt" | ||
"strings" | ||
) | ||
|
||
type Client interface { | ||
SendRequest(ctx context.Context, method, relURL string, headers http.Header, requestBody []byte) (int, []byte, error) | ||
/*****************************************/ | ||
/* Salesforce auth token response type */ | ||
/*****************************************/ | ||
|
||
// AccessTokenResponse represents a response of requests to salesforce OAuth token endpoint | ||
// https://${yourInstance}.salesforce.com/services/oauth2/token | ||
type AccessTokenResponse struct { | ||
// Success response fields | ||
AccessToken string `json:"access_token"` | ||
Scope string `json:"scope"` | ||
Instance string `json:"instance_url"` | ||
ID string `json:"id"` | ||
TokenType string `json:"token_type"` | ||
} | ||
|
||
// OAuthErr represents an error that occurs during the authorization flow | ||
// https://help.salesforce.com/articleView?id=remoteaccess_oauth_flow_errors.htm&type=5 | ||
type OAuthErr struct { | ||
// Error response fields | ||
Code string `json:"error"` | ||
Description string `json:"error_description"` | ||
} | ||
|
||
func (e *OAuthErr) Error() string { | ||
return fmt.Sprintf("error code: %s, description: %s", e.Code, e.Description) | ||
} | ||
|
||
/**********************************************/ | ||
/* Salesforce REST API error response types */ | ||
/**********************************************/ | ||
|
||
// APIErrs represents an error response from salesforce REST API endpoints | ||
// Example: | ||
// [ | ||
// { | ||
// "statusCode": "MALFORMED_ID", | ||
// "message": "SomeSaleforceObject ID: id value of incorrect type: 1234", | ||
// "fields": [ | ||
// "Id" | ||
// ] | ||
// } | ||
// ] | ||
type APIErrs []APIErr | ||
|
||
func (e *APIErrs) Error() string { | ||
var str []string | ||
if e != nil { | ||
for _, e := range *e { | ||
str = append(str, e.Error()) | ||
} | ||
} | ||
return strings.Join(str, "|") | ||
} | ||
|
||
type APIErr struct { | ||
Message string `json:"message"` | ||
ErrCode string `json:"errorCode"` | ||
Fields []string `json:"fields"` | ||
} | ||
|
||
func (e *APIErr) Error() string { | ||
if len(e.Fields) > 0 { | ||
return fmt.Sprintf("error code: %s, message: %s, fields: %s", e.ErrCode, e.Message, strings.Join(e.Fields, ",")) | ||
} | ||
return fmt.Sprintf("error code: %s, message: %s", e.ErrCode, e.Message) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,232 +1,10 @@ | ||
package sfdcclient | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"crypto/rsa" | ||
"encoding/json" | ||
"fmt" | ||
"io/ioutil" | ||
"net/http" | ||
"strings" | ||
"sync" | ||
"time" | ||
|
||
"github.com/dgrijalva/jwt-go" | ||
) | ||
|
||
type jwtBearer struct { | ||
// Underlying http client used for making all HTTP requests to salesforce | ||
// note that the configuration of this HTTP client will affect all HTTP | ||
// requests done by this struct (including the OAuth requests) | ||
client http.Client | ||
|
||
// URL of server where the salesforce organization lives | ||
instanceURL string | ||
|
||
// Variables needed for the generation and signing of the JWT token | ||
rsaPrivateKey *rsa.PrivateKey | ||
consumerKey string | ||
username string | ||
authServerURL string | ||
tokenExpTimeout time.Duration | ||
|
||
// Authentication token issued by Salesforce | ||
accessToken string | ||
accessTokenMutex *sync.RWMutex | ||
} | ||
|
||
func NewClientWithJWTBearer(sandbox bool, instanceURL, consumerKey, username string, privateKey []byte, tokenExpTimeout time.Duration, httpClient http.Client) (Client, error) { | ||
baseSFURL := "https://%s.salesforce.com" | ||
|
||
var authServerURL string | ||
if sandbox { | ||
authServerURL = fmt.Sprintf(baseSFURL, "test") | ||
} else { | ||
authServerURL = fmt.Sprintf(baseSFURL, "login") | ||
} | ||
|
||
jwtBearer := jwtBearer{ | ||
client: httpClient, | ||
instanceURL: instanceURL, | ||
authServerURL: authServerURL, | ||
// oauthTokenURL: instanceURL + "/services/oauth2/token", | ||
consumerKey: consumerKey, | ||
username: username, | ||
accessTokenMutex: &sync.RWMutex{}, | ||
} | ||
|
||
if sandbox { | ||
jwtBearer.authServerURL = fmt.Sprintf(baseSFURL, "test") | ||
} else { | ||
jwtBearer.authServerURL = fmt.Sprintf(baseSFURL, "login") | ||
} | ||
|
||
var err error | ||
|
||
if jwtBearer.rsaPrivateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKey); err != nil { | ||
return nil, err | ||
} | ||
|
||
if err = jwtBearer.newAccessToken(); err != nil { | ||
return nil, err | ||
} | ||
|
||
return &jwtBearer, nil | ||
} | ||
|
||
// newAccessToken updates the client's accessToken if salesforce successfully grants one | ||
// This function implements "OAuth 2.0 JWT Bearer Flow for Server-to-Server Integration" | ||
// see https://help.salesforce.com/articleView?id=remoteaccess_oauth_jwt_flow.htm | ||
func (c *jwtBearer) newAccessToken() error { | ||
// Create JWT | ||
token := jwt.NewWithClaims( | ||
jwt.SigningMethodRS256, | ||
jwt.StandardClaims{ | ||
Issuer: c.consumerKey, | ||
Audience: c.authServerURL, | ||
Subject: c.username, | ||
ExpiresAt: time.Now().Add(c.tokenExpTimeout).UTC().Unix(), | ||
}, | ||
) | ||
// Sign JWT with the private key | ||
signedJWT, err := token.SignedString(c.rsaPrivateKey) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
oauthTokenURL := c.instanceURL + "/services/oauth2/token" | ||
// oauthTokenURL := c.authServerURL + "/services/oauth2/token" | ||
// Request new access token from salesforce's OAuth endpoint | ||
req, err := http.NewRequest( | ||
"POST", | ||
oauthTokenURL, | ||
strings.NewReader( | ||
fmt.Sprintf("grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer&assertion=%s", signedJWT), | ||
), | ||
) | ||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") | ||
res, err := c.client.Do(req) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
resBytes, err := ioutil.ReadAll(res.Body) | ||
res.Body.Close() | ||
if err != nil { | ||
return err | ||
} | ||
|
||
switch res.StatusCode { | ||
case http.StatusOK: | ||
break | ||
case http.StatusBadRequest: | ||
var errRes OAuthErr | ||
if err := json.Unmarshal(resBytes, &errRes); err != nil { | ||
return err | ||
} | ||
return &errRes | ||
default: | ||
return fmt.Errorf("%s responded with an unexpected HTTP status code: %d", oauthTokenURL, res.StatusCode) | ||
} | ||
|
||
var tokenRes AccessTokenResponse | ||
if err := json.Unmarshal(resBytes, &tokenRes); err != nil { | ||
return err | ||
} | ||
|
||
c.accessTokenMutex.Lock() | ||
c.accessToken = tokenRes.AccessToken | ||
c.accessTokenMutex.Unlock() | ||
|
||
return nil | ||
} | ||
|
||
// SendRequest sends a n HTTP request as specified by its function parameters | ||
// If the server responds with an unauthorized 401 HTTP status code, the client attempts | ||
// to get a new authorization access token and retries the same request one more time | ||
func (c jwtBearer) SendRequest(ctx context.Context, method, relURL string, headers http.Header, requestBody []byte) (int, []byte, error) { | ||
url := c.instanceURL + relURL | ||
var err error | ||
|
||
// Issue the request to salesforce | ||
statusCode, resBody, err := c.sendRequest(ctx, method, url, headers, requestBody) | ||
if err != nil { | ||
// Check if the error came from salesforce's API | ||
if _, ok := err.(*ErrorObjects); ok { | ||
// If the status code returned is Unauthorized (401) | ||
// Presumably, the current client's access token we have has expired, | ||
// hence, we attempt to update the client's access token and retry the request once | ||
// see https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/errorcodes.htm | ||
if statusCode == http.StatusUnauthorized { | ||
err := c.newAccessToken() | ||
if err != nil { | ||
return statusCode, nil, err | ||
} | ||
|
||
// Retry the original request | ||
statusCode, resBody, err = c.sendRequest(ctx, method, url, headers, requestBody) | ||
if err != nil { | ||
return statusCode, nil, err | ||
} | ||
} | ||
} else { | ||
return statusCode, nil, err | ||
} | ||
} | ||
|
||
return statusCode, resBody, nil | ||
} | ||
|
||
func (c jwtBearer) sendRequest(ctx context.Context, method, url string, headers http.Header, requestBody []byte) (int, []byte, error) { | ||
var req *http.Request | ||
var err error | ||
if requestBody == nil { | ||
req, err = http.NewRequestWithContext(ctx, method, url, nil) | ||
} else { | ||
req, err = http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(requestBody)) | ||
} | ||
if err != nil { | ||
return -1, nil, err | ||
} | ||
|
||
c.accessTokenMutex.RLock() | ||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.accessToken)) | ||
c.accessTokenMutex.RUnlock() | ||
for hKey, hVals := range headers { | ||
for _, hVal := range hVals { | ||
req.Header.Add(hKey, hVal) | ||
} | ||
} | ||
|
||
res, err := c.client.Do(req) | ||
if err != nil { | ||
return -1, nil, err | ||
} | ||
resBytes, err := ioutil.ReadAll(res.Body) | ||
res.Body.Close() | ||
if err != nil { | ||
return -1, nil, err | ||
} | ||
|
||
var errs ErrorObjects | ||
switch res.StatusCode { | ||
// Salesforce HTTP status codes and error responses: | ||
// https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/errorcodes.htm | ||
case http.StatusOK, http.StatusCreated, http.StatusNoContent, | ||
http.StatusMultipleChoices, http.StatusNotModified: | ||
break | ||
case http.StatusBadRequest, http.StatusUnauthorized, http.StatusForbidden, | ||
http.StatusNotFound, http.StatusMethodNotAllowed, http.StatusUnsupportedMediaType, | ||
http.StatusInternalServerError: | ||
err = json.Unmarshal(resBytes, &errs) | ||
if err != nil { | ||
return res.StatusCode, nil, err | ||
} | ||
return res.StatusCode, nil, &errs | ||
default: | ||
return res.StatusCode, nil, fmt.Errorf("unexpected HTTP status code: %d", res.StatusCode) | ||
} | ||
|
||
return res.StatusCode, resBytes, nil | ||
type Client interface { | ||
SendRequest(ctx context.Context, method, relURL string, headers http.Header, requestBody []byte) (int, []byte, error) | ||
} |
File renamed without changes.
Oops, something went wrong.