|
33 | 33 | MAX_PRIORITY = 100 |
34 | 34 | DEFAULT_PRIORITY = 0 |
35 | 35 |
|
| 36 | +TASK_REFRESH_ATTRS = { |
| 37 | + "_exception_data", |
| 38 | + "_return_value", |
| 39 | + "finished_at", |
| 40 | + "started_at", |
| 41 | + "status", |
| 42 | +} |
| 43 | + |
36 | 44 |
|
37 | 45 | class ResultStatus(TextChoices): |
38 | 46 | NEW = ("NEW", _("New")) |
@@ -239,59 +247,58 @@ class TaskResult(Generic[T]): |
239 | 247 | backend: str |
240 | 248 | """The name of the backend the task will run on""" |
241 | 249 |
|
242 | | - _result: Optional[Union[T, SerializedExceptionDict]] = field( |
243 | | - init=False, default=None |
244 | | - ) |
| 250 | + _return_value: Optional[T] = field(init=False, default=None) |
| 251 | + _exception_data: Optional[SerializedExceptionDict] = field(init=False, default=None) |
245 | 252 |
|
246 | 253 | @property |
247 | | - def result(self) -> Optional[Union[T, BaseException]]: |
248 | | - if self.status == ResultStatus.COMPLETE: |
249 | | - return cast(T, self._result) |
250 | | - elif self.status == ResultStatus.FAILED: |
251 | | - return ( |
252 | | - exception_from_dict(cast(SerializedExceptionDict, self._result)) |
253 | | - if self._result is not None |
254 | | - else None |
255 | | - ) |
256 | | - |
257 | | - raise ValueError("Task has not finished yet") |
| 254 | + def exception(self) -> Optional[BaseException]: |
| 255 | + return ( |
| 256 | + exception_from_dict(cast(SerializedExceptionDict, self._exception_data)) |
| 257 | + if self.status == ResultStatus.FAILED and self._exception_data is not None |
| 258 | + else None |
| 259 | + ) |
258 | 260 |
|
259 | 261 | @property |
260 | 262 | def traceback(self) -> Optional[str]: |
261 | 263 | """ |
262 | 264 | Return the string representation of the traceback of the task if it failed |
263 | 265 | """ |
264 | | - if self.status == ResultStatus.FAILED and self._result is not None: |
265 | | - return cast(SerializedExceptionDict, self._result)["exc_traceback"] |
266 | | - |
267 | | - return None |
| 266 | + return ( |
| 267 | + cast(SerializedExceptionDict, self._exception_data)["exc_traceback"] |
| 268 | + if self.status == ResultStatus.FAILED and self._exception_data is not None |
| 269 | + else None |
| 270 | + ) |
268 | 271 |
|
269 | | - def get_result(self) -> Optional[T]: |
| 272 | + @property |
| 273 | + def return_value(self) -> Optional[T]: |
270 | 274 | """ |
271 | | - A convenience method to get the result, or None if it's not ready yet or has failed. |
| 275 | + Get the return value of the task. |
| 276 | +
|
| 277 | + If the task didn't complete successfully, an exception is raised. |
| 278 | + This is to distinguish against the task returning None. |
272 | 279 | """ |
273 | | - return cast(T, self.result) if self.status == ResultStatus.COMPLETE else None |
| 280 | + if self.status == ResultStatus.FAILED: |
| 281 | + raise ValueError("Task failed") |
| 282 | + |
| 283 | + elif self.status != ResultStatus.COMPLETE: |
| 284 | + raise ValueError("Task has not finished yet") |
| 285 | + |
| 286 | + return cast(T, self._return_value) |
274 | 287 |
|
275 | 288 | def refresh(self) -> None: |
276 | 289 | """ |
277 | 290 | Reload the cached task data from the task store |
278 | 291 | """ |
279 | 292 | refreshed_task = self.task.get_backend().get_result(self.id) |
280 | 293 |
|
281 | | - # status, started_at, finished_at and result are the only refreshable attributes |
282 | | - self.status = refreshed_task.status |
283 | | - self.started_at = refreshed_task.started_at |
284 | | - self.finished_at = refreshed_task.finished_at |
285 | | - self._result = refreshed_task._result |
| 294 | + for attr in TASK_REFRESH_ATTRS: |
| 295 | + setattr(self, attr, getattr(refreshed_task, attr)) |
286 | 296 |
|
287 | 297 | async def arefresh(self) -> None: |
288 | 298 | """ |
289 | 299 | Reload the cached task data from the task store |
290 | 300 | """ |
291 | 301 | refreshed_task = await self.task.get_backend().aget_result(self.id) |
292 | 302 |
|
293 | | - # status, started_at, finished_at and result are the only refreshable attributes |
294 | | - self.status = refreshed_task.status |
295 | | - self.started_at = refreshed_task.started_at |
296 | | - self.finished_at = refreshed_task.finished_at |
297 | | - self._result = refreshed_task._result |
| 303 | + for attr in TASK_REFRESH_ATTRS: |
| 304 | + setattr(self, attr, getattr(refreshed_task, attr)) |
0 commit comments