Skip to content

Commit

Permalink
feat: Add RemoteOAuth Token helper to refresh access_token from clo…
Browse files Browse the repository at this point in the history
…ud environment (#1866)

Implements cloudquery/cloudquery-issues#1978 (internal issue)
  • Loading branch information
disq authored Aug 12, 2024
1 parent d1dd099 commit bcd9081
Show file tree
Hide file tree
Showing 6 changed files with 495 additions and 0 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ require (
go.opentelemetry.io/otel/sdk/metric v1.28.0
go.opentelemetry.io/otel/trace v1.28.0
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56
golang.org/x/oauth2 v0.20.0
golang.org/x/sync v0.7.0
golang.org/x/text v0.16.0
google.golang.org/grpc v1.65.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
golang.org/x/oauth2 v0.20.0 h1:4mQdhULixXKP1rwYBW0vAijoXnkTG0BLCDRzfe1idMo=
golang.org/x/oauth2 v0.20.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
Expand Down
187 changes: 187 additions & 0 deletions helpers/remoteoauth/token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
package remoteoauth

import (
"context"
"errors"
"fmt"
"net/http"
"os"

cloudquery_api "github.com/cloudquery/cloudquery-api-go"
"github.com/google/uuid"
"golang.org/x/oauth2"
)

func NewTokenSource(opts ...TokenSourceOption) (oauth2.TokenSource, error) {
t := &tokenSource{}
for _, opt := range opts {
opt(t)
}

if _, cloudEnabled := os.LookupEnv("CQ_CLOUD"); !cloudEnabled {
return oauth2.StaticTokenSource(&t.currentToken), nil
}

cloudToken, err := newCloudTokenSource(t.defaultContext)
if err != nil {
return nil, err
}
if t.noWrap {
return cloudToken, nil
}

return oauth2.ReuseTokenSource(nil, cloudToken), nil
}

type tokenSource struct {
defaultContext context.Context
currentToken oauth2.Token
noWrap bool
}

type cloudTokenSource struct {
defaultContext context.Context
apiClient *cloudquery_api.ClientWithResponses

apiURL string
apiToken string
teamName string
syncName string
testConnUUID uuid.UUID
syncRunUUID uuid.UUID
connectorUUID uuid.UUID
isTestConnection bool
}

var _ oauth2.TokenSource = (*cloudTokenSource)(nil)

func newCloudTokenSource(defaultContext context.Context) (oauth2.TokenSource, error) {
t := &cloudTokenSource{
defaultContext: defaultContext,
}
if t.defaultContext == nil {
t.defaultContext = context.Background()
}

err := t.initCloudOpts()
if err != nil {
return nil, err
}

t.apiClient, err = cloudquery_api.NewClientWithResponses(t.apiURL,
cloudquery_api.WithRequestEditorFn(func(_ context.Context, req *http.Request) error {
req.Header.Set("Authorization", "Bearer "+t.apiToken)
return nil
}))
if err != nil {
return nil, fmt.Errorf("failed to create api client: %w", err)
}

return t, nil
}

// Token returns a new token from the remote source using the default context.
func (t *cloudTokenSource) Token() (*oauth2.Token, error) {
return t.retrieveToken(t.defaultContext)
}

func (t *cloudTokenSource) retrieveToken(ctx context.Context) (*oauth2.Token, error) {
var oauthResp *cloudquery_api.ConnectorCredentialsResponseOAuth
if !t.isTestConnection {
resp, err := t.apiClient.GetSyncRunConnectorCredentialsWithResponse(ctx, t.teamName, t.syncName, t.syncRunUUID, t.connectorUUID)
if err != nil {
return nil, fmt.Errorf("failed to get sync run connector credentials: %w", err)
}
if resp.StatusCode() != http.StatusOK {
if resp.JSON422 != nil {
return nil, fmt.Errorf("failed to get sync run connector credentials: %s", resp.JSON422.Message)
}
return nil, fmt.Errorf("failed to get sync run connector credentials: %s", resp.Status())
}
oauthResp = resp.JSON200.Oauth
} else {
resp, err := t.apiClient.GetTestConnectionConnectorCredentialsWithResponse(ctx, t.teamName, t.testConnUUID, t.connectorUUID)
if err != nil {
return nil, fmt.Errorf("failed to get test connection connector credentials: %w", err)
}
if resp.StatusCode() != http.StatusOK {
if resp.JSON422 != nil {
return nil, fmt.Errorf("failed to get test connection connector credentials: %s", resp.JSON422.Message)
}
return nil, fmt.Errorf("failed to get test connection connector credentials: %s", resp.Status())
}
oauthResp = resp.JSON200.Oauth
}

if oauthResp == nil {
return nil, fmt.Errorf("missing oauth credentials in response")
}

tok := &oauth2.Token{
AccessToken: oauthResp.AccessToken,
}
if oauthResp.Expires != nil {
tok.Expiry = *oauthResp.Expires
}
return tok, nil
}

func (t *cloudTokenSource) initCloudOpts() error {
var allErr error

t.apiToken = os.Getenv("CLOUDQUERY_API_KEY")
if t.apiToken == "" {
allErr = errors.Join(allErr, errors.New("CLOUDQUERY_API_KEY missing"))
}
t.apiURL = os.Getenv("CLOUDQUERY_API_URL")
if t.apiURL == "" {
t.apiURL = "https://api.cloudquery.io"
}

t.teamName = os.Getenv("_CQ_TEAM_NAME")
if t.teamName == "" {
allErr = errors.Join(allErr, errors.New("_CQ_TEAM_NAME missing"))
}
t.syncName = os.Getenv("_CQ_SYNC_NAME")
syncRunID := os.Getenv("_CQ_SYNC_RUN_ID")
testConnID := os.Getenv("_CQ_SYNC_TEST_CONNECTION_ID")
if testConnID == "" && syncRunID == "" {
allErr = errors.Join(allErr, errors.New("_CQ_SYNC_TEST_CONNECTION_ID or _CQ_SYNC_RUN_ID missing"))
} else if testConnID != "" && syncRunID != "" {
allErr = errors.Join(allErr, errors.New("_CQ_SYNC_TEST_CONNECTION_ID and _CQ_SYNC_RUN_ID are mutually exclusive"))
}

var err error
if syncRunID != "" {
if t.syncName == "" {
allErr = errors.Join(allErr, errors.New("_CQ_SYNC_NAME missing"))
}

t.syncRunUUID, err = uuid.Parse(syncRunID)
if err != nil {
allErr = errors.Join(allErr, fmt.Errorf("_CQ_SYNC_RUN_ID is not a valid UUID: %w", err))
}
}
if testConnID != "" {
if t.syncName != "" {
allErr = errors.Join(allErr, errors.New("_CQ_SYNC_NAME should be empty"))
}

t.testConnUUID, err = uuid.Parse(testConnID)
if err != nil {
allErr = errors.Join(allErr, fmt.Errorf("_CQ_SYNC_TEST_CONNECTION_ID is not a valid UUID: %w", err))
}
t.isTestConnection = true
}

connectorID := os.Getenv("_CQ_CONNECTOR_ID")
if connectorID == "" {
allErr = errors.Join(allErr, errors.New("_CQ_CONNECTOR_ID missing"))
} else {
t.connectorUUID, err = uuid.Parse(connectorID)
if err != nil {
allErr = errors.Join(allErr, fmt.Errorf("_CQ_CONNECTOR_ID is not a valid UUID: %w", err))
}
}
return allErr
}
149 changes: 149 additions & 0 deletions helpers/remoteoauth/token_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package remoteoauth

import (
"net/http"
"net/http/httptest"
"os"
"testing"
"time"

"github.com/google/uuid"
"github.com/stretchr/testify/require"
)

const testAPIKey = "test-key"

func TestLocalTokenAccess(t *testing.T) {
r := require.New(t)
_, cloud := os.LookupEnv("CQ_CLOUD")
r.False(cloud, "CQ_CLOUD should not be set")
tok, err := NewTokenSource(WithAccessToken("token", "bearer", time.Time{}))
r.NoError(err)
tk, err := tok.Token()
r.NoError(err)
r.True(tk.Valid())
r.Equal("token", tk.AccessToken)
}

func TestFirstLocalTokenAccess(t *testing.T) {
runID := uuid.NewString()
connID := uuid.NewString()
testURL := setupMockTokenServer(t, map[string]string{
"/teams/the-team/syncs/the-sync/runs/" + runID + "/connector/" + connID + "/credentials": `{"oauth":{"access_token":"new-token"}}`,
})
setEnvs(t, map[string]string{
"CQ_CLOUD": "1",
"CLOUDQUERY_API_URL": testURL,
"CLOUDQUERY_API_KEY": testAPIKey,
"_CQ_TEAM_NAME": "the-team",
"_CQ_SYNC_NAME": "the-sync",
"_CQ_SYNC_RUN_ID": runID,
"_CQ_CONNECTOR_ID": connID,
})
r := require.New(t)
tok, err := NewTokenSource(WithAccessToken("token", "bearer", time.Time{}))
r.NoError(err)
tk, err := tok.Token()
r.NoError(err)
r.True(tk.Valid())
r.Equal("new-token", tk.AccessToken)
}

func TestInvalidAPIKeyTokenAccess(t *testing.T) {
runID := uuid.NewString()
connID := uuid.NewString()
testURL := setupMockTokenServer(t, nil)
setEnvs(t, map[string]string{
"CQ_CLOUD": "1",
"CLOUDQUERY_API_URL": testURL,
"CLOUDQUERY_API_KEY": "invalid",
"_CQ_TEAM_NAME": "the-team",
"_CQ_SYNC_NAME": "the-sync",
"_CQ_SYNC_RUN_ID": runID,
"_CQ_CONNECTOR_ID": connID,
})
r := require.New(t)
tok, err := NewTokenSource(WithAccessToken("token", "bearer", time.Time{}))
r.NoError(err)
tk, err := tok.Token()
r.Nil(tk)
r.False(tk.Valid())
r.ErrorContains(err, "failed to get sync run connector credentials")
}

func TestSyncRunTokenAccess(t *testing.T) {
runID := uuid.NewString()
connID := uuid.NewString()
testURL := setupMockTokenServer(t, map[string]string{
"/teams/the-team/syncs/the-sync/runs/" + runID + "/connector/" + connID + "/credentials": `{"oauth":{"access_token":"new-token"}}`,
})
setEnvs(t, map[string]string{
"CQ_CLOUD": "1",
"CLOUDQUERY_API_URL": testURL,
"CLOUDQUERY_API_KEY": testAPIKey,
"_CQ_TEAM_NAME": "the-team",
"_CQ_SYNC_NAME": "the-sync",
"_CQ_SYNC_RUN_ID": runID,
"_CQ_CONNECTOR_ID": connID,
})
r := require.New(t)
tok, err := NewTokenSource()
r.NoError(err)
tk, err := tok.Token()
r.NoError(err)
r.True(tk.Valid())
r.Equal("new-token", tk.AccessToken)
}

func TestTestConnectionTokenAccess(t *testing.T) {
testID := uuid.NewString()
connID := uuid.NewString()
testURL := setupMockTokenServer(t, map[string]string{
"/teams/the-team/syncs/test-connections/" + testID + "/connector/" + connID + "/credentials": `{"oauth":{"access_token":"new-token"}}`,
})
setEnvs(t, map[string]string{
"CQ_CLOUD": "1",
"CLOUDQUERY_API_URL": testURL,
"CLOUDQUERY_API_KEY": testAPIKey,
"_CQ_TEAM_NAME": "the-team",
"_CQ_SYNC_TEST_CONNECTION_ID": testID,
"_CQ_CONNECTOR_ID": connID,
})
r := require.New(t)
tok, err := NewTokenSource(WithAccessToken("token", "bearer", time.Time{}))
r.NoError(err)
tk, err := tok.Token()
r.NoError(err)
r.True(tk.Valid())
r.Equal("new-token", tk.AccessToken)
}

func setEnvs(t *testing.T, envs map[string]string) {
t.Helper()
for k, v := range envs {
t.Setenv(k, v)
}
}

func setupMockTokenServer(t *testing.T, responses map[string]string) string {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if a := r.Header.Get("Authorization"); a != "Bearer "+testAPIKey {
w.WriteHeader(http.StatusUnauthorized)
return
}

resp, ok := responses[r.URL.Path]
if !ok {
w.WriteHeader(http.StatusNotFound)
return
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(resp))
}))
t.Cleanup(func() {
ts.Close()
})
return ts.URL
}
32 changes: 32 additions & 0 deletions helpers/remoteoauth/tokenoptions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package remoteoauth

import (
"context"
"time"

"golang.org/x/oauth2"
)

type TokenSourceOption func(source *tokenSource)

func WithAccessToken(token, tokenType string, expiry time.Time) TokenSourceOption {
return func(t *tokenSource) {
t.currentToken = oauth2.Token{
AccessToken: token,
TokenType: tokenType,
Expiry: expiry,
}
}
}

func WithDefaultContext(ctx context.Context) TokenSourceOption {
return func(t *tokenSource) {
t.defaultContext = ctx
}
}

func withNoWrap() TokenSourceOption {
return func(t *tokenSource) {
t.noWrap = true
}
}
Loading

1 comment on commit bcd9081

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⏱️ Benchmark results

  • Glob-8 ns/op: 94.47

Please sign in to comment.