Skip to content

Commit

Permalink
Use different serialization context for each driver. (ray-project#2406)
Browse files Browse the repository at this point in the history
  • Loading branch information
surehb authored and robertnishihara committed Jul 21, 2018
1 parent 05f485e commit 99d0d96
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 54 deletions.
10 changes: 5 additions & 5 deletions python/ray/import_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,13 @@ def f():

def fetch_and_execute_function_to_run(self, key):
"""Run on arbitrary function on the worker."""
driver_id, serialized_function = self.redis_client.hmget(
key, ["driver_id", "function"])
(driver_id, serialized_function,
run_on_other_drivers) = self.redis_client.hmget(
key, ["driver_id", "function", "run_on_other_drivers"])

if (self.worker.mode in [ray.SCRIPT_MODE, ray.SILENT_MODE]
if (run_on_other_drivers == "False"
and self.worker.mode in [ray.SCRIPT_MODE, ray.SILENT_MODE]
and driver_id != self.worker.task_driver_id.id()):
# This export was from a different driver and there's no need for
# this driver to import it.
return

try:
Expand Down
162 changes: 113 additions & 49 deletions python/ray/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,25 @@ def __init__(self):
self.original_gpu_ids = ray.utils.get_cuda_visible_devices()
self.profiler = profiling.Profiler(self)
self.state_lock = threading.Lock()
# A dictionary that maps from driver id to SerializationContext
# TODO: clean up the SerializationContext once the job has finished.
self.serialization_context_map = {}
# Identity of the driver that this worker is processing.
self.task_driver_id = None

def get_serialization_context(self, driver_id):
    """Return the serialization context registered for the given driver.

    Contexts are created lazily: the first request for a driver that has
    no entry in ``self.serialization_context_map`` triggers
    ``_initialize_serialization`` for that driver, and the resulting
    context is then returned from the map.

    Args:
        driver_id: The ID of the driver that indicates which driver to get
            the serialization context for.

    Returns:
        The serialization context of the given driver.
    """
    if driver_id not in self.serialization_context_map:
        # Initialize against this worker explicitly. The helper's `worker`
        # parameter defaults to the module-level global worker, and it
        # stores the new context in `worker.serialization_context_map`; if
        # `self` were a different Worker instance, the lookup below would
        # raise KeyError.
        _initialize_serialization(driver_id, worker=self)
    return self.serialization_context_map[driver_id]

def check_connected(self):
"""Check if the worker is connected.
Expand Down Expand Up @@ -308,7 +327,8 @@ def store_and_register(self, object_id, value, depth=100):
value,
object_id=pyarrow.plasma.ObjectID(object_id.id()),
memcopy_threads=self.memcopy_threads,
serialization_context=self.serialization_context)
serialization_context=self.get_serialization_context(
self.task_driver_id))
break
except pyarrow.SerializationCallbackError as e:
try:
Expand Down Expand Up @@ -400,7 +420,8 @@ def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10):
results += self.plasma_client.get(
object_ids[i:(
i + ray._config.worker_get_request_size())],
timeout, self.serialization_context)
timeout,
self.get_serialization_context(self.task_driver_id))
return results
except pyarrow.lib.ArrowInvalid:
# TODO(ekl): the local scheduler could include relevant
Expand Down Expand Up @@ -690,7 +711,8 @@ def export_remote_function(self, function_id, function_name, function,
})
self.redis_client.rpush("Exports", key)

def run_function_on_all_workers(self, function):
def run_function_on_all_workers(self, function,
run_on_other_drivers=False):
"""Run arbitrary code on all of the workers.
This function will first be run on the driver, and then it will be
Expand All @@ -702,6 +724,9 @@ def run_function_on_all_workers(self, function):
function (Callable): The function to run on all of the workers. It
should not take any arguments. If it returns anything, its
return values will not be used.
run_on_other_drivers: The boolean that indicates whether we want to
run this function on other drivers. One case is we may need to
share objects across drivers.
"""
# If ray.init has not been called yet, then cache the function and
# export it when connect is called. Otherwise, run the function on all
Expand Down Expand Up @@ -734,7 +759,8 @@ def run_function_on_all_workers(self, function):
key, {
"driver_id": self.task_driver_id.id(),
"function_id": function_to_run_id,
"function": pickled_function
"function": pickled_function,
"run_on_other_drivers": run_on_other_drivers
})
self.redis_client.rpush("Exports", key)
# TODO(rkn): If the worker fails after it calls setnx and before it
Expand Down Expand Up @@ -1209,17 +1235,17 @@ def error_info(worker=global_worker):
return errors


def _initialize_serialization(worker=global_worker):
def _initialize_serialization(driver_id, worker=global_worker):
"""Initialize the serialization library.
This defines a custom serializer for object IDs and also tells ray to
serialize several exception classes that we define for error handling.
"""
worker.serialization_context = pyarrow.default_serialization_context()
serialization_context = pyarrow.default_serialization_context()
# Tell the serialization context to use the cloudpickle version that we
# ship with Ray.
worker.serialization_context.set_pickle(pickle.dumps, pickle.loads)
pyarrow.register_torch_serialization_handlers(worker.serialization_context)
serialization_context.set_pickle(pickle.dumps, pickle.loads)
pyarrow.register_torch_serialization_handlers(serialization_context)

# Define a custom serializer and deserializer for handling Object IDs.
def object_id_custom_serializer(obj):
Expand All @@ -1231,7 +1257,7 @@ def object_id_custom_deserializer(serialized_obj):
# We register this serializer on each worker instead of calling
# register_custom_serializer from the driver so that isinstance still
# works.
worker.serialization_context.register_type(
serialization_context.register_type(
ray.ObjectID,
"ray.ObjectID",
pickle=False,
Expand All @@ -1249,28 +1275,55 @@ def actor_handle_deserializer(serialized_obj):
# We register this serializer on each worker instead of calling
# register_custom_serializer from the driver so that isinstance still
# works.
worker.serialization_context.register_type(
serialization_context.register_type(
ray.actor.ActorHandle,
"ray.ActorHandle",
pickle=False,
custom_serializer=actor_handle_serializer,
custom_deserializer=actor_handle_deserializer)

if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
# These should only be called on the driver because
# register_custom_serializer will export the class to all of the
# workers.
register_custom_serializer(RayTaskError, use_dict=True)
register_custom_serializer(RayGetError, use_dict=True)
register_custom_serializer(RayGetArgumentError, use_dict=True)
# Tell Ray to serialize lambdas with pickle.
register_custom_serializer(type(lambda: 0), use_pickle=True)
# Tell Ray to serialize types with pickle.
register_custom_serializer(type(int), use_pickle=True)
# Tell Ray to serialize FunctionSignatures as dictionaries. This is
# used when passing around actor handles.
register_custom_serializer(
ray.signature.FunctionSignature, use_dict=True)
worker.serialization_context_map[driver_id] = serialization_context

register_custom_serializer(
RayTaskError,
use_dict=True,
local=True,
driver_id=driver_id,
class_id="ray.RayTaskError")
register_custom_serializer(
RayGetError,
use_dict=True,
local=True,
driver_id=driver_id,
class_id="ray.RayGetError")
register_custom_serializer(
RayGetArgumentError,
use_dict=True,
local=True,
driver_id=driver_id,
class_id="ray.RayGetArgumentError")
# Tell Ray to serialize lambdas with pickle.
register_custom_serializer(
type(lambda: 0),
use_pickle=True,
local=True,
driver_id=driver_id,
class_id="lambda")
# Tell Ray to serialize types with pickle.
register_custom_serializer(
type(int),
use_pickle=True,
local=True,
driver_id=driver_id,
class_id="type")
# Tell Ray to serialize FunctionSignatures as dictionaries. This is
# used when passing around actor handles.
register_custom_serializer(
ray.signature.FunctionSignature,
use_dict=True,
local=True,
driver_id=driver_id,
class_id="ray.signature.FunctionSignature")


def get_address_info_from_redis_helper(redis_address,
Expand Down Expand Up @@ -2167,10 +2220,6 @@ def connect(info,
# driver task.
worker.current_task_id = driver_task.task_id()

# Initialize the serialization library. This registers some classes, and so
# it must be run before we export all of the cached remote functions.
_initialize_serialization()

# Start the import thread
import_thread.ImportThread(worker, mode).start()

Expand Down Expand Up @@ -2242,7 +2291,7 @@ def disconnect(worker=global_worker):
worker.connected = False
worker.cached_functions_to_run = []
worker.cached_remote_functions_and_actors = []
worker.serialization_context = pyarrow.SerializationContext()
worker.serialization_context_map.clear()


def _try_to_compute_deterministic_class_id(cls, depth=5):
Expand Down Expand Up @@ -2293,6 +2342,8 @@ def register_custom_serializer(cls,
serializer=None,
deserializer=None,
local=False,
driver_id=None,
class_id=None,
worker=global_worker):
"""Enable serialization and deserialization for a particular class.
Expand All @@ -2313,6 +2364,9 @@ def register_custom_serializer(cls,
if and only if use_pickle and use_dict are False.
local: True if the serializers should only be registered on the current
worker. This should usually be False.
driver_id: ID of the driver that we want to register the class for.
class_id: ID of the class that we are registering. If this is not
specified, we will calculate a new one inside the function.
Raises:
Exception: An exception is raised if pickle=False and the class cannot
Expand All @@ -2332,33 +2386,43 @@ def register_custom_serializer(cls,
# Raise an exception if cls cannot be serialized efficiently by Ray.
serialization.check_serializable(cls)

if not local:
# In this case, the class ID will be used to deduplicate the class
# across workers. Note that cloudpickle unfortunately does not produce
# deterministic strings, so these IDs could be different on different
# workers. We could use something weaker like cls.__name__, however
# that would run the risk of having collisions. TODO(rkn): We should
# improve this.
try:
# Attempt to produce a class ID that will be the same on each
# worker. However, determinism is not guaranteed, and the result
# may be different on different workers.
class_id = _try_to_compute_deterministic_class_id(cls)
except Exception:
raise serialization.CloudPickleError("Failed to pickle class "
"'{}'".format(cls))
if class_id is None:
if not local:
# In this case, the class ID will be used to deduplicate the class
# across workers. Note that cloudpickle unfortunately does not
# produce deterministic strings, so these IDs could be different
# on different workers. We could use something weaker like
# cls.__name__, however that would run the risk of having
# collisions.
# TODO(rkn): We should improve this.
try:
# Attempt to produce a class ID that will be the same on each
# worker. However, determinism is not guaranteed, and the
# result may be different on different workers.
class_id = _try_to_compute_deterministic_class_id(cls)
except Exception as e:
raise serialization.CloudPickleError("Failed to pickle class "
"'{}'".format(cls))
else:
# In this case, the class ID only needs to be meaningful on this
# worker and not across workers.
class_id = random_string()

if driver_id is None:
driver_id_bytes = worker.task_driver_id.id()
else:
# In this case, the class ID only needs to be meaningful on this worker
# and not across workers.
class_id = random_string()
driver_id_bytes = driver_id.id()

def register_class_for_serialization(worker_info):
# TODO(rkn): We need to be more thoughtful about what to do if custom
# serializers have already been registered for class_id. In some cases,
# we may want to use the last user-defined serializers and ignore
# subsequent calls to register_custom_serializer that were made by the
# system.
worker_info["worker"].serialization_context.register_type(

serialization_context = worker_info[
"worker"].get_serialization_context(ray.ObjectID(driver_id_bytes))
serialization_context.register_type(
cls,
class_id,
pickle=use_pickle,
Expand Down

0 comments on commit 99d0d96

Please sign in to comment.