Skip to content

Create commands after payload conversion #591

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 63 additions & 39 deletions temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,8 +599,6 @@ def _apply_query_workflow(
) -> None:
# Wrap entire bunch of work in a task
async def run_query() -> None:
command = self._add_command()
command.respond_to_query.query_id = job.query_id
try:
with self._as_read_only():
# Named query or dynamic
Expand Down Expand Up @@ -632,11 +630,13 @@ async def run_query() -> None:
raise ValueError(
f"Expected 1 result payload, got {len(result_payloads)}"
)
command.respond_to_query.succeeded.response.CopyFrom(
result_payloads[0]
)
command = self._add_command()
command.respond_to_query.query_id = job.query_id
command.respond_to_query.succeeded.response.CopyFrom(result_payloads[0])
Copy link
Contributor

@dandavison dandavison Jul 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-blocking: in general my view is that try blocks should enclose as little as possible. it seems that these 3 lines could go after?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No because this needs to only be called on the success path, there is a failure path that does another thing, there is no need for any common code after the try

except Exception as err:
try:
command = self._add_command()
command.respond_to_query.query_id = job.query_id
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-blocking: I assume these lines cannot raise an exception; I would put them outside the try.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They were outside of try in the previous code, but the purpose of this fix is to not invoke these until it's time to make the command because if you call this before an exception is raised, you'll have an unfinished command that was added

self._failure_converter.to_failure(
err,
self._payload_converter,
Expand Down Expand Up @@ -1427,7 +1427,7 @@ async def run_activity() -> Any:
await asyncio.sleep(
err.backoff.backoff_duration.ToTimedelta().total_seconds()
)
handle._apply_schedule_command(self._add_command(), err.backoff)
handle._apply_schedule_command(err.backoff)
# We have to put the handle back on the pending activity
# dict with its new seq
self._pending_activities[handle._seq] = handle
Expand All @@ -1437,35 +1437,41 @@ async def run_activity() -> Any:

# Create the handle and set as pending
handle = _ActivityHandle(self, input, run_activity())
handle._apply_schedule_command(self._add_command())
handle._apply_schedule_command()
self._pending_activities[handle._seq] = handle
return handle

async def _outbound_signal_child_workflow(
self, input: SignalChildWorkflowInput
) -> None:
payloads = (
self._payload_converter.to_payloads(input.args) if input.args else None
)
command = self._add_command()
v = command.signal_external_workflow_execution
v.child_workflow_id = input.child_workflow_id
v.signal_name = input.signal
if input.args:
v.args.extend(self._payload_converter.to_payloads(input.args))
if payloads:
v.args.extend(payloads)
if input.headers:
temporalio.common._apply_headers(input.headers, v.headers)
await self._signal_external_workflow(command)

async def _outbound_signal_external_workflow(
self, input: SignalExternalWorkflowInput
) -> None:
payloads = (
self._payload_converter.to_payloads(input.args) if input.args else None
)
command = self._add_command()
v = command.signal_external_workflow_execution
v.workflow_execution.namespace = input.namespace
v.workflow_execution.workflow_id = input.workflow_id
if input.workflow_run_id:
v.workflow_execution.run_id = input.workflow_run_id
v.signal_name = input.signal
if input.args:
v.args.extend(self._payload_converter.to_payloads(input.args))
if payloads:
v.args.extend(payloads)
if input.headers:
temporalio.common._apply_headers(input.headers, v.headers)
await self._signal_external_workflow(command)
Expand Down Expand Up @@ -1510,7 +1516,7 @@ async def run_child() -> Any:
handle = _ChildWorkflowHandle(
self, self._next_seq("child_workflow"), input, run_child()
)
handle._apply_start_command(self._add_command())
handle._apply_start_command()
self._pending_child_workflows[handle._seq] = handle

# Wait on start before returning
Expand Down Expand Up @@ -1761,7 +1767,7 @@ async def _run_top_level_workflow_function(self, coro: Awaitable[None]) -> None:
await coro
except _ContinueAsNewError as err:
logger.debug("Workflow requested continue as new")
err._apply_command(self._add_command())
err._apply_command()
except (Exception, asyncio.CancelledError) as err:
# During tear down we can ignore exceptions. Technically the
# command-adding done later would throw a not-in-workflow exception
Expand All @@ -1776,7 +1782,7 @@ async def _run_top_level_workflow_function(self, coro: Awaitable[None]) -> None:
# Handle continue as new
if isinstance(err, _ContinueAsNewError):
logger.debug("Workflow requested continue as new")
err._apply_command(self._add_command())
err._apply_command()
return

logger.debug(
Expand Down Expand Up @@ -2261,11 +2267,18 @@ def _resolve_backoff(

def _apply_schedule_command(
self,
command: temporalio.bridge.proto.workflow_commands.WorkflowCommand,
local_backoff: Optional[
temporalio.bridge.proto.activity_result.DoBackoff
] = None,
) -> None:
# Convert arguments before creating command in case it raises error
payloads = (
self._instance._payload_converter.to_payloads(self._input.args)
if self._input.args
else None
)

command = self._instance._add_command()
# TODO(cretz): Why can't MyPy infer this?
v: Union[
temporalio.bridge.proto.workflow_commands.ScheduleActivity,
Expand All @@ -2280,10 +2293,8 @@ def _apply_schedule_command(
v.activity_type = self._input.activity
if self._input.headers:
temporalio.common._apply_headers(self._input.headers, v.headers)
if self._input.args:
v.arguments.extend(
self._instance._payload_converter.to_payloads(self._input.args)
)
if payloads:
v.arguments.extend(payloads)
if self._input.schedule_to_close_timeout:
v.schedule_to_close_timeout.FromTimedelta(
self._input.schedule_to_close_timeout
Expand Down Expand Up @@ -2403,20 +2414,23 @@ def _resolve_failure(self, err: BaseException) -> None:
# future
self._result_fut.set_result(None)

def _apply_start_command(
self,
command: temporalio.bridge.proto.workflow_commands.WorkflowCommand,
) -> None:
def _apply_start_command(self) -> None:
# Convert arguments before creating command in case it raises error
payloads = (
self._instance._payload_converter.to_payloads(self._input.args)
if self._input.args
else None
)

command = self._instance._add_command()
v = command.start_child_workflow_execution
v.seq = self._seq
v.namespace = self._instance._info.namespace
v.workflow_id = self._input.id
v.workflow_type = self._input.workflow
v.task_queue = self._input.task_queue or self._instance._info.task_queue
if self._input.args:
v.input.extend(
self._instance._payload_converter.to_payloads(self._input.args)
)
if payloads:
v.input.extend(payloads)
if self._input.execution_timeout:
v.workflow_execution_timeout.FromTimedelta(self._input.execution_timeout)
if self._input.run_timeout:
Expand Down Expand Up @@ -2520,19 +2534,31 @@ def __init__(
self._instance = instance
self._input = input

def _apply_command(
self, command: temporalio.bridge.proto.workflow_commands.WorkflowCommand
) -> None:
def _apply_command(self) -> None:
# Convert arguments before creating command in case it raises error
payloads = (
self._instance._payload_converter.to_payloads(self._input.args)
if self._input.args
else None
)
memo_payloads = (
{
k: self._instance._payload_converter.to_payloads([val])[0]
for k, val in self._input.memo.items()
}
if self._input.memo
else None
)

command = self._instance._add_command()
v = command.continue_as_new_workflow_execution
v.SetInParent()
if self._input.workflow:
v.workflow_type = self._input.workflow
if self._input.task_queue:
v.task_queue = self._input.task_queue
if self._input.args:
v.arguments.extend(
self._instance._payload_converter.to_payloads(self._input.args)
)
if payloads:
v.arguments.extend(payloads)
if self._input.run_timeout:
v.workflow_run_timeout.FromTimedelta(self._input.run_timeout)
if self._input.task_timeout:
Expand All @@ -2541,11 +2567,9 @@ def _apply_command(
temporalio.common._apply_headers(self._input.headers, v.headers)
if self._input.retry_policy:
self._input.retry_policy.apply_to_proto(v.retry_policy)
if self._input.memo:
for k, val in self._input.memo.items():
v.memo[k].CopyFrom(
self._instance._payload_converter.to_payloads([val])[0]
)
if memo_payloads:
for k, val in memo_payloads.items():
v.memo[k].CopyFrom(val)
if self._input.search_attributes:
_encode_search_attributes(
self._input.search_attributes, v.search_attributes
Expand Down
56 changes: 51 additions & 5 deletions tests/worker/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3350,15 +3350,27 @@ async def test_workflow_optional_param(client: Client):


class ExceptionRaisingPayloadConverter(DefaultPayloadConverter):
bad_str = "bad-payload-str"
bad_outbound_str = "bad-outbound-payload-str"
bad_inbound_str = "bad-inbound-payload-str"

def to_payloads(self, values: Sequence[Any]) -> List[Payload]:
if any(
value == ExceptionRaisingPayloadConverter.bad_outbound_str
for value in values
):
raise ApplicationError("Intentional outbound converter failure")
return super().to_payloads(values)

def from_payloads(
self, payloads: Sequence[Payload], type_hints: Optional[List] = None
) -> List[Any]:
# Check if any payloads contain the bad data
for payload in payloads:
if ExceptionRaisingPayloadConverter.bad_str.encode() in payload.data:
raise ApplicationError("Intentional converter failure")
if (
ExceptionRaisingPayloadConverter.bad_inbound_str.encode()
in payload.data
):
raise ApplicationError("Intentional inbound converter failure")
return super().from_payloads(payloads, type_hints)


Expand All @@ -3383,12 +3395,46 @@ async def test_exception_raising_converter_param(client: Client):
with pytest.raises(WorkflowFailureError) as err:
await client.execute_workflow(
ExceptionRaisingConverterWorkflow.run,
ExceptionRaisingPayloadConverter.bad_str,
ExceptionRaisingPayloadConverter.bad_inbound_str,
id=f"workflow-{uuid.uuid4()}",
task_queue=worker.task_queue,
)
assert isinstance(err.value.cause, ApplicationError)
assert "Intentional converter failure" in str(err.value.cause)
assert "Intentional inbound converter failure" in str(err.value.cause)


@workflow.defn
class ActivityOutboundConversionFailureWorkflow:
@workflow.run
async def run(self) -> None:
await workflow.execute_activity(
"some-activity",
ExceptionRaisingPayloadConverter.bad_outbound_str,
start_to_close_timeout=timedelta(seconds=10),
)


async def test_workflow_activity_outbound_conversion_failure(client: Client):
# This test used to fail because we created commands _before_ we attempted
# to convert the arguments thereby causing half-built commands to get sent
# to the server.

# Clone the client but change the data converter to use our converter
config = client.config()
config["data_converter"] = dataclasses.replace(
config["data_converter"],
payload_converter_class=ExceptionRaisingPayloadConverter,
)
client = Client(**config)
async with new_worker(client, ActivityOutboundConversionFailureWorkflow) as worker:
with pytest.raises(WorkflowFailureError) as err:
await client.execute_workflow(
ActivityOutboundConversionFailureWorkflow.run,
id=f"wf-{uuid.uuid4()}",
task_queue=worker.task_queue,
)
assert isinstance(err.value.cause, ApplicationError)
assert "Intentional outbound converter failure" in str(err.value.cause)


@dataclass
Expand Down
Loading