Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Output caching #587

Merged
merged 16 commits into from
Jan 29, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Begin modifying the _metadata attribute of states within the task runner
  • Loading branch information
cicdw committed Jan 28, 2019
commit de0ad5837f234b59251c292584bd4862a3a996d4
2 changes: 1 addition & 1 deletion src/prefect/engine/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class State:
def __init__(self, message: str = None, result: Any = None):
self.message = message
self.result = result
self.metadata = {} # type: dict
self._metadata = {"result": {}} # type: dict

def __repr__(self) -> str:
if self.message:
Expand Down
16 changes: 13 additions & 3 deletions src/prefect/engine/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
NamedTuple,
Sized,
Tuple,
TYPE_CHECKING,
Union,
)

Expand All @@ -24,7 +25,6 @@
from prefect import config
from prefect.core import Edge, Task
from prefect.engine import signals
from prefect.engine.result_handlers import ResultHandler
from prefect.engine.runner import ENDRUN, Runner, call_state_handlers
from prefect.engine.state import (
CachedState,
Expand All @@ -45,6 +45,10 @@
)
from prefect.utilities.executors import main_thread_timeout, run_with_heartbeat

if TYPE_CHECKING:
from prefect.engine.result_handlers import ResultHandler


TaskRunnerInitializeResult = NamedTuple(
"TaskRunnerInitializeResult",
[
Expand Down Expand Up @@ -87,7 +91,7 @@ def __init__(
self,
task: Task,
state_handlers: Iterable[Callable] = None,
result_handler: ResultHandler = None,
result_handler: "ResultHandler" = None,
):
self.task = task
self.result_handler = (
Expand Down Expand Up @@ -314,7 +318,13 @@ def run(
)
)

## finally, update state metadata attribute with information about how to handle this state's data
## finally, update state _metadata attribute with information about how to handle this state's data
from prefect.serialization.result_handlers import ResultHandlerSchema

state._metadata["result"].setdefault(
"result_handler", ResultHandlerSchema().dump(self.result_handler)
)

return state

@call_state_handlers
Expand Down
10 changes: 5 additions & 5 deletions src/prefect/serialization/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

class ResultHandlerField(fields.Field):
def _serialize(self, value, attr, obj, **kwargs):
if hasattr(obj, "metadata"):
is_raw = obj.metadata.get(attr, {}).get("raw", True)
if hasattr(obj, "_metadata"):
is_raw = obj._metadata.get(attr, {}).get("raw", True)
# "raw" results are never serialized
if is_raw:
value = None
Expand All @@ -40,14 +40,14 @@ class Meta:
object_class = state.State

message = fields.String(allow_none=True)
metadata = fields.Dict(keys=fields.Str())
_metadata = fields.Dict(keys=fields.Str())
result = ResultHandlerField(allow_none=True)

@post_load
def create_object(self, data):
metadata = data.pop("metadata", {})
_metadata = data.pop("_metadata", {})
base_obj = super().create_object(data)
base_obj.metadata = metadata
base_obj._metadata = _metadata
return base_obj


Expand Down
8 changes: 5 additions & 3 deletions tests/engine/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,11 @@ def test_serialize_and_deserialize_with_metadata():
cached_result=dict(hi=5, bye=6),
cached_result_expiration=now,
)
cached.metadata.update(cached_inputs=dict(raw=False), cached_result=dict(raw=False))
cached._metadata.update(
cached_inputs=dict(raw=False), cached_result=dict(raw=False)
)
state = Success(result=dict(hi=5, bye=6), cached=cached)
state.metadata.update(dict(result=dict(raw=False)))
state._metadata.update(dict(result=dict(raw=False)))
serialized = state.serialize()
new_state = State.deserialize(serialized)
assert isinstance(new_state, Success)
Expand All @@ -158,7 +160,7 @@ def test_serialize_and_deserialize_with_metadata():

def test_serialization_of_cached_inputs():
state = Pending(cached_inputs=dict(hi=5, bye=6))
state.metadata.update(cached_inputs=dict(raw=False))
state._metadata.update(cached_inputs=dict(raw=False))
serialized = state.serialize()
new_state = State.deserialize(serialized)
assert isinstance(new_state, Pending)
Expand Down
22 changes: 11 additions & 11 deletions tests/serialization/test_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_nested_serializes_without_derived_result_attrs_if_raw(self):

def test_serializes_with_result_if_not_raw(self):
s = state.Success(message="hi", result=42)
s.metadata.update(result=dict(raw=False))
s._metadata.update(result=dict(raw=False))
schema = StateSchema()
serialized = schema.dump(s)
print(serialized)
Expand All @@ -151,7 +151,7 @@ def test_serializes_with_derived_result_attrs_if_not_raw(self):
cached_result={"x": {"y": {"z": 55}}},
cached_parameters=dict(three=3),
)
s.metadata.update(
s._metadata.update(
result=dict(raw=False),
cached_result=dict(raw=True),
cached_inputs=dict(raw=False),
Expand All @@ -172,7 +172,7 @@ def test_nested_serializes_with_derived_result_attrs_if_not_raw(self):
cached_result={"x": {"y": {"z": 55}}},
cached_parameters=dict(three=3),
)
s.metadata.update(
s._metadata.update(
result=dict(raw=False),
cached_result=dict(raw=False),
cached_inputs=dict(raw=False),
Expand Down Expand Up @@ -203,7 +203,7 @@ def test_serialize_state(cls):
@pytest.mark.parametrize("cls", [s for s in all_states if s is not state.Mapped])
def test_serialize_state_with_metadata(cls):
state = cls(message="message", result=1)
state.metadata.update(result=dict(raw=False))
state._metadata.update(result=dict(raw=False))
serialized = StateSchema().dump(state)
assert isinstance(serialized, dict)
assert serialized["type"] == cls.__name__
Expand All @@ -215,8 +215,8 @@ def test_serialize_state_with_metadata(cls):
def test_serialize_mapped():
s = state.Success(message="1", result=1)
f = state.Failed(message="2", result=2)
s.metadata.update(result=dict(raw=False))
f.metadata.update(result=dict(raw=False))
s._metadata.update(result=dict(raw=False))
f._metadata.update(result=dict(raw=False))
serialized = StateSchema().dump(state.Mapped(message="message", map_states=[s, f]))
assert isinstance(serialized, dict)
assert serialized["type"] == "Mapped"
Expand All @@ -230,7 +230,7 @@ def test_serialize_mapped():
@pytest.mark.parametrize("cls", [s for s in all_states if s is not state.Mapped])
def test_deserialize_state(cls):
s = cls(message="message", result=1)
s.metadata.update(result=dict(raw=False))
s._metadata.update(result=dict(raw=False))
serialized = StateSchema().dump(s)
deserialized = StateSchema().load(serialized)
assert isinstance(deserialized, cls)
Expand Down Expand Up @@ -269,7 +269,7 @@ def test_deserialize_state_with_unknown_type_fails():

@pytest.mark.parametrize("state", complex_states())
def test_complex_state_attributes_are_handled(state):
state.metadata.update(
state._metadata.update(
result=dict(raw=False),
cached_result=dict(raw=False),
cached_parameters=dict(raw=False),
Expand All @@ -282,21 +282,21 @@ def test_complex_state_attributes_are_handled(state):

def test_result_must_be_valid_json():
s = state.Success(result={"x": {"y": {"z": 1}}})
s.metadata.update(result=dict(raw=False))
s._metadata.update(result=dict(raw=False))
serialized = StateSchema().dump(s)
assert serialized["result"] == s.result


def test_result_doesnt_raise_error_on_dump_if_raw():
s = state.Success(result={"x": {"y": {"z": lambda: 1}}})
s.metadata.update(result=dict(raw=True))
s._metadata.update(result=dict(raw=True))
serialized = StateSchema().dump(s)
assert serialized["result"] is None


def test_result_raises_error_on_dump_if_not_valid_json():
s = state.Success(result={"x": {"y": {"z": lambda: 1}}})
s.metadata.update(result=dict(raw=False))
s._metadata.update(result=dict(raw=False))
with pytest.raises(TypeError):
StateSchema().dump(s)

Expand Down
2 changes: 1 addition & 1 deletion versioneer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1686,7 +1686,7 @@ def run(self):
self._versioneer_generated_versions = versions
# unless we update this, the command will keep using the old
# version
self.distribution.metadata.version = versions["version"]
self.distribution._metadata.version = versions["version"]
return _sdist.run(self)

def make_release_tree(self, base_dir, files):
Expand Down