Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,14 @@ def _get_default_azure_credential(
*,
managed_identity_client_id: str | None = None,
workload_identity_tenant_id: str | None = None,
use_async: bool = False,
) -> DefaultAzureCredential | AsyncDefaultAzureCredential:
) -> DefaultAzureCredential:
"""
Get DefaultAzureCredential based on provided arguments.

If managed_identity_client_id and workload_identity_tenant_id are provided, this function returns
DefaultAzureCredential with managed identity.
"""
credential_cls: type[AsyncDefaultAzureCredential] | type[DefaultAzureCredential] = (
AsyncDefaultAzureCredential if use_async else DefaultAzureCredential
)
credential_cls: type[DefaultAzureCredential] = DefaultAzureCredential
if managed_identity_client_id and workload_identity_tenant_id:
return credential_cls(
managed_identity_client_id=managed_identity_client_id,
Expand All @@ -80,14 +77,31 @@ def _get_default_azure_credential(
return credential_cls()


get_sync_default_azure_credential: partial[DefaultAzureCredential] = partial(
_get_default_azure_credential, # type: ignore[arg-type]
use_async=False,
)
def _get_async_default_azure_credential(
*,
managed_identity_client_id: str | None = None,
workload_identity_tenant_id: str | None = None,
) -> AsyncDefaultAzureCredential:
"""
Get AsyncDefaultAzureCredential based on provided arguments.

If managed_identity_client_id and workload_identity_tenant_id are provided, this function returns
AsyncDefaultAzureCredential with managed identity.
"""
credential_cls: type[AsyncDefaultAzureCredential] = AsyncDefaultAzureCredential
if managed_identity_client_id and workload_identity_tenant_id:
return credential_cls(
managed_identity_client_id=managed_identity_client_id,
workload_identity_tenant_id=workload_identity_tenant_id,
additionally_allowed_tenants=[workload_identity_tenant_id],
)
return credential_cls()


get_sync_default_azure_credential: partial[DefaultAzureCredential] = partial(_get_default_azure_credential)

get_async_default_azure_credential: partial[AsyncDefaultAzureCredential] = partial(
_get_default_azure_credential, # type: ignore[arg-type]
use_async=True,
_get_async_default_azure_credential
)


Expand Down