-
Notifications
You must be signed in to change notification settings - Fork 46
/
Copy pathauth_databricks_oidc.go
91 lines (83 loc) · 2.9 KB
/
auth_databricks_oidc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
package config
import (
"context"
"errors"
"net/url"
"github.com/databricks/databricks-sdk-go/config/experimental/auth"
"github.com/databricks/databricks-sdk-go/credentials/u2m"
"github.com/databricks/databricks-sdk-go/logger"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
)
// NewDatabricksOIDCTokenSource returns an auth.TokenSource backed by the given
// configuration. The configuration is stored as-is; it is only validated when
// Token is called.
func NewDatabricksOIDCTokenSource(cfg DatabricksOIDCTokenSourceConfig) auth.TokenSource {
	src := new(databricksOIDCTokenSource)
	src.cfg = cfg
	return src
}
// DatabricksOIDCTokenSourceConfig holds the settings for the token source
// created by NewDatabricksOIDCTokenSource.
type DatabricksOIDCTokenSourceConfig struct {
// ClientID is the client ID of the Databricks OIDC application. For
// Databricks Service Principal, this is the Application ID of the Service Principal.
// Token returns an error when it is empty.
ClientID string
// [Optional] AccountID is the account ID of the Databricks Account.
// This is only used for Account level tokens. When Audience is empty and
// AccountID is set, AccountID is used as the audience of the ID token.
AccountID string
// Host is the host of the Databricks account or workspace.
// Token returns an error when it is empty.
Host string
// TokenEndpointProvider returns the token endpoint for the Databricks OIDC application.
// Token calls it on every invocation and uses the returned TokenEndpoint as
// the token URL of the exchange request.
TokenEndpointProvider func(ctx context.Context) (*u2m.OAuthAuthorizationServer, error)
// Audience is the audience of the Databricks OIDC application.
// When non-empty it is the audience requested from IdTokenSource. When
// empty, the audience falls back to AccountID if set, otherwise to the
// token endpoint returned by TokenEndpointProvider.
Audience string
// IdTokenSource returns the IDToken to be used for the token exchange.
// Its value is sent as the subject_token of the exchange request.
IdTokenSource IDTokenSource
}
// databricksOIDCTokenSource is an auth.TokenSource which exchanges an ID token
// for a Databricks token using Workload Identity Federation.
type databricksOIDCTokenSource struct {
// cfg is the configuration supplied to NewDatabricksOIDCTokenSource.
cfg DatabricksOIDCTokenSourceConfig
}
// Token implements [TokenSource.Token].
//
// It validates the configuration, resolves the token endpoint through
// cfg.TokenEndpointProvider, obtains an ID token for the audience chosen by
// determineAudience, and exchanges that ID token for a Databricks token with
// an OAuth token-exchange request carrying the "all-apis" scope.
//
// An error is returned when ClientID or Host is empty, when
// TokenEndpointProvider or IdTokenSource is nil, when the provider yields no
// endpoints, or when any of the underlying calls fails.
func (w *databricksOIDCTokenSource) Token(ctx context.Context) (*oauth2.Token, error) {
	if w.cfg.ClientID == "" {
		logger.Debugf(ctx, "Missing ClientID")
		return nil, errors.New("missing ClientID")
	}
	if w.cfg.Host == "" {
		logger.Debugf(ctx, "Missing Host")
		return nil, errors.New("missing Host")
	}
	// Calling a nil function field or a method on a nil source would panic;
	// report a configuration error instead.
	if w.cfg.TokenEndpointProvider == nil {
		logger.Debugf(ctx, "Missing TokenEndpointProvider")
		return nil, errors.New("missing TokenEndpointProvider")
	}
	if w.cfg.IdTokenSource == nil {
		logger.Debugf(ctx, "Missing IdTokenSource")
		return nil, errors.New("missing IdTokenSource")
	}

	endpoints, err := w.cfg.TokenEndpointProvider(ctx)
	if err != nil {
		return nil, err
	}
	// Guard against a provider that returns (nil, nil): endpoints is
	// dereferenced below and in determineAudience.
	if endpoints == nil {
		logger.Debugf(ctx, "Missing token endpoint")
		return nil, errors.New("missing token endpoint")
	}

	audience := w.determineAudience(endpoints)
	idToken, err := w.cfg.IdTokenSource.IDToken(ctx, audience)
	if err != nil {
		return nil, err
	}

	c := &clientcredentials.Config{
		ClientID:  w.cfg.ClientID,
		AuthStyle: oauth2.AuthStyleInParams,
		TokenURL:  endpoints.TokenEndpoint,
		Scopes:    []string{"all-apis"},
		// The grant_type here replaces the default client_credentials grant
		// so the request is a token exchange with the ID token as subject.
		EndpointParams: url.Values{
			"subject_token_type": {"urn:ietf:params:oauth:token-type:jwt"},
			"subject_token":      {idToken.Value},
			"grant_type":         {"urn:ietf:params:oauth:grant-type:token-exchange"},
		},
	}
	return c.Token(ctx)
}
// determineAudience picks the audience to request for the ID token. The
// order of precedence is: the explicitly configured Audience, then the
// AccountID (the default for Databricks Accounts), and finally the token
// endpoint from endpoints (the default for Databricks Workspaces).
func (w *databricksOIDCTokenSource) determineAudience(endpoints *u2m.OAuthAuthorizationServer) string {
	switch {
	case w.cfg.Audience != "":
		return w.cfg.Audience
	case w.cfg.AccountID != "":
		return w.cfg.AccountID
	default:
		return endpoints.TokenEndpoint
	}
}