44# See https://peps.python.org/pep-0563/
55from __future__ import annotations
66
7+ import math
78from abc import ABC , abstractmethod
89from datetime import datetime , timedelta
9- from typing import Any , Callable , Generator , Generic , List , TypeVar , Union
10+ from typing import (Any , Callable , Generator , Generic , List , Optional , TypeVar ,
11+ Union )
1012
1113import durabletask .internal .helpers as pbh
1214import durabletask .internal .orchestrator_service_pb2 as pb
@@ -87,17 +89,18 @@ def create_timer(self, fire_at: Union[datetime, timedelta]) -> Task:
8789
8890 @abstractmethod
8991 def call_activity (self , activity : Union [Activity [TInput , TOutput ], str ], * ,
90- input : Union [TInput , None ] = None ) -> Task [TOutput ]:
92+ input : Optional [TInput ] = None ,
93+ retry_policy : Optional [RetryPolicy ] = None ) -> Task [TOutput ]:
9194 """Schedule an activity for execution.
9295
9396 Parameters
9497 ----------
9598 activity: Union[Activity[TInput, TOutput], str]
9699 A reference to the activity function to call.
97- input: Union [TInput, None ]
100+ input: Optional [TInput]
98101 The JSON-serializable input (or None) to pass to the activity.
99- return_type: task.Task[TOutput ]
100- The JSON-serializable output type to expect from the activity result .
102+ retry_policy: Optional[RetryPolicy ]
103+ The retry policy to use for this activity call .
101104
102105 Returns
103106 -------
@@ -108,19 +111,22 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *,
108111
109112 @abstractmethod
110113 def call_sub_orchestrator (self , orchestrator : Orchestrator [TInput , TOutput ], * ,
111- input : Union [TInput , None ] = None ,
112- instance_id : Union [str , None ] = None ) -> Task [TOutput ]:
114+ input : Optional [TInput ] = None ,
115+ instance_id : Optional [str ] = None ,
116+ retry_policy : Optional [RetryPolicy ] = None ) -> Task [TOutput ]:
113117 """Schedule sub-orchestrator function for execution.
114118
115119 Parameters
116120 ----------
117121 orchestrator: Orchestrator[TInput, TOutput]
118122 A reference to the orchestrator function to call.
119- input: Union [TInput, None ]
123+ input: Optional [TInput]
120124 The optional JSON-serializable input to pass to the orchestrator function.
121- instance_id: Union [str, None ]
125+ instance_id: Optional [str]
122126 A unique ID to use for the sub-orchestration instance. If not specified, a
123127 random UUID will be used.
128+ retry_policy: Optional[RetryPolicy]
129+ The retry policy to use for this sub-orchestrator call.
124130
125131 Returns
126132 -------
@@ -162,7 +168,7 @@ def continue_as_new(self, new_input: Any, *, save_events: bool = False) -> None:
162168
163169
164170class FailureDetails :
165- def __init__ (self , message : str , error_type : str , stack_trace : Union [str , None ]):
171+ def __init__ (self , message : str , error_type : str , stack_trace : Optional [str ]):
166172 self ._message = message
167173 self ._error_type = error_type
168174 self ._stack_trace = stack_trace
@@ -176,7 +182,7 @@ def error_type(self) -> str:
176182 return self ._error_type
177183
178184 @property
179- def stack_trace (self ) -> Union [str , None ]:
185+ def stack_trace (self ) -> Optional [str ]:
180186 return self ._stack_trace
181187
182188
@@ -206,8 +212,8 @@ class OrchestrationStateError(Exception):
206212class Task (ABC , Generic [T ]):
207213 """Abstract base class for asynchronous tasks in a durable orchestration."""
208214 _result : T
209- _exception : Union [TaskFailedError , None ]
210- _parent : Union [CompositeTask [T ], None ]
215+ _exception : Optional [TaskFailedError ]
216+ _parent : Optional [CompositeTask [T ]]
211217
212218 def __init__ (self ) -> None :
213219 super ().__init__ ()
@@ -261,29 +267,6 @@ def get_tasks(self) -> List[Task]:
261267 def on_child_completed (self , task : Task [T ]):
262268 pass
263269
264-
265- class CompletableTask (Task [T ]):
266-
267- def __init__ (self ):
268- super ().__init__ ()
269-
270- def complete (self , result : T ):
271- if self ._is_complete :
272- raise ValueError ('The task has already completed.' )
273- self ._result = result
274- self ._is_complete = True
275- if self ._parent is not None :
276- self ._parent .on_child_completed (self )
277-
278- def fail (self , message : str , details : pb .TaskFailureDetails ):
279- if self ._is_complete :
280- raise ValueError ('The task has already completed.' )
281- self ._exception = TaskFailedError (message , details )
282- self ._is_complete = True
283- if self ._parent is not None :
284- self ._parent .on_child_completed (self )
285-
286-
287270class WhenAllTask (CompositeTask [List [T ]]):
288271 """A task that completes when all of its child tasks complete."""
289272
@@ -313,6 +296,76 @@ def get_completed_tasks(self) -> int:
313296 return self ._completed_tasks
314297
315298
299+ class CompletableTask (Task [T ]):
300+
301+ def __init__ (self ):
302+ super ().__init__ ()
303+ self ._retryable_parent = None
304+
305+ def complete (self , result : T ):
306+ if self ._is_complete :
307+ raise ValueError ('The task has already completed.' )
308+ self ._result = result
309+ self ._is_complete = True
310+ if self ._parent is not None :
311+ self ._parent .on_child_completed (self )
312+
313+ def fail (self , message : str , details : pb .TaskFailureDetails ):
314+ if self ._is_complete :
315+ raise ValueError ('The task has already completed.' )
316+ self ._exception = TaskFailedError (message , details )
317+ self ._is_complete = True
318+ if self ._parent is not None :
319+ self ._parent .on_child_completed (self )
320+
321+
322+ class RetryableTask (CompletableTask [T ]):
323+ """A task that can be retried according to a retry policy."""
324+
325+ def __init__ (self , retry_policy : RetryPolicy , action : pb .OrchestratorAction ,
326+ start_time :datetime , is_sub_orch : bool ) -> None :
327+ super ().__init__ ()
328+ self ._action = action
329+ self ._retry_policy = retry_policy
330+ self ._attempt_count = 1
331+ self ._start_time = start_time
332+ self ._is_sub_orch = is_sub_orch
333+
334+ def increment_attempt_count (self ) -> None :
335+ self ._attempt_count += 1
336+
337+ def compute_next_delay (self ) -> Union [timedelta , None ]:
338+ if self ._attempt_count >= self ._retry_policy .max_number_of_attempts :
339+ return None
340+
341+ retry_expiration : datetime = datetime .max
342+ if self ._retry_policy .retry_timeout is not None and self ._retry_policy .retry_timeout != datetime .max :
343+ retry_expiration = self ._start_time + self ._retry_policy .retry_timeout
344+
345+ if self ._retry_policy .backoff_coefficient is None :
346+ backoff_coefficient = 1.0
347+ else :
348+ backoff_coefficient = self ._retry_policy .backoff_coefficient
349+
350+ if datetime .utcnow () < retry_expiration :
351+ next_delay_f = math .pow (backoff_coefficient , self ._attempt_count - 1 ) * self ._retry_policy .first_retry_interval .total_seconds ()
352+
353+ if self ._retry_policy .max_retry_interval is not None :
354+ next_delay_f = min (next_delay_f , self ._retry_policy .max_retry_interval .total_seconds ())
355+ return timedelta (seconds = next_delay_f )
356+
357+ return None
358+
359+
360+ class TimerTask (CompletableTask [T ]):
361+
362+ def __init__ (self ) -> None :
363+ super ().__init__ ()
364+
365+ def set_retryable_parent (self , retryable_task : RetryableTask ):
366+ self ._retryable_parent = retryable_task
367+
368+
316369class WhenAnyTask (CompositeTask [Task ]):
317370 """A task that completes when any of its child tasks complete."""
318371
@@ -376,6 +429,74 @@ def task_id(self) -> int:
376429Activity = Callable [[ActivityContext , TInput ], TOutput ]
377430
378431
432+ class RetryPolicy :
433+ """Represents the retry policy for an orchestration or activity function."""
434+
435+ def __init__ (self , * ,
436+ first_retry_interval : timedelta ,
437+ max_number_of_attempts : int ,
438+ backoff_coefficient : Optional [float ] = 1.0 ,
439+ max_retry_interval : Optional [timedelta ] = None ,
440+ retry_timeout : Optional [timedelta ] = None ):
441+ """Creates a new RetryPolicy instance.
442+
443+ Parameters
444+ ----------
445+ first_retry_interval : timedelta
446+ The retry interval to use for the first retry attempt.
447+ max_number_of_attempts : int
448+ The maximum number of retry attempts.
449+ backoff_coefficient : Optional[float]
450+ The backoff coefficient to use for calculating the next retry interval.
451+ max_retry_interval : Optional[timedelta]
452+ The maximum retry interval to use for any retry attempt.
453+ retry_timeout : Optional[timedelta]
454+ The maximum amount of time to spend retrying the operation.
455+ """
456+ # validate inputs
457+ if first_retry_interval < timedelta (seconds = 0 ):
458+ raise ValueError ('first_retry_interval must be >= 0' )
459+ if max_number_of_attempts < 1 :
460+ raise ValueError ('max_number_of_attempts must be >= 1' )
461+ if backoff_coefficient is not None and backoff_coefficient < 1 :
462+ raise ValueError ('backoff_coefficient must be >= 1' )
463+ if max_retry_interval is not None and max_retry_interval < timedelta (seconds = 0 ):
464+ raise ValueError ('max_retry_interval must be >= 0' )
465+ if retry_timeout is not None and retry_timeout < timedelta (seconds = 0 ):
466+ raise ValueError ('retry_timeout must be >= 0' )
467+
468+ self ._first_retry_interval = first_retry_interval
469+ self ._max_number_of_attempts = max_number_of_attempts
470+ self ._backoff_coefficient = backoff_coefficient
471+ self ._max_retry_interval = max_retry_interval
472+ self ._retry_timeout = retry_timeout
473+
474+ @property
475+ def first_retry_interval (self ) -> timedelta :
476+ """The retry interval to use for the first retry attempt."""
477+ return self ._first_retry_interval
478+
479+ @property
480+ def max_number_of_attempts (self ) -> int :
481+ """The maximum number of retry attempts."""
482+ return self ._max_number_of_attempts
483+
484+ @property
485+ def backoff_coefficient (self ) -> Optional [float ]:
486+ """The backoff coefficient to use for calculating the next retry interval."""
487+ return self ._backoff_coefficient
488+
489+ @property
490+ def max_retry_interval (self ) -> Optional [timedelta ]:
491+ """The maximum retry interval to use for any retry attempt."""
492+ return self ._max_retry_interval
493+
494+ @property
495+ def retry_timeout (self ) -> Optional [timedelta ]:
496+ """The maximum amount of time to spend retrying the operation."""
497+ return self ._retry_timeout
498+
499+
379500def get_name (fn : Callable ) -> str :
380501 """Returns the name of the provided function"""
381502 name = fn .__name__
0 commit comments