Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions durabletask/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,26 @@ def resume_orchestration(self, instance_id: str):
self._logger.info(f"Resuming instance '{instance_id}'.")
self._stub.ResumeInstance(req)

def restart_orchestration(self, instance_id: str, *,
restart_with_new_instance_id: bool = False) -> str:
"""Restarts an existing orchestration instance.

Args:
instance_id: The ID of the orchestration instance to restart.
restart_with_new_instance_id: If True, the restarted orchestration will use a new instance ID.
If False (default), the restarted orchestration will reuse the same instance ID.

Returns:
The instance ID of the restarted orchestration.
"""
req = pb.RestartInstanceRequest(
instanceId=instance_id,
restartWithNewInstanceId=restart_with_new_instance_id)

self._logger.info(f"Restarting instance '{instance_id}'.")
res: pb.RestartInstanceResponse = self._stub.RestartInstance(req)
return res.instanceId

def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult:
req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
self._logger.info(f"Purging instance '{instance_id}'.")
Expand Down
1 change: 1 addition & 0 deletions durabletask/internal/PROTO_SOURCE_COMMIT_HASH
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
443b333f4f65a438dc9eb4f090560d232afec4b7
fd9369c6a03d6af4e95285e432b7c4e943c06970
026329c53fe6363985655857b9ca848ec7238bd2
482 changes: 262 additions & 220 deletions durabletask/internal/orchestrator_service_pb2.py

Large diffs are not rendered by default.

208 changes: 186 additions & 22 deletions durabletask/internal/orchestrator_service_pb2.pyi

Large diffs are not rendered by default.

132 changes: 132 additions & 0 deletions durabletask/internal/orchestrator_service_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def __init__(self, channel):
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RewindInstanceRequest.SerializeToString,
response_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RewindInstanceResponse.FromString,
_registered_method=True)
self.RestartInstance = channel.unary_unary(
'/TaskHubSidecarService/RestartInstance',
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RestartInstanceRequest.SerializeToString,
response_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RestartInstanceResponse.FromString,
_registered_method=True)
self.WaitForInstanceStart = channel.unary_unary(
'/TaskHubSidecarService/WaitForInstanceStart',
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.GetInstanceRequest.SerializeToString,
Expand Down Expand Up @@ -95,6 +100,11 @@ def __init__(self, channel):
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.QueryInstancesRequest.SerializeToString,
response_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.QueryInstancesResponse.FromString,
_registered_method=True)
self.ListInstanceIds = channel.unary_unary(
'/TaskHubSidecarService/ListInstanceIds',
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIdsRequest.SerializeToString,
response_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIdsResponse.FromString,
_registered_method=True)
self.PurgeInstances = channel.unary_unary(
'/TaskHubSidecarService/PurgeInstances',
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.PurgeInstancesRequest.SerializeToString,
Expand Down Expand Up @@ -170,6 +180,11 @@ def __init__(self, channel):
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.AbandonEntityTaskRequest.SerializeToString,
response_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.AbandonEntityTaskResponse.FromString,
_registered_method=True)
self.SkipGracefulOrchestrationTerminations = channel.unary_unary(
'/TaskHubSidecarService/SkipGracefulOrchestrationTerminations',
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.SkipGracefulOrchestrationTerminationsRequest.SerializeToString,
response_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.SkipGracefulOrchestrationTerminationsResponse.FromString,
_registered_method=True)


class TaskHubSidecarServiceServicer(object):
Expand Down Expand Up @@ -203,6 +218,13 @@ def RewindInstance(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def RestartInstance(self, request, context):
"""Restarts an orchestration instance.
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def WaitForInstanceStart(self, request, context):
"""Waits for an orchestration instance to reach a running or completion state.
"""
Expand Down Expand Up @@ -253,6 +275,12 @@ def QueryInstances(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def ListInstanceIds(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def PurgeInstances(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
Expand Down Expand Up @@ -353,6 +381,14 @@ def AbandonTaskEntityWorkItem(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def SkipGracefulOrchestrationTerminations(self, request, context):
""""Skip" graceful termination of orchestrations by immediately changing their status in storage to "terminated".
Note that a maximum of 500 orchestrations can be terminated at a time using this method.
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_TaskHubSidecarServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand All @@ -376,6 +412,11 @@ def add_TaskHubSidecarServiceServicer_to_server(servicer, server):
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RewindInstanceRequest.FromString,
response_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RewindInstanceResponse.SerializeToString,
),
'RestartInstance': grpc.unary_unary_rpc_method_handler(
servicer.RestartInstance,
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RestartInstanceRequest.FromString,
response_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RestartInstanceResponse.SerializeToString,
),
'WaitForInstanceStart': grpc.unary_unary_rpc_method_handler(
servicer.WaitForInstanceStart,
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.GetInstanceRequest.FromString,
Expand Down Expand Up @@ -411,6 +452,11 @@ def add_TaskHubSidecarServiceServicer_to_server(servicer, server):
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.QueryInstancesRequest.FromString,
response_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.QueryInstancesResponse.SerializeToString,
),
'ListInstanceIds': grpc.unary_unary_rpc_method_handler(
servicer.ListInstanceIds,
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIdsRequest.FromString,
response_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIdsResponse.SerializeToString,
),
'PurgeInstances': grpc.unary_unary_rpc_method_handler(
servicer.PurgeInstances,
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.PurgeInstancesRequest.FromString,
Expand Down Expand Up @@ -486,6 +532,11 @@ def add_TaskHubSidecarServiceServicer_to_server(servicer, server):
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.AbandonEntityTaskRequest.FromString,
response_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.AbandonEntityTaskResponse.SerializeToString,
),
'SkipGracefulOrchestrationTerminations': grpc.unary_unary_rpc_method_handler(
servicer.SkipGracefulOrchestrationTerminations,
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.SkipGracefulOrchestrationTerminationsRequest.FromString,
response_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.SkipGracefulOrchestrationTerminationsResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'TaskHubSidecarService', rpc_method_handlers)
Expand Down Expand Up @@ -605,6 +656,33 @@ def RewindInstance(request,
metadata,
_registered_method=True)

@staticmethod
def RestartInstance(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/TaskHubSidecarService/RestartInstance',
durabletask_dot_internal_dot_orchestrator__service__pb2.RestartInstanceRequest.SerializeToString,
durabletask_dot_internal_dot_orchestrator__service__pb2.RestartInstanceResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

@staticmethod
def WaitForInstanceStart(request,
target,
Expand Down Expand Up @@ -794,6 +872,33 @@ def QueryInstances(request,
metadata,
_registered_method=True)

@staticmethod
def ListInstanceIds(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/TaskHubSidecarService/ListInstanceIds',
durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIdsRequest.SerializeToString,
durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIdsResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

@staticmethod
def PurgeInstances(request,
target,
Expand Down Expand Up @@ -1198,3 +1303,30 @@ def AbandonTaskEntityWorkItem(request,
timeout,
metadata,
_registered_method=True)

@staticmethod
def SkipGracefulOrchestrationTerminations(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/TaskHubSidecarService/SkipGracefulOrchestrationTerminations',
durabletask_dot_internal_dot_orchestrator__service__pb2.SkipGracefulOrchestrationTerminationsRequest.SerializeToString,
durabletask_dot_internal_dot_orchestrator__service__pb2.SkipGracefulOrchestrationTerminationsResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
3 changes: 3 additions & 0 deletions durabletask/internal/proto_task_hub_sidecar_service_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ class ProtoTaskHubSidecarServiceStub(Protocol):
StartInstance: Callable[..., Any]
GetInstance: Callable[..., Any]
RewindInstance: Callable[..., Any]
RestartInstance: Callable[..., Any]
WaitForInstanceStart: Callable[..., Any]
WaitForInstanceCompletion: Callable[..., Any]
RaiseEvent: Callable[..., Any]
TerminateInstance: Callable[..., Any]
SuspendInstance: Callable[..., Any]
ResumeInstance: Callable[..., Any]
QueryInstances: Callable[..., Any]
ListInstanceIds: Callable[..., Any]
PurgeInstances: Callable[..., Any]
GetWorkItems: Callable[..., Any]
CompleteActivityTask: Callable[..., Any]
Expand All @@ -31,3 +33,4 @@ class ProtoTaskHubSidecarServiceStub(Protocol):
AbandonTaskActivityWorkItem: Callable[..., Any]
AbandonTaskOrchestratorWorkItem: Callable[..., Any]
AbandonTaskEntityWorkItem: Callable[..., Any]
SkipGracefulOrchestrationTerminations: Callable[..., Any]
77 changes: 77 additions & 0 deletions tests/durabletask-azuremanaged/test_dts_orchestration_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@
endpoint = os.getenv("ENDPOINT", "http://localhost:8080")


def _get_credential():
"""Returns DefaultAzureCredential if endpoint is https, otherwise None (for emulator)."""
if endpoint.startswith("https://"):
from azure.identity import DefaultAzureCredential
return DefaultAzureCredential()
return None


def test_empty_orchestration():

invoked = False
Expand Down Expand Up @@ -371,6 +379,75 @@ def child(ctx: task.OrchestrationContext, _):
assert state is None


def test_restart_with_same_instance_id():
def orchestrator(ctx: task.OrchestrationContext, _):
result = yield ctx.call_activity(say_hello, input="World")
return result

def say_hello(ctx: task.ActivityContext, input: str):
return f"Hello, {input}!"

credential = _get_credential()

# Start a worker, which will connect to the sidecar in a background thread
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=credential) as w:
w.add_orchestrator(orchestrator)
w.add_activity(say_hello)
w.start()

task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=credential)
id = task_hub_client.schedule_new_orchestration(orchestrator)
state = task_hub_client.wait_for_orchestration_completion(id, timeout=30)
assert state is not None
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
assert state.serialized_output == json.dumps("Hello, World!")

# Restart the orchestration with the same instance ID
restarted_id = task_hub_client.restart_orchestration(id)
assert restarted_id == id

state = task_hub_client.wait_for_orchestration_completion(restarted_id, timeout=30)
assert state is not None
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
assert state.serialized_output == json.dumps("Hello, World!")


def test_restart_with_new_instance_id():
def orchestrator(ctx: task.OrchestrationContext, _):
result = yield ctx.call_activity(say_hello, input="World")
return result

def say_hello(ctx: task.ActivityContext, input: str):
return f"Hello, {input}!"

credential = _get_credential()

# Start a worker, which will connect to the sidecar in a background thread
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=credential) as w:
w.add_orchestrator(orchestrator)
w.add_activity(say_hello)
w.start()

task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=credential)
id = task_hub_client.schedule_new_orchestration(orchestrator)
state = task_hub_client.wait_for_orchestration_completion(id, timeout=30)
assert state is not None
assert state.runtime_status == client.OrchestrationStatus.COMPLETED

# Restart the orchestration with a new instance ID
restarted_id = task_hub_client.restart_orchestration(id, restart_with_new_instance_id=True)
assert restarted_id != id

state = task_hub_client.wait_for_orchestration_completion(restarted_id, timeout=30)
assert state is not None
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
assert state.serialized_output == json.dumps("Hello, World!")


# def test_continue_as_new():
# all_results = []

Expand Down
Loading
Loading