Description
Hey! Awesome library overall, but after upgrading to version 1.0 I am facing an issue with returning pydantic models from tasks.
Due to the change here, my pydantic model is now returned as a dict rather than the model instance, i.e.:
```python
result = await task.wait_result()
print(type(result.return_value))  # dict here
```
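For now, on the client side I can recover the model by validating the dict back into it. A minimal sketch, assuming pydantic v2 and that `RegressionResult` lives in a `ml_models.schemas` module (path assumed):

```python
from ml_models.schemas import RegressionResult  # module path assumed

result = await task.wait_result()
metrics = RegressionResult.model_validate(result.return_value)  # dict -> model instance
```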
One side note that might be causing trouble: maybe I have an edge case, but I have two packages, one containing the taskiq tasks and the broker, and a separate client app (FastAPI) that calls these tasks. Because the redis URL and credentials can differ depending on whether the taskiq worker or the client app is running, I had to create a function that accepts parameters and creates the broker.
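Concretely, each app loads its own connection settings, roughly like this (a sketch using pydantic-settings; the field names just mirror what I pass to `get_broker` below):

```python
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    # each app reads these from its own environment / .env file
    redis_url: str = "redis://localhost:6379"
    redis_password: str | None = None


settings = Settings()
```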
I have something like the following in my taskiq tasks app:
`ml_models/tasks.py`:

```python
async def train_regression_model_task(
    schema: str, table_name: str, data: RegressionRequest
) -> RegressionResult:
    ...
    return metrics
```
`ml_models/taskiq.py`:

```python
from types import ModuleType
from typing import Callable, ParamSpec, TypeVar

from taskiq import AsyncBroker, AsyncTaskiqDecoratedTask
from taskiq.serializers import PickleSerializer  # import path assumed
from taskiq_redis import ListQueueBroker

from ml_models import tasks
from ml_models.result_backend import CustomResultBackend  # path assumed; subclass shown below

_P = ParamSpec("_P")
_R = TypeVar("_R")


def _extract_tasks(mod: ModuleType) -> list[Callable]:
    # collect every callable whose name ends with "_task"
    return [v for k, v in mod.__dict__.items() if callable(v) and k.endswith("_task")]


def get_broker(redis_url: str, redis_password: str | None) -> AsyncBroker:
    broker = ListQueueBroker(
        url=redis_url,
        password=redis_password,
    ).with_result_backend(
        CustomResultBackend(
            redis_url=redis_url,
            password=redis_password,
            serializer=PickleSerializer(),
        )
    )
    for task in _extract_tasks(tasks):
        broker.register_task(task)
    return broker


def find_task(
    broker: AsyncBroker, f: Callable[_P, _R]
) -> AsyncTaskiqDecoratedTask[_P, _R]:
    name = f.__module__ + ":" + f.__name__
    t = broker.find_task(task_name=name)
    if t is None:
        raise ValueError(
            f"Could not find task with name {name}. "
            f"Available tasks are: {broker.get_all_tasks()}"
        )
    return t
```
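For completeness, the worker side instantiates the broker at import time so the taskiq CLI can find it (module and env-var names here are just an example):

```python
# ml_models/worker.py -- run with: taskiq worker ml_models.worker:broker
import os

from ml_models.taskiq import get_broker

broker = get_broker(
    redis_url=os.environ["TASKIQ_REDIS_URL"],  # env-var names assumed
    redis_password=os.environ.get("TASKIQ_REDIS_PASSWORD"),
)
```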
and then in the FastAPI app I call the tasks as follows:
```python
broker = taskiq.get_broker(settings.redis_url, settings.redis_password)  # ml_models.taskiq from above
train_regression_model_task = find_task(broker, tasks.train_regression_model_task)


async def execute_regression():
    task = await train_regression_model_task.kiq(
        schema_name, table_name, regression_data
    )
    result = await task.wait_result(timeout=60)
```
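For reference, this is roughly the check that shows the symptom (the error handling here is illustrative, not my real code):

```python
result = await task.wait_result(timeout=60)
if result.is_err:
    raise RuntimeError(f"Task failed: {result.error}")
print(type(result.return_value))  # dict on 1.0, RegressionResult before
```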
As a temporary fix, I had to subclass RedisAsyncResultBackend as follows:
```python
from typing import TypeVar

from redis.asyncio import Redis
from taskiq import TaskiqResult
from taskiq_redis import RedisAsyncResultBackend

_ReturnType = TypeVar("_ReturnType")


class CustomResultBackend(RedisAsyncResultBackend):
    async def set_result(
        self,
        task_id: str,
        result: TaskiqResult[_ReturnType],
    ) -> None:
        """
        Sets task result in redis.

        Dumps the TaskiqResult instance into bytes and writes
        it to redis.

        :param task_id: ID of the task.
        :param result: TaskiqResult instance.
        """
        redis_set_params: dict[str, str | int | bytes] = {
            "name": task_id,
            "value": self.serializer.dumpb(result),
        }
        if self.result_ex_time:
            redis_set_params["ex"] = self.result_ex_time
        elif self.result_px_time:
            redis_set_params["px"] = self.result_px_time

        async with Redis(connection_pool=self.redis_pool) as redis:
            await redis.set(**redis_set_params)  # type: ignore
```
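With that subclass in place, a quick round-trip check behaves as expected again. A minimal sketch, with a dummy model standing in for `RegressionResult` and an arbitrary task id:

```python
import asyncio

from pydantic import BaseModel
from taskiq import TaskiqResult


class DummyResult(BaseModel):  # stand-in for RegressionResult
    score: float


async def check() -> None:
    backend = CustomResultBackend(
        redis_url="redis://localhost:6379",
        serializer=PickleSerializer(),
    )
    await backend.startup()
    original: TaskiqResult[DummyResult] = TaskiqResult(
        is_err=False,
        return_value=DummyResult(score=0.93),
        execution_time=0.0,
    )
    await backend.set_result("debug-task-id", original)
    restored = await backend.get_result("debug-task-id")
    print(type(restored.return_value))  # DummyResult with the subclass, dict without it
    await backend.shutdown()


asyncio.run(check())
```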