Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import time
from collections import namedtuple
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

from azure.mgmt.containerinstance.models import (
Container,
Expand All @@ -33,8 +33,10 @@
DnsConfiguration,
EnvironmentVariable,
IpAddress,
ResourceIdentityType,
ResourceRequests,
ResourceRequirements,
UserAssignedIdentities,
Volume as _AzureVolume,
VolumeMount,
)
Expand Down Expand Up @@ -147,10 +149,13 @@ class AzureContainerInstancesOperator(BaseOperator):
},
priority="Regular",
identity = {
{
"type": "UserAssigned",
"resource_ids": ["/subscriptions/00000000-0000-0000-0000-00000000000/resourceGroups/my_rg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/my_identity"],
},
"type": "UserAssigned" | "SystemAssigned" | "SystemAssigned,UserAssigned",
"resource_ids": [
"/subscriptions/<sub>/resourceGroups/<rg>/providers/Microsoft.ManagedIdentity/userAssignedIdentities/<id>"
]
"user_assigned_identities": {
"/subscriptions/.../userAssignedIdentities/<id>": {}
}
}
command=["/bin/echo", "world"],
task_id="start_container",
Expand Down Expand Up @@ -188,7 +193,7 @@ def __init__(
dns_config: DnsConfiguration | None = None,
diagnostics: ContainerGroupDiagnostics | None = None,
priority: str | None = "Regular",
identity: ContainerGroupIdentity | None = None,
identity: ContainerGroupIdentity | dict | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -231,14 +236,74 @@ def __init__(
self.dns_config = dns_config
self.diagnostics = diagnostics
self.priority = priority
self.identity = identity
self.identity = self._ensure_identity(identity)
if self.priority not in ["Regular", "Spot"]:
raise AirflowException(
"Invalid value for the priority argument. "
"Please set 'Regular' or 'Spot' as the priority. "
f"Found `{self.priority}`."
)

# helper to accept dict (user-friendly) or ContainerGroupIdentity (SDK object)
@staticmethod
def _ensure_identity(identity: ContainerGroupIdentity | dict | None) -> ContainerGroupIdentity | None:
"""
Normalize identity input into a ContainerGroupIdentity instance.

Accepts:
- None -> returns None
- ContainerGroupIdentity -> returned as-is
- dict -> converted to ContainerGroupIdentity
- any other object -> returned as-is (pass-through) to preserve backwards compatibility

Expected dict shapes:
{"type": "UserAssigned", "resource_ids": ["/.../userAssignedIdentities/id1", ...]}
or
{"type": "SystemAssigned"}
or
{"type": "SystemAssigned,UserAssigned", "resource_ids": [...]}
"""
if identity is None:
return None

if isinstance(identity, ContainerGroupIdentity):
return identity

if isinstance(identity, dict):
# require type
id_type = identity.get("type")
if not id_type:
raise AirflowException(
"identity dict must include 'type' key with value 'UserAssigned' or 'SystemAssigned'"
)

# map common string type names to ResourceIdentityType enum values if available
type_map = {
"SystemAssigned": ResourceIdentityType.system_assigned,
"UserAssigned": ResourceIdentityType.user_assigned,
"SystemAssigned,UserAssigned": ResourceIdentityType.system_assigned_user_assigned,
"SystemAssigned, UserAssigned": ResourceIdentityType.system_assigned_user_assigned,
}
cg_type = type_map.get(id_type, id_type)

# build user_assigned_identities mapping if resource_ids provided
resource_ids = identity.get("resource_ids")
if resource_ids:
if not isinstance(resource_ids, (list, tuple)):
raise AirflowException("identity['resource_ids'] must be a list of resource id strings")
user_assigned_identities: dict[str, Any] = {rid: {} for rid in resource_ids}
else:
# accept a pre-built mapping if given
user_assigned_identities = identity.get("user_assigned_identities") or {}

return ContainerGroupIdentity(
type=cg_type,
user_assigned_identities=cast(
"dict[str, UserAssignedIdentities] | None", user_assigned_identities
),
)
return identity

def execute(self, context: Context) -> int:
# Check name again in case it was templated.
self._check_name(self.name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,41 @@ def test_execute_with_identity(self, aci_mock):

assert called_cg.identity == identity

@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
def test_execute_with_identity_dict(self, aci_mock):
# New test: pass a dict and verify operator converts it to ContainerGroupIdentity
resource_id = "/subscriptions/00000000-0000-0000-0000-00000000000/resourceGroups/my_rg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/my_identity"
identity_dict = {
"type": "UserAssigned",
"resource_ids": [resource_id],
}

aci_mock.return_value.get_state.return_value = make_mock_container(
state="Terminated", exit_code=0, detail_status="test"
)

aci_mock.return_value.exists.return_value = False

aci = AzureContainerInstancesOperator(
ci_conn_id=None,
registry_conn_id=None,
resource_group="resource-group",
name="container-name",
image="container-image",
region="region",
task_id="task",
identity=identity_dict,
)
aci.execute(None)
assert aci_mock.return_value.create_or_update.call_count == 1
(_, _, called_cg), _ = aci_mock.return_value.create_or_update.call_args

# verify the operator converted dict -> ContainerGroupIdentity with proper mapping
assert hasattr(called_cg, "identity")
assert called_cg.identity is not None
# user_assigned_identities should contain the resource id as a key
assert resource_id in (called_cg.identity.user_assigned_identities or {})


class XcomMock:
def __init__(self) -> None:
Expand Down