Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cancel state #1758

Merged
merged 11 commits into from
Nov 21, 2019
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ These changes are available in the [master branch](https://github.com/PrefectHQ/

### Features

- None
- Add graceful cancellation hooks to Flow and Task runners - [#1757](https://github.com/PrefectHQ/prefect/pull/1757)

### Enhancements

Expand Down
334 changes: 167 additions & 167 deletions docs/.vuepress/public/state_inheritance_diagram.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/outline.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ classes = [
"Mapped",
"Skipped",
"Failed",
"Aborted",
"Cancelled",
"TriggerFailed",
"TimedOut",
]
Expand Down
64 changes: 39 additions & 25 deletions src/prefect/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import logging
import signal
import time
from typing import Any, Iterable, Union
from contextlib import contextmanager
from typing import Any, Callable, Generator, Iterable, Union

import pendulum

Expand All @@ -23,6 +24,24 @@
"""


def _exit(agent: "Agent") -> Callable:
    """
    Build a signal handler bound to the given agent.

    Args:
        - agent (Agent): the agent whose `is_running` flag the handler clears and
            whose `logger` the handler writes to

    Returns:
        - Callable: a function accepting arbitrary positional and keyword
            arguments (so it can be registered with `signal.signal`) that marks
            the agent as no longer running and logs a shutdown message
    """

    def _stop_agent(*_args: Any, **_kwargs: Any) -> None:
        # Clear the flag first so the agent is marked stopped even if logging fails
        agent.is_running = False
        agent.logger.info("Keyboard Interrupt received: Agent is shutting down.")

    return _stop_agent


@contextmanager
def keyboard_handler(agent: "Agent") -> Generator:
    """
    Context manager that routes SIGINT to the given agent for the duration of
    the `with` block and reinstates the previous handler on exit.

    Args:
        - agent (Agent): the agent passed to `_exit` to build the temporary
            SIGINT handler

    Yields:
        - None: control is handed to the `with` body while the temporary handler
            is installed

    Raises:
        - ValueError: if the handler cannot be installed, e.g. when called
            outside the main thread (raised by `signal.signal`)
    """
    original = signal.getsignal(signal.SIGINT)
    if original is None:
        # `getsignal` returns None when the previous handler was not installed
        # from Python; `signal.signal` rejects None with a TypeError, so fall
        # back to Python's standard SIGINT handler when restoring.
        original = signal.default_int_handler

    # Install before entering the `try`: if installation fails there is nothing
    # to restore, and a second failing `signal.signal` call in `finally` would
    # only obscure the original error.
    signal.signal(signal.SIGINT, _exit(agent))
    try:
        yield
    finally:
        signal.signal(signal.SIGINT, original)


class Agent:
"""
Base class for Agents. Information on using the Prefect agents can be found at
Expand Down Expand Up @@ -63,14 +82,6 @@ def __init__(self, name: str = None, labels: Iterable[str] = None) -> None:
logger.addHandler(ch)

self.logger = logger
self.add_signal_handlers()

def add_signal_handlers(self) -> None:
def _exit(*args: Any, **kwargs: Any) -> None:
self.is_running = False
self.logger.info("Keyboard Interrupt received: Agent is shutting down.")

signal.signal(signal.SIGINT, _exit)

def _verify_token(self, token: str) -> None:
"""
Expand Down Expand Up @@ -98,26 +109,29 @@ def start(self) -> None:
The main entrypoint to the agent. This function loops and constantly polls for
new flow runs to deploy
"""
self.is_running = True
tenant_id = self.agent_connect()
with keyboard_handler(self):
self.is_running = True
tenant_id = self.agent_connect()

# Loop intervals for query sleep backoff
loop_intervals = {0: 0.25, 1: 0.5, 2: 1.0, 3: 2.0, 4: 4.0, 5: 8.0, 6: 10.0}
# Loop intervals for query sleep backoff
loop_intervals = {0: 0.25, 1: 0.5, 2: 1.0, 3: 2.0, 4: 4.0, 5: 8.0, 6: 10.0}

index = 0
while self.is_running:
self.heartbeat()
index = 0
while self.is_running:
self.heartbeat()

runs = self.agent_process(tenant_id)
if runs:
index = 0
elif index < max(loop_intervals.keys()):
index += 1
runs = self.agent_process(tenant_id)
if runs:
index = 0
elif index < max(loop_intervals.keys()):
index += 1

self.logger.debug(
"Next query for flow runs in {} seconds".format(loop_intervals[index])
)
time.sleep(loop_intervals[index])
self.logger.debug(
"Next query for flow runs in {} seconds".format(
loop_intervals[index]
)
)
time.sleep(loop_intervals[index])

def agent_connect(self) -> str:
"""
Expand Down
9 changes: 8 additions & 1 deletion src/prefect/engine/cloud/flow_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import _thread
import warnings
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

Expand Down Expand Up @@ -63,8 +64,14 @@ def __init__(self, flow: Flow, state_handlers: Iterable[Callable] = None) -> Non

def _heartbeat(self) -> bool:
try:
flow_run_id = prefect.context.get("flow_run_id")
# use empty string for testing purposes
flow_run_id = prefect.context.get("flow_run_id", "") # type: str
self.client.update_flow_run_heartbeat(flow_run_id)
query = 'query{flow_run_by_pk(id: "' + flow_run_id + '"){state}}'
state = self.client.graphql(query).data.flow_run_by_pk.state
if state == "Cancelled":
_thread.interrupt_main()
return False
return True
except Exception as exc:
self.logger.exception(
Expand Down
9 changes: 9 additions & 0 deletions src/prefect/engine/cloud/task_runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import datetime
import _thread
import time
import warnings
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
Expand Down Expand Up @@ -63,6 +64,14 @@ def _heartbeat(self) -> bool:
try:
task_run_id = self.task_run_id # type: ignore
self.client.update_task_run_heartbeat(task_run_id) # type: ignore

# use empty string for testing purposes
flow_run_id = prefect.context.get("flow_run_id", "") # type: str
query = 'query{flow_run_by_pk(id: "' + flow_run_id + '"){state}}'
state = self.client.graphql(query).data.flow_run_by_pk.state
if state == "Cancelled":
_thread.interrupt_main()
return False
return True
except Exception as exc:
self.logger.exception(
Expand Down
5 changes: 5 additions & 0 deletions src/prefect/engine/flow_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from prefect.engine.result_handlers import ConstantResultHandler
from prefect.engine.runner import ENDRUN, Runner, call_state_handlers
from prefect.engine.state import (
Cancelled,
Failed,
Mapped,
Pending,
Expand Down Expand Up @@ -258,6 +259,10 @@ def run(
except ENDRUN as exc:
state = exc.state

except KeyboardInterrupt:
self.logger.exception("Interrupt signal raised, cancelling Flow run.")
state = Cancelled(message="Interrupt signal raised, cancelling flow run.")

# All other exceptions are trapped and turned into Failed states
except Exception as exc:
self.logger.exception(
Expand Down
4 changes: 2 additions & 2 deletions src/prefect/engine/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,9 +700,9 @@ def __init__(
self.cached_inputs = cached_inputs


class Aborted(Failed):
class Cancelled(Failed):
"""
Finished state indicating that a user aborted the flow run manually.
Finished state indicating that a user cancelled the flow run manually, mid-run.

Args:
- message (str or Exception, optional): Defaults to `None`. A message about the
Expand Down
6 changes: 6 additions & 0 deletions src/prefect/engine/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from prefect.engine.runner import ENDRUN, Runner, call_state_handlers
from prefect.engine.state import (
Cached,
Cancelled,
Failed,
Looped,
Mapped,
Expand Down Expand Up @@ -869,6 +870,11 @@ def get_task_run_state(
self.task.run, timeout=self.task.timeout, **raw_inputs
)

except KeyboardInterrupt:
self.logger.exception("Interrupt signal raised, cancelling task run.")
state = Cancelled(message="Interrupt signal raised, cancelling task run.")
return state

# inform user of timeout
except TimeoutError as exc:
if prefect.context.get("raise_on_exception"):
Expand Down
6 changes: 3 additions & 3 deletions src/prefect/serialization/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,9 @@ class Meta:
)


class AbortedSchema(FailedSchema):
class CancelledSchema(FailedSchema):
class Meta:
object_class = state.Aborted
object_class = state.Cancelled


class TimedOutSchema(FinishedSchema):
Expand Down Expand Up @@ -203,7 +203,7 @@ class StateSchema(OneOfSchema):

# map class name to schema
type_schemas = {
"Aborted": AbortedSchema,
"Cancelled": CancelledSchema,
"Cached": CachedSchema,
"ClientFailed": ClientFailedSchema,
"Failed": FailedSchema,
Expand Down
16 changes: 16 additions & 0 deletions tests/core/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
from prefect.core.flow import Flow
from prefect.core.task import Parameter, Task
from prefect.engine.cache_validators import all_inputs, partial_inputs_only
from prefect.engine.executors import LocalExecutor
from prefect.engine.result_handlers import LocalResultHandler, ResultHandler
from prefect.engine.signals import PrefectError, FAIL, LOOP
from prefect.engine.state import (
Cancelled,
Failed,
Finished,
Mapped,
Expand Down Expand Up @@ -2132,6 +2134,20 @@ def record_start_time():
f.run()
assert REPORTED_START_TIMES == start_times

def test_flow_dot_run_handles_keyboard_signals_gracefully(self):
    """
    A KeyboardInterrupt raised while the executor submits work should be
    surfaced by `Flow.run` as a `Cancelled` state rather than propagating.
    """

    class InterruptingExecutor(LocalExecutor):
        # Simulates the user interrupting as soon as any work is submitted
        def submit(self, *args, **kwargs):
            raise KeyboardInterrupt

    @task
    def do_something():
        pass

    flow = Flow("test", tasks=[do_something])
    final_state = flow.run(executor=InterruptingExecutor())

    assert isinstance(final_state, Cancelled)
    assert "interrupt" in final_state.message.lower()


class TestFlowDeploy:
@pytest.mark.parametrize(
Expand Down
26 changes: 26 additions & 0 deletions tests/engine/cloud/test_cloud_flow_runner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import time
import uuid
from box import Box
from datetime import timedelta
from unittest.mock import MagicMock

Expand All @@ -18,6 +19,7 @@
SecretResultHandler,
)
from prefect.engine.state import (
Cancelled,
Failed,
Finished,
Pending,
Expand Down Expand Up @@ -594,3 +596,27 @@ def get_flow_run_info(self, *args, **kwargs):

task_run_ids = [c["taskRunId"] for c in logs if c["taskRunId"]]
assert task_run_ids == ["TESTME"] * 3


def test_db_cancelled_states_interrupt_flow_run(client, monkeypatch):
    """
    When the flow run state reported via GraphQL becomes "Cancelled", the
    running flow should be interrupted and finish in a `Cancelled` state.
    """
    calls = {"count": 0}

    def heartbeat_counter(*args, **kwargs):
        # Report "Running" for the first three queries, "Cancelled" afterwards
        if calls["count"] != 3:
            calls["count"] += 1
            reported = "Running"
        else:
            reported = "Cancelled"
        return Box({"data": {"flow_run_by_pk": {"state": reported}}})

    client.graphql = heartbeat_counter

    @prefect.task
    def sleeper():
        time.sleep(3)

    flow = prefect.Flow("test", tasks=[sleeper])

    with set_temporary_config({"cloud.heartbeat_interval": 0.025}):
        runner = CloudFlowRunner(flow=flow)
        final_state = runner.run(return_tasks=[sleeper])

    assert isinstance(final_state, Cancelled)
    assert "interrupt" in final_state.message.lower()
24 changes: 24 additions & 0 deletions tests/engine/cloud/test_cloud_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import tempfile
import time
import uuid
from box import Box
from unittest.mock import MagicMock

import cloudpickle
Expand All @@ -20,6 +21,7 @@
from prefect.engine.signals import LOOP
from prefect.engine.state import (
Cached,
Cancelled,
ClientFailed,
Failed,
Finished,
Expand Down Expand Up @@ -982,3 +984,25 @@ def tagged_task(x):

# ensures result handler was called and persisted
assert calls[2]["state"].cached_inputs["x"].safe_value.value == "42"


def test_db_cancelled_states_interrupt_task_run(client, monkeypatch):
    """
    When the flow run state reported via GraphQL becomes "Cancelled", a running
    task should be interrupted and finish in a `Cancelled` state.
    """
    calls = {"count": 0}

    def heartbeat_counter(*args, **kwargs):
        # Report "Running" for the first three queries, "Cancelled" afterwards
        if calls["count"] != 3:
            calls["count"] += 1
            reported = "Running"
        else:
            reported = "Cancelled"
        return Box({"data": {"flow_run_by_pk": {"state": reported}}})

    client.graphql = heartbeat_counter

    @prefect.task
    def sleeper():
        time.sleep(3)

    with set_temporary_config({"cloud.heartbeat_interval": 0.025}):
        runner = CloudTaskRunner(task=sleeper)
        final_state = runner.run()

    assert isinstance(final_state, Cancelled)
    assert "interrupt" in final_state.message.lower()
8 changes: 4 additions & 4 deletions tests/engine/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from prefect.engine.result import NoResult, Result, SafeResult
from prefect.engine.result_handlers import JSONResultHandler, LocalResultHandler
from prefect.engine.state import (
Aborted,
Cancelled,
Cached,
ClientFailed,
Failed,
Expand Down Expand Up @@ -359,8 +359,8 @@ def test_success_is_finished(self):
def test_failed_is_finished(self):
assert issubclass(Failed, Finished)

def test_aborted_is_failed(self):
assert issubclass(Aborted, Failed)
def test_cancelled_is_failed(self):
assert issubclass(Cancelled, Failed)

def test_trigger_failed_is_finished(self):
assert issubclass(TriggerFailed, Finished)
Expand All @@ -384,7 +384,7 @@ def test_trigger_failed_is_failed(self):
@pytest.mark.parametrize(
"state_check",
[
dict(state=Aborted(), assert_true={"is_finished", "is_failed"}),
dict(state=Cancelled(), assert_true={"is_finished", "is_failed"}),
dict(state=Cached(), assert_true={"is_cached", "is_finished", "is_successful"}),
dict(state=ClientFailed(), assert_true={"is_meta_state"}),
dict(state=Failed(), assert_true={"is_finished", "is_failed"}),
Expand Down