Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 9 additions & 18 deletions msticpy/auth/azure_auth_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,37 +351,28 @@ def _az_connect_core(
azure_identity_logger.handlers = [handler]

if not credential:
chained_credential: ChainedTokenCredential = _build_chained_creds(
chained_credential: ChainedTokenCredential = _create_chained_credential(
aad_uri=aad_uri,
requested_clients=auth_methods,
tenant_id=tenant_id,
**kwargs,
)
legacy_creds: CredentialWrapper = CredentialWrapper(
wrapped_credentials: CredentialWrapper = CredentialWrapper(
chained_credential, resource_id=az_config.token_uri
)
else:
# Connect to the subscription client to validate
legacy_creds = CredentialWrapper(credential, resource_id=az_config.token_uri)
return AzCredentials(wrapped_credentials, chained_credential) # type: ignore[arg-type]

if not credential:
err_msg: str = (
"Cannot authenticate with specified credential types. "
"At least one valid authentication method required."
)
raise MsticpyAzureConfigError(
err_msg,
help_uri=_HELP_URI,
title="Authentication failure",
)

return AzCredentials(legacy_creds, ChainedTokenCredential(credential)) # type: ignore[arg-type]
# Create the wrapped credential using the passed credential
wrapped_credentials = CredentialWrapper(credential, resource_id=az_config.token_uri)
return AzCredentials(
wrapped_credentials, ChainedTokenCredential(credential) # type: ignore[arg-type]
)


az_connect_core: Callable[..., AzCredentials] = _az_connect_core


def _build_chained_creds(
def _create_chained_credential(
aad_uri,
requested_clients: list[str] | None = None,
tenant_id: str | None = None,
Expand Down
50 changes: 50 additions & 0 deletions tests/auth/test_azure_auth_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
import pytest_check as check

from msticpy.auth.azure_auth_core import (
AzCredentials,
AzureCliStatus,
AzureCloudConfig,
DeviceCodeCredential,
_az_connect_core,
_build_env_client,
check_cli_credentials,
)
Expand Down Expand Up @@ -178,3 +181,50 @@ def test_build_env_client(env_vars, expected, monkeypatch):
check.is_true(
mock_env_cred.called_once_with(authority="test_aad_uri") or not expected
)


@pytest.mark.parametrize(
"auth_methods, cloud, tenant_id, silent, region, credential",
[
(["env", "cli"], "global", "tenant1", False, "region1", None),
(["msi", "interactive"], "usgov", "tenant2", True, "region2", None),
(None, None, None, False, None, DeviceCodeCredential()),
],
)
def test_az_connect_core(auth_methods, cloud, tenant_id, silent, region, credential):
"""
Test _az_connect_core function with different parameters.

Parameters
----------
auth_methods : list[str]
List of authentication methods to try.
cloud : str
Azure cloud to connect to.
tenant_id : str
Tenant to authenticate against.
silent : bool
Whether to display any output during auth process.
region : str
Azure region to connect to.
credential : AzCredentials
Azure credential to use directly.

Returns
-------
None
"""
# Call the function with the test parameters
result = _az_connect_core(
auth_methods=auth_methods,
cloud=cloud,
tenant_id=tenant_id,
silent=silent,
region=region,
credential=credential,
)

# Assert that the result matches the expected credential
assert isinstance(result, AzCredentials)
assert result.legacy is not None
assert result.modern is not None