Retry application-level errors (#18176)
* Retry application-level errors

* Retry application-level errors

* Push retry message to the driver
jjyao authored Sep 1, 2021
1 parent 673bf35 commit fbb3ac6
Showing 26 changed files with 296 additions and 59 deletions.
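
In user terms, this change adds a `retry_exceptions` option so that tasks can be re-executed not only after system-level failures (e.g. a dead worker process) but also after exceptions raised by the task itself. A minimal usage sketch, modelled on the new test in `python/ray/tests/test_failure_4.py` (the `Counter` actor and the fail-once pattern are taken from that test):

```python
import ray

ray.init()

@ray.remote
class Counter:
    def __init__(self):
        self.value = 0

    def increment(self):
        self.value += 1
        return self.value

# Fails with ValueError on its first execution; succeeds when retried.
@ray.remote(max_retries=1, retry_exceptions=True)
def flaky(counter):
    if ray.get(counter.increment.remote()) == 1:
        raise ValueError("transient application-level error")
    return 2

counter = Counter.remote()
# Without retry_exceptions=True the ValueError would propagate to the caller;
# with it, Ray resubmits the task up to max_retries times.
assert ray.get(flaky.remote(counter)) == 2
```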
4 changes: 2 additions & 2 deletions cpp/src/ray/runtime/task/native_task_submitter.cc
@@ -42,8 +42,8 @@ ObjectID NativeTaskSubmitter::Submit(InvocationSpec &invocation,
                             invocation.args, options, &return_ids);
   } else {
     core_worker.SubmitTask(BuildRayFunction(invocation), invocation.args, options,
-                           &return_ids, 1, std::make_pair(PlacementGroupID::Nil(), -1),
-                           true, "");
+                           &return_ids, 1, false,
+                           std::make_pair(PlacementGroupID::Nil(), -1), true, "");
   }
   return return_ids[0];
 }
3 changes: 2 additions & 1 deletion cpp/src/ray/runtime/task/task_executor.cc
@@ -126,7 +126,8 @@ Status TaskExecutor::ExecuteTask(
     const std::vector<ObjectID> &arg_reference_ids,
     const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
     std::vector<std::shared_ptr<ray::RayObject>> *results,
-    std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes) {
+    std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes,
+    bool *is_application_level_error) {
   RAY_LOG(INFO) << "Execute task: " << TaskType_Name(task_type);
   RAY_CHECK(ray_function.GetLanguage() == ray::Language::CPP);
   auto function_descriptor = ray_function.GetFunctionDescriptor();
3 changes: 2 additions & 1 deletion cpp/src/ray/runtime/task/task_executor.h
@@ -79,7 +79,8 @@ class TaskExecutor {
       const std::vector<ObjectID> &arg_reference_ids,
       const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
       std::vector<std::shared_ptr<ray::RayObject>> *results,
-      std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes);
+      std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes,
+      bool *is_application_level_error);

   virtual ~TaskExecutor(){};

17 changes: 13 additions & 4 deletions python/ray/_raylet.pyx
@@ -406,7 +406,10 @@ cdef execute_task(
         const c_vector[CObjectID] &c_arg_reference_ids,
         const c_vector[CObjectID] &c_return_ids,
         const c_string debugger_breakpoint,
-        c_vector[shared_ptr[CRayObject]] *returns):
+        c_vector[shared_ptr[CRayObject]] *returns,
+        c_bool *is_application_level_error):
+
+    is_application_level_error[0] = False

     worker = ray.worker.global_worker
     manager = worker.function_actor_manager
@@ -579,6 +582,9 @@ cdef execute_task(
                     except KeyboardInterrupt as e:
                         raise TaskCancelledError(
                             core_worker.get_current_task_id())
+                    except Exception as e:
+                        is_application_level_error[0] = True
+                        raise e
                 if c_return_ids.size() == 1:
                     outputs = (outputs,)
             # Check for a cancellation that was called when the function
@@ -656,15 +662,17 @@ cdef CRayStatus task_execution_handler(
         const c_vector[CObjectID] &c_return_ids,
         const c_string debugger_breakpoint,
         c_vector[shared_ptr[CRayObject]] *returns,
-        shared_ptr[LocalMemoryBuffer] &creation_task_exception_pb_bytes) nogil:
+        shared_ptr[LocalMemoryBuffer] &creation_task_exception_pb_bytes,
+        c_bool *is_application_level_error) nogil:
     with gil, disable_client_hook():
         try:
             try:
                 # The call to execute_task should never raise an exception. If
                 # it does, that indicates that there was an internal error.
                 execute_task(task_type, task_name, ray_function, c_resources,
                              c_args, c_arg_reference_ids, c_return_ids,
-                             debugger_breakpoint, returns)
+                             debugger_breakpoint, returns,
+                             is_application_level_error)
             except Exception as e:
                 sys_exit = SystemExit()
                 if isinstance(e, RayActorError) and \
@@ -1318,6 +1326,7 @@ cdef class CoreWorker:
                     int num_returns,
                     resources,
                     int max_retries,
+                    c_bool retry_exceptions,
                     PlacementGroupID placement_group_id,
                     int64_t placement_group_bundle_index,
                     c_bool placement_group_capture_child_tasks,
@@ -1354,7 +1363,7 @@ cdef class CoreWorker:
                 b"",
                 c_serialized_runtime_env,
                 c_override_environment_variables),
-            &return_ids, max_retries,
+            &return_ids, max_retries, retry_exceptions,
             c_pair[CPlacementGroupID, int64_t](
                 c_placement_group_id, placement_group_bundle_index),
             placement_group_capture_child_tasks,
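
The Cython changes above set `is_application_level_error` when the user function raises, and the flag is passed back through `task_execution_handler` so the submitting core worker can decide whether to resubmit. Conceptually, the retry decision then looks roughly like the following pseudocode (function and field names here are illustrative, not the actual core-worker API):

```python
def should_retry(task_spec, failure_kind, retries_left):
    # System-level failures (e.g. the worker process died) are retried as
    # before, governed only by max_retries.
    if failure_kind == "system":
        return retries_left > 0
    # Application-level errors (the user function raised) are now retried too,
    # but only when the task was submitted with retry_exceptions=True.
    if failure_kind == "application":
        return task_spec.retry_exceptions and retries_left > 0
    return False
```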
3 changes: 2 additions & 1 deletion python/ray/cross_language.py
@@ -77,7 +77,8 @@ def java_function(class_name, function_name):
         None,  # accelerator_type,
         None,  # num_returns,
         None,  # max_calls,
-        None,  # max_retries
+        None,  # max_retries,
+        None,  # retry_exceptions,
         None)  # runtime_env


4 changes: 3 additions & 1 deletion python/ray/includes/libcoreworker.pxd
@@ -104,6 +104,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
             const c_vector[unique_ptr[CTaskArg]] &args,
             const CTaskOptions &options, c_vector[CObjectID] *return_ids,
             int max_retries,
+            c_bool retry_exceptions,
             c_pair[CPlacementGroupID, int64_t] placement_options,
             c_bool placement_group_capture_child_tasks,
             c_string debugger_breakpoint)
@@ -280,7 +281,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
                 const c_string debugger_breakpoint,
                 c_vector[shared_ptr[CRayObject]] *returns,
                 shared_ptr[LocalMemoryBuffer]
-                &creation_task_exception_pb_bytes) nogil
+                &creation_task_exception_pb_bytes,
+                c_bool *is_application_level_error) nogil
             ) task_execution_callback
             (void(const CWorkerID &) nogil) on_worker_shutdown
             (CRayStatus() nogil) check_signals
16 changes: 15 additions & 1 deletion python/ray/remote_function.py
@@ -24,6 +24,7 @@
 # Normal tasks may be retried on failure this many times.
 # TODO(swang): Allow this to be set globally for an application.
 DEFAULT_REMOTE_FUNCTION_NUM_TASK_RETRIES = 3
+DEFAULT_REMOTE_FUNCTION_RETRY_EXCEPTIONS = False

 logger = logging.getLogger(__name__)

@@ -53,6 +54,9 @@ class RemoteFunction:
         of this remote function.
     _max_calls: The number of times a worker can execute this function
         before exiting.
+    _max_retries: The number of times this task may be retried
+        on worker failure.
+    _retry_exceptions: Whether application-level errors should be retried.
     _runtime_env: The runtime environment for this task.
     _decorator: An optional decorator that should be applied to the remote
         function invocation (as opposed to the function execution) before
@@ -73,7 +77,7 @@ class RemoteFunction:
     def __init__(self, language, function, function_descriptor, num_cpus,
                  num_gpus, memory, object_store_memory, resources,
                  accelerator_type, num_returns, max_calls, max_retries,
-                 runtime_env):
+                 retry_exceptions, runtime_env):
         if inspect.iscoroutinefunction(function):
             raise ValueError("'async def' should not be used for remote "
                              "tasks. You can wrap the async function with "
@@ -100,6 +104,9 @@ def __init__(self, language, function, function_descriptor, num_cpus,
                             if max_calls is None else max_calls)
         self._max_retries = (DEFAULT_REMOTE_FUNCTION_NUM_TASK_RETRIES
                              if max_retries is None else max_retries)
+        self._retry_exceptions = (DEFAULT_REMOTE_FUNCTION_RETRY_EXCEPTIONS
+                                  if retry_exceptions is None else
+                                  retry_exceptions)
         self._runtime_env = runtime_env
         self._decorator = getattr(function, "__ray_invocation_decorator__",
                                   None)
@@ -131,6 +138,7 @@ def options(self,
                 accelerator_type=None,
                 resources=None,
                 max_retries=None,
+                retry_exceptions=None,
                 placement_group="default",
                 placement_group_bundle_index=-1,
                 placement_group_capture_child_tasks=None,
@@ -168,6 +176,7 @@ def remote(self, *args, **kwargs):
                 accelerator_type=accelerator_type,
                 resources=resources,
                 max_retries=max_retries,
+                retry_exceptions=retry_exceptions,
                 placement_group=placement_group,
                 placement_group_bundle_index=placement_group_bundle_index,
                 placement_group_capture_child_tasks=(
@@ -191,6 +200,7 @@ def _remote(self,
                 accelerator_type=None,
                 resources=None,
                 max_retries=None,
+                retry_exceptions=None,
                 placement_group="default",
                 placement_group_bundle_index=-1,
                 placement_group_capture_child_tasks=None,
@@ -211,6 +221,7 @@ def _remote(self,
                 accelerator_type=accelerator_type,
                 resources=resources,
                 max_retries=max_retries,
+                retry_exceptions=retry_exceptions,
                 placement_group=placement_group,
                 placement_group_bundle_index=placement_group_bundle_index,
                 placement_group_capture_child_tasks=(
@@ -251,6 +262,8 @@ def _remote(self,
             num_returns = self._num_returns
         if max_retries is None:
             max_retries = self._max_retries
+        if retry_exceptions is None:
+            retry_exceptions = self._retry_exceptions

         if placement_group_capture_child_tasks is None:
             placement_group_capture_child_tasks = (
@@ -307,6 +320,7 @@ def invocation(args, kwargs):
                 num_returns,
                 resources,
                 max_retries,
+                retry_exceptions,
                 placement_group.id,
                 placement_group_bundle_index,
                 placement_group_capture_child_tasks,
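
The `.options()` plumbing above means `retry_exceptions` can also be overridden per invocation, in the same way `max_retries` already could. A small sketch (the `parse` function is illustrative, not part of the change):

```python
import ray

@ray.remote(max_retries=3, retry_exceptions=True)
def parse(text):
    return int(text)

# This particular call opts out of exception retries: a bad input fails fast
# with the original ValueError instead of being re-executed three times.
ref = parse.options(retry_exceptions=False).remote("not-a-number")
```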
62 changes: 62 additions & 0 deletions python/ray/tests/test_failure_4.py
@@ -10,6 +10,68 @@
run_string_as_driver)


def test_retry_system_level_error(ray_start_regular):
@ray.remote
class Counter:
def __init__(self):
self.value = 0

def increment(self):
self.value += 1
return self.value

@ray.remote(max_retries=1)
def func(counter):
count = counter.increment.remote()
if ray.get(count) == 1:
import os
os._exit(0)
else:
return 1

counter1 = Counter.remote()
r1 = func.remote(counter1)
assert ray.get(r1) == 1

counter2 = Counter.remote()
r2 = func.options(max_retries=0).remote(counter2)
with pytest.raises(ray.exceptions.WorkerCrashedError):
ray.get(r2)


def test_retry_application_level_error(ray_start_regular):
@ray.remote
class Counter:
def __init__(self):
self.value = 0

def increment(self):
self.value += 1
return self.value

@ray.remote(max_retries=1, retry_exceptions=True)
def func(counter):
count = counter.increment.remote()
if ray.get(count) == 1:
raise ValueError()
else:
return 2

counter1 = Counter.remote()
r1 = func.remote(counter1)
assert ray.get(r1) == 2

counter2 = Counter.remote()
r2 = func.options(max_retries=0).remote(counter2)
with pytest.raises(ValueError):
ray.get(r2)

counter3 = Counter.remote()
r3 = func.options(retry_exceptions=False).remote(counter3)
with pytest.raises(ValueError):
ray.get(r3)


def test_connect_with_disconnected_node(shutdown_only):
config = {
"num_heartbeats_timeout": 50,
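
Assuming a standard Ray development checkout, the two new tests above can presumably be run directly with pytest, for example:

```python
# Run only the new retry tests from this file (the invocation below is an
# assumption about the dev workflow, not part of the change itself).
import pytest
pytest.main(["-v", "-k", "retry", "python/ray/tests/test_failure_4.py"])
```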
1 change: 1 addition & 0 deletions python/ray/util/client/options.py
@@ -22,6 +22,7 @@
     "max_retries": (int, lambda x: x >= -1,
                     "The keyword 'max_retries' only accepts 0, -1 "
                     "or a positive integer"),
+    "retry_exceptions": (),
     "max_concurrency": (),
     "name": (),
     "namespace": (),
23 changes: 18 additions & 5 deletions python/ray/worker.py
@@ -651,7 +651,7 @@ def init(
         a raylet, a plasma store, a plasma manager, and some workers.
         It will also kill these processes when Python exits. If the driver
         is running on a node in a Ray cluster, using `auto` as the value
-        tells the driver to detect the the cluster, removing the need to
+        tells the driver to detect the cluster, removing the need to
         specify a specific node address. If the environment variable
         `RAY_ADDRESS` is defined and the address is None or "auto", Ray
         will set `address` to `RAY_ADDRESS`.
@@ -1924,7 +1924,8 @@ def make_decorator(num_returns=None,
                    max_restarts=None,
                    max_task_retries=None,
                    runtime_env=None,
-                   worker=None):
+                   worker=None,
+                   retry_exceptions=None):
     def decorator(function_or_class):
         if (inspect.isfunction(function_or_class)
                 or is_cython(function_or_class)):
@@ -1953,12 +1954,19 @@ def decorator(function_or_class):
             return ray.remote_function.RemoteFunction(
                 Language.PYTHON, function_or_class, None, num_cpus, num_gpus,
                 memory, object_store_memory, resources, accelerator_type,
-                num_returns, max_calls, max_retries, runtime_env)
+                num_returns, max_calls, max_retries, retry_exceptions,
+                runtime_env)

         if inspect.isclass(function_or_class):
             if num_returns is not None:
                 raise TypeError("The keyword 'num_returns' is not "
                                 "allowed for actors.")
+            if max_retries is not None:
+                raise TypeError("The keyword 'max_retries' is not "
+                                "allowed for actors.")
+            if retry_exceptions is not None:
+                raise TypeError("The keyword 'retry_exceptions' is not "
+                                "allowed for actors.")
             if max_calls is not None:
                 raise TypeError("The keyword 'max_calls' is not "
                                 "allowed for actors.")
@@ -2082,6 +2090,9 @@ def method(self):
             this actor or task and its children. See
             :ref:`runtime-environments` for detailed documentation. This API is
             in beta and may change before becoming stable.
+        retry_exceptions (bool): Only for *remote functions*. This specifies
+            whether application-level errors should be retried
+            up to max_retries times.
         override_environment_variables (Dict[str, str]): (Deprecated in Ray
             1.4.0, will be removed in Ray 1.6--please use the ``env_vars``
             field of :ref:`runtime-environments` instead.) This specifies
@@ -2102,7 +2113,7 @@ def method(self):
     valid_kwargs = [
         "num_returns", "num_cpus", "num_gpus", "memory", "object_store_memory",
         "resources", "accelerator_type", "max_calls", "max_restarts",
-        "max_task_retries", "max_retries", "runtime_env"
+        "max_task_retries", "max_retries", "runtime_env", "retry_exceptions"
     ]
     error_string = ("The @ray.remote decorator must be applied either "
                     "with no arguments and no parentheses, for example "
@@ -2135,6 +2146,7 @@ def method(self):
     object_store_memory = kwargs.get("object_store_memory")
     max_retries = kwargs.get("max_retries")
     runtime_env = kwargs.get("runtime_env")
+    retry_exceptions = kwargs.get("retry_exceptions")

     return make_decorator(
         num_returns=num_returns,
@@ -2149,4 +2161,5 @@ def method(self):
         max_task_retries=max_task_retries,
         max_retries=max_retries,
         runtime_env=runtime_env,
-        worker=worker)
+        worker=worker,
+        retry_exceptions=retry_exceptions)
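
The validation added to `make_decorator` makes `retry_exceptions` (like `max_retries`) a task-only option. A short illustration of the intended behaviour, using the error text from the check added above:

```python
import ray

# Allowed: remote functions may combine max_retries and retry_exceptions.
@ray.remote(max_retries=1, retry_exceptions=True)
def f():
    return 1

# Rejected: passing retry_exceptions to an actor class raises a TypeError
# ("The keyword 'retry_exceptions' is not allowed for actors.").
try:
    @ray.remote(retry_exceptions=True)
    class A:
        pass
except TypeError as err:
    print(err)
```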
6 changes: 6 additions & 0 deletions src/ray/common/task/task_util.h
@@ -134,6 +134,12 @@ class TaskSpecBuilder {
     return *this;
   }

+  TaskSpecBuilder &SetNormalTaskSpec(int max_retries, bool retry_exceptions) {
+    message_->set_max_retries(max_retries);
+    message_->set_retry_exceptions(retry_exceptions);
+    return *this;
+  }
+
   /// Set the driver attributes of the task spec.
   /// See `common.proto` for meaning of the arguments.
   ///