Skip to content

Commit

Permalink
better error messages when composing remote functions (ray-project#339)
Browse files Browse the repository at this point in the history
Better error messages when composing remote functions
  • Loading branch information
robertnishihara authored and pcmoritz committed Aug 3, 2016
1 parent 07baf44 commit de200ff
Showing 1 changed file with 171 additions and 53 deletions.
224 changes: 171 additions & 53 deletions lib/python/ray/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,61 +28,168 @@
PYTHON_MODE = 2
SILENT_MODE = 3 # This is only used during testing.

class RayFailedObject(object):
class RayTaskError(Exception):
"""An object used internally to represent a task that threw an exception.
If a task throws an exception during execution, a RayFailedObject is stored in
the object store for each of the tasks outputs. When an object is retrieved
from the object store, the Python method that retrieved it should check to see
if the object is a RayFailedObject and if it is then an exception should be
thrown containing the error message.
If a task throws an exception during execution, a RayTaskError is stored in
the object store for each of the task's outputs. When an object is retrieved
from the object store, the Python method that retrieved it checks to see if
the object is a RayTaskError and if it is then an exception is thrown
propagating the error message.
Attributes
error_message (str): The error message raised by the task that failed.
Currently, we either use the exception attribute or the traceback attribute
but not both.
Attributes:
function_name (str): The name of the function that failed and produced the
RayTaskError.
exception (Exception): The exception object thrown by the failed task.
traceback_str (str): The traceback from the exception.
"""

def __init__(self, error_message):
"""Initialize a RayFailedObject.
def __init__(self, function_name, exception, traceback_str):
"""Initialize a RayTaskError."""
self.function_name = function_name
if isinstance(exception, RayGetError) or isinstance(exception, RayGetArgumentError) or isinstance(exception, RayGetArgumentTypeError):
self.exception = exception
else:
self.exception = None
self.traceback_str = traceback_str

Args:
error_message (str): The error message raised by the task for which a
RayFailedObject is being created.
"""
self.error_message = error_message
@staticmethod
def deserialize(primitives):
"""Create a RayTaskError from a primitive object."""
function_name, exception, traceback_str = primitives
if exception[0] == "RayGetError":
exception = RayGetError.deserialize(exception[1])
elif exception[0] == "RayGetArgumentError":
exception = RayGetArgumentError.deserialize(exception[1])
elif exception[0] == "RayGetArgumentTypeError":
exception = RayGetArgumentTypeError.deserialize(exception[1])
elif exception[0] == "None":
exception = None
else:
assert False, "This code should be unreachable."
return RayTaskError(function_name, exception, traceback_str)

def serialize(self):
"""Turn a RayTaskError into a primitive object."""
if isinstance(self.exception, RayGetError):
serialized_exception = ("RayGetError", self.exception.serialize())
elif isinstance(self.exception, RayGetArgumentError):
serialized_exception = ("RayGetArgumentError", self.exception.serialize())
elif isinstance(self.exception, RayGetArgumentTypeError):
serialized_exception = ("RayGetArgumentTypeError", self.exception.serialize())
elif self.exception is None:
serialized_exception = ("None",)
else:
assert False, "This code should be unreachable."
return (self.function_name, serialized_exception, self.traceback_str)

def __str__(self):
"""Format a RayTaskError as a string."""
if self.traceback_str is None:
# This path is taken if getting the task arguments failed.
return "Remote function {}{}{} failed with:\n\n{}".format(colorama.Fore.RED, self.function_name, colorama.Fore.RESET, self.exception)
else:
# This path is taken if the task execution failed.
return "Remote function {}{}{} failed with:\n\n{}".format(colorama.Fore.RED, self.function_name, colorama.Fore.RESET, self.traceback_str)

class RayGetError(Exception):
"""An exception used when get is called on an output of a failed task.
Attributes:
objectid (lib.ObjectID): The ObjectID that get was called on.
task_error (RayTaskError): The RayTaskError object created by the failed
task.
"""

def __init__(self, objectid, task_error):
"""Initialize a RayGetError object."""
self.objectid = objectid
self.task_error = task_error

@staticmethod
def deserialize(primitives):
"""Create a RayFailedObject from a primitive object.
"""Create a RayGetError from a primitive object."""
objectid, task_error = primitives
return RayGetError(objectid, RayTaskError.deserialize(task_error))

This initializes a RayFailedObject from a primitive object created by the
serialize method. This method is required in order for Ray to serialize
custom Python classes.
def serialize(self):
"""Turn a RayGetError into a primitive object."""
return (self.objectid, self.task_error.serialize())

Note:
This method should not be called by users.
def __str__(self):
"""Format a RayGetError as a string."""
return "Could not get objectid {}. It was created by remote function {}{}{} which failed with:\n\n{}".format(self.objectid, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error)

Args:
primitives (str): The object's error message.
"""
return RayFailedObject(primitives)
class RayGetArgumentError(Exception):
"""An exception used when a task's argument was produced by a failed task.
Attributes:
argument_index (int): The index (zero indexed) of the failed argument in
the present task's remote function call.
function_name (str): The name of the function for the current task.
objectid (lib.ObjectID): The ObjectID that was passed in as the argument.
task_error (RayTaskError): The RayTaskError object created by the failed
task.
"""

def __init__(self, function_name, argument_index, objectid, task_error):
"""Initialize a RayGetArgumentError object."""
self.argument_index = argument_index
self.function_name = function_name
self.objectid = objectid
self.task_error = task_error

@staticmethod
def deserialize(primitives):
"""Create a RayGetArgumentError from a primitive object."""
function_name, argument_index, objectid, task_error = primitives
return RayGetArgumentError(function_name, argument_index, objectid, RayTaskError.deserialize(task_error))

def serialize(self):
"""Turn a RayFailedObject into a primitive object.
"""Turn a RayGetArgumentError into a primitive object."""
return (self.function_name, self.argument_index, self.objectid, self.task_error.serialize())

This method is required in order for Ray to serialize
custom Python classes.
def __str__(self):
"""Format a RayGetArgumentError as a string."""
return "Failed to get objectid {} as argument {} for remote function {}{}{}. It was created by remote function {}{}{} which failed with:\n{}".format(self.objectid, self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error)

Note:
The output of this method should only be used by the deserialize method.
This method should not be called by users.
class RayGetArgumentTypeError(Exception):
"""An exception used when a task's argument doesn't type check.
Args:
primitives (str): The object's error message.
Attributes:
function_name (str): The name of the function for the current task.
argument_index (int): The index (zero indexed) of the argument in the
present task's remote function call.
received_type: The type of the argument that was passed in.
expected_type: The type that was expected. This is determined by the remote
decorator.
"""

Returns:
A primitive representation of a RayFailedObject.
"""
return self.error_message
def __init__(self, function_name, argument_index, received_type, expected_type):
"""Initialize a RayGetArgumentTypeError object."""
self.function_name = function_name
self.argument_index = argument_index
# TODO(rkn): when we support the serialization of types, then we should
# remove the string conversions below.
self.received_type = str(received_type)
self.expected_type = str(expected_type)

@staticmethod
def deserialize(primitives):
"""Create a RayGetArgumentTypeError from a primitive object."""
function_name, argument_index, received_type, expected_type = primitives
return RayGetArgumentTypeError(function_name, argument_index, received_type, expected_type)

def serialize(self):
"""Turn a RayGetArgumentTypeError into a primitive object."""
return (self.function_name, self.argument_index, self.received_type, self.expected_type)

def __str__(self):
"""Format a RayGetArgumentTypeError as a string."""
return "Argument {} for remote function {}{}{} has type {} but an argument of type {} was expected.".format(self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, self.received_type, self.expected_type)

class RayDealloc(object):
"""An object used internally to properly implement reference counting.
Expand Down Expand Up @@ -689,8 +796,10 @@ def get(objectid, worker=global_worker):
if worker.mode == SCRIPT_MODE:
worker.print_new_failures()
value = worker.get_object(objectid)
if isinstance(value, RayFailedObject):
raise Exception("The task that created this object ID failed with error message:\n{}".format(value.error_message))
if isinstance(value, RayTaskError):
# If the result is a RayTaskError, then the task that created this object
# failed, and we should propagate the error message here.
raise RayGetError(objectid, value)
return value

def put(value, worker=global_worker):
Expand Down Expand Up @@ -749,7 +858,7 @@ def restart_workers_local(num_workers, worker_path, worker=global_worker):
def format_error_message(exception_message):
"""Improve the formatting of an exception thrown by a remote function.
This method takes an backtrace from an exception and makes it nicer by
This method takes a traceback from an exception and makes it nicer by
removing a few uninformative lines and adding some space to indent the
remaining lines nicely.
Expand All @@ -763,7 +872,6 @@ def format_error_message(exception_message):
# Remove lines 1, 2, 3, and 4, which are always the same, they just contain
# information about the main loop.
lines = lines[0:1] + lines[5:]
lines = [10 * " " + line for line in lines]
return "\n".join(lines)

def main_loop(worker=global_worker):
Expand All @@ -778,7 +886,7 @@ def main_loop(worker=global_worker):
If the process of getting the arguments for execution (which does some type
checking) or the process of executing the task fails, then the main loop will
catch the exception and store RayFailedObject objects containing the relevant
catch the exception and store RayTaskError objects containing the relevant
error messages in the object store in place of the actual outputs. These
objects are used to propagate the error messages.
"""
Expand All @@ -792,14 +900,16 @@ def process_task(task): # wrapping these lines in a function should cause the lo
outputs = worker.functions[func_name].executor(arguments) # execute the function
if len(return_objectids) == 1:
outputs = (outputs,)
except Exception:
exception_message = format_error_message(traceback.format_exc())
# Here we are storing RayFailedObjects in the object store to indicate
# failure (this is only interpreted by the worker).
failure_objects = [RayFailedObject(exception_message) for _ in range(len(return_objectids))]
except Exception as e:
# If the task threw an exception, then record the traceback. We determine
# whether the exception was thrown in the task execution by whether the
# variable "arguments" is defined.
traceback_str = format_error_message(traceback.format_exc()) if "arguments" in locals() else None
failure_object = RayTaskError(func_name, e, traceback_str)
failure_objects = [failure_object for _ in range(len(return_objectids))]
store_outputs_in_objstore(return_objectids, failure_objects, worker)
raylib.notify_task_completed(worker.handle, False, exception_message) # notify the scheduler that the task threw an exception
_logger().info("Worker threw exception with message: \n\n{}\n, while running function {}.".format(exception_message, func_name))
raylib.notify_task_completed(worker.handle, False, str(failure_object))
_logger().info("Worker threw exception with message: \n\n{}\n, while running function {}.".format(str(failure_object), func_name))
else:
store_outputs_in_objstore(return_objectids, outputs, worker) # store output in local object store
raylib.notify_task_completed(worker.handle, True, "") # notify the scheduler that the task completed successfully
Expand Down Expand Up @@ -1013,7 +1123,8 @@ def typecheck_arg(arg, expected_type, i, name):
name (str): The name of the function.
Raises:
Exception: An exception is raised if arg does not have the expected type.
RayGetArgumentTypeError: An exception is raised if arg does not have the
expected type.
"""
if issubclass(type(arg), expected_type):
# Passed the type-check
Expand All @@ -1023,7 +1134,7 @@ def typecheck_arg(arg, expected_type, i, name):
# TODO(mehrdadn): Should long really be convertible to int?
pass
else:
raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, name, type(arg), expected_type))
raise RayGetArgumentTypeError(name, i, type(arg), expected_type)

def check_arguments(arg_types, has_vararg_param, name, args):
"""Check that the arguments to the remote function have the right types.
Expand Down Expand Up @@ -1080,7 +1191,10 @@ def get_arguments_for_execution(function, args, worker=global_worker):
value.
Raises:
Exception: An exception is raised the args do not all have the right types.
RayGetArgumentError: This exception is raised if a task that created one of
the arguments failed.
RayGetArgumentTypeError: This exception is raised (via typecheck_arg) if one
of the arguments does not have the expected type.
"""
# TODO(rkn): Eventually, all of the type checking can be put in `check_arguments` above so that the error will happen immediately when calling a remote function.
arguments = []
Expand All @@ -1102,12 +1216,16 @@ def get_arguments_for_execution(function, args, worker=global_worker):
# get the object from the local object store
_logger().info("Getting argument {} for function {}.".format(i, function.__name__))
argument = worker.get_object(arg)
if isinstance(argument, RayTaskError):
# If the result is a RayTaskError, then the task that created this
# object failed, and we should propagate the error message here.
raise RayGetArgumentError(function.__name__, i, arg, argument)
_logger().info("Successfully retrieved argument {} for function {}.".format(i, function.__name__))
else:
# pass the argument by value
argument = arg

typecheck_arg(argument, expected_type, i, function)
typecheck_arg(argument, expected_type, i, function.__name__)
arguments.append(argument)
return arguments

Expand Down

0 comments on commit de200ff

Please sign in to comment.