Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Output caching #587

Merged
merged 16 commits into from
Jan 29, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Ensure all states coming into task runners are raw
  • Loading branch information
cicdw committed Jan 29, 2019
commit 8509cf22bd9f71da2422786b1ff1c09f4ec93562
4 changes: 2 additions & 2 deletions src/prefect/engine/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def ensure_raw(self) -> None:
setattr(self, attr, unpacked_value)
self._metadata[attr].update(raw=True)

if hasattr(self, "cached_inputs"):
if getattr(self, "cached_inputs", None) is not None:
# each variable could presumably come from different tasks with
# different result handlers
for variable in self.cached_inputs: # type: ignore
Expand All @@ -108,7 +108,7 @@ def ensure_raw(self) -> None:
self.cached_inputs[variable] = unpacked_value # type: ignore
self._metadata["cached_inputs"][variable]["raw"] = True

if hasattr(self, "cached"):
if getattr(self, "cached", None) is not None:
self.cached.ensure_raw() # type: ignore

def is_pending(self) -> bool:
Expand Down
4 changes: 4 additions & 0 deletions src/prefect/engine/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def initialize_run( # type: ignore
- tuple: a tuple of the updated state, context, upstream_states, and inputs objects
"""
state, context = super().initialize_run(state=state, context=context)
state.ensure_raw()

if isinstance(state, Retrying):
run_count = state.run_count + 1
Expand All @@ -158,6 +159,9 @@ def initialize_run( # type: ignore

context.update(task_run_count=run_count, task_name=self.task.name)

for up_state in upstream_states.values():
up_state.ensure_raw() # ensures no inputs need handling from this point forward

return TaskRunnerInitializeResult(
state=state, context=context, upstream_states=upstream_states
)
Expand Down
43 changes: 43 additions & 0 deletions tests/engine/test_task_runner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import cloudpickle
import collections
import tempfile
from datetime import datetime, timedelta
from time import sleep
from unittest.mock import MagicMock
Expand All @@ -19,6 +21,7 @@
partial_inputs_only,
partial_parameters_only,
)
from prefect.engine.result_handlers import LocalResultHandler
from prefect.engine.state import (
CachedState,
Failed,
Expand All @@ -38,6 +41,7 @@
TriggerFailed,
)
from prefect.engine.task_runner import ENDRUN, TaskRunner
from prefect.serialization.result_handlers import ResultHandlerSchema
from prefect.utilities.configuration import set_temporary_config
from prefect.utilities.debug import raise_on_exception
from prefect.utilities.tasks import pause_task
Expand Down Expand Up @@ -373,6 +377,45 @@ def test_unwrap_submitted_states(self):
)
assert result.state is state

def test_ensures_all_upstream_states_are_raw(self):
serialized_handler = ResultHandlerSchema().dump(LocalResultHandler())

with tempfile.NamedTemporaryFile() as tmp:
with open(tmp.name, "wb") as f:
cloudpickle.dump(42, f)

a, b, c = (
Success(result=tmp.name),
Failed(result=55),
Pending(result=tmp.name),
)
a._metadata["result"] = dict(raw=False, result_handler=serialized_handler)
c._metadata["result"] = dict(raw=False, result_handler=serialized_handler)
result = TaskRunner(Task()).initialize_run(
state=Success(), context={}, upstream_states={1: a, 2: b, 3: c}
)

assert result.upstream_states[1].result == 42
assert result.upstream_states[2].result == 55
assert result.upstream_states[3].result == 42

def test_ensures_provided_initial_state_is_raw(self):
serialized_handler = ResultHandlerSchema().dump(LocalResultHandler())

with tempfile.NamedTemporaryFile() as tmp:
with open(tmp.name, "wb") as f:
cloudpickle.dump(42, f)

state = Success(result=tmp.name)
state._metadata["result"] = dict(
raw=False, result_handler=serialized_handler
)
result = TaskRunner(Task()).initialize_run(
state=state, context={}, upstream_states={}
)

assert result.state.result == 42


class TestCheckUpstreamFinished:
def test_with_empty(self):
Expand Down