Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 44 additions & 48 deletions temporalio/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,9 @@ async def _run_activities(self) -> None:

if task.HasField("start"):
# Cancelled event and sync field will be updated inside
# _run_activity when the activity function is obtained
activity = _RunningActivity()
# _run_activity when the activity function is obtained. Max
# size of 1000 should be plenty for the heartbeat queue.
activity = _RunningActivity(pending_heartbeats=asyncio.Queue(1000))
activity.task = asyncio.create_task(
self._run_activity(task.task_token, task.start, activity)
)
Expand Down Expand Up @@ -409,22 +410,27 @@ def _heartbeat_activity(self, task_token: bytes, *details: Any) -> None:
logger = temporalio.activity.logger
activity = self._running_activities.get(task_token)
if activity and not activity.done:
# Just set as next pending if one is already running
coro = self._heartbeat_activity_async(
logger, activity, task_token, *details
# Put on queue and schedule a task. We will let the queue-full error
# be thrown here
activity.pending_heartbeats.put_nowait(details)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably wrap this error with a meaningful error so users can understand this happens because they heartbeat too fast.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered it, but I was under the impression wrapping was less common in Python. I can do so on my next PR (which will be moving these files to a completely separate place).

activity.last_heartbeat_task = asyncio.create_task(
self._heartbeat_activity_async(logger, activity, task_token)
)
if activity.current_heartbeat_task:
activity.pending_heartbeat = coro
else:
activity.current_heartbeat_task = asyncio.create_task(coro)

async def _heartbeat_activity_async(
self,
logger: logging.LoggerAdapter,
activity: _RunningActivity,
task_token: bytes,
*details: Any,
) -> None:
# Drain the queue, only taking the last value to actually heartbeat
details: Optional[Iterable[Any]] = None
while not activity.pending_heartbeats.empty():
details = activity.pending_heartbeats.get_nowait()
if details is None:
return

# Perform the heartbeat
try:
heartbeat = temporalio.bridge.proto.ActivityHeartbeat(task_token=task_token)
if details:
Expand All @@ -437,16 +443,7 @@ async def _heartbeat_activity_async(
)
logger.debug("Recording heartbeat with details %s", details)
self._bridge_worker.record_activity_heartbeat(heartbeat)
# If there is one pending, schedule it
if activity.pending_heartbeat:
activity.current_heartbeat_task = asyncio.create_task(
activity.pending_heartbeat
)
activity.pending_heartbeat = None
else:
activity.current_heartbeat_task = None
except Exception as err:
activity.current_heartbeat_task = None
# If the activity is done, nothing we can do but log
if activity.done:
logger.exception(
Expand Down Expand Up @@ -696,12 +693,12 @@ async def _run_activity(

# Do final completion
try:
# We mark the activity as done and let the currently running (and next
# pending) heartbeat task finish
# We mark the activity as done and let the currently running
# heartbeat task finish
running_activity.done = True
while running_activity.current_heartbeat_task:
if running_activity.last_heartbeat_task:
try:
await running_activity.current_heartbeat_task
await running_activity.last_heartbeat_task
except:
# Should never happen because it's trapped in-task
temporalio.activity.logger.exception(
Expand Down Expand Up @@ -749,12 +746,12 @@ class _ActivityDefinition:

@dataclass
class _RunningActivity:
pending_heartbeats: asyncio.Queue[Iterable[Any]]
# Most of these optional values are set before use
info: Optional[temporalio.activity.Info] = None
task: Optional[asyncio.Task] = None
cancelled_event: Optional[temporalio.activity._CompositeEvent] = None
pending_heartbeat: Optional[Coroutine] = None
current_heartbeat_task: Optional[asyncio.Task] = None
last_heartbeat_task: Optional[asyncio.Task] = None
sync: bool = False
done: bool = False
cancelled_by_request: bool = False
Expand Down Expand Up @@ -895,19 +892,16 @@ async def execute_activity(self, input: ExecuteActivityInput) -> Any:
# loop (even though it's sync). So we need a call that puts the
# context back on the activity and calls heartbeat, then another
# call schedules it.
def heartbeat_with_context(*details: Any) -> None:
async def heartbeat_with_context(*details: Any) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did this become async?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I could use asyncio.run_coroutine_threadsafe which returns a future instead of loop.call_soon_threadsafe (which it uses internally anyways). It also helps enforce that an event loop is available.

temporalio.activity._Context.set(ctx)
assert orig_heartbeat
orig_heartbeat(*details)

def thread_safe_heartbeat(*details: Any) -> None:
# TODO(cretz): Final heartbeat can be flaky if we don't wait on
# result here, but waiting on result of
# asyncio.run_coroutine_threadsafe times out in rare cases.
# Need more investigation: https://github.com/temporalio/sdk-python/issues/12
loop.call_soon_threadsafe(heartbeat_with_context, *details)

ctx.heartbeat = thread_safe_heartbeat
# Invoke the async heartbeat waiting a max of 10 seconds for
# accepting
ctx.heartbeat = lambda *details: asyncio.run_coroutine_threadsafe(
heartbeat_with_context(*details), loop
).result(10)

# For heartbeats, we use the existing heartbeat callable for thread
# pool executors or a multiprocessing queue for others
Expand All @@ -917,7 +911,7 @@ def thread_safe_heartbeat(*details: Any) -> None:
# Should always be present in worker, pre-checked on init
shared_manager = input._worker._config["shared_state_manager"]
assert shared_manager
heartbeat = shared_manager.register_heartbeater(
heartbeat = await shared_manager.register_heartbeater(
info.task_token, ctx.heartbeat
)

Expand All @@ -935,7 +929,7 @@ def thread_safe_heartbeat(*details: Any) -> None:
)
finally:
if shared_manager:
shared_manager.unregister_heartbeater(info.task_token)
await shared_manager.unregister_heartbeater(info.task_token)

# Otherwise for async activity, just run
return await input.fn(*input.args)
Expand Down Expand Up @@ -1032,7 +1026,7 @@ def new_event(self) -> threading.Event:
raise NotImplementedError

@abstractmethod
def register_heartbeater(
async def register_heartbeater(
self, task_token: bytes, heartbeat: Callable[..., None]
) -> SharedHeartbeatSender:
"""Register a heartbeat function.
Expand All @@ -1048,7 +1042,7 @@ def register_heartbeater(
raise NotImplementedError

@abstractmethod
def unregister_heartbeater(self, task_token: bytes) -> None:
async def unregister_heartbeater(self, task_token: bytes) -> None:
"""Unregisters a previously registered heartbeater for the task
token. This should also flush any pending heartbeats.
"""
Expand Down Expand Up @@ -1084,12 +1078,12 @@ def __init__(
1000
)
self._heartbeats: Dict[bytes, Callable[..., None]] = {}
self._heartbeat_completions: Dict[bytes, Callable[[], None]] = {}
self._heartbeat_completions: Dict[bytes, Callable] = {}

def new_event(self) -> threading.Event:
return self._mgr.Event()

def register_heartbeater(
async def register_heartbeater(
self, task_token: bytes, heartbeat: Callable[..., None]
) -> SharedHeartbeatSender:
self._heartbeats[task_token] = heartbeat
Expand All @@ -1098,17 +1092,19 @@ def register_heartbeater(
self._queue_poller_executor.submit(self._heartbeat_processor)
return _MultiprocessingSharedHeartbeatSender(self._heartbeat_queue)

def unregister_heartbeater(self, task_token: bytes) -> None:
# Put a completion on the queue and wait for it to happen
flush_complete = threading.Event()
self._heartbeat_completions[task_token] = flush_complete.set
async def unregister_heartbeater(self, task_token: bytes) -> None:
# Put a callback on the queue and wait for it to happen
loop = asyncio.get_running_loop()
finish_event = asyncio.Event()
self._heartbeat_completions[task_token] = lambda: loop.call_soon_threadsafe(
finish_event.set
)
try:
# 30 seconds to put complete, 30 to get notified should be plenty
# We only give the queue a few seconds to have enough room
self._heartbeat_queue.put(
(task_token, _multiprocess_heartbeat_complete), True, 30
(task_token, _multiprocess_heartbeat_complete), True, 5
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: instead of this string you can use object() to create a unique empty object.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My original version had that, but I didn't want to change the type of the queue

)
if not flush_complete.wait(30):
raise RuntimeError("Timeout waiting for heartbeat flush")
await finish_event.wait()
finally:
del self._heartbeat_completions[task_token]

Expand Down
10 changes: 4 additions & 6 deletions tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ async def test_sync_activity_thread_cancel(
):
def wait_cancel() -> str:
while not temporalio.activity.is_cancelled():
temporalio.activity.heartbeat()
time.sleep(1)
temporalio.activity.heartbeat()
return "Cancelled"

with concurrent.futures.ThreadPoolExecutor() as executor:
Expand All @@ -228,16 +228,16 @@ def wait_cancel() -> str:
wait_cancel,
cancel_after_ms=100,
wait_for_cancellation=True,
heartbeat_timeout_ms=30000,
heartbeat_timeout_ms=3000,
worker_config={"activity_executor": executor},
)
assert result.result == "Cancelled"


def picklable_activity_wait_cancel() -> str:
while not temporalio.activity.is_cancelled():
temporalio.activity.heartbeat()
time.sleep(1)
temporalio.activity.heartbeat()
return "Cancelled"


Expand All @@ -251,7 +251,7 @@ async def test_sync_activity_process_cancel(
picklable_activity_wait_cancel,
cancel_after_ms=100,
wait_for_cancellation=True,
heartbeat_timeout_ms=30000,
heartbeat_timeout_ms=3000,
worker_config={"activity_executor": executor},
)
assert result.result == "Cancelled"
Expand Down Expand Up @@ -430,8 +430,6 @@ def picklable_heartbeat_details_activity() -> str:
some_list.append(f"attempt: {info.attempt}")
temporalio.activity.logger.debug("Heartbeating with value: %s", some_list)
temporalio.activity.heartbeat(some_list)
# TODO(cretz): Remove when we fix multiprocess heartbeats
time.sleep(1)
if len(some_list) < 2:
raise RuntimeError(f"Try again, list contains: {some_list}")
return ", ".join(some_list)
Expand Down