-
Notifications
You must be signed in to change notification settings - Fork 6.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TaskCancellation #7669
TaskCancellation #7669
Changes from 35 commits
169c540
682c5b5
2d020ba
1958e05
4fdeb5a
40b2bb5
d1295c3
269a3b1
028d9f7
a4b58e5
33ad6a1
4f7eec7
cc3ca28
bd47066
c0b5ab4
58c8bed
bae435f
9ab039d
d85496d
d0ba816
652a0fe
daac610
b050b28
b0457a3
18b3dbc
b813faf
112d7d8
ff8bbd3
af35898
b015c51
616f487
2361273
9dba915
9a43056
68a6458
46545e1
7308225
1f95492
9a0d120
82a6248
3270f92
794f146
e43ea33
9beea80
2a789f5
8c75a83
8f7bdfe
6f1ef56
f7fb69f
e8cd360
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,8 @@ import logging | |
import os | ||
import pickle | ||
import sys | ||
import _thread | ||
import setproctitle | ||
|
||
from libc.stdint cimport ( | ||
int32_t, | ||
|
@@ -90,6 +92,7 @@ from ray.exceptions import ( | |
RayTaskError, | ||
ObjectStoreFullError, | ||
RayTimeoutError, | ||
RayCancellationError | ||
) | ||
from ray.utils import decode | ||
import gc | ||
|
@@ -452,14 +455,22 @@ cdef execute_task( | |
actor_title = "{}({}, {})".format( | ||
class_name, repr(args), repr(kwargs)) | ||
core_worker.set_actor_title(actor_title.encode("utf-8")) | ||
# Ensure no previous signals are still around | ||
check_signals() | ||
# Execute the task. | ||
with ray.worker._changeproctitle(title, next_title): | ||
with core_worker.profile_event(b"task:execute"): | ||
task_exception = True | ||
outputs = function_executor(*args, **kwargs) | ||
task_exception = False | ||
try: | ||
outputs = function_executor(*args, **kwargs) | ||
task_exception = False | ||
except KeyboardInterrupt as e: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Couldn't this get raised outside of the try block? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I believe this can be raised on any line of python code There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It can only be raised in regular python code. When in compiled Cython or C++ code, the interrupts can be observed with |
||
raise RayCancellationError( | ||
core_worker.get_current_task_id()) | ||
if c_return_ids.size() == 1: | ||
outputs = (outputs,) | ||
# Ensure no signals are still around | ||
check_signals() | ||
ijrsvt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Store the outputs in the object store. | ||
with core_worker.profile_event(b"task:store_outputs"): | ||
core_worker.store_task_outputs( | ||
|
@@ -551,6 +562,14 @@ cdef void async_plasma_callback(CObjectID object_id, | |
event_handler._loop.call_soon_threadsafe( | ||
event_handler._complete_future, obj_id) | ||
|
||
cdef c_bool kill_main_task() nogil: | ||
with gil: | ||
if setproctitle.getproctitle() != "ray::IDLE": | ||
ijrsvt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
_thread.interrupt_main() | ||
return True | ||
return False | ||
|
||
|
||
cdef CRayStatus check_signals() nogil: | ||
with gil: | ||
try: | ||
|
@@ -657,6 +676,7 @@ cdef class CoreWorker: | |
options.ref_counting_enabled = True | ||
options.is_local_mode = local_mode | ||
options.num_workers = 1 | ||
options.kill_main = kill_main_task | ||
|
||
CCoreWorkerProcess.Initialize(options) | ||
|
||
|
@@ -952,6 +972,17 @@ cdef class CoreWorker: | |
check_status(CCoreWorkerProcess.GetCoreWorker().KillActor( | ||
c_actor_id, True, no_reconstruction)) | ||
|
||
def kill_task(self, ObjectID object_id, c_bool force_kill): | ||
cdef: | ||
CObjectID c_object_id = object_id.native() | ||
CRayStatus status = CRayStatus.OK() | ||
|
||
status = CCoreWorkerProcess.GetCoreWorker().KillTask( | ||
c_object_id, force_kill) | ||
|
||
if not status.ok(): | ||
raise ValueError(status.message().decode()) | ||
|
||
def resource_ids(self): | ||
cdef: | ||
ResourceMappingType resource_mapping = ( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
import pytest | ||
import ray | ||
import random | ||
import sys | ||
import time | ||
from ray.exceptions import RayTaskError, RayTimeoutError, RayCancellationError | ||
from ray.test_utils import SignalActor | ||
|
||
|
||
@pytest.mark.parametrize("use_force", [True, False]) | ||
def test_cancel_chain(ray_start_regular, use_force): | ||
"""A helper method for chain of events tests""" | ||
signaler = SignalActor.remote() | ||
|
||
@ray.remote | ||
def wait_for(t): | ||
return ray.get(t[0]) | ||
|
||
obj1 = wait_for.remote([signaler.wait.remote()]) | ||
obj2 = wait_for.remote([obj1]) | ||
obj3 = wait_for.remote([obj2]) | ||
obj4 = wait_for.remote([obj3]) | ||
|
||
assert len(ray.wait([obj1], timeout=.1)[0]) == 0 | ||
ray.cancel(obj1, use_force) | ||
for ob in [obj1, obj2, obj3, obj4]: | ||
with pytest.raises((RayTaskError, RayCancellationError)): | ||
ray.get(ob) | ||
|
||
signaler2 = SignalActor.remote() | ||
obj1 = wait_for.remote([signaler2.wait.remote()]) | ||
obj2 = wait_for.remote([obj1]) | ||
obj3 = wait_for.remote([obj2]) | ||
obj4 = wait_for.remote([obj3]) | ||
|
||
assert len(ray.wait([obj3], timeout=.1)[0]) == 0 | ||
ray.cancel(obj3, use_force) | ||
for ob in [obj3, obj4]: | ||
with pytest.raises((RayTaskError, RayCancellationError)): | ||
ray.get(ob) | ||
|
||
with pytest.raises(RayTimeoutError): | ||
ray.get(obj1, timeout=.1) | ||
|
||
with pytest.raises(RayTimeoutError): | ||
ray.get(obj2, timeout=.1) | ||
|
||
signaler2.send.remote() | ||
ray.get(obj1, timeout=.1) | ||
|
||
|
||
@pytest.mark.parametrize("use_force", [True, False]) | ||
def test_cancel_multiple_dependents(ray_start_regular, use_force): | ||
"""A helper method for multiple waiters on events tests""" | ||
signaler = SignalActor.remote() | ||
|
||
@ray.remote | ||
def wait_for(t): | ||
return ray.get(t[0]) | ||
|
||
head = wait_for.remote([signaler.wait.remote()]) | ||
deps = [] | ||
for _ in range(3): | ||
deps.append(wait_for.remote([head])) | ||
|
||
assert len(ray.wait([head], timeout=.1)[0]) == 0 | ||
ray.cancel(head, use_force) | ||
for d in deps: | ||
with pytest.raises((RayTaskError, RayCancellationError)): | ||
ray.get(d) | ||
|
||
head2 = wait_for.remote([signaler.wait.remote()]) | ||
|
||
deps2 = [] | ||
for _ in range(3): | ||
deps2.append(wait_for.remote([head])) | ||
|
||
for d in deps2: | ||
ray.cancel(d, use_force) | ||
|
||
for d in deps2: | ||
with pytest.raises((RayTaskError, RayCancellationError)): | ||
ray.get(d) | ||
|
||
signaler.send.remote() | ||
ray.get(head2, timeout=1) | ||
|
||
|
||
@pytest.mark.parametrize("use_force", [True, False]) | ||
def test_single_cpu_cancel(shutdown_only, use_force): | ||
ray.init(num_cpus=1) | ||
signaler = SignalActor.remote() | ||
|
||
@ray.remote | ||
def wait_for(t): | ||
return ray.get(t[0]) | ||
|
||
obj1 = wait_for.remote([signaler.wait.remote()]) | ||
obj2 = wait_for.remote([obj1]) | ||
obj3 = wait_for.remote([obj2]) | ||
indep = wait_for.remote([signaler.wait.remote()]) | ||
|
||
assert len(ray.wait([obj3], timeout=.1)[0]) == 0 | ||
ray.cancel(obj3, use_force) | ||
with pytest.raises((RayTaskError, RayCancellationError)): | ||
ray.get(obj3, 0.1) | ||
|
||
ray.cancel(obj1, use_force) | ||
|
||
for d in [obj1, obj2]: | ||
with pytest.raises((RayTaskError, RayCancellationError)): | ||
ray.get(d) | ||
|
||
signaler.send.remote() | ||
ray.get(indep) | ||
|
||
|
||
@pytest.mark.parametrize("use_force", [True, False]) | ||
def test_comprehensive(ray_start_regular, use_force): | ||
signaler = SignalActor.remote() | ||
|
||
@ray.remote | ||
def wait_for(t): | ||
ray.get(t[0]) | ||
return "Result" | ||
|
||
@ray.remote | ||
def combine(a, b): | ||
return str(a) + str(b) | ||
|
||
a = wait_for.remote([signaler.wait.remote()]) | ||
b = wait_for.remote([signaler.wait.remote()]) | ||
combo = combine.remote(a, b) | ||
a2 = wait_for.remote([a]) | ||
|
||
assert len(ray.wait([a, b, a2, combo], timeout=1)[0]) == 0 | ||
|
||
ray.cancel(a, use_force) | ||
with pytest.raises((RayTaskError, RayCancellationError)): | ||
ray.get(a, 1) | ||
|
||
with pytest.raises((RayTaskError, RayCancellationError)): | ||
ray.get(a2, 1) | ||
|
||
signaler.send.remote() | ||
|
||
with pytest.raises((RayTaskError, RayCancellationError)): | ||
ray.get(combo, 10) | ||
|
||
|
||
@pytest.mark.parametrize("use_force", [True, False]) | ||
def test_stress(shutdown_only, use_force): | ||
ray.init(num_cpus=1) | ||
|
||
@ray.remote | ||
def infinite_sleep(y): | ||
if y: | ||
while True: | ||
time.sleep(1 / 10) | ||
|
||
first = infinite_sleep.remote(True) | ||
|
||
sleep_or_no = [random.randint(0, 1) for _ in range(100)] | ||
tasks = [infinite_sleep.remote(i) for i in sleep_or_no] | ||
cancelled = set() | ||
for t in tasks: | ||
if random.random() > 0.5: | ||
ray.cancel(t, use_force) | ||
cancelled.add(t) | ||
|
||
ray.cancel(first, use_force) | ||
cancelled.add(first) | ||
|
||
for done in cancelled: | ||
with pytest.raises((RayTaskError, RayCancellationError)): | ||
ray.get(done, 10) | ||
|
||
for indx in range(len(tasks)): | ||
t = tasks[indx] | ||
if sleep_or_no[indx]: | ||
ray.cancel(t, use_force) | ||
cancelled.add(t) | ||
if t in cancelled: | ||
with pytest.raises((RayTaskError, RayCancellationError)): | ||
ray.get(t, 10) | ||
else: | ||
ray.get(t) | ||
|
||
|
||
if __name__ == "__main__": | ||
sys.exit(pytest.main(["-v", __file__])) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What signals is this checking for? If it's the keyboardinterrupt from interrupt_main, don't we need to catch that and handle it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It was initially just to 'clear' any signals, but I can make it also handle cancellation.