Skip to content

Commit cf54981

Browse files
committed
feat(worker): return workflow completed result on decision processing
Signed-off-by: Shijie Sheng <liouvetren@gmail.com>
1 parent 72cb7c7 commit cf54981

File tree

5 files changed

+142
-238
lines changed

5 files changed

+142
-238
lines changed

cadence/_internal/workflow/context.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from contextlib import contextmanager
12
from datetime import timedelta
23
from math import ceil
3-
from typing import Optional, Any, Unpack, Type, cast
4+
from typing import Iterator, Optional, Any, Unpack, Type, cast
45

56
from cadence._internal.workflow.statemachine.decision_manager import DecisionManager
67
from cadence._internal.workflow.decisions_helper import DecisionsHelper
@@ -15,13 +16,12 @@ class Context(WorkflowContext):
1516
def __init__(
1617
self,
1718
info: WorkflowInfo,
18-
decision_helper: DecisionsHelper,
1919
decision_manager: DecisionManager,
2020
):
2121
self._info = info
2222
self._replay_mode = True
2323
self._replay_current_time_milliseconds: Optional[int] = None
24-
self._decision_helper = decision_helper
24+
self._decision_helper = DecisionsHelper()
2525
self._decision_manager = decision_manager
2626

2727
def info(self) -> WorkflowInfo:
@@ -110,6 +110,12 @@ def get_replay_current_time_milliseconds(self) -> Optional[int]:
110110
"""Get the current replay time in milliseconds."""
111111
return self._replay_current_time_milliseconds
112112

113+
@contextmanager
114+
def _activate(self) -> Iterator["Context"]:
115+
token = WorkflowContext._var.set(self)
116+
yield self
117+
WorkflowContext._var.reset(token)
118+
113119

114120
def _round_to_nearest_second(delta: timedelta) -> timedelta:
115121
return timedelta(seconds=ceil(delta.total_seconds()))
Lines changed: 95 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
1-
import asyncio
21
import logging
32
from dataclasses import dataclass
4-
from typing import Any, Optional
53

64
from cadence._internal.workflow.context import Context
7-
from cadence._internal.workflow.decisions_helper import DecisionsHelper
85
from cadence._internal.workflow.decision_events_iterator import DecisionEventsIterator
9-
from cadence._internal.workflow.deterministic_event_loop import DeterministicEventLoop
106
from cadence._internal.workflow.statemachine.decision_manager import DecisionManager
117
from cadence._internal.workflow.workflow_intance import WorkflowInstance
12-
from cadence.api.v1.decision_pb2 import Decision
8+
from cadence.api.v1.common_pb2 import Payload
9+
from cadence.api.v1.decision_pb2 import (
10+
CompleteWorkflowExecutionDecisionAttributes,
11+
Decision,
12+
)
13+
from cadence.api.v1.history_pb2 import WorkflowExecutionStartedEventAttributes
1314
from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse
1415
from cadence.workflow import WorkflowDefinition, WorkflowInfo
1516

@@ -23,12 +24,13 @@ class DecisionResult:
2324

2425
class WorkflowEngine:
2526
def __init__(self, info: WorkflowInfo, workflow_definition: WorkflowDefinition):
26-
self._workflow_instance = WorkflowInstance(workflow_definition)
27-
self._decision_manager = DecisionManager()
28-
self._decisions_helper = DecisionsHelper()
29-
self._context = Context(info, self._decisions_helper, self._decision_manager)
30-
self._loop = DeterministicEventLoop()
31-
self._task: Optional[asyncio.Task] = None
27+
self._workflow_instance = WorkflowInstance(
28+
workflow_definition, info.data_converter
29+
)
30+
self._decision_manager = (
31+
DecisionManager()
32+
) # TODO: remove this stateful object and use the context instead
33+
self._context = Context(info, self._decision_manager)
3234

3335
def process_decision(
3436
self, decision_task: PollForDecisionTaskResponse
@@ -46,54 +48,58 @@ def process_decision(
4648
DecisionResult containing the list of decisions
4749
"""
4850
try:
49-
# Log decision task processing start with full context (matches Java ReplayDecisionTaskHandler)
50-
logger.info(
51-
"Processing decision task for workflow",
52-
extra={
53-
"workflow_type": self._context.info().workflow_type,
54-
"workflow_id": self._context.info().workflow_id,
55-
"run_id": self._context.info().workflow_run_id,
56-
"started_event_id": decision_task.started_event_id,
57-
"attempt": decision_task.attempt,
58-
},
59-
)
60-
6151
# Activate workflow context for the entire decision processing
62-
with self._context._activate():
52+
with self._context._activate() as ctx:
53+
# Log decision task processing start with full context (matches Java ReplayDecisionTaskHandler)
54+
logger.info(
55+
"Processing decision task for workflow",
56+
extra={
57+
"workflow_type": ctx.info().workflow_type,
58+
"workflow_id": ctx.info().workflow_id,
59+
"run_id": ctx.info().workflow_run_id,
60+
"started_event_id": decision_task.started_event_id,
61+
"attempt": decision_task.attempt,
62+
},
63+
)
64+
6365
# Create DecisionEventsIterator for structured event processing
6466
events_iterator = DecisionEventsIterator(
65-
decision_task, self._context.info().workflow_events
67+
decision_task, ctx.info().workflow_events
6668
)
6769

6870
# Process decision events using iterator-driven approach
69-
self._process_decision_events(events_iterator, decision_task)
71+
self._process_decision_events(ctx, events_iterator, decision_task)
7072

7173
# Collect all pending decisions from state machines
7274
decisions = self._decision_manager.collect_pending_decisions()
7375

74-
# Log decision task completion with metrics (matches Java ReplayDecisionTaskHandler)
75-
logger.debug(
76-
"Decision task completed",
77-
extra={
78-
"workflow_type": self._context.info().workflow_type,
79-
"workflow_id": self._context.info().workflow_id,
80-
"run_id": self._context.info().workflow_run_id,
81-
"started_event_id": decision_task.started_event_id,
82-
"decisions_count": len(decisions),
83-
"replay_mode": self._context.is_replay_mode(),
84-
},
85-
)
76+
# complete workflow if it is done
77+
try:
78+
if self._workflow_instance.is_done():
79+
result = self._workflow_instance.get_result()
80+
decisions.append(
81+
Decision(
82+
complete_workflow_execution_decision_attributes=CompleteWorkflowExecutionDecisionAttributes(
83+
result=result
84+
)
85+
)
86+
)
87+
return DecisionResult(decisions=decisions)
8688

87-
return DecisionResult(decisions=decisions)
89+
except Exception:
90+
# TODO: handle CancellationError
91+
# TODO: handle WorkflowError
92+
# TODO: handle unknown error, fail decision task and try again instead of breaking the engine
93+
raise
8894

8995
except Exception as e:
9096
# Log decision task failure with full context (matches Java ReplayDecisionTaskHandler)
9197
logger.error(
9298
"Decision task processing failed",
9399
extra={
94-
"workflow_type": self._context.info().workflow_type,
95-
"workflow_id": self._context.info().workflow_id,
96-
"run_id": self._context.info().workflow_run_id,
100+
"workflow_type": ctx.info().workflow_type,
101+
"workflow_id": ctx.info().workflow_id,
102+
"run_id": ctx.info().workflow_run_id,
97103
"started_event_id": decision_task.started_event_id,
98104
"attempt": decision_task.attempt,
99105
"error_type": type(e).__name__,
@@ -104,10 +110,11 @@ def process_decision(
104110
raise
105111

106112
def is_done(self) -> bool:
107-
return self._task is not None and self._task.done()
113+
return self._workflow_instance.is_done()
108114

109115
def _process_decision_events(
110116
self,
117+
ctx: Context,
111118
events_iterator: DecisionEventsIterator,
112119
decision_task: PollForDecisionTaskResponse,
113120
) -> None:
@@ -131,7 +138,7 @@ def _process_decision_events(
131138
logger.debug(
132139
"Processing decision events batch",
133140
extra={
134-
"workflow_id": self._context.info().workflow_id,
141+
"workflow_id": ctx.info().workflow_id,
135142
"events_count": len(decision_events.get_events()),
136143
"markers_count": len(decision_events.get_markers()),
137144
"replay_mode": decision_events.is_replay(),
@@ -140,109 +147,55 @@ def _process_decision_events(
140147
)
141148

142149
# Update context with replay information
143-
self._context.set_replay_mode(decision_events.is_replay())
150+
ctx.set_replay_mode(decision_events.is_replay())
144151
if decision_events.replay_current_time_milliseconds:
145-
self._context.set_replay_current_time_milliseconds(
152+
ctx.set_replay_current_time_milliseconds(
146153
decision_events.replay_current_time_milliseconds
147154
)
148155

149-
# Phase 1: Process markers first for deterministic replay
150-
for marker_event in decision_events.get_markers():
151-
try:
152-
logger.debug(
153-
"Processing marker event",
154-
extra={
155-
"workflow_id": self._context.info().workflow_id,
156-
"marker_name": getattr(
157-
marker_event, "marker_name", "unknown"
158-
),
159-
"event_id": getattr(marker_event, "event_id", None),
160-
"replay_mode": self._context.is_replay_mode(),
161-
},
162-
)
163-
# Process through state machines (DecisionsHelper now delegates to DecisionManager)
164-
self._decision_manager.handle_history_event(marker_event)
165-
except Exception as e:
166-
# Warning for unexpected markers (matches Java ClockDecisionContext)
167-
logger.warning(
168-
"Unexpected marker event encountered",
169-
extra={
170-
"workflow_id": self._context.info().workflow_id,
171-
"marker_name": getattr(
172-
marker_event, "marker_name", "unknown"
173-
),
174-
"event_id": getattr(marker_event, "event_id", None),
175-
"error_type": type(e).__name__,
176-
},
177-
exc_info=True,
178-
)
179-
180-
# Phase 2: Process regular events to update workflow state
181-
for event in decision_events.get_events():
182-
try:
183-
logger.debug(
184-
"Processing history event",
185-
extra={
186-
"workflow_id": self._context.info().workflow_id,
187-
"event_type": getattr(event, "event_type", "unknown"),
188-
"event_id": getattr(event, "event_id", None),
189-
"replay_mode": self._context.is_replay_mode(),
190-
},
191-
)
192-
# Process through state machines (DecisionsHelper now delegates to DecisionManager)
193-
self._decision_manager.handle_history_event(event)
194-
except Exception as e:
195-
logger.warning(
196-
"Error processing history event",
197-
extra={
198-
"workflow_id": self._context.info().workflow_id,
199-
"event_type": getattr(event, "event_type", "unknown"),
200-
"event_id": getattr(event, "event_id", None),
201-
"error_type": type(e).__name__,
202-
},
203-
exc_info=True,
204-
)
205-
206-
# Phase 3: Execute workflow logic
207-
self._execute_workflow_once(decision_task)
208-
209-
def _execute_workflow_once(
210-
self, decision_task: PollForDecisionTaskResponse
211-
) -> None:
212-
"""
213-
Execute the workflow function to generate new decisions.
156+
# Phase 1: Process markers first
157+
for marker_event in decision_events.markers:
158+
logger.debug(
159+
"Processing marker event",
160+
extra={
161+
"workflow_id": ctx.info().workflow_id,
162+
"marker_name": getattr(marker_event, "marker_name", "unknown"),
163+
"event_id": getattr(marker_event, "event_id", None),
164+
"replay_mode": ctx.is_replay_mode(),
165+
},
166+
)
167+
# Process through state machines (DecisionsHelper now delegates to DecisionManager)
168+
self._decision_manager.handle_history_event(marker_event)
214169

215-
This blocks until the workflow schedules an activity or completes.
170+
# Phase 2: Process regular input events
171+
for event in decision_events.input:
172+
logger.debug(
173+
"Processing history event",
174+
extra={
175+
"workflow_id": ctx.info().workflow_id,
176+
"event_type": getattr(event, "event_type", "unknown"),
177+
"event_id": getattr(event, "event_id", None),
178+
"replay_mode": ctx.is_replay_mode(),
179+
},
180+
)
181+
# Process through state machines (DecisionsHelper now delegates to DecisionManager)
182+
self._decision_manager.handle_history_event(event)
216183

217-
Args:
218-
decision_task: The decision task containing workflow context
219-
"""
220-
try:
221-
# Extract workflow input from history
222-
if self._task is None:
223-
workflow_input = self._extract_workflow_input(decision_task)
224-
self._task = self._loop.create_task(
225-
self._workflow_instance.run(workflow_input)
184+
# Phase 3: Execute workflow logic
185+
if not self._workflow_instance.is_started():
186+
self._workflow_instance.start(
187+
self._extract_workflow_input(decision_task)
226188
)
227189

228-
self._loop.run_until_yield()
190+
self._workflow_instance.run_once()
229191

230-
except Exception as e:
231-
logger.error(
232-
"Error executing workflow function",
233-
extra={
234-
"workflow_type": self._context.info().workflow_type,
235-
"workflow_id": self._context.info().workflow_id,
236-
"run_id": self._context.info().workflow_run_id,
237-
"error_type": type(e).__name__,
238-
},
239-
exc_info=True,
240-
)
241-
raise
192+
# Phase 4: update state machine with output events
193+
for event in decision_events.output:
194+
self._decision_manager.handle_history_event(event)
242195

243196
def _extract_workflow_input(
244197
self, decision_task: PollForDecisionTaskResponse
245-
) -> Any:
198+
) -> Payload:
246199
"""
247200
Extract workflow input from the decision task history.
248201
@@ -253,26 +206,15 @@ def _extract_workflow_input(
253206
The workflow input data, or None if not found
254207
"""
255208
if not decision_task.history or not hasattr(decision_task.history, "events"):
256-
logger.warning("No history events found in decision task")
257-
return None
209+
raise ValueError("No history events found in decision task")
258210

259211
# Look for WorkflowExecutionStarted event
260212
for event in decision_task.history.events:
261213
if hasattr(event, "workflow_execution_started_event_attributes"):
262-
started_attrs = event.workflow_execution_started_event_attributes
214+
started_attrs: WorkflowExecutionStartedEventAttributes = (
215+
event.workflow_execution_started_event_attributes
216+
)
263217
if started_attrs and hasattr(started_attrs, "input"):
264-
# Deserialize the input using the client's data converter
265-
try:
266-
# Use from_data method with a single type hint of None (no type conversion)
267-
input_data_list = self._context.data_converter().from_data(
268-
started_attrs.input, [None]
269-
)
270-
input_data = input_data_list[0] if input_data_list else None
271-
logger.debug(f"Extracted workflow input: {input_data}")
272-
return input_data
273-
except Exception as e:
274-
logger.warning(f"Failed to deserialize workflow input: {e}")
275-
return None
276-
277-
logger.warning("No WorkflowExecutionStarted event found in history")
278-
return None
218+
return started_attrs.input
219+
220+
raise ValueError("No WorkflowExecutionStarted event found in history")

0 commit comments

Comments
 (0)