Skip to content

Commit 77c8aa7

Browse files
robertnishihara authored and pcmoritz committed
Make ActorHandles pickleable, also make proper ActorHandle and ActorC… (#2007)
* Make ActorHandles pickleable, also make proper ActorHandle and ActorClass classes.
* Fix bug.
* Fix actor test bug.
* Update __ray_terminate__ usage.
* Fix most linting, add documentation, and small cleanups.
* Handle forking and pickling differently for actor handles. Fix linting.
* Fixes for named actors via pickling.
* Generate actor handle IDs deterministically in the pickling case.
1 parent 2048b54 commit 77c8aa7

File tree

10 files changed

+568
-388
lines changed

10 files changed

+568
-388
lines changed

python/ray/actor.py

Lines changed: 488 additions & 348 deletions
Large diffs are not rendered by default.

python/ray/rllib/a3c/a3c.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -126,7 +126,7 @@ def _fetch_metrics_from_remote_evaluators(self):
126126
def _stop(self):
127127
# workaround for https://github.com/ray-project/ray/issues/1516
128128
for ev in self.remote_evaluators:
129-
ev.__ray_terminate__.remote(ev._ray_actor_id.id())
129+
ev.__ray_terminate__.remote()
130130

131131
def _save(self, checkpoint_dir):
132132
checkpoint_path = os.path.join(

python/ray/rllib/ddpg/ddpg.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -234,7 +234,7 @@ def _train_stats(self, start_timestep):
234234
def _stop(self):
235235
# workaround for https://github.com/ray-project/ray/issues/1516
236236
for ev in self.remote_evaluators:
237-
ev.__ray_terminate__.remote(ev._ray_actor_id.id())
237+
ev.__ray_terminate__.remote()
238238

239239
def _save(self, checkpoint_dir):
240240
checkpoint_path = self.saver.save(

python/ray/rllib/dqn/dqn.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -232,7 +232,7 @@ def _train_stats(self, start_timestep):
232232
def _stop(self):
233233
# workaround for https://github.com/ray-project/ray/issues/1516
234234
for ev in self.remote_evaluators:
235-
ev.__ray_terminate__.remote(ev._ray_actor_id.id())
235+
ev.__ray_terminate__.remote()
236236

237237
def _save(self, checkpoint_dir):
238238
checkpoint_path = self.saver.save(

python/ray/rllib/es/es.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -311,7 +311,7 @@ def _train(self):
311311
def _stop(self):
312312
# workaround for https://github.com/ray-project/ray/issues/1516
313313
for w in self.workers:
314-
w.__ray_terminate__.remote(w._ray_actor_id.id())
314+
w.__ray_terminate__.remote()
315315

316316
def _save(self, checkpoint_dir):
317317
checkpoint_path = os.path.join(

python/ray/rllib/ppo/ppo.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -269,7 +269,7 @@ def _fetch_metrics_from_remote_evaluators(self):
269269
def _stop(self):
270270
# workaround for https://github.com/ray-project/ray/issues/1516
271271
for ev in self.remote_evaluators:
272-
ev.__ray_terminate__.remote(ev._ray_actor_id.id())
272+
ev.__ray_terminate__.remote()
273273

274274
def _save(self, checkpoint_dir):
275275
checkpoint_path = self.saver.save(

python/ray/rllib/utils/actors.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ def count(self):
3030
def drop_colocated(actors):
3131
colocated, non_colocated = split_colocated(actors)
3232
for a in colocated:
33-
a.__ray_terminate__.remote(a._ray_actor_id.id())
33+
a.__ray_terminate__.remote()
3434
return non_colocated
3535

3636

python/ray/tune/trial.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -182,9 +182,7 @@ def stop(self, error=False, error_msg=None, stop_logger=True):
182182
if self.runner:
183183
stop_tasks = []
184184
stop_tasks.append(self.runner.stop.remote())
185-
stop_tasks.append(
186-
self.runner.__ray_terminate__.remote(
187-
self.runner._ray_actor_id.id()))
185+
stop_tasks.append(self.runner.__ray_terminate__.remote())
188186
# TODO(ekl) seems like wait hangs when killing actors
189187
_, unfinished = ray.wait(
190188
stop_tasks, num_returns=2, timeout=250)

python/ray/worker.py

Lines changed: 43 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -267,7 +267,7 @@ def set_mode(self, mode):
267267
print any information about errors because some of the tests
268268
intentionally fail.
269269
270-
args:
270+
Args:
271271
mode: One of SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, and
272272
SILENT_MODE.
273273
"""
@@ -363,11 +363,6 @@ def put_object(self, object_id, value):
363363
"do this, you can wrap the ObjectID in a list and "
364364
"call 'put' on it (or return it).")
365365

366-
if isinstance(value, ray.actor.ActorHandleParent):
367-
raise Exception("Calling 'put' on an actor handle is currently "
368-
"not allowed (similarly, returning an actor "
369-
"handle from a remote function is not allowed).")
370-
371366
# Serialize and put the object in the object store.
372367
try:
373368
self.store_and_register(object_id, value)
@@ -525,7 +520,8 @@ def submit_task(self,
525520
num_return_vals=None,
526521
num_cpus=None,
527522
num_gpus=None,
528-
resources=None):
523+
resources=None,
524+
driver_id=None):
529525
"""Submit a remote task to the scheduler.
530526
531527
Tell the scheduler to schedule the execution of the function with ID
@@ -552,6 +548,11 @@ def submit_task(self,
552548
num_cpus: The number of CPUs required by this task.
553549
num_gpus: The number of GPUs required by this task.
554550
resources: The resource requirements for this task.
551+
driver_id: The ID of the relevant driver. This is almost always the
552+
driver ID of the driver that is currently running. However, in
553+
the exceptional case that an actor task is being dispatched to
554+
an actor created by a different driver, this should be the
555+
driver ID of the driver that created the actor.
555556
556557
Returns:
557558
The return object IDs for this task.
@@ -579,9 +580,6 @@ def submit_task(self,
579580
for arg in args:
580581
if isinstance(arg, ray.local_scheduler.ObjectID):
581582
args_for_local_scheduler.append(arg)
582-
elif isinstance(arg, ray.actor.ActorHandleParent):
583-
args_for_local_scheduler.append(
584-
put(ray.actor.wrap_actor_handle(arg)))
585583
elif ray.local_scheduler.check_simple_value(arg):
586584
args_for_local_scheduler.append(arg)
587585
else:
@@ -591,9 +589,12 @@ def submit_task(self,
591589
if execution_dependencies is None:
592590
execution_dependencies = []
593591

592+
if driver_id is None:
593+
driver_id = self.task_driver_id
594+
594595
# Look up the various function properties.
595-
function_properties = self.function_properties[
596-
self.task_driver_id.id()][function_id.id()]
596+
function_properties = self.function_properties[driver_id.id()][
597+
function_id.id()]
597598

598599
if num_return_vals is None:
599600
num_return_vals = function_properties.num_return_vals
@@ -610,8 +611,7 @@ def submit_task(self,
610611

611612
# Submit the task to local scheduler.
612613
task = ray.local_scheduler.Task(
613-
self.task_driver_id,
614-
ray.local_scheduler.ObjectID(
614+
driver_id, ray.local_scheduler.ObjectID(
615615
function_id.id()), args_for_local_scheduler,
616616
num_return_vals, self.current_task_id, self.task_index,
617617
actor_creation_id, actor_creation_dummy_object_id, actor_id,
@@ -749,8 +749,6 @@ def _get_arguments_for_execution(self, function_name, serialized_args):
749749
# created this object failed, and we should propagate the
750750
# error message here.
751751
raise RayGetArgumentError(function_name, i, arg, argument)
752-
elif isinstance(argument, ray.actor.ActorHandleWrapper):
753-
argument = ray.actor.unwrap_actor_handle(self, argument)
754752
else:
755753
# pass the argument by value
756754
argument = arg
@@ -779,6 +777,10 @@ def _store_outputs_in_objstore(self, object_ids, outputs):
779777
passed into this function.
780778
"""
781779
for i in range(len(object_ids)):
780+
if isinstance(outputs[i], ray.actor.ActorHandle):
781+
raise Exception("Returning an actor handle from a remote "
782+
"function is not allowed).")
783+
782784
self.put_object(object_ids[i], outputs[i])
783785

784786
def _process_task(self, task):
@@ -1137,18 +1139,39 @@ def _initialize_serialization(worker=global_worker):
11371139
pyarrow.register_torch_serialization_handlers(worker.serialization_context)
11381140

11391141
# Define a custom serializer and deserializer for handling Object IDs.
1140-
def objectid_custom_serializer(obj):
1142+
def object_id_custom_serializer(obj):
11411143
return obj.id()
11421144

1143-
def objectid_custom_deserializer(serialized_obj):
1145+
def object_id_custom_deserializer(serialized_obj):
11441146
return ray.local_scheduler.ObjectID(serialized_obj)
11451147

1148+
# We register this serializer on each worker instead of calling
1149+
# register_custom_serializer from the driver so that isinstance still
1150+
# works.
11461151
worker.serialization_context.register_type(
11471152
ray.local_scheduler.ObjectID,
11481153
"ray.ObjectID",
11491154
pickle=False,
1150-
custom_serializer=objectid_custom_serializer,
1151-
custom_deserializer=objectid_custom_deserializer)
1155+
custom_serializer=object_id_custom_serializer,
1156+
custom_deserializer=object_id_custom_deserializer)
1157+
1158+
def actor_handle_serializer(obj):
1159+
return obj._serialization_helper(True)
1160+
1161+
def actor_handle_deserializer(serialized_obj):
1162+
new_handle = ray.actor.ActorHandle.__new__(ray.actor.ActorHandle)
1163+
new_handle._deserialization_helper(serialized_obj, True)
1164+
return new_handle
1165+
1166+
# We register this serializer on each worker instead of calling
1167+
# register_custom_serializer from the driver so that isinstance still
1168+
# works.
1169+
worker.serialization_context.register_type(
1170+
ray.actor.ActorHandle,
1171+
"ray.ActorHandle",
1172+
pickle=False,
1173+
custom_serializer=actor_handle_serializer,
1174+
custom_deserializer=actor_handle_deserializer)
11521175

11531176
if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
11541177
# These should only be called on the driver because
@@ -1161,8 +1184,6 @@ def objectid_custom_deserializer(serialized_obj):
11611184
register_custom_serializer(type(lambda: 0), use_pickle=True)
11621185
# Tell Ray to serialize types with pickle.
11631186
register_custom_serializer(type(int), use_pickle=True)
1164-
# Ray can serialize actor handles that have been wrapped.
1165-
register_custom_serializer(ray.actor.ActorHandleWrapper, use_dict=True)
11661187
# Tell Ray to serialize FunctionSignatures as dictionaries. This is
11671188
# used when passing around actor handles.
11681189
register_custom_serializer(

test/actor_test.py

Lines changed: 30 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1822,7 +1822,12 @@ def testCallingPutOnActorHandle(self):
18221822

18231823
@ray.remote
18241824
class Counter(object):
1825-
pass
1825+
def __init__(self):
1826+
self.x = 0
1827+
1828+
def inc(self):
1829+
self.x += 1
1830+
return self.x
18261831

18271832
@ray.remote
18281833
def f():
@@ -1832,18 +1837,34 @@ def f():
18321837
def g():
18331838
return [Counter.remote()]
18341839

1835-
with self.assertRaises(Exception):
1836-
ray.put(Counter.remote())
1840+
# Currently, calling ray.put on an actor handle is allowed, but is
1841+
# there a good use case?
1842+
counter = Counter.remote()
1843+
counter_id = ray.put(counter)
1844+
new_counter = ray.get(counter_id)
1845+
assert ray.get(new_counter.inc.remote()) == 1
1846+
assert ray.get(counter.inc.remote()) == 2
1847+
assert ray.get(new_counter.inc.remote()) == 3
18371848

18381849
with self.assertRaises(Exception):
18391850
ray.get(f.remote())
18401851

1841-
# The below test is commented out because it currently does not behave
1842-
# properly. The call to g.remote() does not raise an exception because
1843-
# even though the actor handle cannot be pickled, pyarrow attempts to
1844-
# serialize it as a dictionary of its fields which kind of works.
1845-
# self.assertRaises(Exception):
1846-
# ray.get(g.remote())
1852+
# The below test works, but do we want to disallow this usage?
1853+
ray.get(g.remote())
1854+
1855+
def testPicklingActorHandle(self):
1856+
ray.worker.init(num_workers=1)
1857+
1858+
@ray.remote
1859+
class Foo(object):
1860+
def method(self):
1861+
pass
1862+
1863+
f = Foo.remote()
1864+
new_f = ray.worker.pickle.loads(ray.worker.pickle.dumps(f))
1865+
# Verify that we can call a method on the unpickled handle. TODO(rkn):
1866+
# we should also test this from a different driver.
1867+
ray.get(new_f.method.remote())
18471868

18481869

18491870
class ActorPlacementAndResources(unittest.TestCase):

0 commit comments

Comments
 (0)