Skip to content

Fallback to account-level auth if possible when using CLI auth #943

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .codegen/accounts.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func NewAccountClient(c ...*Config) (*AccountClient, error) {
if err != nil {
return nil, err
}
if cfg.AccountID == "" || !cfg.IsAccountClient() {
if !cfg.IsAccountClient() {
return nil, ErrNotAccountClient
}
apiClient, err := client.New(cfg)
Expand Down
2 changes: 1 addition & 1 deletion account_client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions common/environment/environments.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ func (de DatabricksEnvironment) AzureActiveDirectoryEndpoint() string {
return de.AzureEnvironment.ActiveDirectoryEndpoint
}

// AccountsHost returns the accounts host URL for this environment, built by
// appending the environment's DNS zone to "https://accounts".
func (de DatabricksEnvironment) AccountsHost() string {
	const accountsPrefix = "https://accounts"
	return accountsPrefix + de.DnsZone
}

// We default to the AWS Prod environment, since PVC deployments will hit this case.
func DefaultEnvironment() DatabricksEnvironment {
return DatabricksEnvironment{
Expand Down
46 changes: 34 additions & 12 deletions config/auth_databricks_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (c DatabricksCliCredentials) Configure(ctx context.Context, cfg *Config) (c
return nil, nil
}

ts, err := newDatabricksCliTokenSource(cfg)
ts, err := newDatabricksCliTokenSource(ctx, cfg)
if err != nil {
if errors.Is(err, exec.ErrNotFound) {
logger.Debugf(ctx, "Most likely the Databricks CLI is not installed")
Expand Down Expand Up @@ -61,17 +61,12 @@ func (c DatabricksCliCredentials) Configure(ctx context.Context, cfg *Config) (c
// errLegacyDatabricksCli is the sentinel error returned by
// newDatabricksCliTokenSource when the located CLI is considered legacy.
var errLegacyDatabricksCli = errors.New("legacy Databricks CLI detected")

// databricksCliTokenSource obtains OAuth tokens by running the Databricks CLI
// executable and parsing its output.
type databricksCliTokenSource struct {
// ctx is handed to the logger when CLI invocations are reported.
ctx context.Context
// name is the path of the CLI executable that is run.
name string
// args is a fixed argument list passed to the CLI executable.
// NOTE(review): appears to be superseded by arguments built from cfg — confirm.
args []string
// cfg supplies the host and account ID used to build the CLI arguments.
cfg *Config
}

func newDatabricksCliTokenSource(cfg *Config) (*databricksCliTokenSource, error) {
args := []string{"auth", "token", "--host", cfg.Host}

if cfg.IsAccountClient() {
args = append(args, "--account-id", cfg.AccountID)
}

func newDatabricksCliTokenSource(ctx context.Context, cfg *Config) (*databricksCliTokenSource, error) {
databricksCliPath := cfg.DatabricksCliPath
if databricksCliPath == "" {
databricksCliPath = "databricks"
Expand Down Expand Up @@ -101,16 +96,43 @@ func newDatabricksCliTokenSource(cfg *Config) (*databricksCliTokenSource, error)
return nil, errLegacyDatabricksCli
}

return &databricksCliTokenSource{name: path, args: args}, nil
return &databricksCliTokenSource{ctx: ctx, name: path, cfg: cfg}, nil
}

// Token returns an OAuth token obtained from the Databricks CLI.
//
// When the config describes an account client, the CLI is asked for an
// account-level token directly. Otherwise workspace-level auth is attempted
// first; if that fails and an account ID is configured, account-level auth
// against the environment's accounts host is tried as a fallback. The error of
// the last attempt is returned when no token could be obtained.
func (ts *databricksCliTokenSource) Token() (*oauth2.Token, error) {
	host, accountID := ts.cfg.Host, ts.cfg.AccountID
	if ts.cfg.IsAccountClient() {
		return ts.tokenInner([]string{"auth", "token", "--host", host, "--account-id", accountID})
	}

	// Try workspace-level auth first, falling back to account-level auth if
	// an account ID is available.
	t, wsErr := ts.tokenInner([]string{"auth", "token", "--host", host})
	if wsErr == nil {
		return t, nil
	}
	if accountID == "" {
		// Nothing to fall back to: surface the workspace-level failure.
		return nil, wsErr
	}

	logger.Debugf(ts.ctx, "account ID available, falling back to account-level authentication")
	accountsHost := ts.cfg.Environment().AccountsHost()
	return ts.tokenInner([]string{"auth", "token", "--host", accountsHost, "--account-id", accountID})
}

func (ts *databricksCliTokenSource) tokenInner(args []string) (*oauth2.Token, error) {
logger.Debugf(ts.ctx, "running command: '%s %s'", ts.name, strings.Join(args, " "))
out, err := exec.Command(ts.name, args...).Output()
if ee, ok := err.(*exec.ExitError); ok {
logger.Debugf(ts.ctx, "command '%s %s' failed: %s", ts.name, strings.Join(args, " "), string(ee.Stderr))
return nil, fmt.Errorf("cannot get access token: %s", string(ee.Stderr))
}
if err != nil {
return nil, fmt.Errorf("cannot get access token: %v", err)
logger.Debugf(ts.ctx, "command '%s %s' failed to run: %w", ts.name, strings.Join(args, " "), err)
return nil, fmt.Errorf("cannot get access token: %w", err)
}
var t oauth2.Token
err = json.Unmarshal(out, &t)
Expand Down
80 changes: 53 additions & 27 deletions config/auth_databricks_cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,54 @@ import (

// cliDummy is the workspace-host config shared by the CLI credential tests in
// this file; it sets only Host.
var cliDummy = &Config{Host: "https://abc.cloud.databricks.com/"}

func writeSmallDummyExecutable(t *testing.T, path string) {
f, err := os.Create(filepath.Join(path, "databricks"))
require.NoError(t, err)
defer f.Close()
err = os.Chmod(f.Name(), 0755)
require.NoError(t, err)
_, err = f.WriteString("#!/bin/sh\necho hello world\n")
require.NoError(t, err)
// smallExecutable is a minimal shell script that only prints a greeting. The
// legacy-CLI tests install it, untruncated, as the fake "databricks" binary.
const smallExecutable = "#!/bin/sh\n" +
	"echo hello world\n"

// largeExecutable is a shell script that prints a bearer token as JSON on
// stdout and then exits with status 0. The tests install it as the fake
// "databricks" binary padded to 1 MiB.
// NOTE(review): the explicit `exit 0` presumably stops the shell before it
// reaches the padding added by truncation — confirm.
const largeExecutable = `#!/bin/sh
cat <<EOF
{
"access_token": "token",
"token_type": "Bearer",
"expiry": "2023-05-22T00:00:00.000000+00:00"
}
EOF
exit 0
`

func writeLargeDummyExecutable(t *testing.T, path string) {
f, err := os.Create(filepath.Join(path, "databricks"))
require.NoError(t, err)
defer f.Close()
err = os.Chmod(f.Name(), 0755)
require.NoError(t, err)
_, err = f.WriteString("#!/bin/sh\n")
require.NoError(t, err)
// failFirstSucceedThereafter is a shell script standing in for the Databricks
// CLI. Its first invocation writes an error to stderr, creates a marker file
// next to the script and exits with status 1; every later invocation finds the
// marker and prints a bearer token as JSON.
//
// The test condition must be spelled `[ ! -f ... ]`: without the space after
// `[` the shell looks for a command named `[!`, the check never succeeds and
// the script would not fail on its first run.
const failFirstSucceedThereafter = `#!/bin/sh

# Check if a token file exists in the same directory as the script
if [ ! -f "$(dirname "$0")/.token_file" ]; then
# If not, create the file and set the token
echo "error: workspace auth not configured" >&2
touch "$(dirname "$0")/.token_file"
exit 1
fi

cat <<EOF
{
"access_token": "token",
"token_type": "Bearer",
"expiry": "2023-05-22T00:00:00.000000+00:00"
}
EOF
exit 0
`

// writeDummyExecutable creates an executable file named "databricks" inside
// the directory path and writes contents to it.
//
// When truncateSize is positive the file is afterwards resized to exactly that
// many bytes; with a non-positive truncateSize the file keeps the size of
// contents. Any failure aborts the calling test.
func writeDummyExecutable(t *testing.T, path, contents string, truncateSize int) {
	t.Helper()

	f, err := os.Create(filepath.Join(path, "databricks"))
	require.NoError(t, err)
	defer f.Close()

	err = os.Chmod(f.Name(), 0755)
	require.NoError(t, err)

	// contents must be the first bytes of the file so that its shebang line
	// is honoured; nothing else is written ahead of it.
	_, err = f.WriteString(contents)
	require.NoError(t, err)

	// Resize only on request, so that callers passing 0 get a small file.
	if truncateSize > 0 {
		err = f.Truncate(int64(truncateSize))
		require.NoError(t, err)
	}
}

func TestDatabricksCliCredentials_SkipAzure(t *testing.T) {
Expand All @@ -73,7 +87,7 @@ func TestDatabricksCliCredentials_NotInstalled(t *testing.T) {
func TestDatabricksCliCredentials_InstalledLegacy(t *testing.T) {
// Create a dummy databricks executable.
tmp := t.TempDir()
writeSmallDummyExecutable(t, tmp)
writeDummyExecutable(t, tmp, smallExecutable, 0)
t.Setenv("PATH", tmp)

aa := DatabricksCliCredentials{}
Expand All @@ -85,7 +99,7 @@ func TestDatabricksCliCredentials_InstalledLegacyWithSymlink(t *testing.T) {
// Create a dummy databricks executable.
tmp1 := t.TempDir()
tmp2 := t.TempDir()
writeSmallDummyExecutable(t, tmp1)
writeDummyExecutable(t, tmp1, smallExecutable, 0)
os.Symlink(filepath.Join(tmp1, "databricks"), filepath.Join(tmp2, "databricks"))
t.Setenv("PATH", tmp2+string(os.PathListSeparator)+os.Getenv("PATH"))

Expand All @@ -99,7 +113,19 @@ func TestDatabricksCliCredentials_InstalledNew(t *testing.T) {

// Create a dummy databricks executable.
tmp := t.TempDir()
writeLargeDummyExecutable(t, tmp)
writeDummyExecutable(t, tmp, largeExecutable, 1024*1024)
t.Setenv("PATH", tmp+string(os.PathListSeparator)+os.Getenv("PATH"))

aa := DatabricksCliCredentials{}
_, err := aa.Configure(context.Background(), cliDummy)
require.NoError(t, err)
}

func TestDatabricksCliCredentials_FallbackToAccountLevel(t *testing.T) {
env.CleanupEnvironment(t)

tmp := t.TempDir()
writeDummyExecutable(t, tmp, failFirstSucceedThereafter, 0)
t.Setenv("PATH", tmp+string(os.PathListSeparator)+os.Getenv("PATH"))

aa := DatabricksCliCredentials{}
Expand Down
2 changes: 1 addition & 1 deletion config/auth_m2m.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (c M2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials

func oidcEndpoints(ctx context.Context, cfg *Config) (*oauthAuthorizationServer, error) {
prefix := cfg.Host
if cfg.IsAccountClient() && cfg.AccountID != "" {
if cfg.IsAccountClient() {
// TODO: technically, we could use the same config profile for both workspace
// and account, but we have to add logic for determining accounts host from
// workspace host.
Expand Down
17 changes: 8 additions & 9 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,18 +248,17 @@ func (c *Config) IsAws() bool {

// IsAccountClient returns true if client is configured for Accounts API
func (c *Config) IsAccountClient() bool {
if c.AccountID != "" && c.isTesting {
if c.AccountID == "" {
return false
}
if c.isTesting {
return true
}

accountsPrefixes := []string{
"https://accounts.",
"https://accounts-dod.",
if c.Host == c.Environment().AccountsHost() {
return true
}
for _, prefix := range accountsPrefixes {
if strings.HasPrefix(c.Host, prefix) {
return true
}
if strings.HasPrefix(c.Host, "https://accounts-dod.") {
return true
}
return false
}
Expand Down