Skip to content

Commit a2a34b9

Browse files
Add support to authenticate with Account-wide token federation (#1219)
## What changes are proposed in this pull request? This PR adds support to authenticate with [Account-wide token federation](https://docs.databricks.com/aws/en/dev-tools/auth/oauth-federation#account-wide-token-federation) from the following auth methods: `env-oidc`, `file-oidc`, and `github-oidc`. The PR also slightly re-organize the code by moving the OIDC token source and Github IDTokenSource in the `oidc` package. ## How is this tested? Unit test + local validation.
1 parent 4f0fe87 commit a2a34b9

File tree

7 files changed

+116
-54
lines changed

7 files changed

+116
-54
lines changed

NEXT_CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44

55
### New Features and Improvements
66

7+
- Add support to authenticate with Account-wide token federation from the
8+
following auth methods: `env-oidc`, `file-oidc`, and `github-oidc`.
9+
710
### Bug Fixes
811

912
### Documentation

config/auth_azure_github_oidc.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"time"
88

99
"github.com/databricks/databricks-sdk-go/config/credentials"
10+
"github.com/databricks/databricks-sdk-go/config/experimental/auth/oidc"
1011
"github.com/databricks/databricks-sdk-go/httpclient"
1112
"golang.org/x/oauth2"
1213
)
@@ -26,10 +27,11 @@ func (c AzureGithubOIDCCredentials) Configure(ctx context.Context, cfg *Config)
2627
if !cfg.IsAzure() || cfg.AzureClientID == "" || cfg.Host == "" || cfg.AzureTenantID == "" || cfg.ActionsIDTokenRequestURL == "" || cfg.ActionsIDTokenRequestToken == "" {
2728
return nil, nil
2829
}
29-
supplier := githubIDTokenSource{actionsIDTokenRequestURL: cfg.ActionsIDTokenRequestURL,
30-
actionsIDTokenRequestToken: cfg.ActionsIDTokenRequestToken,
31-
refreshClient: cfg.refreshClient,
32-
}
30+
supplier := oidc.NewGithubIDTokenSource(
31+
cfg.refreshClient,
32+
cfg.ActionsIDTokenRequestURL,
33+
cfg.ActionsIDTokenRequestToken,
34+
)
3335

3436
idToken, err := supplier.IDToken(ctx, "api://AzureADTokenExchange")
3537
if err != nil {

config/auth_default.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,28 +35,28 @@ func buildOidcTokenCredentialStrategies(cfg *Config) []CredentialsStrategy {
3535
},
3636
{
3737
name: "github-oidc",
38-
tokenSource: &githubIDTokenSource{
39-
actionsIDTokenRequestURL: cfg.ActionsIDTokenRequestURL,
40-
actionsIDTokenRequestToken: cfg.ActionsIDTokenRequestToken,
41-
refreshClient: cfg.refreshClient,
42-
},
38+
tokenSource: oidc.NewGithubIDTokenSource(
39+
cfg.refreshClient,
40+
cfg.ActionsIDTokenRequestURL,
41+
cfg.ActionsIDTokenRequestToken,
42+
),
4343
},
4444
// Add new providers at the end of the list
4545
}
4646

4747
strategies := []CredentialsStrategy{}
4848
for _, idTokenSource := range idTokenSources {
49-
oidcConfig := DatabricksOIDCTokenSourceConfig{
49+
oidcConfig := oidc.DatabricksOIDCTokenSourceConfig{
5050
ClientID: cfg.ClientID,
5151
Host: cfg.CanonicalHostName(),
5252
TokenEndpointProvider: cfg.getOidcEndpoints,
5353
Audience: cfg.TokenAudience,
54-
IdTokenSource: idTokenSource.tokenSource,
54+
IDTokenSource: idTokenSource.tokenSource,
5555
}
5656
if cfg.IsAccountClient() {
5757
oidcConfig.AccountID = cfg.AccountID
5858
}
59-
tokenSource := NewDatabricksOIDCTokenSource(oidcConfig)
59+
tokenSource := oidc.NewDatabricksOIDCTokenSource(oidcConfig)
6060
strategies = append(strategies, NewTokenSourceStrategy(idTokenSource.name, tokenSource))
6161
}
6262
return strategies

config/id_token_source_github_oidc.go renamed to config/experimental/auth/oidc/github.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,25 @@
1-
package config
1+
package oidc
22

33
import (
44
"context"
55
"errors"
66
"fmt"
77

8-
"github.com/databricks/databricks-sdk-go/config/experimental/auth/oidc"
98
"github.com/databricks/databricks-sdk-go/httpclient"
109
"github.com/databricks/databricks-sdk-go/logger"
1110
)
1211

12+
// NewGithubIDTokenSource returns a new IDTokenSource that retrieves an IDToken
13+
// from the Github Actions environment. This IDTokenSource is only valid when
14+
// running in Github Actions with OIDC enabled.
15+
func NewGithubIDTokenSource(client *httpclient.ApiClient, actionsIDTokenRequestURL, actionsIDTokenRequestToken string) IDTokenSource {
16+
return &githubIDTokenSource{
17+
actionsIDTokenRequestURL: actionsIDTokenRequestURL,
18+
actionsIDTokenRequestToken: actionsIDTokenRequestToken,
19+
refreshClient: client,
20+
}
21+
}
22+
1323
// githubIDTokenSource retrieves JWT Tokens from Github Actions.
1424
type githubIDTokenSource struct {
1525
actionsIDTokenRequestURL string
@@ -19,7 +29,7 @@ type githubIDTokenSource struct {
1929

2030
// IDToken returns a JWT Token for the specified audience. It will return
2131
// an error if not running in GitHub Actions.
22-
func (g *githubIDTokenSource) IDToken(ctx context.Context, audience string) (*oidc.IDToken, error) {
32+
func (g *githubIDTokenSource) IDToken(ctx context.Context, audience string) (*IDToken, error) {
2333
if g.actionsIDTokenRequestURL == "" {
2434
logger.Debugf(ctx, "Missing ActionsIDTokenRequestURL, likely not calling from a Github action")
2535
return nil, errors.New("missing ActionsIDTokenRequestURL")
@@ -29,7 +39,7 @@ func (g *githubIDTokenSource) IDToken(ctx context.Context, audience string) (*oi
2939
return nil, errors.New("missing ActionsIDTokenRequestToken")
3040
}
3141

32-
resp := &oidc.IDToken{}
42+
resp := &IDToken{}
3343
requestUrl := g.actionsIDTokenRequestURL
3444
if audience != "" {
3545
requestUrl = fmt.Sprintf("%s&audience=%s", requestUrl, audience)

config/id_token_source_github_oidc_test.go renamed to config/experimental/auth/oidc/github_test.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
package config
1+
package oidc
22

33
import (
44
"context"
55
"net/http"
66
"testing"
77

8-
"github.com/databricks/databricks-sdk-go/config/experimental/auth/oidc"
98
"github.com/databricks/databricks-sdk-go/httpclient"
109
"github.com/databricks/databricks-sdk-go/httpclient/fixtures"
1110
"github.com/google/go-cmp/cmp"
@@ -18,7 +17,7 @@ func TestGithubIDTokenSource(t *testing.T) {
1817
tokenRequestToken string
1918
audience string
2019
httpTransport http.RoundTripper
21-
wantToken *oidc.IDToken
20+
wantToken *IDToken
2221
wantErrPrefix *string
2322
}{
2423
{
@@ -60,7 +59,7 @@ func TestGithubIDTokenSource(t *testing.T) {
6059
Response: `{"value": "id-token-42"}`,
6160
},
6261
},
63-
wantToken: &oidc.IDToken{
62+
wantToken: &IDToken{
6463
Value: "id-token-42",
6564
},
6665
},

config/auth_databricks_oidc.go renamed to config/experimental/auth/oidc/tokensource.go

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,49 @@
1-
package config
1+
package oidc
22

33
import (
44
"context"
55
"errors"
66
"net/url"
77

88
"github.com/databricks/databricks-sdk-go/config/experimental/auth"
9-
"github.com/databricks/databricks-sdk-go/config/experimental/auth/oidc"
109
"github.com/databricks/databricks-sdk-go/credentials/u2m"
1110
"github.com/databricks/databricks-sdk-go/logger"
1211
"golang.org/x/oauth2"
1312
"golang.org/x/oauth2/clientcredentials"
1413
)
1514

16-
// Creates a new Databricks OIDC TokenSource.
17-
func NewDatabricksOIDCTokenSource(cfg DatabricksOIDCTokenSourceConfig) auth.TokenSource {
18-
return &databricksOIDCTokenSource{
19-
cfg: cfg,
20-
}
21-
}
22-
23-
// Config for Databricks OIDC TokenSource.
15+
// DatabricksOIDCTokenSourceConfig is the configuration for a Databricks OIDC
16+
// TokenSource.
2417
type DatabricksOIDCTokenSourceConfig struct {
25-
// ClientID is the client ID of the Databricks OIDC application. For
26-
// Databricks Service Principal, this is the Application ID of the Service Principal.
18+
// ClientID of the Databricks OIDC application. It corresponds to the
19+
// Application ID of the Databricks Service Principal.
20+
//
21+
// This field is only required for Workload Identity Federation and should
22+
// be empty for Account-wide token federation.
2723
ClientID string
28-
// [Optional] AccountID is the account ID of the Databricks Account.
29-
// This is only used for Account level tokens.
24+
25+
// AccountID is the account ID of the Databricks Account. This field is
26+
// only required for Account-wide token federation.
3027
AccountID string
28+
3129
// Host is the host of the Databricks account or workspace.
3230
Host string
33-
// TokenEndpointProvider returns the token endpoint for the Databricks OIDC application.
31+
32+
// TokenEndpointProvider returns the token endpoint for the Databricks OIDC
33+
// application.
3434
TokenEndpointProvider func(ctx context.Context) (*u2m.OAuthAuthorizationServer, error)
35+
3536
// Audience is the audience of the Databricks OIDC application.
3637
// This is only used for Workspace level tokens.
3738
Audience string
38-
// IdTokenSource returns the IDToken to be used for the token exchange.
39-
IdTokenSource oidc.IDTokenSource
39+
40+
// IDTokenSource returns the IDToken to be used for the token exchange.
41+
IDTokenSource IDTokenSource
42+
}
43+
44+
// NewDatabricksOIDCTokenSource returns a new Databricks OIDC TokenSource.
45+
func NewDatabricksOIDCTokenSource(cfg DatabricksOIDCTokenSourceConfig) auth.TokenSource {
46+
return &databricksOIDCTokenSource{cfg: cfg}
4047
}
4148

4249
// databricksOIDCTokenSource is a auth.TokenSource which exchanges a token using
@@ -47,10 +54,6 @@ type databricksOIDCTokenSource struct {
4754

4855
// Token implements [TokenSource.Token]
4956
func (w *databricksOIDCTokenSource) Token(ctx context.Context) (*oauth2.Token, error) {
50-
if w.cfg.ClientID == "" {
51-
logger.Debugf(ctx, "Missing ClientID")
52-
return nil, errors.New("missing ClientID")
53-
}
5457
if w.cfg.Host == "" {
5558
logger.Debugf(ctx, "Missing Host")
5659
return nil, errors.New("missing Host")
@@ -59,8 +62,17 @@ func (w *databricksOIDCTokenSource) Token(ctx context.Context) (*oauth2.Token, e
5962
if err != nil {
6063
return nil, err
6164
}
65+
66+
if w.cfg.ClientID == "" {
67+
logger.Debugf(ctx, "No ClientID provided, authenticating with Account-wide token federation")
68+
} else {
69+
logger.Debugf(ctx, "Client ID provided, authenticating with Workload Identity Federation")
70+
}
71+
72+
// TODO: The audience is a concept of the IDToken that should likely be
73+
// configured when the IDTokenSource is created.
6274
audience := w.determineAudience(endpoints)
63-
idToken, err := w.cfg.IdTokenSource.IDToken(ctx, audience)
75+
idToken, err := w.cfg.IDTokenSource.IDToken(ctx, audience)
6476
if err != nil {
6577
return nil, err
6678
}

config/auth_databricks_oidc_test.go renamed to config/experimental/auth/oidc/tokensource_test.go

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,27 @@
1-
package config
1+
package oidc
22

33
import (
44
"context"
55
"errors"
66
"net/http"
77
"net/url"
8+
"strings"
89
"testing"
910

10-
"github.com/databricks/databricks-sdk-go/config/experimental/auth/oidc"
1111
"github.com/databricks/databricks-sdk-go/credentials/u2m"
1212
"github.com/databricks/databricks-sdk-go/httpclient/fixtures"
1313
"github.com/google/go-cmp/cmp"
1414
"golang.org/x/oauth2"
1515
)
1616

17+
func errPrefix(s string) *string {
18+
return &s
19+
}
20+
21+
func hasPrefix(err error, prefix string) bool {
22+
return strings.HasPrefix(err.Error(), prefix)
23+
}
24+
1725
func TestDatabricksOidcTokenSource(t *testing.T) {
1826
testCases := []struct {
1927
desc string
@@ -35,12 +43,6 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
3543
tokenAudience: "token-audience",
3644
wantErrPrefix: errPrefix("missing Host"),
3745
},
38-
{
39-
desc: "missing client ID",
40-
host: "http://host.com",
41-
tokenAudience: "token-audience",
42-
wantErrPrefix: errPrefix("missing ClientID"),
43-
},
4446
{
4547
desc: "token provider error",
4648

@@ -104,7 +106,7 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
104106
wantErrPrefix: errPrefix("oauth2: server response missing access_token"),
105107
},
106108
{
107-
desc: "success workspace",
109+
desc: "success WIF workspace",
108110
clientID: "client-id",
109111
host: "http://host.com",
110112
tokenAudience: "token-audience",
@@ -140,7 +142,7 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
140142
wantToken: "test-auth-token",
141143
},
142144
{
143-
desc: "success account",
145+
desc: "success WIF account",
144146
clientID: "client-id",
145147
accountID: "ac123",
146148
host: "https://accounts.databricks.com",
@@ -230,6 +232,40 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
230232
idToken: "id-token-42",
231233
wantToken: "test-auth-token",
232234
},
235+
{
236+
desc: "success account-wide",
237+
host: "http://host.com",
238+
tokenAudience: "token-audience",
239+
oidcEndpointProvider: func(ctx context.Context) (*u2m.OAuthAuthorizationServer, error) {
240+
return &u2m.OAuthAuthorizationServer{
241+
TokenEndpoint: "https://host.com/oidc/v1/token",
242+
}, nil
243+
},
244+
httpTransport: fixtures.MappingTransport{
245+
"POST /oidc/v1/token": {
246+
247+
Status: http.StatusOK,
248+
ExpectedHeaders: map[string]string{
249+
"Content-Type": "application/x-www-form-urlencoded",
250+
},
251+
ExpectedRequest: url.Values{
252+
"scope": {"all-apis"},
253+
"subject_token_type": {"urn:ietf:params:oauth:token-type:jwt"},
254+
"subject_token": {"id-token-42"},
255+
"grant_type": {"urn:ietf:params:oauth:grant-type:token-exchange"},
256+
},
257+
Response: map[string]string{
258+
"token_type": "access-token",
259+
"access_token": "test-auth-token",
260+
"refresh_token": "refresh",
261+
"expires_on": "0",
262+
},
263+
},
264+
},
265+
wantAudience: "token-audience",
266+
idToken: "id-token-42",
267+
wantToken: "test-auth-token",
268+
},
233269
}
234270

235271
for _, tc := range testCases {
@@ -241,9 +277,9 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
241277
Host: tc.host,
242278
TokenEndpointProvider: tc.oidcEndpointProvider,
243279
Audience: tc.tokenAudience,
244-
IdTokenSource: oidc.IDTokenSourceFn(func(ctx context.Context, aud string) (*oidc.IDToken, error) {
280+
IDTokenSource: IDTokenSourceFn(func(ctx context.Context, aud string) (*IDToken, error) {
245281
gotAudience = aud
246-
return &oidc.IDToken{Value: tc.idToken}, tc.tokenProviderError
282+
return &IDToken{Value: tc.idToken}, tc.tokenProviderError
247283
}),
248284
}
249285

0 commit comments

Comments
 (0)