
Commit 9f91eb8

robertnishihara authored and pcmoritz committed
Change API for remote function declaration, actor instantiation, and actor method invocation. (ray-project#541)
* Direction substitution of @ray.remote -> @ray.task.
* Changes to make '@ray.task' work.
* Instantiate actors with Class.remote() instead of Class().
* Convert actor instantiation in tests and examples from Class() to Class.remote().
* Change actor method invocation from object.method() to object.method.remote().
* Update tests and examples to invoke actor methods with .remote().
* Fix bugs in jenkins tests.
* Fix example applications.
* Change @ray.task back to @ray.remote.
* Changes to make @ray.actor -> @ray.remote work.
* Direct substitution of @ray.actor -> @ray.remote.
* Fixes.
* Raise exception if @ray.actor decorator is used.
* Simplify ActorMethod class.
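For readers skimming the diff, a minimal sketch of the API after this change. The ``Counter`` class mirrors the one in doc/source/actors.rst; ``add_one`` is an illustrative remote function, and the snippet assumes a local ``ray.init()``:

    import ray

    ray.init()

    # Remote functions keep the @ray.remote decorator.
    @ray.remote
    def add_one(x):
        return x + 1

    # Actors are now declared with @ray.remote as well; per this commit,
    # @ray.actor now raises an exception.
    @ray.remote
    class Counter(object):
        def __init__(self):
            self.value = 0

        def increment(self):
            self.value += 1
            return self.value

    # Instantiate with Class.remote() instead of Class().
    counter = Counter.remote()

    # Invoke methods with object.method.remote(); each call returns an object ID.
    result_id = counter.increment.remote()
    print(ray.get(result_id))           # 1
    print(ray.get(add_one.remote(41)))  # 42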
1 parent 22c6a22 commit 9f91eb8

19 files changed: +390 / -346 lines

doc/source/actors.rst (+26, -25)

@@ -27,7 +27,7 @@ An actor can be defined as follows.
 
   import gym
 
-  @ray.actor
+  @ray.remote
   class GymEnvironment(object):
     def __init__(self, name):
       self.env = gym.make(name)
@@ -63,27 +63,28 @@ We can use the actor by calling one of its methods.
 
 .. code-block:: python
 
-  a1.step(0)
-  a2.step(0)
+  a1.step.remote(0)
+  a2.step.remote(0)
 
-When ``a1.step(0)`` is called, a task is created and scheduled on the first
-actor. This scheduling procedure bypasses the global scheduler, and is assigned
-directly to the local scheduler responsible for the actor by the driver's local
-scheduler. Since the method call is a task, ``a1.step(0)`` returns an object ID.
-We can call `ray.get` on the object ID to retrieve the actual value.
+When ``a1.step.remote(0)`` is called, a task is created and scheduled on the
+first actor. This scheduling procedure bypasses the global scheduler, and is
+assigned directly to the local scheduler responsible for the actor by the
+driver's local scheduler. Since the method call is a task, ``a1.step(0)``
+returns an object ID. We can call `ray.get` on the object ID to retrieve the
+actual value.
 
-The call to ``a2.step(0)`` generates a task which is scheduled on the second
-actor. Since these two tasks run on different actors, they can be executed in
-parallel (note that only actor methods will be scheduled on actor workers, not
-regular remote functions).
+The call to ``a2.step.remote(0)`` generates a task which is scheduled on the
+second actor. Since these two tasks run on different actors, they can be
+executed in parallel (note that only actor methods will be scheduled on actor
+workers, not regular remote functions).
 
 On the other hand, methods called on the same actor are executed serially and
 share in the order that they are called and share state with one another. We
 illustrate this with a simple example.
 
 .. code-block:: python
 
-  @ray.actor
+  @ray.remote
   class Counter(object):
     def __init__(self):
       self.value = 0
@@ -96,12 +97,12 @@ illustrate this with a simple example.
 
   # Increment each counter once and get the results. These tasks all happen in
   # parallel.
-  results = ray.get([c.increment() for c in counters])
+  results = ray.get([c.increment.remote() for c in counters])
   print(results) # prints [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
 
   # Increment the first counter five times. These tasks are executed serially
  # and share state.
-  results = ray.get([counters[0].increment() for _ in range(5)])
+  results = ray.get([counters[0].increment.remote() for _ in range(5)])
   print(results) # prints [2, 3, 4, 5, 6]
 
 Using GPUs on actors
@@ -136,8 +137,8 @@ We can then define an actor for this network as follows.
   import os
 
   # Define an actor that runs on GPUs. If there are no GPUs, then simply use
-  # ray.actor without any arguments and no parentheses.
-  @ray.actor(num_gpus=1)
+  # ray.remote without any arguments and no parentheses.
+  @ray.remote(num_gpus=1)
   class NeuralNetOnGPU(object):
     def __init__(self):
      # Set an environment variable to tell TensorFlow which GPUs to use. Note
@@ -154,15 +155,15 @@ We can then define an actor for this network as follows.
       self.sess.run(init)
 
 To indicate that an actor requires one GPU, we pass in ``num_gpus=1`` to
-``ray.actor``. Note that in order for this to work, Ray must have been started
+``ray.remote``. Note that in order for this to work, Ray must have been started
 with some GPUs, e.g., via ``ray.init(num_gpus=2)``. Otherwise, when you try to
-instantiate the GPU version with ``NeuralNetOnGPU()``, an exception will be
-thrown saying that there aren't enough GPUs in the system.
+instantiate the GPU version with ``NeuralNetOnGPU.remote()``, an exception will
+be thrown saying that there aren't enough GPUs in the system.
 
 When the actor is created, it will have access to a list of the IDs of the GPUs
 that it is allowed to use via ``ray.get_gpu_ids()``. This is a list of integers,
 like ``[]``, or ``[1]``, or ``[2, 5, 6]``. Since we passed in
-``ray.actor(num_gpus=1)``, this list will have length one.
+``ray.remote(num_gpus=1)``, this list will have length one.
 
 We can put this all together as follows.
 
@@ -190,7 +191,7 @@ We can put this all together as follows.
 
     return x, y_, train_step, accuracy
 
-  @ray.actor(num_gpus=1)
+  @ray.remote(num_gpus=1)
   class NeuralNetOnGPU(object):
     def __init__(self, mnist_data):
       self.mnist = mnist_data
@@ -223,9 +224,9 @@ We can put this all together as follows.
   ray.register_class(type(mnist.train))
 
   # Create the actor.
-  nn = NeuralNetOnGPU(mnist)
+  nn = NeuralNetOnGPU.remote(mnist)
 
   # Run a few steps of training and print the accuracy.
-  nn.train(100)
-  accuracy = ray.get(nn.get_accuracy())
+  nn.train.remote(100)
+  accuracy = ray.get(nn.get_accuracy.remote())
   print("Accuracy is {}.".format(accuracy))

doc/source/example-a3c.rst (+10, -4)

@@ -73,7 +73,7 @@ We use a Ray Actor to simulate the environment.
   import numpy as np
   import ray
 
-  @ray.actor
+  @ray.remote
   class Runner(object):
     """Actor object to start running simulation on workers.
       Gradient computation is also executed on this object."""
@@ -127,7 +127,7 @@ global model parameters. The main training script looks like the following.
 
   # Start gradient calculation tasks on each actor
   parameters = policy.get_weights()
-  gradient_list = [agent.compute_gradient(parameters) for agent in agents]
+  gradient_list = [agent.compute_gradient.remote(parameters) for agent in agents]
 
   while True: # Replace with your termination condition
     # wait for some gradient to be computed - unblock as soon as the earliest arrives
@@ -147,6 +147,12 @@ global model parameters. The main training script looks like the following.
 Benchmarks and Visualization
 ----------------------------
 
-For the :code:`PongDeterministic-v3` and an Amazon EC2 m4.16xlarge instance, we are able to train the agent with 16 workers in around 15 minutes. With 8 workers, we can train the agent in around 25 minutes.
+For the :code:`PongDeterministic-v3` and an Amazon EC2 m4.16xlarge instance, we
+are able to train the agent with 16 workers in around 15 minutes. With 8
+workers, we can train the agent in around 25 minutes.
 
-You can visualize performance by running :code:`tensorboard --logdir [directory]` in a separate screen, where :code:`[directory]` is defaulted to :code:`./results/`. If you are running multiple experiments, be sure to vary the directory to which Tensorflow saves its progress (found in :code:`driver.py`).
+You can visualize performance by running
+:code:`tensorboard --logdir [directory]` in a separate screen, where
+:code:`[directory]` is defaulted to :code:`./results/`. If you are running
+multiple experiments, be sure to vary the directory to which Tensorflow saves
+its progress (found in :code:`driver.py`).
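The "wait for some gradient to be computed - unblock as soon as the earliest arrives" comment above relies on ``ray.wait``. A self-contained toy sketch of that wait-and-resubmit pattern under the new actor API (``Worker`` and ``compute`` are made-up stand-ins for ``Runner`` and ``compute_gradient``):

    import ray

    ray.init()

    # A toy stand-in for the Runner actor above.
    @ray.remote
    class Worker(object):
        def __init__(self, worker_id):
            self.worker_id = worker_id

        def compute(self, params):
            return params + 1, {"id": self.worker_id}

    workers = [Worker.remote(i) for i in range(4)]
    params = 0
    pending = [w.compute.remote(params) for w in workers]

    for _ in range(10):
        # Unblock as soon as the earliest result arrives.
        [done_id], pending = ray.wait(pending)
        result, info = ray.get(done_id)
        params = result
        # Resubmit work to the worker that just finished.
        pending.append(workers[info["id"]].compute.remote(params))

    print(params)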

doc/source/example-resnet.rst (+3, -3)

@@ -53,7 +53,7 @@ The core of the script is the actor definition.
 
 .. code-block:: python
 
-  @ray.actor(num_gpus=1)
+  @ray.remote(num_gpus=1)
   class ResNetTrainActor(object):
     def __init__(self, path, num_gpus):
       # Set the CUDA_VISIBLE_DEVICES environment variable in order to restrict
@@ -78,7 +78,7 @@ The main script first creates one actor for each GPU.
 
 .. code-block:: python
 
-  train_actors = [ResNetTrainActor(train_data, num_gpus) for _ in range(num_gpus)]
+  train_actors = [ResNetTrainActor.remote(train_data, num_gpus) for _ in range(num_gpus)]
 
 Then after initializing the actors with the same weights, the main loop performs
 updates on each model, averages the updates, and puts the new weights in the
@@ -87,7 +87,7 @@ object store.
 .. code-block:: python
 
   while True:
-    all_weights = ray.get([actor.compute_steps(weight_id) for actor in train_actors])
+    all_weights = ray.get([actor.compute_steps.remote(weight_id) for actor in train_actors])
     mean_weights = {k: sum([weights[k] for weights in all_weights]) / num_gpus for k in all_weights[0]}
     weight_id = ray.put(mean_weights)

doc/source/example-rl-pong.rst (+2, -2)

@@ -58,7 +58,7 @@ the actor.
 
 .. code-block:: python
 
-  @ray.actor
+  @ray.remote
   class PongEnv(object):
     def __init__(self):
       # Tell numpy to only use one core. If we don't do this, each actor may try
@@ -93,7 +93,7 @@ perform rollouts and compute gradients in parallel.
     actions = []
     # Launch tasks to compute gradients from multiple rollouts in parallel.
     for i in range(batch_size):
-      action_id = actors[i].compute_gradient(model_id)
+      action_id = actors[i].compute_gradient.remote(model_id)
       actions.append(action_id)
 

doc/source/using-ray-with-tensorflow.rst (+7, -7)

@@ -158,11 +158,11 @@ complex Python objects.
   x_test, y_test = ray.get(generate_fake_x_y_data.remote(BATCH_SIZE, seed=NUM_BATCHES))
 
   # Create actors to store the networks.
-  remote_network = ray.actor(Network)
-  actor_list = [remote_network(x_ids[i], y_ids[i]) for i in range(NUM_BATCHES)]
+  remote_network = ray.remote(Network)
+  actor_list = [remote_network.remote(x_ids[i], y_ids[i]) for i in range(NUM_BATCHES)]
 
   # Get initial weights of some actor.
-  weights = ray.get(actor_list[0].get_weights())
+  weights = ray.get(actor_list[0].get_weights.remote())
 
   # Do some steps of training.
   for iteration in range(NUM_ITERS):
@@ -173,7 +173,7 @@ complex Python objects.
     # more efficient.
     weights_id = ray.put(weights)
     # Call the remote function multiple times in parallel.
-    new_weights_ids = [actor.step(weights_id) for actor in actor_list]
+    new_weights_ids = [actor.step.remote(weights_id) for actor in actor_list]
     # Get all of the weights.
     new_weights_list = ray.get(new_weights_ids)
     # Add up all the different weights. Each element of new_weights_list is a dict
@@ -288,8 +288,8 @@ For reference, the full code is below:
   x_test, y_test = ray.get(generate_fake_x_y_data.remote(BATCH_SIZE, seed=NUM_BATCHES))
 
   # Create actors to store the networks.
-  remote_network = ray.actor(Network)
-  actor_list = [remote_network(x_ids[i], y_ids[i]) for i in range(NUM_BATCHES)]
+  remote_network = ray.remote(Network)
+  actor_list = [remote_network.remote(x_ids[i], y_ids[i]) for i in range(NUM_BATCHES)]
   local_network = Network(x_test, y_test)
 
   # Get initial weights of local network.
@@ -304,7 +304,7 @@ For reference, the full code is below:
     # more efficient.
    weights_id = ray.put(weights)
     # Call the remote function multiple times in parallel.
-    gradients_ids = [actor.step(weights_id) for actor in actor_list]
+    gradients_ids = [actor.step.remote(weights_id) for actor in actor_list]
     # Get all of the weights.
     gradients_list = ray.get(gradients_ids)
 
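The hunks above also show the non-decorator form: ``ray.remote(...)`` can wrap an existing class, and the wrapper is then instantiated with ``.remote()``. A minimal sketch with a made-up ``Echo`` class:

    import ray

    ray.init()

    class Echo(object):
        def __init__(self, prefix):
            self.prefix = prefix

        def say(self, message):
            return self.prefix + message

    # Wrap the plain class instead of decorating it.
    RemoteEcho = ray.remote(Echo)

    actor = RemoteEcho.remote("echo: ")
    print(ray.get(actor.say.remote("hello")))  # prints "echo: hello"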

examples/a3c/driver.py (+4, -4)

@@ -15,7 +15,7 @@
 from misc import timestamp, time_string
 from envs import create_env
 
-@ray.actor
+@ray.remote
 class Runner(object):
   """Actor object to start running simulation on workers.
   Gradient computation is also executed from this object."""
@@ -58,9 +58,9 @@ def compute_gradient(self, params):
 def train(num_workers, env_name="PongDeterministic-v3"):
   env = create_env(env_name)
   policy = LSTMPolicy(env.observation_space.shape, env.action_space.n, 0)
-  agents = [Runner(env_name, i) for i in range(num_workers)]
+  agents = [Runner.remote(env_name, i) for i in range(num_workers)]
   parameters = policy.get_weights()
-  gradient_list = [agent.compute_gradient(parameters) for agent in agents]
+  gradient_list = [agent.compute_gradient.remote(parameters) for agent in agents]
   steps = 0
   obs = 0
   while True:
@@ -70,7 +70,7 @@ def train(num_workers, env_name="PongDeterministic-v3"):
     parameters = policy.get_weights()
     steps += 1
     obs += info["size"]
-    gradient_list.extend([agents[info["id"]].compute_gradient(parameters)])
+    gradient_list.extend([agents[info["id"]].compute_gradient.remote(parameters)])
   return policy
 
 if __name__ == '__main__':

examples/lbfgs/driver.py (+5, -5)

@@ -60,7 +60,7 @@ def grad(self, xs, ys):
     """Computes the gradients of the network."""
     return self.sess.run(self.cross_entropy_grads, feed_dict={self.x: xs, self.y_: ys})
 
-@ray.actor
+@ray.remote
 class NetActor(object):
   def __init__(self, xs, ys):
     os.environ["CUDA_VISIBLE_DEVICES"] = ""
@@ -88,13 +88,13 @@ def get_flat_size(self):
 # Compute the loss on the entire dataset.
 def full_loss(theta):
   theta_id = ray.put(theta)
-  loss_ids = [actor.loss(theta_id) for actor in actors]
+  loss_ids = [actor.loss.remote(theta_id) for actor in actors]
   return sum(ray.get(loss_ids))
 
 # Compute the gradient of the loss on the entire dataset.
 def full_grad(theta):
   theta_id = ray.put(theta)
-  grad_ids = [actor.grad(theta_id) for actor in actors]
+  grad_ids = [actor.grad.remote(theta_id) for actor in actors]
   return sum(ray.get(grad_ids)).astype("float64") # This conversion is necessary for use with fmin_l_bfgs_b.
 
 if __name__ == "__main__":
@@ -117,9 +117,9 @@ def full_grad(theta):
   batch_size = mnist.train.num_examples // num_batches
   batches = [mnist.train.next_batch(batch_size) for _ in range(num_batches)]
   print("Putting MNIST in the object store.")
-  actors = [NetActor(xs, ys) for (xs, ys) in batches]
+  actors = [NetActor.remote(xs, ys) for (xs, ys) in batches]
   # Initialize the weights for the network to the vector of all zeros.
-  dim = ray.get(actors[0].get_flat_size())
+  dim = ray.get(actors[0].get_flat_size.remote())
   theta_init = 1e-2 * np.random.normal(size=dim)
 
   # Use L-BFGS to minimize the loss function.
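For context, ``full_loss`` and ``full_grad`` above are the pieces handed to SciPy's L-BFGS routine in the final step. A hedged sketch of that call, assuming the definitions from the snippet above are in scope (the ``maxiter`` value is illustrative):

    import scipy.optimize

    # full_loss, full_grad, and theta_init come from the snippet above; both
    # helpers fan work out to the NetActor actors and aggregate the results.
    result = scipy.optimize.fmin_l_bfgs_b(full_loss, theta_init,
                                          fprime=full_grad, maxiter=10)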

examples/policy_gradient/examples/example.py (+2, -2)

@@ -47,15 +47,15 @@
 preprocessor = AtariPixelPreprocessor()
 
 print("Using the environment {}.".format(mdp_name))
-agents = [RemoteAgent(mdp_name, 1, preprocessor, config, False) for _ in range(5)]
+agents = [RemoteAgent.remote(mdp_name, 1, preprocessor, config, False) for _ in range(5)]
 agent = Agent(mdp_name, 1, preprocessor, config, True)
 
 kl_coeff = config["kl_coeff"]
 
 for j in range(1000):
   print("== iteration", j)
   weights = ray.put(agent.get_weights())
-  [a.load_weights(weights) for a in agents]
+  [a.load_weights.remote(weights) for a in agents]
   trajectory, total_reward, traj_len_mean = collect_samples(agents, config["timesteps_per_batch"], 0.995, 1.0, 2000)
   print("total reward is ", total_reward)
   print("trajectory length mean is ", traj_len_mean)

examples/policy_gradient/reinforce/agent.py (+1, -1)

@@ -40,4 +40,4 @@ def compute_trajectory(self, gamma, lam, horizon):
     add_advantage_values(trajectory, gamma, lam, self.reward_filter)
     return trajectory
 
-RemoteAgent = ray.actor(Agent)
+RemoteAgent = ray.remote(Agent)

examples/policy_gradient/reinforce/rollout.py (+1, -1)

@@ -79,7 +79,7 @@ def collect_samples(agents, num_timesteps, gamma, lam, horizon, observation_filt
   total_rewards = []
   traj_len_means = []
   while num_timesteps_so_far < num_timesteps:
-    trajectory_batch = ray.get([agent.compute_trajectory(gamma, lam, horizon) for agent in agents])
+    trajectory_batch = ray.get([agent.compute_trajectory.remote(gamma, lam, horizon) for agent in agents])
     trajectory = concatenate(trajectory_batch)
     total_rewards.append(trajectory["raw_rewards"].sum(axis=0).mean() / len(agents))
     trajectory = flatten(trajectory)

examples/resnet/resnet_main.py (+10, -10)

@@ -44,7 +44,7 @@ def get_data(path, size, dataset):
           images[int(2 * size / 3):, :],
           labels)
 
-@ray.actor(num_gpus=use_gpu)
+@ray.remote(num_gpus=use_gpu)
 class ResNetTrainActor(object):
   def __init__(self, data, dataset, num_gpus):
     if num_gpus > 0:
@@ -89,7 +89,7 @@ def compute_steps(self, weights):
   def get_weights(self):
     return self.model.variables.get_weights()
 
-@ray.actor
+@ray.remote
 class ResNetTestActor(object):
   def __init__(self, data, dataset, eval_batch_count, eval_dir):
     hps = resnet_model.HParams(batch_size=100,
@@ -162,25 +162,25 @@ def train():
   train_data = get_data.remote(FLAGS.train_data_path, 50000, FLAGS.dataset)
   test_data = get_data.remote(FLAGS.eval_data_path, 10000, FLAGS.dataset)
   if num_gpus > 0:
-    train_actors = [ResNetTrainActor(train_data, FLAGS.dataset, num_gpus) for _ in range(num_gpus)]
+    train_actors = [ResNetTrainActor.remote(train_data, FLAGS.dataset, num_gpus) for _ in range(num_gpus)]
   else:
-    train_actors = [ResNetTrainActor(train_data, num_gpus)]
-  test_actor = ResNetTestActor(test_data, FLAGS.dataset, FLAGS.eval_batch_count, FLAGS.eval_dir)
-  print('The log files for tensorboard are stored at ip {}.'.format(ray.get(test_actor.get_ip_addr())))
+    train_actors = [ResNetTrainActor.remote(train_data, num_gpus, 0)]
+  test_actor = ResNetTestActor.remote(test_data, FLAGS.dataset, FLAGS.eval_batch_count, FLAGS.eval_dir)
+  print('The log files for tensorboard are stored at ip {}.'.format(ray.get(test_actor.get_ip_addr.remote())))
   step = 0
-  weight_id = train_actors[0].get_weights()
-  acc_id = test_actor.accuracy(weight_id, step)
+  weight_id = train_actors[0].get_weights.remote()
+  acc_id = test_actor.accuracy.remote(weight_id, step)
   if num_gpus == 0:
     num_gpus = 1
   print("Starting computation.")
   while True:
-    all_weights = ray.get([actor.compute_steps(weight_id) for actor in train_actors])
+    all_weights = ray.get([actor.compute_steps.remote(weight_id) for actor in train_actors])
     mean_weights = {k: sum([weights[k] for weights in all_weights]) / num_gpus for k in all_weights[0]}
     weight_id = ray.put(mean_weights)
     step += 10
     if step % 200 == 0:
       acc = ray.get(acc_id)
-      acc_id = test_actor.accuracy(weight_id, step)
+      acc_id = test_actor.accuracy.remote(weight_id, step)
       print('Step {0}: {1:.6f}'.format(step - 200, acc))
 
 def main(_):
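The ``@ray.remote(num_gpus=use_gpu)`` line above also shows that the resource argument can be an ordinary Python value computed before the class is defined. A hedged, self-contained sketch of that pattern (``use_gpu`` and ``Trainer`` are illustrative names):

    import ray

    use_gpu = 0  # e.g. derived from a command-line flag

    ray.init(num_gpus=use_gpu)

    # The decorator argument is evaluated once, when the class is defined.
    @ray.remote(num_gpus=use_gpu)
    class Trainer(object):
        def step(self):
            return ray.get_gpu_ids()

    trainer = Trainer.remote()
    print(ray.get(trainer.step.remote()))  # [] when use_gpu == 0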
