Skip to content

Commit c95c29b

Browse files
Update enum parsing from TOML, add test for TOML connection config
1 parent d8638d8 commit c95c29b

File tree

3 files changed

+70
-19
lines changed

3 files changed

+70
-19
lines changed

src/snowflake/connector/connection.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,6 +1158,13 @@ def __open_connection(self):
11581158
"errno": ER_INVALID_WIF_SETTINGS,
11591159
},
11601160
)
1161+
# Standardize the provider enum.
1162+
if self._workload_identity_provider and isinstance(
1163+
self._workload_identity_provider, str
1164+
):
1165+
self._workload_identity_provider = AttestationProvider.from_string(
1166+
self._workload_identity_provider
1167+
)
11611168
self.auth_class = AuthByWorkloadIdentity(
11621169
provider=self._workload_identity_provider,
11631170
token=self._token,

src/snowflake/connector/wif_util.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ class AttestationProvider(Enum):
4141
OIDC = "OIDC"
4242
"""Provider that looks for an OIDC ID token."""
4343

44+
@staticmethod
45+
def from_string(provider: str) -> AttestationProvider:
46+
"""Converts a string to a strongly-typed enum value of AttestationProvider."""
47+
return AttestationProvider[provider.upper()]
48+
4449

4550
@dataclass
4651
class WorkloadIdentityAttestation:

test/unit/test_connection.py

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,13 @@ def mock_post_request(request, url, headers, json_body, **kwargs):
9898
return request_body
9999

100100

101+
def write_temp_file(file_path: Path, contents: str) -> Path:
102+
"""Write the given string text to the given path, chmods it to be accessible, and returns the same path."""
103+
file_path.write_text(contents)
104+
file_path.chmod(stat.S_IRUSR | stat.S_IWUSR)
105+
return file_path
106+
107+
101108
def test_connect_with_service_name(mock_post_requests):
102109
assert fake_connector().service_name == "FAKE_SERVICE_NAME"
103110

@@ -628,24 +635,56 @@ def test_cannot_set_wlid_authenticator_without_env_variable(mock_post_requests):
628635
)
629636

630637

631-
@patch("snowflake.connector.SnowflakeConnection._authenticate", return_value=None)
632-
@patch("snowflake.connector.auth.AuthByWorkloadIdentity.__init__", return_value=None)
633-
def test_connection_params_are_plumbed_into_authbyworkloadidentity(
634-
mock_auth_constructor, mock_authenticate
638+
def test_connection_params_are_plumbed_into_authbyworkloadidentity(monkeypatch):
639+
with monkeypatch.context() as m:
640+
m.setattr(
641+
"snowflake.connector.SnowflakeConnection._authenticate", lambda *_: None
642+
)
643+
m.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "") # Can be set to anything.
644+
645+
conn = snowflake.connector.connect(
646+
account="my_account_1",
647+
workload_identity_provider=AttestationProvider.AWS,
648+
workload_identity_entra_resource="api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b",
649+
token="my_token",
650+
authenticator="WORKLOAD_IDENTITY",
651+
)
652+
assert conn.auth_class.provider == AttestationProvider.AWS
653+
assert (
654+
conn.auth_class.entra_resource
655+
== "api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b"
656+
)
657+
assert conn.auth_class.token == "my_token"
658+
659+
660+
def test_toml_connection_params_are_plumbed_into_authbyworkloadidentity(
661+
monkeypatch, tmp_path
635662
):
636-
# We can set this to any value.
637-
os.environ["SF_ENABLE_EXPERIMENTAL_AUTHENTICATION"] = ""
638-
snowflake.connector.connect(
639-
account="my_account_1",
640-
workload_identity_provider=AttestationProvider.AWS,
641-
workload_identity_entra_resource="api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b",
642-
token="my_token",
643-
authenticator="WORKLOAD_IDENTITY",
663+
token_file = write_temp_file(tmp_path / "token.txt", contents="my_token")
664+
connections_file = write_temp_file(
665+
tmp_path / "connections.toml",
666+
contents=dedent(
667+
f"""\
668+
[default]
669+
account = "my_account_1"
670+
authenticator = "WORKLOAD_IDENTITY"
671+
workload_identity_provider = "OIDC"
672+
workload_identity_entra_resource = "api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b"
673+
token_file_path = "{token_file}"
674+
"""
675+
),
644676
)
645-
_, kwargs = mock_auth_constructor.call_args
646-
assert kwargs == {
647-
"provider": AttestationProvider.AWS,
648-
"entra_resource": "api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b",
649-
"token": "my_token",
650-
}
651-
os.environ.pop("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION")
677+
678+
with monkeypatch.context() as m:
679+
m.setattr(
680+
"snowflake.connector.SnowflakeConnection._authenticate", lambda *_: None
681+
)
682+
m.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "")
683+
684+
conn = snowflake.connector.connect(connections_file_path=connections_file)
685+
assert conn.auth_class.provider == AttestationProvider.OIDC
686+
assert (
687+
conn.auth_class.entra_resource
688+
== "api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b"
689+
)
690+
assert conn.auth_class.token == "my_token"

0 commit comments

Comments
 (0)