Added filter step. #1

Merged 1 commit on Aug 29, 2022
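This PR adds a filter step to pipelines: like map, it spawns one subtask per item of the previous step's result and keeps only the items whose subtask returned a truthy value. A minimal usage sketch follows; the InMemoryBroker setup, the PipelineMiddleware wiring, and the task names are illustrative assumptions, not part of this diff.

import asyncio
from typing import List

from taskiq import InMemoryBroker

from taskiq_pipelines import Pipeline, PipelineMiddleware

# Hypothetical broker setup; pipelines need the pipeline middleware installed.
broker = InMemoryBroker().with_middlewares(PipelineMiddleware())


@broker.task
async def generate(count: int) -> List[int]:
    # Produces the list of items the filter step fans out over.
    return list(range(count))


@broker.task
async def is_even(num: int) -> bool:
    # Runs once per item; a truthy return keeps the item.
    return num % 2 == 0


async def main() -> None:
    pipe = Pipeline(broker, generate).filter(is_even, check_interval=0.1)
    task = await pipe.kiq(10)
    result = await task.wait_result()
    print(result.return_value)  # expected: [0, 2, 4, 6, 8]


if __name__ == "__main__":
    asyncio.run(main())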
2 changes: 2 additions & 0 deletions taskiq_pipelines/abc.py
@@ -43,6 +43,7 @@ async def act(
self,
broker: AsyncBroker,
step_number: int,
parent_task_id: str,
task_id: str,
pipe_data: str,
result: "TaskiqResult[Any]",
@@ -57,6 +58,7 @@

:param broker: current broker.
:param step_number: current step number.
:param parent_task_id: task id of the previous (parent) task.
:param task_id: task_id to use.
:param pipe_data: serialized pipeline must be in labels.
:param result: result of a previous task.
1 change: 1 addition & 0 deletions taskiq_pipelines/middleware.py
@@ -63,6 +63,7 @@ async def post_execute( # noqa: C901, WPS212
await next_step.act(
broker=self.broker,
step_number=current_step_num + 1,
parent_task_id=message.task_id,
task_id=next_step_data.task_id,
pipe_data=pipeline_data,
result=result,
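The one-line middleware change above is what threads the just-finished task's id into the next step as parent_task_id; the new filter step relies on it to fetch the previous step's return value from the result backend.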
74 changes: 73 additions & 1 deletion taskiq_pipelines/pipeliner.py
@@ -8,7 +8,7 @@
from typing_extensions import ParamSpec

from taskiq_pipelines.constants import CURRENT_STEP, PIPELINE_DATA
from taskiq_pipelines.steps import MapperStep, SequentialStep, parse_step
from taskiq_pipelines.steps import FilterStep, MapperStep, SequentialStep, parse_step

_ReturnType = TypeVar("_ReturnType")
_FuncParams = ParamSpec("_FuncParams")
@@ -182,6 +182,78 @@ def map(
)
return self

@overload
def filter(
self: "Pipeline[_FuncParams, _ReturnType]",
task: Union[
AsyncKicker[Any, Coroutine[Any, Any, bool]],
AsyncTaskiqDecoratedTask[Any, Coroutine[Any, Any, bool]],
],
param_name: Optional[str] = None,
skip_errors: bool = False,
check_interval: float = 0.5,
**additional_kwargs: Any,
) -> "Pipeline[_FuncParams, _ReturnType]":
...

@overload
def filter(
self: "Pipeline[_FuncParams, _ReturnType]",
task: Union[
AsyncKicker[Any, bool],
AsyncTaskiqDecoratedTask[Any, bool],
],
param_name: Optional[str] = None,
skip_errors: bool = False,
check_interval: float = 0.5,
**additional_kwargs: Any,
) -> "Pipeline[_FuncParams, _ReturnType]":
...

def filter(
self,
task: Union[
AsyncKicker[Any, Any],
AsyncTaskiqDecoratedTask[Any, Any],
],
param_name: Optional[str] = None,
skip_errors: bool = False,
check_interval: float = 0.5,
**additional_kwargs: Any,
) -> Any:
"""
Add filter step.

This step is executed on a list of items,
like map.

It spawns one small subtask per item,
and if a subtask returns a truthy value,
the corresponding item is added to the final list.

:param task: task to execute on every item.
:param param_name: parameter name to pass item into, defaults to None
:param skip_errors: skip errors if any, defaults to False
:param check_interval: how often the result of all subtasks is checked,
defaults to 0.5
:param additional_kwargs: additional function's kwargs.
:return: pipeline with filtering step.
"""
self.steps.append(
DumpedStep(
step_type=FilterStep.step_name,
step_data=FilterStep.from_task(
task=task,
param_name=param_name,
skip_errors=skip_errors,
check_interval=check_interval,
**additional_kwargs,
).dumps(),
task_id="",
),
)
return self

def dumps(self) -> str:
"""
Dumps current pipeline as string.
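As with map, param_name controls how each item reaches the subtask: when set, the item is passed as that keyword argument; when omitted, the item is the first positional argument (see FilterStep.act below). A hedged sketch, with hypothetical task and pipeline names:

@broker.task
async def longer_than(text: str, limit: int) -> bool:
    return len(text) > limit


# Item passed positionally; extra kwargs are forwarded to every subtask.
pipe = Pipeline(broker, produce_words).filter(longer_than, limit=3)

# Item passed as the "text" keyword argument instead.
pipe = Pipeline(broker, produce_words).filter(
    longer_than,
    param_name="text",
    limit=3,
)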
2 changes: 2 additions & 0 deletions taskiq_pipelines/steps/__init__.py
@@ -2,6 +2,7 @@
from logging import getLogger

from taskiq_pipelines.abc import AbstractStep
from taskiq_pipelines.steps.filter import FilterStep
from taskiq_pipelines.steps.mapper import MapperStep
from taskiq_pipelines.steps.sequential import SequentialStep

@@ -19,4 +20,5 @@ def parse_step(step_type: str, step_data: str) -> AbstractStep:
__all__ = [
"MapperStep",
"SequentialStep",
"FilterStep",
]
185 changes: 185 additions & 0 deletions taskiq_pipelines/steps/filter.py
@@ -0,0 +1,185 @@
import asyncio
from typing import Any, Dict, Iterable, List, Optional, Union

import pydantic
from taskiq import AsyncBroker, TaskiqError, TaskiqResult
from taskiq.brokers.shared_broker import async_shared_broker
from taskiq.context import Context, default_context
from taskiq.decor import AsyncTaskiqDecoratedTask
from taskiq.kicker import AsyncKicker

from taskiq_pipelines.abc import AbstractStep
from taskiq_pipelines.constants import CURRENT_STEP, PIPELINE_DATA
from taskiq_pipelines.exceptions import AbortPipeline


@async_shared_broker.task(task_name="taskiq_pipelines.shared.filter_tasks")
async def filter_tasks( # noqa: C901, WPS210, WPS231
task_ids: List[str],
parent_task_id: str,
check_interval: float,
context: Context = default_context,
skip_errors: bool = False,
) -> List[Any]:
"""
Filter the results of subtasks.

It takes a list of subtask ids
and the parent task id.

After all subtasks have completed, it fetches
the result of the parent task, and for every
subtask whose return value evaluates to True,
the corresponding item from the parent's
result is added to the resulting array.

:param task_ids: ordered list of task ids.
:param parent_task_id: task id of a parent task.
:param check_interval: how often checks are performed.
:param context: context of the execution, defaults to default_context
:param skip_errors: skip errors of subtasks, defaults to False
:raises TaskiqError: if any subtask has returned an error.
:return: filtered results.
"""
ordered_ids = task_ids[:]
tasks_set = set(task_ids)
while tasks_set:
for task_id in task_ids: # noqa: WPS327
if await context.broker.result_backend.is_result_ready(task_id):
try:
tasks_set.remove(task_id)
except LookupError:
continue
await asyncio.sleep(check_interval)

results = await context.broker.result_backend.get_result(parent_task_id)
filtered_results = []
for task_id, value in zip( # type: ignore # noqa: WPS352, WPS440
ordered_ids,
results.return_value,
):
result = await context.broker.result_backend.get_result(task_id)
if result.is_err:
if skip_errors:
continue
raise TaskiqError(f"Task {task_id} returned error. Filtering failed.")
if result.return_value:
filtered_results.append(value)
return filtered_results


class FilterStep(pydantic.BaseModel, AbstractStep, step_name="filter"):
"""Task to filter results."""

task_name: str
labels: Dict[str, str]
param_name: Optional[str]
additional_kwargs: Dict[str, Any]
skip_errors: bool
check_interval: float

def dumps(self) -> str:
"""
Dumps step as string.

:return: returns json.
"""
return self.json()

@classmethod
def loads(cls, data: str) -> "FilterStep":
"""
Parses filter step from string.

:param data: dumped data.
:return: parsed step.
"""
return pydantic.parse_raw_as(FilterStep, data)

async def act(
self,
broker: AsyncBroker,
step_number: int,
parent_task_id: str,
task_id: str,
pipe_data: str,
result: "TaskiqResult[Any]",
) -> None:
"""
Run filter action.

This function spawns one subtask per item,
and then collects all results into one filtered array
using the 'filter_tasks' shared task.

:param broker: current broker.
:param step_number: current step number.
:param parent_task_id: task_id of the previous step.
:param task_id: task_id to use in this step.
:param pipe_data: serialized pipeline.
:param result: result of the previous task.
:raises AbortPipeline: if result is not iterable.
"""
if not isinstance(result.return_value, Iterable):
raise AbortPipeline("Result of the previous task is not iterable.")
sub_task_ids = []
for item in result.return_value:
kicker: "AsyncKicker[Any, Any]" = AsyncKicker(
task_name=self.task_name,
broker=broker,
labels=self.labels,
)
if self.param_name:
self.additional_kwargs[self.param_name] = item
task = await kicker.kiq(**self.additional_kwargs)
else:
task = await kicker.kiq(item, **self.additional_kwargs)
sub_task_ids.append(task.task_id)

await filter_tasks.kicker().with_task_id(task_id).with_broker(
broker,
).with_labels(
**{CURRENT_STEP: step_number, PIPELINE_DATA: pipe_data}, # type: ignore
).kiq(
sub_task_ids,
parent_task_id,
check_interval=self.check_interval,
skip_errors=self.skip_errors,
)

@classmethod
def from_task(
cls,
task: Union[
AsyncKicker[Any, Any],
AsyncTaskiqDecoratedTask[Any, Any],
],
param_name: Optional[str],
skip_errors: bool,
check_interval: float,
**additional_kwargs: Any,
) -> "FilterStep":
"""
Create new filter step from task.

:param task: task to execute.
:param param_name: parameter name.
:param skip_errors: don't fail collector
task on errors.
:param check_interval: how often tasks are checked.
:param additional_kwargs: additional function's kwargs.
:return: new filter step.
"""
if isinstance(task, AsyncTaskiqDecoratedTask):
kicker = task.kicker()
else:
kicker = task
message = kicker._prepare_message() # noqa: WPS437
return FilterStep(
task_name=message.task_name,
labels=message.labels,
param_name=param_name,
additional_kwargs=additional_kwargs,
skip_errors=skip_errors,
check_interval=check_interval,
)
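Taken together: FilterStep.act kicks one subtask per item, then schedules the shared filter_tasks collector under the pipeline step's own task_id, so the next pipeline step receives the filtered list as an ordinary result. The final zip-and-keep loop can be read in isolation; a simplified sketch of that logic, detached from the result backend:

from typing import Any, List, Tuple


def keep_truthy(
    items: List[Any],
    subtask_results: List[Tuple[bool, Any]],  # (is_err, return_value) pairs
    skip_errors: bool = False,
) -> List[Any]:
    # Mirrors the final loop of filter_tasks: keep items whose subtask
    # returned a truthy value; errors either skip the item or fail the step.
    filtered = []
    for item, (is_err, flag) in zip(items, subtask_results):
        if is_err:
            if skip_errors:
                continue
            raise RuntimeError("A subtask returned an error. Filtering failed.")
        if flag:
            filtered.append(item)
    return filtered


# "a" is dropped (falsy flag), "bb" kept, "ccc" skipped as an error.
assert keep_truthy(
    ["a", "bb", "ccc"],
    [(False, False), (False, True), (True, None)],
    skip_errors=True,
) == ["bb"]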
13 changes: 9 additions & 4 deletions taskiq_pipelines/steps/mapper.py
@@ -17,8 +17,8 @@
from taskiq_pipelines.exceptions import AbortPipeline


@async_shared_broker.task(task_name="taskiq_pipelines.wait_tasks")
async def wait_tasks( # noqa: C901
@async_shared_broker.task(task_name="taskiq_pipelines.shared.wait_tasks")
async def wait_tasks( # noqa: C901, WPS231
task_ids: List[str],
check_interval: float,
context: Context = default_context,
@@ -44,9 +44,12 @@
ordered_ids = task_ids[:]
tasks_set = set(task_ids)
while tasks_set:
for task_id in task_ids:
for task_id in task_ids: # noqa: WPS327
if await context.broker.result_backend.is_result_ready(task_id):
tasks_set.remove(task_id)
try:
tasks_set.remove(task_id)
except LookupError:
continue
await asyncio.sleep(check_interval)

results = []
@@ -92,6 +95,7 @@ async def act(
self,
broker: AsyncBroker,
step_number: int,
parent_task_id: str,
task_id: str,
pipe_data: str,
result: "TaskiqResult[Any]",
@@ -109,6 +113,7 @@
:param broker: current broker.
:param step_number: current step number.
:param task_id: waiter task_id.
:param parent_task_id: task_id of the previous step.
:param pipe_data: serialized pipeline.
:param result: result of the previous task.
:raises AbortPipeline: if the result of the
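A note on the try/except LookupError added above (and mirrored in the new filter_tasks): the inner for loop walks the full task_ids list on every polling pass, so a task that became ready on an earlier pass is seen again, and a second set.remove of its id would raise KeyError (a LookupError subclass) and crash the waiter.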
2 changes: 2 additions & 0 deletions taskiq_pipelines/steps/sequential.py
@@ -44,6 +44,7 @@ async def act(
self,
broker: AsyncBroker,
step_number: int,
parent_task_id: str,
task_id: str,
pipe_data: str,
result: "TaskiqResult[Any]",
@@ -61,6 +62,7 @@

:param broker: current broker.
:param step_number: current step number.
:param parent_task_id: task id of the previous step.
:param task_id: new task id.
:param pipe_data: serialized pipeline.
:param result: result of the previous task.