Skip to content

Commit aa9d198

Browse files
committed
Async client for durabletask-azuremanged
1 parent f5367b6 commit aa9d198

File tree

5 files changed

+655
-2
lines changed

5 files changed

+655
-2
lines changed

durabletask-azuremanaged/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## Unreleased
9+
10+
- Added `AsyncDurableTaskSchedulerClient` for async/await usage with `grpc.aio`
11+
- Added `DTSAsyncDefaultClientInterceptorImpl` async gRPC interceptor for DTS authentication
12+
813
## v1.3.0
914

1015
- Updates base dependency to durabletask v1.3.0

durabletask-azuremanaged/durabletask/azuremanaged/client.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
from azure.core.credentials import TokenCredential
99

1010
from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import (
11+
DTSAsyncDefaultClientInterceptorImpl,
1112
DTSDefaultClientInterceptorImpl,
1213
)
13-
from durabletask.client import TaskHubGrpcClient
14+
from durabletask.client import AsyncTaskHubGrpcClient, TaskHubGrpcClient
1415

1516

1617
# Client class used for Durable Task Scheduler (DTS)
@@ -39,3 +40,65 @@ def __init__(self, *,
3940
log_formatter=log_formatter,
4041
interceptors=interceptors,
4142
default_version=default_version)
43+
44+
45+
# Async client class used for Durable Task Scheduler (DTS)
46+
class AsyncDurableTaskSchedulerClient(AsyncTaskHubGrpcClient):
47+
"""An async client implementation for Azure Durable Task Scheduler (DTS).
48+
49+
This class extends AsyncTaskHubGrpcClient to provide integration with Azure's
50+
Durable Task Scheduler service using async gRPC. It handles authentication via
51+
Azure credentials and configures the necessary gRPC interceptors for DTS
52+
communication.
53+
54+
Args:
55+
host_address (str): The gRPC endpoint address of the DTS service.
56+
taskhub (str): The name of the task hub. Cannot be empty.
57+
token_credential (Optional[TokenCredential]): Azure credential for authentication.
58+
If None, anonymous authentication will be used.
59+
secure_channel (bool, optional): Whether to use a secure gRPC channel (TLS).
60+
Defaults to True.
61+
default_version (Optional[str], optional): Default version string for orchestrations.
62+
log_handler (Optional[logging.Handler], optional): Custom logging handler for client logs.
63+
log_formatter (Optional[logging.Formatter], optional): Custom log formatter for client logs.
64+
65+
Raises:
66+
ValueError: If taskhub is empty or None.
67+
68+
Example:
69+
>>> from azure.identity.aio import DefaultAzureCredential
70+
>>> from durabletask.azuremanaged import AsyncDurableTaskSchedulerClient
71+
>>>
72+
>>> credential = DefaultAzureCredential()
73+
>>> async with AsyncDurableTaskSchedulerClient(
74+
... host_address="my-dts-service.azure.com:443",
75+
... taskhub="my-task-hub",
76+
... token_credential=credential
77+
... ) as client:
78+
... instance_id = await client.schedule_new_orchestration("my_orchestrator")
79+
"""
80+
81+
def __init__(self, *,
82+
host_address: str,
83+
taskhub: str,
84+
token_credential: Optional[TokenCredential],
85+
secure_channel: bool = True,
86+
default_version: Optional[str] = None,
87+
log_handler: Optional[logging.Handler] = None,
88+
log_formatter: Optional[logging.Formatter] = None):
89+
90+
if not taskhub:
91+
raise ValueError("Taskhub value cannot be empty. Please provide a value for your taskhub")
92+
93+
interceptors = [DTSAsyncDefaultClientInterceptorImpl(token_credential, taskhub)]
94+
95+
# We pass in None for the metadata so we don't construct an additional interceptor in the parent class
96+
# Since the parent class doesn't use anything metadata for anything else, we can set it as None
97+
super().__init__(
98+
host_address=host_address,
99+
secure_channel=secure_channel,
100+
metadata=None,
101+
log_handler=log_handler,
102+
log_formatter=log_formatter,
103+
interceptors=interceptors,
104+
default_version=default_version)

durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99

1010
from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager
1111
from durabletask.internal.grpc_interceptor import (
12+
DefaultAsyncClientInterceptorImpl,
1213
DefaultClientInterceptorImpl,
14+
_AsyncClientCallDetails,
1315
_ClientCallDetails,
1416
)
1517

@@ -52,3 +54,44 @@ def _intercept_call(
5254
self._metadata[i] = ("authorization", f"Bearer {new_token.token}") # Update the token
5355

5456
return super()._intercept_call(client_call_details)
57+
58+
59+
class DTSAsyncDefaultClientInterceptorImpl(DefaultAsyncClientInterceptorImpl):
60+
"""Async version of DTSDefaultClientInterceptorImpl for use with grpc.aio channels.
61+
62+
This class implements async gRPC interceptors to add DTS-specific headers
63+
(task hub name, user agent, and authentication token) to all async calls."""
64+
65+
def __init__(self, token_credential: Optional[TokenCredential], taskhub_name: str):
66+
try:
67+
# Get the version of the azuremanaged package
68+
sdk_version = version('durabletask-azuremanaged')
69+
except Exception:
70+
# Fallback if version cannot be determined
71+
sdk_version = "unknown"
72+
user_agent = f"durabletask-python/{sdk_version}"
73+
self._metadata = [
74+
("taskhub", taskhub_name),
75+
("x-user-agent", user_agent)]
76+
super().__init__(self._metadata)
77+
78+
if token_credential is not None:
79+
self._token_credential = token_credential
80+
self._token_manager = AccessTokenManager(token_credential=self._token_credential)
81+
access_token = self._token_manager.get_access_token()
82+
if access_token is not None:
83+
self._metadata.append(("authorization", f"Bearer {access_token.token}"))
84+
85+
def _intercept_call(
86+
self, client_call_details: _AsyncClientCallDetails) -> grpc.aio.ClientCallDetails:
87+
"""Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
88+
call details."""
89+
# Refresh the auth token if it is present and needed
90+
if self._metadata is not None:
91+
for i, (key, _) in enumerate(self._metadata):
92+
if key.lower() == "authorization": # Ensure case-insensitive comparison
93+
new_token = self._token_manager.get_access_token() # Get the new token
94+
if new_token is not None:
95+
self._metadata[i] = ("authorization", f"Bearer {new_token.token}") # Update the token
96+
97+
return super()._intercept_call(client_call_details)

durabletask/internal/grpc_interceptor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class _ClientCallDetails(
2222
class _AsyncClientCallDetails(
2323
namedtuple(
2424
'_AsyncClientCallDetails',
25-
['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression']),
25+
['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready']),
2626
grpc.aio.ClientCallDetails):
2727
"""This is an implementation of the aio ClientCallDetails interface needed for async interceptors.
2828
This class takes six named values and inherits the ClientCallDetails from grpc.aio package.

0 commit comments

Comments
 (0)