Skip to content

Commit aa34c48

Browse files
committed
More tests
1 parent 3a1c309 commit aa34c48

File tree

10 files changed

+476
-120
lines changed

10 files changed

+476
-120
lines changed

README.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -467,10 +467,8 @@ synchronous multithreaded activities.
467467

468468
###### Synchronous Multiprocess/Other Activities
469469

470-
Synchronous activities, i.e. functions that do not have `async def`, can be used with workers, but the
471-
`activity_executor` worker parameter must be set with a `concurrent.futures.Executor` instance to use for executing the
472-
activities. If this is _not_ set to an instance of `concurrent.futures.ThreadPoolExecutor` then the synchronous
473-
activities are considered multiprocess/other activities.
470+
If `activity_executor` is set to an instance of `concurrent.futures.Executor` that is _not_
471+
`concurrent.futures.ThreadPoolExecutor`, then the synchronous activities are considered multiprocess/other activities.
474472

475473
These require special primitives for heartbeating and cancellation. The `shared_state_manager` worker parameter must be
476474
set to an instance of `temporalio.worker.SharedStateManager`. The most common implementation can be created by passing a

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ asyncio_mode = "auto"
9090
log_cli = true
9191
log_cli_level = "INFO"
9292
log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
93-
timeout = 300
93+
timeout = 600
9494
timeout_func_only = true
9595

9696
[tool.isort]

temporalio/client.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -347,14 +347,13 @@ async def start_workflow(
347347
"""
348348
# Use definition if callable
349349
name: str
350-
arg_types: Optional[List[Type]] = None
351350
ret_type: Optional[Type] = None
352351
if isinstance(workflow, str):
353352
name = workflow
354353
elif callable(workflow):
355354
defn = temporalio.workflow._Definition.must_from_run_fn(workflow)
356355
name = defn.name
357-
arg_types, ret_type = self._type_lookup.get_type_hints(defn.run_fn)
356+
_, ret_type = self._type_lookup.get_type_hints(defn.run_fn)
358357
else:
359358
raise TypeError("Workflow must be a string or callable")
360359

@@ -375,7 +374,6 @@ async def start_workflow(
375374
header=header,
376375
start_signal=start_signal,
377376
start_signal_args=start_signal_args,
378-
arg_types=arg_types,
379377
ret_type=ret_type,
380378
)
381379
)
@@ -646,13 +644,15 @@ def __init__(
646644
run_id: Optional[str] = None,
647645
result_run_id: Optional[str] = None,
648646
first_execution_run_id: Optional[str] = None,
647+
result_type: Optional[Type] = None,
649648
) -> None:
650649
"""Create workflow handle."""
651650
self._client = client
652651
self._id = id
653652
self._run_id = run_id
654653
self._result_run_id = result_run_id
655654
self._first_execution_run_id = first_execution_run_id
655+
self._result_type = result_type
656656

657657
@property
658658
def id(self) -> str:
@@ -754,9 +754,10 @@ async def result(self, *, follow_runs: bool = True) -> WorkflowReturnType:
754754
req.next_page_token = b""
755755
continue
756756
# Ignoring anything after the first response like TypeScript
757-
# TODO(cretz): Support type hints
757+
type_hints = [self._result_type] if self._result_type else None
758758
results = await self._client.data_converter.decode_wrapper(
759-
complete_attr.result
759+
complete_attr.result,
760+
type_hints,
760761
)
761762
if not results:
762763
return cast(WorkflowReturnType, None)
@@ -969,6 +970,7 @@ async def query(
969970
RPCError: Workflow details could not be fetched.
970971
"""
971972
query_name: str
973+
ret_type: Optional[Type] = None
972974
if callable(query):
973975
defn = temporalio.workflow._QueryDefinition.from_fn(query)
974976
if not defn:
@@ -980,6 +982,7 @@ async def query(
980982
raise RuntimeError("Cannot invoke dynamic query definition")
981983
# TODO(cretz): Check count/type of args at runtime?
982984
query_name = defn.name
985+
_, ret_type = self._client._type_lookup.get_type_hints(defn.fn)
983986
else:
984987
query_name = str(query)
985988

@@ -991,6 +994,7 @@ async def query(
991994
args=temporalio.common._arg_or_args(arg, args),
992995
reject_condition=reject_condition
993996
or self._client._config["default_workflow_query_reject_condition"],
997+
ret_type=ret_type,
994998
)
995999
)
9961000

@@ -1245,8 +1249,7 @@ class StartWorkflowInput:
12451249
header: Optional[Mapping[str, Any]]
12461250
start_signal: Optional[str]
12471251
start_signal_args: Iterable[Any]
1248-
# The types may be absent
1249-
arg_types: Optional[List[Type]]
1252+
# Type may be absent
12501253
ret_type: Optional[Type]
12511254

12521255

@@ -1268,6 +1271,8 @@ class QueryWorkflowInput:
12681271
query: str
12691272
args: Iterable[Any]
12701273
reject_condition: Optional[temporalio.common.QueryRejectCondition]
1274+
# Type may be absent
1275+
ret_type: Optional[Type]
12711276

12721277

12731278
@dataclass
@@ -1429,6 +1434,7 @@ async def start_workflow(
14291434
req.workflow_id,
14301435
result_run_id=resp.run_id,
14311436
first_execution_run_id=first_execution_run_id,
1437+
result_type=input.ret_type,
14321438
)
14331439

14341440
async def cancel_workflow(self, input: CancelWorkflowInput) -> None:
@@ -1474,7 +1480,10 @@ async def query_workflow(self, input: QueryWorkflowInput) -> Any:
14741480
)
14751481
if not resp.query_result.payloads:
14761482
return None
1477-
results = await self._client.data_converter.decode(resp.query_result.payloads)
1483+
type_hints = [input.ret_type] if input.ret_type else None
1484+
results = await self._client.data_converter.decode(
1485+
resp.query_result.payloads, type_hints
1486+
)
14781487
if not results:
14791488
return None
14801489
elif len(results) > 1:

temporalio/converter.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -571,14 +571,16 @@ async def encode_wrapper(
571571
return temporalio.api.common.v1.Payloads(payloads=(await self.encode(values)))
572572

573573
async def decode_wrapper(
574-
self, payloads: Optional[temporalio.api.common.v1.Payloads]
574+
self,
575+
payloads: Optional[temporalio.api.common.v1.Payloads],
576+
type_hints: Optional[List[Type]] = None,
575577
) -> List[Any]:
576578
""":py:meth:`decode` for the
577579
:py:class:`temporalio.api.common.v1.Payloads` wrapper.
578580
"""
579581
if not payloads or not payloads.payloads:
580582
return []
581-
return await self.decode(payloads.payloads)
583+
return await self.decode(payloads.payloads, type_hints)
582584

583585

584586
_default: Optional[DataConverter] = None

temporalio/worker/activity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ async def _run_activity(
356356
heartbeat_timeout=start.heartbeat_timeout.ToTimedelta()
357357
if start.HasField("heartbeat_timeout")
358358
else None,
359-
is_local=False,
359+
is_local=start.is_local,
360360
retry_policy=temporalio.bridge.worker.retry_policy_from_proto(
361361
start.retry_policy
362362
)

temporalio/worker/worker.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,6 @@ def __init__(
224224
data_converter=client_config["data_converter"],
225225
interceptors=interceptors,
226226
type_hint_eval_str=client_config["type_hint_eval_str"],
227-
max_concurrent_workflow_tasks=max_concurrent_workflow_tasks,
228227
)
229228

230229
# Create bridge worker last. We have empirically observed that if it is

temporalio/worker/workflow.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ def __init__(
5353
data_converter: temporalio.converter.DataConverter,
5454
interceptors: Iterable[Interceptor],
5555
type_hint_eval_str: bool,
56-
max_concurrent_workflow_tasks: int,
5756
) -> None:
5857
self._bridge_worker = bridge_worker
5958
self._namespace = namespace
@@ -182,6 +181,9 @@ async def _handle_activation(
182181
f"Failed converting activation exception: {inner_err}"
183182
)
184183

184+
# Always set the run ID on the completion
185+
completion.run_id = act.run_id
186+
185187
# Encode the completion if there's a codec
186188
if self._data_converter.payload_codec:
187189
try:
@@ -234,6 +236,13 @@ async def _create_workflow_instance(
234236

235237
# Build info
236238
start = start_job.start_workflow
239+
parent: Optional[temporalio.workflow.ParentInfo] = None
240+
if start.HasField("parent_workflow_info"):
241+
parent = temporalio.workflow.ParentInfo(
242+
namespace=start.parent_workflow_info.namespace,
243+
run_id=start.parent_workflow_info.run_id,
244+
workflow_id=start.parent_workflow_info.workflow_id,
245+
)
237246
info = temporalio.workflow.Info(
238247
attempt=start.attempt,
239248
continued_run_id=start.continued_from_execution_run_id or None,
@@ -242,6 +251,7 @@ async def _create_workflow_instance(
242251
if start.HasField("workflow_execution_timeout")
243252
else None,
244253
namespace=self._namespace,
254+
parent=parent,
245255
run_id=act.run_id,
246256
run_timeout=start.workflow_run_timeout.ToTimedelta()
247257
if start.HasField("workflow_run_timeout")

temporalio/worker/workflow_instance.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,8 @@ def activate(
213213
self._current_completion = (
214214
temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion()
215215
)
216-
self._current_completion.run_id = act.run_id
217216
self._current_completion.successful.SetInParent()
217+
self._current_activation_error: Optional[Exception] = None
218218
self._time = act.timestamp.ToMicroseconds() / 1e6
219219
self._is_replaying = act.is_replaying
220220

@@ -641,6 +641,15 @@ def workflow_get_query_handler(self, name: Optional[str]) -> Optional[Callable]:
641641
return cast(Optional[Callable], defn)
642642
# Curry instance on the definition function since that represents an
643643
# unbound method
644+
if inspect.iscoroutinefunction(defn.fn):
645+
# We cannot use functools.partial here because in <= 3.7 that isn't
646+
# considered an inspect.iscoroutinefunction
647+
fn = cast(Callable[..., Awaitable[Any]], defn.fn)
648+
649+
async def with_object(*args, **kwargs) -> Any:
650+
return await fn(self._object, *args, **kwargs)
651+
652+
return with_object
644653
return partial(defn.fn, self._object)
645654

646655
def workflow_get_signal_handler(self, name: Optional[str]) -> Optional[Callable]:
@@ -650,6 +659,15 @@ def workflow_get_signal_handler(self, name: Optional[str]) -> Optional[Callable]
650659
return cast(Optional[Callable], defn)
651660
# Curry instance on the definition function since that represents an
652661
# unbound method
662+
if inspect.iscoroutinefunction(defn.fn):
663+
# We cannot use functools.partial here because in <= 3.7 that isn't
664+
# considered an inspect.iscoroutinefunction
665+
fn = cast(Callable[..., Awaitable[Any]], defn.fn)
666+
667+
async def with_object(*args, **kwargs) -> Any:
668+
return await fn(self._object, *args, **kwargs)
669+
670+
return with_object
653671
return partial(defn.fn, self._object)
654672

655673
def workflow_info(self) -> temporalio.workflow.Info:
@@ -1008,6 +1026,11 @@ def _run_once(self) -> None:
10081026
handle = self._ready.popleft()
10091027
handle._run()
10101028

1029+
# Must throw here. Only really set inside
1030+
# _run_top_level_workflow_function.
1031+
if self._current_activation_error:
1032+
raise self._current_activation_error
1033+
10111034
# Check conditions which may add to the ready list
10121035
self._conditions[:] = [
10131036
t for t in self._conditions if not self._check_condition(*t)
@@ -1018,8 +1041,6 @@ def _run_once(self) -> None:
10181041
# This is used for the primary workflow function and signal handlers in
10191042
# order to apply common exception handling to each
10201043
async def _run_top_level_workflow_function(self, coro: Awaitable[None]) -> None:
1021-
# We intentionally don't catch all errors, instead we bubble those
1022-
# out as task failures
10231044
try:
10241045
await coro
10251046
except _ContinueAsNewError as err:
@@ -1048,6 +1069,8 @@ async def _run_top_level_workflow_function(self, coro: Awaitable[None]) -> None:
10481069
self._payload_converter,
10491070
command.fail_workflow_execution.failure,
10501071
)
1072+
except Exception as err:
1073+
self._current_activation_error = err
10511074

10521075
async def _signal_external_workflow(
10531076
self,
@@ -1514,10 +1537,14 @@ def _resolve_success(self, result: Any) -> None:
15141537
self._result_fut.set_result(result)
15151538

15161539
def _resolve_failure(self, err: Exception) -> None:
1517-
if not self._start_fut.done():
1540+
if self._start_fut.done():
1541+
# We intentionally let this error if already done
1542+
self._result_fut.set_exception(err)
1543+
else:
15181544
self._start_fut.set_exception(err)
1519-
# We intentionally let this error if already done
1520-
self._result_fut.set_exception(err)
1545+
# Set the result as none to avoid Python warning about unhandled
1546+
# future
1547+
self._result_fut.set_result(None)
15211548

15221549
def _apply_start_command(
15231550
self,

temporalio/workflow.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ class Info:
265265
cron_schedule: Optional[str]
266266
execution_timeout: Optional[timedelta]
267267
namespace: str
268+
parent: Optional[ParentInfo]
268269
run_id: str
269270
run_timeout: Optional[timedelta]
270271
start_time: datetime
@@ -274,9 +275,6 @@ class Info:
274275
workflow_type: str
275276

276277
# TODO(cretz): memo
277-
# TODO(cretz): parent_namespace
278-
# TODO(cretz): parent_run_id
279-
# TODO(cretz): parent_workflow_id
280278
# TODO(cretz): retry_policy
281279
# TODO(cretz): search_attributes
282280

@@ -288,6 +286,15 @@ def _logger_details(self) -> Mapping[str, Any]:
288286
}
289287

290288

289+
@dataclass(frozen=True)
290+
class ParentInfo:
291+
"""Information about the parent workflow."""
292+
293+
namespace: str
294+
run_id: str
295+
workflow_id: str
296+
297+
291298
class _Runtime(ABC):
292299
@staticmethod
293300
def current() -> _Runtime:
@@ -548,7 +555,11 @@ class _Definition:
548555

549556
@staticmethod
550557
def from_class(cls: Type) -> Optional[_Definition]:
551-
return getattr(cls, "__temporal_workflow_definition", None)
558+
# We make sure to only return it if it's on _this_ class
559+
defn = getattr(cls, "__temporal_workflow_definition", None)
560+
if defn and defn.cls == cls:
561+
return defn
562+
return None
552563

553564
@staticmethod
554565
def must_from_class(cls: Type) -> _Definition:
@@ -576,7 +587,8 @@ def must_from_run_fn(fn: Callable[..., Awaitable[Any]]) -> _Definition:
576587

577588
@staticmethod
578589
def _apply_to_class(cls: Type, workflow_name: str) -> None:
579-
if hasattr(cls, "__temporal_workflow_definition"):
590+
# Check it's not being doubly applied
591+
if _Definition.from_class(cls):
580592
raise ValueError("Class already contains workflow definition")
581593
issues: List[str] = []
582594

0 commit comments

Comments
 (0)