Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/snowflake/connector/wif_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,10 @@ def create_azure_attestation(
issuer, subject = extract_iss_and_sub_without_signature_verification(jwt_str)
if not issuer or not subject:
return None
if not issuer.startswith("https://sts.windows.net/"):
if not (
issuer.startswith("https://sts.windows.net/")
or issuer.startswith("https://login.microsoftonline.com/")
):
# This might happen if we're running on a different platform that responds to the same metadata request signature as Azure.
logger.debug("Unexpected Azure token issuer '%s'", issuer)
return None
Expand Down
20 changes: 19 additions & 1 deletion test/unit/test_auth_workload_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,14 +236,32 @@ def test_explicit_azure_metadata_server_error_raises_auth_error(exception):


def test_explicit_azure_wrong_issuer_raises_error(fake_azure_metadata_service):
fake_azure_metadata_service.iss = "not-azure"
fake_azure_metadata_service.iss = "https://notazure.com"

auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE)
with pytest.raises(ProgrammingError) as excinfo:
auth_class.prepare()
assert "No workload identity credential was found for 'AZURE'" in str(excinfo.value)


@pytest.mark.parametrize(
"issuer",
[
"https://sts.windows.net/067802cd-8f92-4c7c-bceb-ea8f15d31cc5",
"https://login.microsoftonline.com/067802cd-8f92-4c7c-bceb-ea8f15d31cc5",
"https://login.microsoftonline.com/067802cd-8f92-4c7c-bceb-ea8f15d31cc5/v2.0",
],
ids=["v1", "v2_without_suffix", "v2_with_suffix"],
)
def test_explicit_azure_v1_and_v2_issuers_accepted(fake_azure_metadata_service, issuer):
fake_azure_metadata_service.iss = issuer

auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE)
auth_class.prepare()

assert issuer == json.loads(auth_class.assertion_content)["iss"]


def test_explicit_azure_plumbs_token_to_api(fake_azure_metadata_service):
auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE)
auth_class.prepare()
Expand Down
Loading