Different approach to removing RayGetError (#3471)
ericl authored Dec 13, 2018
1 parent 20c7fad commit 0e00533
Showing 4 changed files with 92 additions and 105 deletions.
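In short: instead of wrapping worker failures in `RayGetError` / `RayGetArgumentError`, the stored `RayTaskError` is now re-raised directly by `ray.get` and by argument deserialization, and a background thread suppresses the redundant error messages that workers push to the driver. A rough sketch of the user-visible change (hypothetical task, Ray as of this commit):

import ray

ray.init()

@ray.remote
def fail():
    raise ValueError("boom")

obj = fail.remote()

# Before this commit: ray.get(obj) raised ray.worker.RayGetError, which
# wrapped the underlying RayTaskError in its `task_error` attribute.
# After this commit, the stored error is raised directly:
try:
    ray.get(obj)
except ray.worker.RayTaskError as e:
    # The message now carries the worker's proctitle, pid and host, with
    # Ray-internal traceback frames filtered out.
    print(e)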
10 changes: 10 additions & 0 deletions python/ray/actor.py
@@ -668,6 +668,16 @@ def __del__(self):
# there are ANY handles in scope in the process that created the actor,
# not just the first one.
worker = ray.worker.get_global_worker()
if (worker.mode == ray.worker.SCRIPT_MODE
and self._ray_actor_driver_id.id() != worker.worker_id):
# If the worker is a driver and the driver id has changed because
# Ray was shut down and re-initialized, the actor is already cleaned
# up and we don't need to send `__ray_terminate__` again.
logger.warn(
"Actor is garbage collected in the wrong driver." +
" Actor id = %s, class name = %s.", self._ray_actor_id,
self._ray_class_name)
return
if worker.connected and self._ray_original_handle:
# TODO(rkn): Should we be passing in the actor cursor as a
# dependency here?
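The guard above handles the case where the driver shut Ray down and re-initialized it before an old actor handle was garbage collected, leaving the handle with a stale driver id. A hypothetical reproduction of that sequence (the actor class is made up for illustration):

import ray

ray.init()

@ray.remote
class Counter(object):
    def value(self):
        return 0

c = Counter.remote()

ray.shutdown()
ray.init()  # new driver id; the actor from the first session is gone

# When `c` is collected now, __del__ sees the mismatched driver id, logs
# the "garbage collected in the wrong driver" warning and returns early
# instead of submitting __ray_terminate__ for an already-dead actor.
del c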
183 changes: 80 additions & 103 deletions python/ray/worker.py
@@ -13,6 +13,7 @@
import os
import redis
import signal
from six.moves import queue
import sys
import threading
import time
@@ -97,82 +98,34 @@ class RayTaskError(Exception):
traceback_str (str): The traceback from the exception.
"""

def __init__(self, function_name, exception, traceback_str):
def __init__(self, function_name, traceback_str):
"""Initialize a RayTaskError."""
self.function_name = function_name
if (isinstance(exception, RayGetError)
or isinstance(exception, RayGetArgumentError)):
self.exception = exception
if setproctitle:
self.proctitle = setproctitle.getproctitle()
else:
self.exception = None
self.proctitle = "ray_worker"
self.pid = os.getpid()
self.host = os.uname()[1]
self.function_name = function_name
self.traceback_str = traceback_str

def __str__(self):
"""Format a RayTaskError as a string."""
if self.traceback_str is None:
# This path is taken if getting the task arguments failed.
return ("Remote function {}{}{} failed with:\n\n{}".format(
colorama.Fore.RED, self.function_name, colorama.Fore.RESET,
self.exception))
else:
# This path is taken if the task execution failed.
return ("Remote function {}{}{} failed with:\n\n{}".format(
colorama.Fore.RED, self.function_name, colorama.Fore.RESET,
self.traceback_str))


class RayGetError(Exception):
"""An exception used when get is called on an output of a failed task.
Attributes:
objectid (lib.ObjectID): The ObjectID that get was called on.
task_error (RayTaskError): The RayTaskError object created by the
failed task.
"""

def __init__(self, objectid, task_error):
"""Initialize a RayGetError object."""
self.objectid = objectid
self.task_error = task_error

def __str__(self):
"""Format a RayGetError as a string."""
return ("Could not get objectid {}. It was created by remote function "
"{}{}{} which failed with:\n\n{}".format(
self.objectid, colorama.Fore.RED,
self.task_error.function_name, colorama.Fore.RESET,
self.task_error))


class RayGetArgumentError(Exception):
"""An exception used when a task's argument was produced by a failed task.
Attributes:
argument_index (int): The index (zero indexed) of the failed argument
in present task's remote function call.
function_name (str): The name of the function for the current task.
objectid (lib.ObjectID): The ObjectID that was passed in as the
argument.
task_error (RayTaskError): The RayTaskError object created by the
failed task.
"""

def __init__(self, function_name, argument_index, objectid, task_error):
"""Initialize a RayGetArgumentError object."""
self.argument_index = argument_index
self.function_name = function_name
self.objectid = objectid
self.task_error = task_error

def __str__(self):
"""Format a RayGetArgumentError as a string."""
return ("Failed to get objectid {} as argument {} for remote function "
"{}{}{}. It was created by remote function {}{}{} which "
"failed with:\n{}".format(
self.objectid, self.argument_index, colorama.Fore.RED,
self.function_name, colorama.Fore.RESET, colorama.Fore.RED,
self.task_error.function_name, colorama.Fore.RESET,
self.task_error))
lines = self.traceback_str.split("\n")
out = []
in_worker = False
for line in lines:
if line.startswith("Traceback "):
out.append("{}{}{} (pid={}, host={})".format(
colorama.Fore.CYAN, self.proctitle, colorama.Fore.RESET,
self.pid, self.host))
elif in_worker:
in_worker = False
elif "ray/worker.py" in line or "ray/function_manager.py" in line:
in_worker = True
else:
out.append(line)
return "\n".join(out)


class Worker(object):
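The new `__str__` above rewrites the stored traceback: the "Traceback (most recent call last):" header is replaced by the worker's process title, pid and host, and frames from ray/worker.py or ray/function_manager.py (plus the source line following each) are dropped. A minimal standalone sketch of that filtering, with the colorama coloring omitted and made-up sample strings:

def strip_ray_frames(traceback_str, proctitle, pid, host):
    # Mirrors RayTaskError.__str__: keep user frames, drop Ray-internal
    # frames and the source line that follows each of them.
    out = []
    in_worker = False
    for line in traceback_str.split("\n"):
        if line.startswith("Traceback "):
            out.append("{} (pid={}, host={})".format(proctitle, pid, host))
        elif in_worker:
            in_worker = False
        elif "ray/worker.py" in line or "ray/function_manager.py" in line:
            in_worker = True
        else:
            out.append(line)
    return "\n".join(out)

tb = "\n".join([
    "Traceback (most recent call last):",
    '  File "/opt/ray/python/ray/worker.py", line 900, in _process_task',
    "    result = function(*arguments)",
    '  File "my_script.py", line 4, in fail',
    "    raise ValueError('boom')",
    "ValueError: boom",
])
# Prints the header line plus only the my_script.py frame and the error.
print(strip_ray_frames(tb, "ray_worker", 1234, "node-1"))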
@@ -449,7 +402,7 @@ def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10):
# TODO(ekl): the local scheduler could include relevant
# metadata in the task kill case for a better error message
invalid_error = RayTaskError(
"<unknown>", None,
"<unknown>",
"Invalid return value: likely worker died or was killed "
"while executing the task; check previous logs or dmesg "
"for errors.")
@@ -757,7 +710,7 @@ def _get_arguments_for_execution(self, function_name, serialized_args):
passed by value.
Raises:
RayGetArgumentError: This exception is raised if a task that
RayTaskError: This exception is raised if a task that
created one of the arguments failed.
"""
arguments = []
@@ -766,10 +719,7 @@ def _get_arguments_for_execution(self, function_name, serialized_args):
# get the object from the local object store
argument = self.get_object([arg])[0]
if isinstance(argument, RayTaskError):
# If the result is a RayTaskError, then the task that
# created this object failed, and we should propagate the
# error message here.
raise RayGetArgumentError(function_name, i, arg, argument)
raise argument
else:
# pass the argument by value
argument = arg
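With `RayGetArgumentError` gone, a task whose argument was produced by a failed task simply re-raises the stored `RayTaskError` while its arguments are being fetched, so the driver sees the same exception type for direct and chained failures alike. A sketch of the chained case (task names are illustrative):

import ray

ray.init()

@ray.remote
def produce():
    raise RuntimeError("upstream failure")

@ray.remote
def consume(x):
    return x

# `consume` never runs its body: fetching its argument raises the
# RayTaskError stored for `produce`, and a RayTaskError is stored as
# `consume`'s result in turn.
try:
    ray.get(consume.remote(produce.remote()))
except ray.worker.RayTaskError as e:
    print("dependent task failed:", e)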
@@ -842,7 +792,7 @@ def _process_task(self, task, function_execution_info):
with profiling.profile("task:deserialize_arguments", worker=self):
arguments = self._get_arguments_for_execution(
function_name, args)
except (RayGetError, RayGetArgumentError) as e:
except RayTaskError as e:
self._handle_process_task_failure(function_id, function_name,
return_object_ids, e, None)
return
@@ -889,7 +839,7 @@ def _process_task(self, task, function_execution_info):

def _handle_process_task_failure(self, function_id, function_name,
return_object_ids, error, backtrace):
failure_object = RayTaskError(function_name, error, backtrace)
failure_object = RayTaskError(function_name, backtrace)
failure_objects = [
failure_object for _ in range(len(return_object_ids))
]
@@ -1196,18 +1146,6 @@ def actor_handle_deserializer(serialized_obj):
local=True,
driver_id=driver_id,
class_id="ray.RayTaskError")
register_custom_serializer(
RayGetError,
use_dict=True,
local=True,
driver_id=driver_id,
class_id="ray.RayGetError")
register_custom_serializer(
RayGetArgumentError,
use_dict=True,
local=True,
driver_id=driver_id,
class_id="ray.RayGetArgumentError")
# Tell Ray to serialize lambdas with pickle.
register_custom_serializer(
type(lambda: 0),
@@ -1833,12 +1771,38 @@ def custom_excepthook(type, value, tb):

sys.excepthook = custom_excepthook

# The last time we raised a RayTaskError in this process. We use this value to
# suppress redundant error messages pushed from the workers.
last_task_error_raise_time = 0

def print_error_messages_raylet(worker):
"""Print error messages in the background on the driver.
# The max amount of seconds to wait before printing out an uncaught error.
UNCAUGHT_ERROR_GRACE_PERIOD = 5

This runs in a separate thread on the driver and prints error messages in
the background.

def print_error_messages_raylet(task_error_queue):
"""Prints message received in the given output queue.
This checks periodically if any un-raised errors occured in the background.
"""

while True:
error, t = task_error_queue.get()
# Delay errors a little bit of time to attempt to suppress redundant
# messages originating from the worker.
while t + UNCAUGHT_ERROR_GRACE_PERIOD > time.time():
time.sleep(1)
if t < last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD:
logger.debug("Suppressing error from worker: {}".format(error))
else:
logger.error(
"Possible unhandled error from worker: {}".format(error))


def listen_error_messages_raylet(worker, task_error_queue):
"""Listen to error messages in the background on the driver.
This runs in a separate thread on the driver and pushes (error, time)
tuples to the output queue.
"""
worker.error_message_pubsub_client = worker.redis_client.pubsub(
ignore_subscribe_messages=True)
Expand Down Expand Up @@ -1875,7 +1839,12 @@ def print_error_messages_raylet(worker):
continue

error_message = ray.utils.decode(error_data.ErrorMessage())
logger.error(error_message)
if (ray.utils.decode(
error_data.Type()) == ray_constants.TASK_PUSH_ERROR):
# Delay it a bit to see if we can suppress it
task_error_queue.put((error_message, time.time()))
else:
logger.error(error_message)

except redis.ConnectionError:
# When Redis terminates the listen call will throw a ConnectionError,
@@ -2164,14 +2133,19 @@ def connect(info,
# temporarily using this implementation which constantly queries the
# scheduler for new error messages.
if mode == SCRIPT_MODE:
t = threading.Thread(
q = queue.Queue()
listener = threading.Thread(
target=listen_error_messages_raylet,
name="ray_listen_error_messages",
args=(worker, q))
printer = threading.Thread(
target=print_error_messages_raylet,
name="ray_print_error_messages",
args=(worker, ))
# Making the thread a daemon causes it to exit when the main thread
# exits.
t.daemon = True
t.start()
args=(q, ))
listener.daemon = True
listener.start()
printer.daemon = True
printer.start()

# If we are using the raylet code path and we are not in local mode, start
# a background thread to periodically flush profiling data to the GCS.
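Stepping back to the error-reporting threads started above: the old single printing thread is split into a listener that pushes (error, timestamp) tuples onto a queue and a printer that waits out a grace period before logging, so errors the driver already re-raised via `ray.get` are only logged at debug level. A simplified, self-contained sketch of that pattern (no Redis; the pushed error is faked; uses the diff's six.moves import, so six is assumed to be installed):

import threading
import time
from six.moves import queue

UNCAUGHT_ERROR_GRACE_PERIOD = 5
last_task_error_raise_time = 0

def printer(q):
    while True:
        error, t = q.get()
        # Wait until the grace period for this error has expired.
        while t + UNCAUGHT_ERROR_GRACE_PERIOD > time.time():
            time.sleep(1)
        if t < last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD:
            print("suppressed (already raised in the driver):", error)
        else:
            print("possible unhandled error from worker:", error)

q = queue.Queue()
printer_thread = threading.Thread(target=printer, args=(q, ))
printer_thread.daemon = True
printer_thread.start()

# A listener thread would normally put pushed TASK_PUSH_ERROR messages
# here; we fake one.
q.put(("task blew up", time.time()))

# If the driver re-raises the error (e.g. inside ray.get) within the
# grace period, the printer downgrades the pushed copy to a
# "suppressed" note instead of reporting it as unhandled.
last_task_error_raise_time = time.time()
time.sleep(UNCAUGHT_ERROR_GRACE_PERIOD + 2)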
@@ -2399,19 +2373,22 @@ def get(object_ids, worker=global_worker):
# In LOCAL_MODE, ray.get is the identity operation (the input will
# actually be a value not an objectid).
return object_ids
global last_task_error_raise_time
if isinstance(object_ids, list):
values = worker.get_object(object_ids)
for i, value in enumerate(values):
if isinstance(value, RayTaskError):
raise RayGetError(object_ids[i], value)
last_task_error_raise_time = time.time()
raise value
return values
else:
value = worker.get_object([object_ids])[0]
if isinstance(value, RayTaskError):
# If the result is a RayTaskError, then the task that created
# this object failed, and we should propagate the error message
# here.
raise RayGetError(object_ids, value)
last_task_error_raise_time = time.time()
raise value
return value


2 changes: 1 addition & 1 deletion test/actor_test.py
@@ -1300,7 +1300,7 @@ def inc(self):
# Submit some new actor tasks.
x_ids = [actor.inc.remote() for _ in range(5)]
for x_id in x_ids:
with pytest.raises(ray.worker.RayGetError):
with pytest.raises(ray.worker.RayTaskError):
# There is some small chance that ray.get will actually
# succeed (if the object is transferred before the raylet
# dies).
2 changes: 1 addition & 1 deletion test/component_failures_test.py
@@ -408,7 +408,7 @@ def ping(self):
for i, out in enumerate(children_out):
try:
ray.get(out)
except ray.worker.RayGetError:
except ray.worker.RayTaskError:
children[i] = Child.remote(death_probability)
# Remove a node. Any actor creation tasks that were forwarded to this
# node must be reconstructed.
