Merge pull request #48 from wayfair-incubator/efficient_writes
efficient writes and update tests
patkivikram authored Sep 5, 2023
2 parents 275530d + 869e2e7 commit 2e751b1
Showing 7 changed files with 19 additions and 23 deletions.
7 changes: 7 additions & 0 deletions dagger/modeler/definition.py
@@ -497,6 +497,7 @@ async def create_instance(
         *,
         repartition: bool = True,
         seed: random.Random = None,
+        submit_task: bool = False,
         **kwargs,
     ) -> ITemplateDAGInstance[KT, VT]:
         """Method for creating an instance of a workflow definition
@@ -506,6 +507,7 @@ async def create_instance(
         :param repartition: Flag indicating if the creation of this instance needs to be stored on the current node or
             by the owner of the partition defined by the partition_key_lookup
         :param seed: the seed to use to create all internal instances of the workflow
+        :param submit_task: if True also submit the task for execution
        :param **kwargs: Other keyword arguments
        :return: An instance of the workflow
        """
@@ -541,6 +543,11 @@ async def create_instance(
         if repartition:
             await self.app.tasks_topic.send(key=template_instance.runtime_parameters[partition_key_lookup], value=template_instance)  # type: ignore
         else:
+            if submit_task:
+                template_instance.status = TaskStatus(
+                    code=TaskStatusEnum.SUBMITTED.name,
+                    value=TaskStatusEnum.SUBMITTED.value,
+                )
             await self.app._store_and_create_task(template_instance)  # type: ignore
         return template_instance
 
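With the new flag, a caller on the non-repartitioned path can create a workflow instance and mark it SUBMITTED in a single call, instead of creating it and then submitting it separately. A minimal sketch of the calling convention, reusing names from the integration tests in this PR:

    # Create and submit in one step; submit_task only takes effect when
    # repartition=False, i.e. the instance is created on the current worker.
    pizza_workflow_instance = await pizza_workflow_template.create_instance(
        uuid.uuid1(),
        repartition=False,
        submit_task=True,
        order_id=order_id,
        customer_id=customer_id,
        pizza_type=pizza_type,
    )
    # No separate workflow_engine.submit(...) call is needed afterwards.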
2 changes: 1 addition & 1 deletion dagger/service/services.py
@@ -775,9 +775,9 @@ def main(self, override_logging=False) -> None:
 
     async def _store_and_create_task(self, task):
         if isinstance(task, ITemplateDAGInstance):
-            await self._store_root_template_instance(task)
             if task.status.code == TaskStatusEnum.SUBMITTED.name:
                 await task.start(workflow_instance=task)
+            await self._store_root_template_instance(task)
 
     async def _process_tasks_create_event(self, stream):
         """Upon creation of tasks, store them in the datastore.
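This reorder is the heart of the efficient-writes change: task.start(...) mutates the instance in memory (status, child-task state), so storing the root template instance afterwards captures the post-start state in one datastore write. A comment sketch of the before and after, as read from this diff:

    # Before: two writes for a submitted task
    #   await self._store_root_template_instance(task)  # write 1 (SUBMITTED)
    #   await task.start(workflow_instance=task)        # start() issued write 2
    #
    # After: one write
    #   await task.start(workflow_instance=task)        # mutate in memory only
    #   await self._store_root_template_instance(task)  # single write, final state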
1 change: 1 addition & 0 deletions dagger/store/stores.py
@@ -168,6 +168,7 @@ async def execute_system_timer_task(self) -> None:  # pragma: no cover
                     finished = await task.start(workflow_instance)
                 else:
                     await task.start(workflow_instance)
+                await self.app._update_instance(task=workflow_instance)  # type: ignore
                 if finished:
                     await self.remove_trigger(trigger)
             if not task or task.status.code in TERMINAL_STATUSES:
14 changes: 3 additions & 11 deletions dagger/tasks/task.py
@@ -253,7 +253,6 @@ async def on_complete(  # noqa: C901
         else:
             time_completed = int(time.time())
             self.time_completed = time_completed
-        await dagger.service.services.Dagger.app._update_instance(task=workflow_instance)  # type: ignore
         if not iterate:
             logger.debug("Skipping on_complete as iterate is false")
             return
@@ -341,7 +340,6 @@ async def start(
             runtime_parameters=workflow_instance.runtime_parameters,
             workflow_instance=workflow_instance,
         )
-        await dagger.service.services.Dagger.app._update_instance(task=workflow_instance)  # type: ignore
         if self.status.code == TaskStatusEnum.FAILURE.name:
             await self.on_complete(
                 status=self.status, workflow_instance=workflow_instance
@@ -421,7 +419,6 @@ async def start(self, workflow_instance: Optional[ITemplateDAGInstance]) -> bool
             code=TaskStatusEnum.EXECUTING.name, value=TaskStatusEnum.EXECUTING.value
         )
         self.time_submitted = int(time.time())
-        await dagger.service.services.Dagger.app._update_instance(task=workflow_instance)  # type: ignore
         if self.time_to_execute and int(time.time()) < self.time_to_execute:
             return False
         if (
@@ -582,7 +579,6 @@ async def start(self, workflow_instance: Optional[ITemplateDAGInstance]) -> None
             logger.warning(
                 f"The task instance to skip with id {next_task_id} was not found. Skipped but did not set status to {TaskStatusEnum.SKIPPED.value}"
             )
-        await dagger.service.services.Dagger.app._update_instance(task=workflow_instance)  # type: ignore
         await self.on_complete(workflow_instance=workflow_instance)
 
     async def execute(
@@ -715,7 +711,6 @@ async def start(self, workflow_instance: Optional[ITemplateDAGInstance]) -> None
             code=TaskStatusEnum.EXECUTING.name, value=TaskStatusEnum.EXECUTING.value
         )
         self.time_submitted = int(time.time())
-        await dagger.service.services.Dagger.app._update_instance(task=workflow_instance)  # type: ignore
 
     async def _update_correletable_key(self, workflow_instance: ITask) -> None:
         """Updates the correletable key if the local is not the same as global key.
@@ -889,10 +884,9 @@ async def process_event_helper(self, event):  # noqa: C901
                     await task_instance.on_complete(
                         workflow_instance=workflow_instance
                     )
-                else:
-                    await dagger.service.services.Dagger.app._update_instance(
-                        task=workflow_instance
-                    )  # type: ignore
+                await dagger.service.services.Dagger.app._update_instance(
+                    task=workflow_instance
+                )  # type: ignore
                 processed_task = True
 
             if getattr(self.__task, "match_only_one", False):
@@ -985,7 +979,6 @@ async def start(self, workflow_instance: Optional[ITemplateDAGInstance]) -> None
             runtime_parameters=workflow_instance.runtime_parameters,
             workflow_instance=workflow_instance,
         )
-        await dagger.service.services.Dagger.app._update_instance(task=workflow_instance)  # type: ignore
         logger.debug(
             f"Starting task {self.task_name} with root dag id {self.root_dag}, parent task id {self.parent_id}, and task id {self.id}"
         )
@@ -1062,7 +1055,6 @@ async def start(self, workflow_instance: Optional[ITemplateDAGInstance]) -> None
             runtime_parameters=workflow_instance.runtime_parameters,
             workflow_instance=workflow_instance,
         )
-        await dagger.service.services.Dagger.app._update_instance(task=workflow_instance)  # type: ignore
         logger.debug(
             f"Starting task {self.task_name} with parent task id {self.parent_id}, and task id {self.id}"
         )
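The removals above all delete per-task _update_instance calls from start and on_complete. Persistence now happens at the call sites: once per timer cycle in execute_system_timer_task (the stores.py hunk above) and once per event in process_event_helper, where the update also became unconditional instead of living in an else branch. A condensed sketch of the resulting pattern, with hypothetical names:

    # Hypothetical driver illustrating the new write discipline; dagger's real
    # call sites are execute_system_timer_task and process_event_helper.
    async def drive_tasks(app, tasks, workflow_instance):
        for task in tasks:
            await task.start(workflow_instance)  # mutates in-memory state only
        # One datastore write per cycle instead of one per task-state change:
        await app._update_instance(task=workflow_instance)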
2 changes: 2 additions & 0 deletions dagger/templates/template.py
@@ -66,6 +66,7 @@ async def create_instance(
         *,
         repartition: bool = True,
         seed: random.Random = None,
+        submit_task: bool = False,
         **kwargs,
     ) -> ITemplateDAGInstance:  # pragma: no cover
         """Method for creating an instance of a workflow definition
@@ -75,6 +76,7 @@ async def create_instance(
         :param repartition: Flag indicating if the creation of this instance needs to be stored on the current node or
             by the owner of the partition defined by the partition_key_lookup
         :param seed: the seed to use to create all internal instances of the workflow
+        :param submit_task: if True also submit the task for execution
        :param **kwargs: Other keyword arguments
        :return: An instance of the workflow
        """
8 changes: 5 additions & 3 deletions integration_tests/test_app.py
@@ -727,19 +727,21 @@ async def create_and_submit_pizza_delivery_workflow(
     pizza_workflow_instance = await pizza_workflow_template.create_instance(
         uuid.uuid1(),
         repartition=False,  # Create this instance on the current worker
+        submit_task=True,
         order_id=order_id,
         customer_id=customer_id,
         pizza_type=pizza_type,
     )
-    await workflow_engine.submit(pizza_workflow_instance, repartition=False)
 
 
 @workflow_engine.faust_app.agent(simple_topic_stop)
 async def simple_data_stream_stop(stream):
     async for value in stream:
 
         instance = await workflow_engine.get_instance(running_task_ids[-1])
-        await instance.stop()
+        await instance.stop(
+            runtime_parameters=instance.runtime_parameters, workflow_instance=instance
+        )
@@ -767,10 +769,10 @@ async def simple_data_stream(stream):
             complete_by_time=120000,
             repartition=False,
             seed=rd,
+            submit_task=True,
         )
         templates.append(instance)
         running_task_ids.append(instance.id)
-        await workflow_engine.submit(instance, repartition=False)
 
 
 @workflow_engine.faust_app.agent(orders_topic)
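The stop call in simple_data_stream_stop changed shape as well: with tasks no longer persisting themselves, stop is presumably handed the runtime parameters and workflow instance explicitly rather than resolving them internally. Caller-side usage, mirroring the test:

    instance = await workflow_engine.get_instance(running_task_ids[-1])
    await instance.stop(
        runtime_parameters=instance.runtime_parameters, workflow_instance=instance
    )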
8 changes: 0 additions & 8 deletions tests/tasks/test_task.py
@@ -149,11 +149,9 @@ async def test_parallel_composite_task_start_non_terminal(
                 workflow_instance=workflow_instance_fixture
             )
             assert parallel_composite_task_fixture.execute.called
-            assert dagger.service.services.Dagger.app._update_instance.called
             assert not parallel_composite_task_fixture.on_complete.called
             assert child_task1.start.called
             assert child_task2.start.called
-            assert dagger.service.services.Dagger.app._update_instance.called
         except Exception:
             pytest.fail("Error should not be thrown")
 
@@ -544,7 +542,6 @@ async def test_executortask(self, executor_fixture, workflow_instance_fixture):
         parent_task.notify = CoroutineMock()
         assert executor_fixture.get_id() == executor_fixture.id
         await executor_fixture.start(workflow_instance=workflow_instance_fixture)
-        assert dagger.service.services.Dagger.app._update_instance.called
         assert executor_fixture.status.code == TaskStatusEnum.COMPLETED.name
         assert executor_fixture.time_completed != 0
         assert parent_task.notify.called
@@ -569,7 +566,6 @@ async def test_decisiontask(self, decision_fixture, workflow_instance_fixture):
         workflow_instance_fixture.runtime_parameters = {}
         assert decision_fixture.get_id() == decision_fixture.id
         await decision_fixture.start(workflow_instance=workflow_instance_fixture)
-        assert dagger.service.services.Dagger.app._update_instance.called
         assert decision_fixture.on_complete.called
         with pytest.raises(NotImplementedError):
             await decision_fixture.execute(
@@ -646,7 +642,6 @@ async def test_sensortask(self, sensor_fixture, workflow_instance_fixture):
         ret_val = sensor_fixture.get_correlatable_key(payload)
         assert payload == ret_val
         await sensor_fixture.start(workflow_instance=workflow_instance_fixture)
-        assert dagger.service.services.Dagger.app._update_instance.called
         assert not sensor_fixture.on_complete.called
         with pytest.raises(NotImplementedError):
             await sensor_fixture.execute(
@@ -714,7 +709,6 @@ async def test_current_triggertask(
         )
         assert trigger_fixture.get_id() == trigger_fixture.id
         await trigger_fixture.start(workflow_instance=workflow_instance_fixture)
-        assert dagger.service.services.Dagger.app._update_instance.called
         assert trigger_fixture.status.code == TaskStatusEnum.COMPLETED.name
         assert (
             dagger.service.services.Dagger.app._store.process_trigger_task_complete.called
@@ -729,7 +723,6 @@ async def test_future_interval_fixture(
         dagger.service.services.Dagger.app._update_instance = CoroutineMock()
         assert interval_fixture.get_id() == interval_fixture.id
         await interval_fixture.start(workflow_instance=workflow_instance_fixture)
-        assert dagger.service.services.Dagger.app._update_instance.called
         assert interval_fixture.status.code == TaskStatusEnum.EXECUTING.name
 
     @pytest.mark.asyncio
@@ -743,7 +736,6 @@ async def test_current_interval_fixture(
         dagger.service.services.Dagger.app._store.insert_trigger = CoroutineMock()
         assert interval_fixture.get_id() == interval_fixture.id
         await interval_fixture.start(workflow_instance=workflow_instance_fixture)
-        assert dagger.service.services.Dagger.app._update_instance.called
         assert interval_fixture.status.code == TaskStatusEnum.COMPLETED.name
 
     @pytest.mark.asyncio
