Access current update info with ID inside update handler #544

Merged
merged 5 commits into from
Jun 6, 2024
5 changes: 5 additions & 0 deletions temporalio/worker/_workflow_instance.py
@@ -463,6 +463,11 @@ def _apply_do_update(
        # inside the task, since the update may not be defined until after we have started the workflow - for example
        # if an update is in the first WFT & is also registered dynamically at the top of workflow code.
        async def run_update() -> None:
            # Set the current update for the life of this task
            temporalio.workflow._set_current_update_info(
                temporalio.workflow.UpdateInfo(id=job.id, name=job.name)
            )

            command = self._add_command()
            command.update_response.protocol_instance_id = job.protocol_instance_id
            past_validation = False
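
Note on the hunk above: the comment says the value is set "for the life of this task", which relies on each asyncio task running in its own copy of the context. The following is a minimal stdlib-only sketch of that behaviour, not the SDK's implementation; `current_id`, `handler`, and `main` are hypothetical names used only for illustration.

```python
import asyncio
import contextvars
from typing import List, Optional

# Hypothetical stand-in for the SDK's private context variable.
current_id: contextvars.ContextVar[str] = contextvars.ContextVar("current_id")


async def handler(update_id: str) -> Optional[str]:
    # Set inside the task, so the value is scoped to this task's context.
    current_id.set(update_id)
    await asyncio.sleep(0)  # yield so the other handlers interleave
    return current_id.get(None)


async def main() -> List[Optional[str]]:
    tasks = [asyncio.create_task(handler(f"update{i}")) for i in range(1, 4)]
    seen = await asyncio.gather(*tasks)
    # The parent context never observes values set inside the child tasks.
    assert current_id.get(None) is None
    return list(seen)


results = asyncio.run(main())
```

Each handler reads back only its own ID even though the three tasks interleave, and the coroutine that created them still sees no value.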
37 changes: 37 additions & 0 deletions temporalio/workflow.py
@@ -3,6 +3,7 @@
from __future__ import annotations

import asyncio
import contextvars
import inspect
import logging
import threading
@@ -424,6 +425,17 @@ class ParentInfo:
    workflow_id: str


@dataclass(frozen=True)
class UpdateInfo:
    """Information about a workflow update."""

    id: str
@cretz (Member, Author), Jun 5, 2024:

We use workflow_id to differentiate from all the other _id fields in Info, should we call this update_id? A simple id I think works here but can change if needed.

In Go we did ID, but update_id is also fine.

Member, Author:

👍 I think id works best here.

    """Update ID."""
Contributor:

Stray string (this should be a comment, not a string, if it's explaining the field).

Member, Author:

For dataclasses we have used docstrings like this in the past. For example, see https://python.temporal.io/temporalio.client.WorkflowExecution.html.
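
To illustrate the point under discussion: a string literal placed after a dataclass field is an ordinary expression statement that Python evaluates and discards, so it does not change the fields or the class docstring; some documentation generators read such strings from the source as attribute docs. A small sketch, using the hypothetical class name `UpdateInfoSketch` rather than the SDK's own class:

```python
from dataclasses import dataclass, fields


@dataclass(frozen=True)
class UpdateInfoSketch:
    """Information about a workflow update."""

    id: str
    """Update ID."""

    name: str
    """Update type name."""


info = UpdateInfoSketch(id="update1", name="do_update")
# The per-field strings have no runtime effect on the dataclass itself.
field_names = [f.name for f in fields(UpdateInfoSketch)]
class_doc = UpdateInfoSketch.__doc__
```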


    name: str
    """Update type name."""

This meaning is unclear to me. Is name the function name (unscoped by the class name) unless it's overridden with the attribute?

@cretz (Member, Author), Jun 7, 2024:

Yes. This is also used in the decorator and the client calls. The meaning of update ID above it may be unclear to users too.

The "update name" is a Temporal concept like "workflow type" in "Info.workflow_type". Ideally the general documentation would detail this, but if necessary we can update all of the Python docs for signal/update/query name, all the workflow info stuff, and all of that if we want. But this one field isn't unique (it applies to the field above it, where "name" is used elsewhere, all the other fields we delegate to docs to describe, etc).



class _Runtime(ABC):
    @staticmethod
    def current() -> _Runtime:
@@ -654,6 +666,31 @@ async def workflow_wait_condition(
    ...


_current_update_info: contextvars.ContextVar[UpdateInfo] = contextvars.ContextVar(
    "__temporal_current_update_info"
)


def _set_current_update_info(info: UpdateInfo) -> None:
    _current_update_info.set(info)


def current_update_info() -> Optional[UpdateInfo]:
    """Info for the current update if any.

    This is powered by :py:mod:`contextvars` so it is only valid within the
    update handler and coroutines/tasks it has started.

    .. warning::
        This API is experimental

    Returns:
        Info for the current update handler the code calling this is executing
        within if any.
    """
    return _current_update_info.get(None)


def deprecate_patch(id: str) -> None:
    """Mark a patch as deprecated.

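
Note on `current_update_info` in the hunk above: it returns `Optional[UpdateInfo]` because `ContextVar.get` is called with a default of `None`; without a default, `get` raises `LookupError` when the variable is unset. A stdlib-only sketch of that distinction, where `var`, `current`, and the other names are illustrative and not SDK names:

```python
import contextvars
from typing import Optional

# Illustrative variable with no default, like the one in the diff.
var: contextvars.ContextVar[str] = contextvars.ContextVar("var")


def current() -> Optional[str]:
    # Passing a default to get() avoids LookupError when the variable is unset.
    return var.get(None)


def raises_when_unset() -> bool:
    try:
        var.get()
    except LookupError:
        return True
    return False


def set_and_read() -> Optional[str]:
    var.set("update1")
    return current()


unset_value = current()
unset_raises = raises_when_unset()
# Run the setter in a copied context so the outer context stays unset.
inside = contextvars.copy_context().run(set_and_read)
after = current()
```

Outside the copied context the variable remains unset, which is the same reason the workflow's `run` method in the test sees no update info.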
75 changes: 75 additions & 0 deletions tests/worker/test_workflow.py
@@ -4927,3 +4927,78 @@ async def test_workflow_wait_utility(client: Client):
task_queue=worker.task_queue,
)
assert len(result) == 10


@workflow.defn
class CurrentUpdateWorkflow:
    def __init__(self) -> None:
        self._pending_get_update_id_tasks: List[asyncio.Task[str]] = []

    @workflow.run
    async def run(self) -> List[str]:
        # Confirm no update info
        assert not workflow.current_update_info()

        # Wait for all tasks to come in, then return the full set
        await workflow.wait_condition(
            lambda: len(self._pending_get_update_id_tasks) == 5
        )
        assert not workflow.current_update_info()
        return list(await asyncio.gather(*self._pending_get_update_id_tasks))

    @workflow.update
    async def do_update(self) -> str:
        # Check that simple helper awaited has the ID
        info = workflow.current_update_info()
        assert info
        assert info.name == "do_update"
        assert info.id == await self.get_update_id()

        # Also schedule the task and wait for it in the main workflow to confirm
        # it still gets the update ID
        self._pending_get_update_id_tasks.append(
            asyncio.create_task(self.get_update_id())
        )

        # Re-fetch and return
        info = workflow.current_update_info()
        assert info
        return info.id

    @do_update.validator
    def do_update_validator(self) -> None:
        info = workflow.current_update_info()
        assert info
        assert info.name == "do_update"

    async def get_update_id(self) -> str:
        await asyncio.sleep(0.01)
        info = workflow.current_update_info()
        assert info
        return info.id


async def test_workflow_current_update(client: Client, env: WorkflowEnvironment):
    if env.supports_time_skipping:
        pytest.skip(
            "Java test server: https://github.com/temporalio/sdk-java/issues/1903"
        )
    async with new_worker(client, CurrentUpdateWorkflow) as worker:
        handle = await client.start_workflow(
            CurrentUpdateWorkflow.run,
            id=f"wf-{uuid.uuid4()}",
            task_queue=worker.task_queue,
        )
        update_ids = await asyncio.gather(
            handle.execute_update(CurrentUpdateWorkflow.do_update, id="update1"),
            handle.execute_update(CurrentUpdateWorkflow.do_update, id="update2"),
            handle.execute_update(CurrentUpdateWorkflow.do_update, id="update3"),
            handle.execute_update(CurrentUpdateWorkflow.do_update, id="update4"),
            handle.execute_update(CurrentUpdateWorkflow.do_update, id="update5"),
        )
        assert {"update1", "update2", "update3", "update4", "update5"} == set(
            update_ids
        )
        assert {"update1", "update2", "update3", "update4", "update5"} == set(
            await handle.result()
        )
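
Note on the test above: its second assertion depends on tasks created inside a handler keeping that handler's update ID even when they are awaited later from the main workflow coroutine, where no update is current. The sketch below reproduces that shape with plain `asyncio` and `contextvars`, without a Temporal worker; `update_id`, `pending`, and the function names are stand-ins chosen for illustration, and the `"<none>"` default is an assumption of this sketch only.

```python
import asyncio
import contextvars
from typing import List

update_id: contextvars.ContextVar[str] = contextvars.ContextVar("update_id")
pending: List[asyncio.Task] = []


async def get_update_id() -> str:
    await asyncio.sleep(0.01)
    return update_id.get("<none>")


async def do_update(uid: str) -> str:
    update_id.set(uid)
    # create_task copies the current context, so the new task keeps this ID.
    pending.append(asyncio.create_task(get_update_id()))
    # A helper awaited directly shares the handler's context.
    return await get_update_id()


async def main() -> List[str]:
    handlers = [
        asyncio.create_task(do_update(f"update{i}")) for i in range(1, 6)
    ]
    direct = await asyncio.gather(*handlers)
    # This coroutine never set the variable, so it sees the fallback.
    assert update_id.get("<none>") == "<none>"
    deferred = await asyncio.gather(*pending)
    return sorted(direct) + sorted(deferred)


results = asyncio.run(main())
```

The first five entries come from helpers awaited inside each handler, and the last five from tasks awaited outside any handler; both halves carry the same set of IDs.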