Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Output caching #587

Merged
merged 16 commits into from
Jan 29, 2019
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/prefect/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ def my_task():
in the `edges` argument. Defaults to the value of `eager_edge_validation` in
your prefect configuration file.
- result_handler (ResultHandler, optional): the handler to use for
retrieving and storing state results during execution
retrieving and storing state results during execution; if not provided, will default
to the one specified in your config

"""

Expand All @@ -151,7 +152,9 @@ def __init__(
self.name = name or type(self).__name__
self.schedule = schedule
self.environment = environment or prefect.environments.LocalEnvironment()
self.result_handler = result_handler
self.result_handler = (
result_handler or prefect.engine.get_default_result_handler_class()()
)

self.tasks = set() # type: Set[Task]
self.edges = set() # type: Set[Edge]
Expand Down
6 changes: 6 additions & 0 deletions src/prefect/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

if TYPE_CHECKING:
from prefect.core.flow import Flow # pylint: disable=W0611
from prefect.engine.result_handlers import ResultHandler
from prefect.engine.state import State

VAR_KEYWORD = inspect.Parameter.VAR_KEYWORD
Expand Down Expand Up @@ -109,6 +110,9 @@ def run(self, x, y):
- cache_validator (Callable, optional): Validator which will determine
whether the cache for this task is still valid (only required if `cache_for`
is provided; defaults to `prefect.engine.cache_validators.duration_only`)
- result_handler (ResultHandler, optional): the handler to use for
retrieving and storing state results during execution; if not provided, will default to the
one attached to the Flow
- state_handlers (Iterable[Callable], optional): A list of state change handlers
that will be called whenever the task changes state, providing an
opportunity to inspect or modify the new state. The handler
Expand Down Expand Up @@ -140,6 +144,7 @@ def __init__(
skip_on_upstream_skip: bool = True,
cache_for: timedelta = None,
cache_validator: Callable = None,
result_handler: "ResultHandler" = None,
state_handlers: List[Callable] = None,
on_failure: Callable = None,
):
Expand Down Expand Up @@ -200,6 +205,7 @@ def __init__(
else prefect.engine.cache_validators.duration_only
)
self.cache_validator = cache_validator or default_validator
self.result_handler = result_handler

if state_handlers and not isinstance(state_handlers, collections.Sequence):
raise TypeError("state_handlers should be iterable.")
Expand Down
96 changes: 83 additions & 13 deletions src/prefect/engine/cloud/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
import prefect
from prefect.client import Client
from prefect.core import Edge, Task
from prefect.engine.cloud import CloudResultHandler
from prefect.engine.result_handlers import ResultHandler
from prefect.engine.runner import ENDRUN
from prefect.engine.runner import ENDRUN, call_state_handlers
from prefect.engine.state import Failed, Mapped, State
from prefect.engine.task_runner import TaskRunner, TaskRunnerInitializeResult
from prefect.utilities.graphql import with_args
Expand All @@ -24,38 +23,33 @@ class CloudTaskRunner(TaskRunner):

Args:
- task (Task): the Task to be run / executed
- result_handler (ResultHandler, optional): the handler to use for
retrieving and storing state results during execution
- state_handlers (Iterable[Callable], optional): A list of state change handlers
that will be called whenever the task changes state, providing an
opportunity to inspect or modify the new state. The handler
will be passed the task runner instance, the old (prior) state, and the new
(current) state, with the following signature:

```
```python
state_handler(
task_runner: TaskRunner,
old_state: State,
new_state: State) -> State
```

If multiple functions are passed, then the `new_state` argument will be the
result of the previous handler.
- result_handler (ResultHandler, optional): the handler to use for
retrieving and storing state results during execution (if the Task doesn't already have one);
if not provided here or by the Task, will default to the one specified in your config
"""

def __init__(
self,
task: Task,
result_handler: ResultHandler = None,
state_handlers: Iterable[Callable] = None,
result_handler: ResultHandler = None,
) -> None:
self.client = Client()
result_handler = (
result_handler or prefect.engine.get_default_result_handler_class()()
)

super().__init__(
task=task, result_handler=result_handler, state_handlers=state_handlers
task=task, state_handlers=state_handlers, result_handler=result_handler
)

def _heartbeat(self) -> None:
Expand Down Expand Up @@ -164,6 +158,82 @@ def initialize_run( # type: ignore
# we assign this so it can be shared with heartbeat thread
self.task_run_id = context.get("task_run_id") # type: ignore

## ensure all inputs have been handled
if state is not None:
state.ensure_raw()
for up_state in upstream_states.values():
up_state.ensure_raw()

return super().initialize_run(
state=state, context=context, upstream_states=upstream_states
)

@call_state_handlers
def finalize_run(self, state: State, upstream_states: Dict[Edge, State]) -> State:
"""
Ensures that all results are handled appropriately on the final state.

Args:
- state (State): the final state of this task
- upstream_states (Dict[Edge, Union[State, List[State]]]): the upstream states

Returns:
- State: the state of the task after running the check
"""
raise_on_exception = prefect.context.get("raise_on_exception", False)
from prefect.serialization.result_handlers import ResultHandlerSchema

## if a state has a "cached" attribute or a "cached_inputs" attribute, we need to handle it
if getattr(state, "cached_inputs", None) is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than checking for the attribute, I suggest checking if state.is_pending() and state.cached_inputs is not None:, since only Pending states will qualify

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually that's not entirely correct --> TimedOut states store inputs as well

Copy link
Member

@jlowin jlowin Jan 29, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIRealized

try:
input_handlers = {}

for edge, upstream_state in upstream_states.items():
if edge.key is not None:
input_handlers[edge.key] = upstream_state._metadata["result"][
"result_handler"
]

state.handle_inputs(input_handlers)
except Exception as exc:
self.logger.debug(
"Exception raised while serializing inputs: {}".format(repr(exc))
)
if raise_on_exception:
raise exc
new_state = Failed(
"Exception raised while serializing inputs.", result=exc
)
return new_state

if getattr(state, "cached", None) is not None:
try:
input_handlers = {}

for edge, upstream_state in upstream_states.items():
if edge.key is not None:
input_handlers[edge.key] = upstream_state._metadata["result"][
"result_handler"
]

state.cached.handle_inputs(input_handlers) # type: ignore
state.cached.handle_result(self.result_handler) # type: ignore
except Exception as exc:
self.logger.debug(
"Exception raised while serializing cached data: {}".format(
repr(exc)
)
)
if raise_on_exception:
raise exc
new_state = Failed(
"Exception raised while serializing cached data.", result=exc
)
return new_state

## finally, update state _metadata attribute with information about how to handle this state's data
state._metadata["result"].setdefault("raw", True)
state._metadata["result"].setdefault(
"result_handler", ResultHandlerSchema().dump(self.result_handler)
)
return state
3 changes: 2 additions & 1 deletion src/prefect/engine/flow_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,10 +511,11 @@ def run_task(
- State: `State` representing the final post-run state of the `Flow`.

"""
default_handler = task.result_handler or self.flow.result_handler
task_runner = self.task_runner_cls(
task=task,
result_handler=self.flow.result_handler,
state_handlers=task_runner_state_handlers,
result_handler=default_handler,
)

# if this task reduces over a mapped state, make sure its children have finished
Expand Down
1 change: 1 addition & 0 deletions src/prefect/engine/result_handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Licensed under LICENSE.md; also available at https://www.prefect.io/licenses/alpha-eula

from prefect.engine.result_handlers.result_handler import ResultHandler
from prefect.engine.result_handlers.json_result_handler import JSONResultHandler
from prefect.engine.result_handlers.local_result_handler import LocalResultHandler
37 changes: 37 additions & 0 deletions src/prefect/engine/result_handlers/json_result_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Licensed under LICENSE.md; also available at https://www.prefect.io/licenses/alpha-eula

import json
from typing import Any

from prefect.engine.result_handlers import ResultHandler


class JSONResultHandler(ResultHandler):
"""
Hook for storing and retrieving task results to / from JSON. Only intended to be used
for small data loads.
"""

def deserialize(self, jblob: str) -> Any:
"""
Deserialize a result from a string JSON blob.

Args:
- jblob (str): the JSON representation of the result

Returns:
- the deserialized result
"""
return json.loads(jblob)

def serialize(self, result: Any) -> str:
"""
Serialize the provided result to JSON.

Args:
- result (Any): the result to serialize

Returns:
- str: the JSON representation of the result
"""
return json.dumps(result)
110 changes: 108 additions & 2 deletions src/prefect/engine/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ class State:
def __init__(self, message: str = None, result: Any = None):
self.message = message
self.result = result
self.metadata = {} # type: dict
self._metadata = {
"result": {},
"cached_result": {},
"cached_inputs": {},
} # type: dict

def __repr__(self) -> str:
if self.message:
Expand All @@ -72,6 +76,108 @@ def __eq__(self, other: object) -> bool:
def __hash__(self) -> int:
return id(self)

def _populate_metadata(self) -> None:
## populate result
if "result" not in self._metadata:
self._metadata["result"] = dict(raw=True)
elif "raw" not in self._metadata["result"]:
self._metadata["result"].update(raw=True)

## populate cached_result
if "cached_result" not in self._metadata:
self._metadata["cached_result"] = dict(raw=True)
elif "raw" not in self._metadata["cached_result"]:
self._metadata["cached_result"].update(raw=True)

## populate cached_inputs
if "cached_inputs" not in self._metadata:
self._metadata["cached_inputs"] = dict()
if getattr(self, "cached_inputs", None) is not None:
for variable in self.cached_inputs:
if variable not in self._metadata["cached_inputs"]:
self._metadata["cached_inputs"][variable] = dict(raw=True)
elif "raw" not in self._metadata["cached_inputs"][variable]:
self._metadata["cached_inputs"][variable].update(raw=True)

def handle_inputs(self, input_handlers: dict) -> None:
"""
Handles the `cached_inputs` attribute of this state (if it has one).

Args:
- input_handlers (dict): the individual serialized result handlers to use when
processing each variable in `cached_inputs`

Modifies the state object in place.
"""
from prefect.serialization.result_handlers import ResultHandlerSchema

schema = ResultHandlerSchema()
self._populate_metadata()
for variable in self.cached_inputs: # type: ignore
var_info = self._metadata["cached_inputs"].get(variable)
if var_info.get("raw") is True:
handler = ResultHandlerSchema().load(input_handlers[variable])
packed_value = handler.serialize(
self.cached_inputs[variable] # type: ignore
)
self.cached_inputs[variable] = packed_value # type: ignore
self._metadata["cached_inputs"][variable]["raw"] = False

def handle_result(self, result_handler: ResultHandler) -> None:
"""
Handles the `cached_result` attribute of this state (if it has one).

Args:
- result_handler (ResultHandler): the result handler to use when
processing the `cached_result`

Modifies the state object in place.
"""
from prefect.serialization.result_handlers import ResultHandlerSchema

schema = ResultHandlerSchema()
self._populate_metadata()
if self._metadata.get("cached_result", {}).get("raw") is True:
packed_value = result_handler.serialize(self.cached_result) # type: ignore
self.cached_result = packed_value # type: ignore
self._metadata["cached_result"]["raw"] = False
self._metadata["cached_result"]["result_handler"] = schema.dump(
result_handler
)

def ensure_raw(self) -> None:
"""
Ensures that all attributes are _raw_ (as specified in `self._metadata`).

Modifies the state object in place.
"""
from prefect.serialization.result_handlers import ResultHandlerSchema

schema = ResultHandlerSchema()

for attr in ["result", "cached_result"]:
if self._metadata[attr].get("raw") is False:
handler = schema.load(self._metadata[attr]["result_handler"])
unpacked_value = handler.deserialize(getattr(self, attr))
setattr(self, attr, unpacked_value)
self._metadata[attr].update(raw=True)

if getattr(self, "cached_inputs", None) is not None:
# each variable could presumably come from different tasks with
# different result handlers
for variable in self.cached_inputs: # type: ignore
var_info = self._metadata["cached_inputs"].get(variable, {})
if var_info.get("raw") is False:
handler = schema.load(var_info["result_handler"])
unpacked_value = handler.deserialize(
self.cached_inputs[variable] # type: ignore
)
self.cached_inputs[variable] = unpacked_value # type: ignore
self._metadata["cached_inputs"][variable]["raw"] = True

if getattr(self, "cached", None) is not None:
self.cached.ensure_raw() # type: ignore

def is_pending(self) -> bool:
"""
Checks if the object is currently in a pending state
Expand Down Expand Up @@ -244,7 +350,7 @@ def __init__(
cached_result_expiration: datetime.datetime = None,
):
super().__init__(message=message, result=result, cached_inputs=cached_inputs)
self.cached_result = cached_result
self.cached_result = cached_result # type: ignore
self.cached_parameters = cached_parameters
if cached_result_expiration is not None:
cached_result_expiration = ensure_tz_aware(cached_result_expiration)
Expand Down
Loading