Skip to content

Make ActorHandles pickleable, also make proper ActorHandle and ActorC… #2007

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 9, 2018
836 changes: 488 additions & 348 deletions python/ray/actor.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion python/ray/rllib/a3c/a3c.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _fetch_metrics_from_remote_evaluators(self):
def _stop(self):
# workaround for https://github.com/ray-project/ray/issues/1516
for ev in self.remote_evaluators:
ev.__ray_terminate__.remote(ev._ray_actor_id.id())
ev.__ray_terminate__.remote()

def _save(self, checkpoint_dir):
checkpoint_path = os.path.join(
Expand Down
2 changes: 1 addition & 1 deletion python/ray/rllib/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def _train_stats(self, start_timestep):
def _stop(self):
# workaround for https://github.com/ray-project/ray/issues/1516
for ev in self.remote_evaluators:
ev.__ray_terminate__.remote(ev._ray_actor_id.id())
ev.__ray_terminate__.remote()

def _save(self, checkpoint_dir):
checkpoint_path = self.saver.save(
Expand Down
2 changes: 1 addition & 1 deletion python/ray/rllib/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def _train_stats(self, start_timestep):
def _stop(self):
# workaround for https://github.com/ray-project/ray/issues/1516
for ev in self.remote_evaluators:
ev.__ray_terminate__.remote(ev._ray_actor_id.id())
ev.__ray_terminate__.remote()

def _save(self, checkpoint_dir):
checkpoint_path = self.saver.save(
Expand Down
2 changes: 1 addition & 1 deletion python/ray/rllib/es/es.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def _train(self):
def _stop(self):
# workaround for https://github.com/ray-project/ray/issues/1516
for w in self.workers:
w.__ray_terminate__.remote(w._ray_actor_id.id())
w.__ray_terminate__.remote()

def _save(self, checkpoint_dir):
checkpoint_path = os.path.join(
Expand Down
2 changes: 1 addition & 1 deletion python/ray/rllib/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def _fetch_metrics_from_remote_evaluators(self):
def _stop(self):
# workaround for https://github.com/ray-project/ray/issues/1516
for ev in self.remote_evaluators:
ev.__ray_terminate__.remote(ev._ray_actor_id.id())
ev.__ray_terminate__.remote()

def _save(self, checkpoint_dir):
checkpoint_path = self.saver.save(
Expand Down
2 changes: 1 addition & 1 deletion python/ray/rllib/utils/actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def count(self):
def drop_colocated(actors):
colocated, non_colocated = split_colocated(actors)
for a in colocated:
a.__ray_terminate__.remote(a._ray_actor_id.id())
a.__ray_terminate__.remote()
return non_colocated


Expand Down
4 changes: 1 addition & 3 deletions python/ray/tune/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,7 @@ def stop(self, error=False, error_msg=None, stop_logger=True):
if self.runner:
stop_tasks = []
stop_tasks.append(self.runner.stop.remote())
stop_tasks.append(
self.runner.__ray_terminate__.remote(
self.runner._ray_actor_id.id()))
stop_tasks.append(self.runner.__ray_terminate__.remote())
# TODO(ekl) seems like wait hangs when killing actors
_, unfinished = ray.wait(
stop_tasks, num_returns=2, timeout=250)
Expand Down
65 changes: 43 additions & 22 deletions python/ray/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def set_mode(self, mode):
print any information about errors because some of the tests
intentionally fail.

args:
Args:
mode: One of SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, and
SILENT_MODE.
"""
Expand Down Expand Up @@ -363,11 +363,6 @@ def put_object(self, object_id, value):
"do this, you can wrap the ObjectID in a list and "
"call 'put' on it (or return it).")

if isinstance(value, ray.actor.ActorHandleParent):
raise Exception("Calling 'put' on an actor handle is currently "
"not allowed (similarly, returning an actor "
"handle from a remote function is not allowed).")

# Serialize and put the object in the object store.
try:
self.store_and_register(object_id, value)
Expand Down Expand Up @@ -525,7 +520,8 @@ def submit_task(self,
num_return_vals=None,
num_cpus=None,
num_gpus=None,
resources=None):
resources=None,
driver_id=None):
"""Submit a remote task to the scheduler.

Tell the scheduler to schedule the execution of the function with ID
Expand All @@ -552,6 +548,11 @@ def submit_task(self,
num_cpus: The number of CPUs required by this task.
num_gpus: The number of GPUs required by this task.
resources: The resource requirements for this task.
driver_id: The ID of the relevant driver. This is almost always the
driver ID of the driver that is currently running. However, in
the exceptional case that an actor task is being dispatched to
an actor created by a different driver, this should be the
driver ID of the driver that created the actor.

Returns:
The return object IDs for this task.
Expand Down Expand Up @@ -579,9 +580,6 @@ def submit_task(self,
for arg in args:
if isinstance(arg, ray.local_scheduler.ObjectID):
args_for_local_scheduler.append(arg)
elif isinstance(arg, ray.actor.ActorHandleParent):
args_for_local_scheduler.append(
put(ray.actor.wrap_actor_handle(arg)))
elif ray.local_scheduler.check_simple_value(arg):
args_for_local_scheduler.append(arg)
else:
Expand All @@ -591,9 +589,12 @@ def submit_task(self,
if execution_dependencies is None:
execution_dependencies = []

if driver_id is None:
driver_id = self.task_driver_id

# Look up the various function properties.
function_properties = self.function_properties[
self.task_driver_id.id()][function_id.id()]
function_properties = self.function_properties[driver_id.id()][
function_id.id()]

if num_return_vals is None:
num_return_vals = function_properties.num_return_vals
Expand All @@ -610,8 +611,7 @@ def submit_task(self,

# Submit the task to local scheduler.
task = ray.local_scheduler.Task(
self.task_driver_id,
ray.local_scheduler.ObjectID(
driver_id, ray.local_scheduler.ObjectID(
function_id.id()), args_for_local_scheduler,
num_return_vals, self.current_task_id, self.task_index,
actor_creation_id, actor_creation_dummy_object_id, actor_id,
Expand Down Expand Up @@ -749,8 +749,6 @@ def _get_arguments_for_execution(self, function_name, serialized_args):
# created this object failed, and we should propagate the
# error message here.
raise RayGetArgumentError(function_name, i, arg, argument)
elif isinstance(argument, ray.actor.ActorHandleWrapper):
argument = ray.actor.unwrap_actor_handle(self, argument)
else:
# pass the argument by value
argument = arg
Expand Down Expand Up @@ -779,6 +777,10 @@ def _store_outputs_in_objstore(self, object_ids, outputs):
passed into this function.
"""
for i in range(len(object_ids)):
if isinstance(outputs[i], ray.actor.ActorHandle):
raise Exception("Returning an actor handle from a remote "
"function is not allowed).")

self.put_object(object_ids[i], outputs[i])

def _process_task(self, task):
Expand Down Expand Up @@ -1137,18 +1139,39 @@ def _initialize_serialization(worker=global_worker):
pyarrow.register_torch_serialization_handlers(worker.serialization_context)

# Define a custom serializer and deserializer for handling Object IDs.
def objectid_custom_serializer(obj):
def object_id_custom_serializer(obj):
return obj.id()

def objectid_custom_deserializer(serialized_obj):
def object_id_custom_deserializer(serialized_obj):
return ray.local_scheduler.ObjectID(serialized_obj)

# We register this serializer on each worker instead of calling
# register_custom_serializer from the driver so that isinstance still
# works.
worker.serialization_context.register_type(
ray.local_scheduler.ObjectID,
"ray.ObjectID",
pickle=False,
custom_serializer=objectid_custom_serializer,
custom_deserializer=objectid_custom_deserializer)
custom_serializer=object_id_custom_serializer,
custom_deserializer=object_id_custom_deserializer)

def actor_handle_serializer(obj):
return obj._serialization_helper(True)

def actor_handle_deserializer(serialized_obj):
new_handle = ray.actor.ActorHandle.__new__(ray.actor.ActorHandle)
new_handle._deserialization_helper(serialized_obj, True)
return new_handle

# We register this serializer on each worker instead of calling
# register_custom_serializer from the driver so that isinstance still
# works.
worker.serialization_context.register_type(
ray.actor.ActorHandle,
"ray.ActorHandle",
pickle=False,
custom_serializer=actor_handle_serializer,
custom_deserializer=actor_handle_deserializer)

if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
# These should only be called on the driver because
Expand All @@ -1161,8 +1184,6 @@ def objectid_custom_deserializer(serialized_obj):
register_custom_serializer(type(lambda: 0), use_pickle=True)
# Tell Ray to serialize types with pickle.
register_custom_serializer(type(int), use_pickle=True)
# Ray can serialize actor handles that have been wrapped.
register_custom_serializer(ray.actor.ActorHandleWrapper, use_dict=True)
# Tell Ray to serialize FunctionSignatures as dictionaries. This is
# used when passing around actor handles.
register_custom_serializer(
Expand Down
39 changes: 30 additions & 9 deletions test/actor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1822,7 +1822,12 @@ def testCallingPutOnActorHandle(self):

@ray.remote
class Counter(object):
pass
def __init__(self):
self.x = 0

def inc(self):
self.x += 1
return self.x

@ray.remote
def f():
Expand All @@ -1832,18 +1837,34 @@ def f():
def g():
return [Counter.remote()]

with self.assertRaises(Exception):
ray.put(Counter.remote())
# Currently, calling ray.put on an actor handle is allowed, but is
# there a good use case?
counter = Counter.remote()
counter_id = ray.put(counter)
new_counter = ray.get(counter_id)
assert ray.get(new_counter.inc.remote()) == 1
assert ray.get(counter.inc.remote()) == 2
assert ray.get(new_counter.inc.remote()) == 3

with self.assertRaises(Exception):
ray.get(f.remote())

# The below test is commented out because it currently does not behave
# properly. The call to g.remote() does not raise an exception because
# even though the actor handle cannot be pickled, pyarrow attempts to
# serialize it as a dictionary of its fields which kind of works.
# self.assertRaises(Exception):
# ray.get(g.remote())
# The below test works, but do we want to disallow this usage?
ray.get(g.remote())

def testPicklingActorHandle(self):
ray.worker.init(num_workers=1)

@ray.remote
class Foo(object):
def method(self):
pass

f = Foo.remote()
new_f = ray.worker.pickle.loads(ray.worker.pickle.dumps(f))
# Verify that we can call a method on the unpickled handle. TODO(rkn):
# we should also test this from a different driver.
ray.get(new_f.method.remote())


class ActorPlacementAndResources(unittest.TestCase):
Expand Down