Skip to content

Commit 80de6c2

Browse files
DeepanshuA and cgillum authored
Retry policies implementation (#11)
Signed-off-by: Deepanshu Agarwal <deepanshu.agarwal1984@gmail.com> Co-authored-by: Chris Gillum <cgillum@microsoft.com>
1 parent c999097 commit 80de6c2

File tree

5 files changed

+721
-80
lines changed

5 files changed

+721
-80
lines changed

CHANGELOG.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
# Changelog
2+
3+
All notable changes to this project will be documented in this file.
4+
5+
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
6+
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
7+
8+
## Unreleased
9+
10+
### New
11+
12+
- Retry policies for activities and sub-orchestrations ([#11](https://github.com/microsoft/durabletask-python/pull/11)) - contributed by [@DeepanshuA](https://github.com/DeepanshuA)
13+
114
## v0.1.0a5
215

316
### New

durabletask/task.py

Lines changed: 157 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
# See https://peps.python.org/pep-0563/
55
from __future__ import annotations
66

7+
import math
78
from abc import ABC, abstractmethod
89
from datetime import datetime, timedelta
9-
from typing import Any, Callable, Generator, Generic, List, TypeVar, Union
10+
from typing import (Any, Callable, Generator, Generic, List, Optional, TypeVar,
11+
Union)
1012

1113
import durabletask.internal.helpers as pbh
1214
import durabletask.internal.orchestrator_service_pb2 as pb
@@ -87,17 +89,18 @@ def create_timer(self, fire_at: Union[datetime, timedelta]) -> Task:
8789

8890
@abstractmethod
8991
def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *,
90-
input: Union[TInput, None] = None) -> Task[TOutput]:
92+
input: Optional[TInput] = None,
93+
retry_policy: Optional[RetryPolicy] = None) -> Task[TOutput]:
9194
"""Schedule an activity for execution.
9295
9396
Parameters
9497
----------
9598
activity: Union[Activity[TInput, TOutput], str]
9699
A reference to the activity function to call.
97-
input: Union[TInput, None]
100+
input: Optional[TInput]
98101
The JSON-serializable input (or None) to pass to the activity.
99-
return_type: task.Task[TOutput]
100-
The JSON-serializable output type to expect from the activity result.
102+
retry_policy: Optional[RetryPolicy]
103+
The retry policy to use for this activity call.
101104
102105
Returns
103106
-------
@@ -108,19 +111,22 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *,
108111

109112
@abstractmethod
110113
def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *,
111-
input: Union[TInput, None] = None,
112-
instance_id: Union[str, None] = None) -> Task[TOutput]:
114+
input: Optional[TInput] = None,
115+
instance_id: Optional[str] = None,
116+
retry_policy: Optional[RetryPolicy] = None) -> Task[TOutput]:
113117
"""Schedule sub-orchestrator function for execution.
114118
115119
Parameters
116120
----------
117121
orchestrator: Orchestrator[TInput, TOutput]
118122
A reference to the orchestrator function to call.
119-
input: Union[TInput, None]
123+
input: Optional[TInput]
120124
The optional JSON-serializable input to pass to the orchestrator function.
121-
instance_id: Union[str, None]
125+
instance_id: Optional[str]
122126
A unique ID to use for the sub-orchestration instance. If not specified, a
123127
random UUID will be used.
128+
retry_policy: Optional[RetryPolicy]
129+
The retry policy to use for this sub-orchestrator call.
124130
125131
Returns
126132
-------
@@ -162,7 +168,7 @@ def continue_as_new(self, new_input: Any, *, save_events: bool = False) -> None:
162168

163169

164170
class FailureDetails:
165-
def __init__(self, message: str, error_type: str, stack_trace: Union[str, None]):
171+
def __init__(self, message: str, error_type: str, stack_trace: Optional[str]):
166172
self._message = message
167173
self._error_type = error_type
168174
self._stack_trace = stack_trace
@@ -176,7 +182,7 @@ def error_type(self) -> str:
176182
return self._error_type
177183

178184
@property
179-
def stack_trace(self) -> Union[str, None]:
185+
def stack_trace(self) -> Optional[str]:
180186
return self._stack_trace
181187

182188

@@ -206,8 +212,8 @@ class OrchestrationStateError(Exception):
206212
class Task(ABC, Generic[T]):
207213
"""Abstract base class for asynchronous tasks in a durable orchestration."""
208214
_result: T
209-
_exception: Union[TaskFailedError, None]
210-
_parent: Union[CompositeTask[T], None]
215+
_exception: Optional[TaskFailedError]
216+
_parent: Optional[CompositeTask[T]]
211217

212218
def __init__(self) -> None:
213219
super().__init__()
@@ -261,29 +267,6 @@ def get_tasks(self) -> List[Task]:
261267
def on_child_completed(self, task: Task[T]):
262268
pass
263269

264-
265-
class CompletableTask(Task[T]):
266-
267-
def __init__(self):
268-
super().__init__()
269-
270-
def complete(self, result: T):
271-
if self._is_complete:
272-
raise ValueError('The task has already completed.')
273-
self._result = result
274-
self._is_complete = True
275-
if self._parent is not None:
276-
self._parent.on_child_completed(self)
277-
278-
def fail(self, message: str, details: pb.TaskFailureDetails):
279-
if self._is_complete:
280-
raise ValueError('The task has already completed.')
281-
self._exception = TaskFailedError(message, details)
282-
self._is_complete = True
283-
if self._parent is not None:
284-
self._parent.on_child_completed(self)
285-
286-
287270
class WhenAllTask(CompositeTask[List[T]]):
288271
"""A task that completes when all of its child tasks complete."""
289272

@@ -313,6 +296,76 @@ def get_completed_tasks(self) -> int:
313296
return self._completed_tasks
314297

315298

299+
class CompletableTask(Task[T]):
    """A task whose outcome is set explicitly through complete() or fail()."""

    def __init__(self):
        super().__init__()
        # No retryable parent by default; TimerTask.set_retryable_parent
        # is the visible place where this gets assigned.
        self._retryable_parent = None

    def complete(self, result: T):
        """Store ``result`` and mark the task as completed.

        Raises
        ------
        ValueError
            If the task has already completed.
        """
        self._ensure_pending()
        self._result = result
        self._finish()

    def fail(self, message: str, details: pb.TaskFailureDetails):
        """Store a TaskFailedError built from the arguments and mark the task as completed.

        Raises
        ------
        ValueError
            If the task has already completed.
        """
        self._ensure_pending()
        self._exception = TaskFailedError(message, details)
        self._finish()

    def _ensure_pending(self) -> None:
        """Reject a second completion attempt."""
        if self._is_complete:
            raise ValueError('The task has already completed.')

    def _finish(self) -> None:
        """Flag the task as complete, then notify the parent task if there is one."""
        self._is_complete = True
        parent = self._parent
        if parent is not None:
            parent.on_child_completed(self)
320+
321+
322+
class RetryableTask(CompletableTask[T]):
    """A task that can be retried according to a retry policy."""

    def __init__(self, retry_policy: RetryPolicy, action: pb.OrchestratorAction,
                 start_time: datetime, is_sub_orch: bool) -> None:
        """Creates a new RetryableTask instance.

        Parameters
        ----------
        retry_policy : RetryPolicy
            The policy consulted by compute_next_delay().
        action : pb.OrchestratorAction
            The orchestrator action associated with this task (stored as-is).
        start_time : datetime
            The reference time from which the policy's retry_timeout is measured.
            NOTE(review): compared against datetime.utcnow(), so this is
            presumably a naive UTC datetime - confirm with callers.
        is_sub_orch : bool
            Stored as-is; presumably marks a sub-orchestration call rather
            than an activity call - confirm with callers.
        """
        super().__init__()
        self._action = action
        self._retry_policy = retry_policy
        # The first attempt is counted, so counting starts at 1.
        self._attempt_count = 1
        self._start_time = start_time
        self._is_sub_orch = is_sub_orch

    def increment_attempt_count(self) -> None:
        """Record that one more attempt has been made."""
        self._attempt_count += 1

    def compute_next_delay(self) -> Optional[timedelta]:
        """Compute how long to wait before the next retry attempt.

        Returns
        -------
        Optional[timedelta]
            The delay before the next attempt, or None when no further retry
            should happen (attempts exhausted or the retry timeout elapsed).
        """
        policy = self._retry_policy
        if self._attempt_count >= policy.max_number_of_attempts:
            return None

        # Work out when retrying must stop. A missing or "infinite" timeout
        # means retries never expire. The sentinel is timedelta.max: the
        # timeout is a timedelta, so comparing it with datetime.max (as the
        # previous code did) could never match.
        retry_expiration: datetime = datetime.max
        retry_timeout = policy.retry_timeout
        if retry_timeout is not None and retry_timeout != timedelta.max:
            try:
                retry_expiration = self._start_time + retry_timeout
            except OverflowError:
                # The timeout reaches past the largest representable datetime,
                # which is equivalent to never expiring.
                retry_expiration = datetime.max

        if datetime.utcnow() >= retry_expiration:
            return None

        backoff_coefficient = policy.backoff_coefficient
        if backoff_coefficient is None:
            backoff_coefficient = 1.0

        # Exponential backoff: first_retry_interval * coefficient^(attempt - 1).
        next_delay_f = (math.pow(backoff_coefficient, self._attempt_count - 1)
                        * policy.first_retry_interval.total_seconds())

        if policy.max_retry_interval is not None:
            next_delay_f = min(next_delay_f, policy.max_retry_interval.total_seconds())
        return timedelta(seconds=next_delay_f)
358+
359+
360+
class TimerTask(CompletableTask[T]):
    """A completable task that can be linked to a RetryableTask."""

    def __init__(self) -> None:
        super().__init__()

    def set_retryable_parent(self, retryable_task: RetryableTask) -> None:
        """Remember ``retryable_task`` as this task's retryable parent."""
        self._retryable_parent = retryable_task
367+
368+
316369
class WhenAnyTask(CompositeTask[Task]):
317370
"""A task that completes when any of its child tasks complete."""
318371

@@ -376,6 +429,74 @@ def task_id(self) -> int:
376429
Activity = Callable[[ActivityContext, TInput], TOutput]
377430

378431

432+
class RetryPolicy:
    """Represents the retry policy for an orchestration or activity function."""

    def __init__(self, *,
                 first_retry_interval: timedelta,
                 max_number_of_attempts: int,
                 backoff_coefficient: Optional[float] = 1.0,
                 max_retry_interval: Optional[timedelta] = None,
                 retry_timeout: Optional[timedelta] = None):
        """Creates a new RetryPolicy instance.

        Parameters
        ----------
        first_retry_interval : timedelta
            The retry interval to use for the first retry attempt.
        max_number_of_attempts : int
            The maximum total number of attempts, including the first one.
            Must be >= 1; a value of 1 means the operation is never retried.
        backoff_coefficient : Optional[float]
            The backoff coefficient to use for calculating the next retry interval.
            Must be >= 1 when provided.
        max_retry_interval : Optional[timedelta]
            The maximum retry interval to use for any retry attempt.
        retry_timeout : Optional[timedelta]
            The maximum amount of time to spend retrying the operation.

        Raises
        ------
        ValueError
            If any argument is outside its permitted range.
        """
        zero = timedelta(seconds=0)

        # validate inputs
        if first_retry_interval < zero:
            raise ValueError('first_retry_interval must be >= 0')
        if max_number_of_attempts < 1:
            raise ValueError('max_number_of_attempts must be >= 1')
        # Written as "not (x >= 1)" so that NaN is rejected as well: every
        # comparison with NaN is False, so "x < 1" would let it through.
        if backoff_coefficient is not None and not (backoff_coefficient >= 1):
            raise ValueError('backoff_coefficient must be >= 1')
        if max_retry_interval is not None and max_retry_interval < zero:
            raise ValueError('max_retry_interval must be >= 0')
        if retry_timeout is not None and retry_timeout < zero:
            raise ValueError('retry_timeout must be >= 0')

        self._first_retry_interval = first_retry_interval
        self._max_number_of_attempts = max_number_of_attempts
        self._backoff_coefficient = backoff_coefficient
        self._max_retry_interval = max_retry_interval
        self._retry_timeout = retry_timeout

    @property
    def first_retry_interval(self) -> timedelta:
        """The retry interval to use for the first retry attempt."""
        return self._first_retry_interval

    @property
    def max_number_of_attempts(self) -> int:
        """The maximum total number of attempts, including the first one."""
        return self._max_number_of_attempts

    @property
    def backoff_coefficient(self) -> Optional[float]:
        """The backoff coefficient to use for calculating the next retry interval."""
        return self._backoff_coefficient

    @property
    def max_retry_interval(self) -> Optional[timedelta]:
        """The maximum retry interval to use for any retry attempt."""
        return self._max_retry_interval

    @property
    def retry_timeout(self) -> Optional[timedelta]:
        """The maximum amount of time to spend retrying the operation."""
        return self._retry_timeout
498+
499+
379500
def get_name(fn: Callable) -> str:
380501
"""Returns the name of the provided function"""
381502
name = fn.__name__

0 commit comments

Comments
 (0)