-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add RemoteOAuth Token helper to refresh
access_token
from clo…
…ud environment (#1866) Implements cloudquery/cloudquery-issues#1978 (internal issue)
- Loading branch information
Showing
6 changed files
with
495 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
package remoteoauth | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"net/http" | ||
"os" | ||
|
||
cloudquery_api "github.com/cloudquery/cloudquery-api-go" | ||
"github.com/google/uuid" | ||
"golang.org/x/oauth2" | ||
) | ||
|
||
func NewTokenSource(opts ...TokenSourceOption) (oauth2.TokenSource, error) { | ||
t := &tokenSource{} | ||
for _, opt := range opts { | ||
opt(t) | ||
} | ||
|
||
if _, cloudEnabled := os.LookupEnv("CQ_CLOUD"); !cloudEnabled { | ||
return oauth2.StaticTokenSource(&t.currentToken), nil | ||
} | ||
|
||
cloudToken, err := newCloudTokenSource(t.defaultContext) | ||
if err != nil { | ||
return nil, err | ||
} | ||
if t.noWrap { | ||
return cloudToken, nil | ||
} | ||
|
||
return oauth2.ReuseTokenSource(nil, cloudToken), nil | ||
} | ||
|
||
type tokenSource struct { | ||
defaultContext context.Context | ||
currentToken oauth2.Token | ||
noWrap bool | ||
} | ||
|
||
type cloudTokenSource struct { | ||
defaultContext context.Context | ||
apiClient *cloudquery_api.ClientWithResponses | ||
|
||
apiURL string | ||
apiToken string | ||
teamName string | ||
syncName string | ||
testConnUUID uuid.UUID | ||
syncRunUUID uuid.UUID | ||
connectorUUID uuid.UUID | ||
isTestConnection bool | ||
} | ||
|
||
var _ oauth2.TokenSource = (*cloudTokenSource)(nil) | ||
|
||
func newCloudTokenSource(defaultContext context.Context) (oauth2.TokenSource, error) { | ||
t := &cloudTokenSource{ | ||
defaultContext: defaultContext, | ||
} | ||
if t.defaultContext == nil { | ||
t.defaultContext = context.Background() | ||
} | ||
|
||
err := t.initCloudOpts() | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
t.apiClient, err = cloudquery_api.NewClientWithResponses(t.apiURL, | ||
cloudquery_api.WithRequestEditorFn(func(_ context.Context, req *http.Request) error { | ||
req.Header.Set("Authorization", "Bearer "+t.apiToken) | ||
return nil | ||
})) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to create api client: %w", err) | ||
} | ||
|
||
return t, nil | ||
} | ||
|
||
// Token returns a new token from the remote source using the default context. | ||
func (t *cloudTokenSource) Token() (*oauth2.Token, error) { | ||
return t.retrieveToken(t.defaultContext) | ||
} | ||
|
||
func (t *cloudTokenSource) retrieveToken(ctx context.Context) (*oauth2.Token, error) { | ||
var oauthResp *cloudquery_api.ConnectorCredentialsResponseOAuth | ||
if !t.isTestConnection { | ||
resp, err := t.apiClient.GetSyncRunConnectorCredentialsWithResponse(ctx, t.teamName, t.syncName, t.syncRunUUID, t.connectorUUID) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to get sync run connector credentials: %w", err) | ||
} | ||
if resp.StatusCode() != http.StatusOK { | ||
if resp.JSON422 != nil { | ||
return nil, fmt.Errorf("failed to get sync run connector credentials: %s", resp.JSON422.Message) | ||
} | ||
return nil, fmt.Errorf("failed to get sync run connector credentials: %s", resp.Status()) | ||
} | ||
oauthResp = resp.JSON200.Oauth | ||
} else { | ||
resp, err := t.apiClient.GetTestConnectionConnectorCredentialsWithResponse(ctx, t.teamName, t.testConnUUID, t.connectorUUID) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to get test connection connector credentials: %w", err) | ||
} | ||
if resp.StatusCode() != http.StatusOK { | ||
if resp.JSON422 != nil { | ||
return nil, fmt.Errorf("failed to get test connection connector credentials: %s", resp.JSON422.Message) | ||
} | ||
return nil, fmt.Errorf("failed to get test connection connector credentials: %s", resp.Status()) | ||
} | ||
oauthResp = resp.JSON200.Oauth | ||
} | ||
|
||
if oauthResp == nil { | ||
return nil, fmt.Errorf("missing oauth credentials in response") | ||
} | ||
|
||
tok := &oauth2.Token{ | ||
AccessToken: oauthResp.AccessToken, | ||
} | ||
if oauthResp.Expires != nil { | ||
tok.Expiry = *oauthResp.Expires | ||
} | ||
return tok, nil | ||
} | ||
|
||
func (t *cloudTokenSource) initCloudOpts() error { | ||
var allErr error | ||
|
||
t.apiToken = os.Getenv("CLOUDQUERY_API_KEY") | ||
if t.apiToken == "" { | ||
allErr = errors.Join(allErr, errors.New("CLOUDQUERY_API_KEY missing")) | ||
} | ||
t.apiURL = os.Getenv("CLOUDQUERY_API_URL") | ||
if t.apiURL == "" { | ||
t.apiURL = "https://api.cloudquery.io" | ||
} | ||
|
||
t.teamName = os.Getenv("_CQ_TEAM_NAME") | ||
if t.teamName == "" { | ||
allErr = errors.Join(allErr, errors.New("_CQ_TEAM_NAME missing")) | ||
} | ||
t.syncName = os.Getenv("_CQ_SYNC_NAME") | ||
syncRunID := os.Getenv("_CQ_SYNC_RUN_ID") | ||
testConnID := os.Getenv("_CQ_SYNC_TEST_CONNECTION_ID") | ||
if testConnID == "" && syncRunID == "" { | ||
allErr = errors.Join(allErr, errors.New("_CQ_SYNC_TEST_CONNECTION_ID or _CQ_SYNC_RUN_ID missing")) | ||
} else if testConnID != "" && syncRunID != "" { | ||
allErr = errors.Join(allErr, errors.New("_CQ_SYNC_TEST_CONNECTION_ID and _CQ_SYNC_RUN_ID are mutually exclusive")) | ||
} | ||
|
||
var err error | ||
if syncRunID != "" { | ||
if t.syncName == "" { | ||
allErr = errors.Join(allErr, errors.New("_CQ_SYNC_NAME missing")) | ||
} | ||
|
||
t.syncRunUUID, err = uuid.Parse(syncRunID) | ||
if err != nil { | ||
allErr = errors.Join(allErr, fmt.Errorf("_CQ_SYNC_RUN_ID is not a valid UUID: %w", err)) | ||
} | ||
} | ||
if testConnID != "" { | ||
if t.syncName != "" { | ||
allErr = errors.Join(allErr, errors.New("_CQ_SYNC_NAME should be empty")) | ||
} | ||
|
||
t.testConnUUID, err = uuid.Parse(testConnID) | ||
if err != nil { | ||
allErr = errors.Join(allErr, fmt.Errorf("_CQ_SYNC_TEST_CONNECTION_ID is not a valid UUID: %w", err)) | ||
} | ||
t.isTestConnection = true | ||
} | ||
|
||
connectorID := os.Getenv("_CQ_CONNECTOR_ID") | ||
if connectorID == "" { | ||
allErr = errors.Join(allErr, errors.New("_CQ_CONNECTOR_ID missing")) | ||
} else { | ||
t.connectorUUID, err = uuid.Parse(connectorID) | ||
if err != nil { | ||
allErr = errors.Join(allErr, fmt.Errorf("_CQ_CONNECTOR_ID is not a valid UUID: %w", err)) | ||
} | ||
} | ||
return allErr | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
package remoteoauth | ||
|
||
import ( | ||
"net/http" | ||
"net/http/httptest" | ||
"os" | ||
"testing" | ||
"time" | ||
|
||
"github.com/google/uuid" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
const testAPIKey = "test-key" | ||
|
||
func TestLocalTokenAccess(t *testing.T) { | ||
r := require.New(t) | ||
_, cloud := os.LookupEnv("CQ_CLOUD") | ||
r.False(cloud, "CQ_CLOUD should not be set") | ||
tok, err := NewTokenSource(WithAccessToken("token", "bearer", time.Time{})) | ||
r.NoError(err) | ||
tk, err := tok.Token() | ||
r.NoError(err) | ||
r.True(tk.Valid()) | ||
r.Equal("token", tk.AccessToken) | ||
} | ||
|
||
func TestFirstLocalTokenAccess(t *testing.T) { | ||
runID := uuid.NewString() | ||
connID := uuid.NewString() | ||
testURL := setupMockTokenServer(t, map[string]string{ | ||
"/teams/the-team/syncs/the-sync/runs/" + runID + "/connector/" + connID + "/credentials": `{"oauth":{"access_token":"new-token"}}`, | ||
}) | ||
setEnvs(t, map[string]string{ | ||
"CQ_CLOUD": "1", | ||
"CLOUDQUERY_API_URL": testURL, | ||
"CLOUDQUERY_API_KEY": testAPIKey, | ||
"_CQ_TEAM_NAME": "the-team", | ||
"_CQ_SYNC_NAME": "the-sync", | ||
"_CQ_SYNC_RUN_ID": runID, | ||
"_CQ_CONNECTOR_ID": connID, | ||
}) | ||
r := require.New(t) | ||
tok, err := NewTokenSource(WithAccessToken("token", "bearer", time.Time{})) | ||
r.NoError(err) | ||
tk, err := tok.Token() | ||
r.NoError(err) | ||
r.True(tk.Valid()) | ||
r.Equal("new-token", tk.AccessToken) | ||
} | ||
|
||
func TestInvalidAPIKeyTokenAccess(t *testing.T) { | ||
runID := uuid.NewString() | ||
connID := uuid.NewString() | ||
testURL := setupMockTokenServer(t, nil) | ||
setEnvs(t, map[string]string{ | ||
"CQ_CLOUD": "1", | ||
"CLOUDQUERY_API_URL": testURL, | ||
"CLOUDQUERY_API_KEY": "invalid", | ||
"_CQ_TEAM_NAME": "the-team", | ||
"_CQ_SYNC_NAME": "the-sync", | ||
"_CQ_SYNC_RUN_ID": runID, | ||
"_CQ_CONNECTOR_ID": connID, | ||
}) | ||
r := require.New(t) | ||
tok, err := NewTokenSource(WithAccessToken("token", "bearer", time.Time{})) | ||
r.NoError(err) | ||
tk, err := tok.Token() | ||
r.Nil(tk) | ||
r.False(tk.Valid()) | ||
r.ErrorContains(err, "failed to get sync run connector credentials") | ||
} | ||
|
||
func TestSyncRunTokenAccess(t *testing.T) { | ||
runID := uuid.NewString() | ||
connID := uuid.NewString() | ||
testURL := setupMockTokenServer(t, map[string]string{ | ||
"/teams/the-team/syncs/the-sync/runs/" + runID + "/connector/" + connID + "/credentials": `{"oauth":{"access_token":"new-token"}}`, | ||
}) | ||
setEnvs(t, map[string]string{ | ||
"CQ_CLOUD": "1", | ||
"CLOUDQUERY_API_URL": testURL, | ||
"CLOUDQUERY_API_KEY": testAPIKey, | ||
"_CQ_TEAM_NAME": "the-team", | ||
"_CQ_SYNC_NAME": "the-sync", | ||
"_CQ_SYNC_RUN_ID": runID, | ||
"_CQ_CONNECTOR_ID": connID, | ||
}) | ||
r := require.New(t) | ||
tok, err := NewTokenSource() | ||
r.NoError(err) | ||
tk, err := tok.Token() | ||
r.NoError(err) | ||
r.True(tk.Valid()) | ||
r.Equal("new-token", tk.AccessToken) | ||
} | ||
|
||
func TestTestConnectionTokenAccess(t *testing.T) { | ||
testID := uuid.NewString() | ||
connID := uuid.NewString() | ||
testURL := setupMockTokenServer(t, map[string]string{ | ||
"/teams/the-team/syncs/test-connections/" + testID + "/connector/" + connID + "/credentials": `{"oauth":{"access_token":"new-token"}}`, | ||
}) | ||
setEnvs(t, map[string]string{ | ||
"CQ_CLOUD": "1", | ||
"CLOUDQUERY_API_URL": testURL, | ||
"CLOUDQUERY_API_KEY": testAPIKey, | ||
"_CQ_TEAM_NAME": "the-team", | ||
"_CQ_SYNC_TEST_CONNECTION_ID": testID, | ||
"_CQ_CONNECTOR_ID": connID, | ||
}) | ||
r := require.New(t) | ||
tok, err := NewTokenSource(WithAccessToken("token", "bearer", time.Time{})) | ||
r.NoError(err) | ||
tk, err := tok.Token() | ||
r.NoError(err) | ||
r.True(tk.Valid()) | ||
r.Equal("new-token", tk.AccessToken) | ||
} | ||
|
||
func setEnvs(t *testing.T, envs map[string]string) { | ||
t.Helper() | ||
for k, v := range envs { | ||
t.Setenv(k, v) | ||
} | ||
} | ||
|
||
func setupMockTokenServer(t *testing.T, responses map[string]string) string { | ||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
if a := r.Header.Get("Authorization"); a != "Bearer "+testAPIKey { | ||
w.WriteHeader(http.StatusUnauthorized) | ||
return | ||
} | ||
|
||
resp, ok := responses[r.URL.Path] | ||
if !ok { | ||
w.WriteHeader(http.StatusNotFound) | ||
return | ||
} | ||
|
||
w.Header().Set("Content-Type", "application/json") | ||
w.WriteHeader(http.StatusOK) | ||
w.Write([]byte(resp)) | ||
})) | ||
t.Cleanup(func() { | ||
ts.Close() | ||
}) | ||
return ts.URL | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
package remoteoauth | ||
|
||
import ( | ||
"context" | ||
"time" | ||
|
||
"golang.org/x/oauth2" | ||
) | ||
|
||
type TokenSourceOption func(source *tokenSource) | ||
|
||
func WithAccessToken(token, tokenType string, expiry time.Time) TokenSourceOption { | ||
return func(t *tokenSource) { | ||
t.currentToken = oauth2.Token{ | ||
AccessToken: token, | ||
TokenType: tokenType, | ||
Expiry: expiry, | ||
} | ||
} | ||
} | ||
|
||
func WithDefaultContext(ctx context.Context) TokenSourceOption { | ||
return func(t *tokenSource) { | ||
t.defaultContext = ctx | ||
} | ||
} | ||
|
||
func withNoWrap() TokenSourceOption { | ||
return func(t *tokenSource) { | ||
t.noWrap = true | ||
} | ||
} |
Oops, something went wrong.
bcd9081
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
⏱️ Benchmark results