[rllib] Use batch.count in async samples optimizer (ray-project#2488)
Using the actual batch size reduces the risk of mis-accounting. Here, we were under-counting samples, since in truncate_episodes mode the batch size was accidentally doubled in policy_evaluator.
ericl authored Jul 27, 2018
1 parent 1e6b130 commit 2464972
Showing 5 changed files with 43 additions and 20 deletions.
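To illustrate the accounting change in isolation, here is a minimal, hypothetical sketch (not RLlib code; `SampleBatch` below is a tiny stand-in and the two helper functions are invented for the example). It shows why summing `batch.count` stays correct when a batch does not have the configured size:

```python
class SampleBatch(object):
    """Tiny stand-in for rllib's SampleBatch: just tracks its row count."""

    def __init__(self, rows):
        self.rows = list(rows)
        self.count = len(self.rows)


def count_samples_configured(batches, sample_batch_size):
    # Old accounting: assume every batch has exactly the configured size.
    return len(batches) * sample_batch_size


def count_samples_actual(batches):
    # New accounting: trust batch.count, whatever the evaluator returned.
    return sum(b.count for b in batches)


if __name__ == "__main__":
    # One batch was accidentally doubled (100 rows instead of 50).
    batches = [SampleBatch(range(100)), SampleBatch(range(50))]
    print(count_samples_configured(batches, 50))  # 100 -> under-counts
    print(count_samples_actual(batches))          # 150 -> correct
```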
1 change: 0 additions & 1 deletion python/ray/rllib/agents/dqn/apex.py
@@ -23,7 +23,6 @@
"learning_starts": 50000,
"train_batch_size": 512,
"sample_batch_size": 50,
"max_weight_sync_delay": 400,
"target_network_update_freq": 500000,
"timesteps_per_iteration": 25000,
"per_worker_exploration": True,
27 changes: 22 additions & 5 deletions python/ray/rllib/evaluation/policy_evaluator.py
@@ -122,10 +122,12 @@ def __init__(self,
in each sample batch returned from this evaluator.
batch_mode (str): One of the following batch modes:
"truncate_episodes": Each call to sample() will return a batch
of exactly `batch_steps` in size. Episodes may be truncated
in order to meet this size requirement. When
`num_envs > 1`, episodes will be truncated to sequences of
`batch_size / num_envs` in length.
of at most `batch_steps` in size. The batch will be exactly
`batch_steps` in size if postprocessing does not change
batch sizes. Episodes may be truncated in order to meet
this size requirement. When `num_envs > 1`, episodes will
be truncated to sequences of `batch_size / num_envs` in
length.
"complete_episodes": Each call to sample() will return a batch
of at least `batch_steps` in size. Episodes will not be
truncated, but multiple episodes may be packed within one
@@ -220,6 +222,7 @@ def make_env():
# Always use vector env for consistency even if num_envs = 1
self.async_env = AsyncVectorEnv.wrap_async(
self.env, make_env=make_env, num_envs=num_envs)
self.num_envs = num_envs

if self.batch_mode == "truncate_episodes":
if batch_steps % num_envs != 0:
@@ -276,7 +279,15 @@ def sample(self):

batches = [self.sampler.get_data()]
steps_so_far = batches[0].count
while steps_so_far < self.batch_steps:

# In truncate_episodes mode, never pull more than 1 batch per env.
# This avoids over-running the target batch size.
if self.batch_mode == "truncate_episodes":
max_batches = self.num_envs
else:
max_batches = float("inf")

while steps_so_far < self.batch_steps and len(batches) < max_batches:
batch = self.sampler.get_data()
steps_so_far += batch.count
batches.append(batch)
@@ -293,6 +304,12 @@ def sample(self):

return batch

@ray.method(num_return_vals=2)
def sample_with_count(self):
"""Same as sample() but returns the count as a separate future."""
batch = self.sample()
return batch, batch.count

def for_policy(self, func, policy_id=DEFAULT_POLICY_ID):
"""Apply the given function to the specified policy graph."""

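As a rough, standalone sketch of the capped batching loop added to sample() above (the `FakeSampler` class and the free-standing `sample()` function here are illustrative only, not the actual `PolicyEvaluator` API):

```python
class FakeSampler(object):
    """Returns fixed-size 'batches'; stands in for the evaluator's sampler."""

    def __init__(self, batch_size):
        self.batch_size = batch_size

    def get_data(self):
        return list(range(self.batch_size))


def sample(sampler, batch_steps, batch_mode, num_envs):
    batches = [sampler.get_data()]
    steps_so_far = len(batches[0])

    # In truncate_episodes mode, never pull more than one batch per env,
    # so the target batch size cannot be over-run.
    if batch_mode == "truncate_episodes":
        max_batches = num_envs
    else:
        max_batches = float("inf")

    while steps_so_far < batch_steps and len(batches) < max_batches:
        batch = sampler.get_data()
        steps_so_far += len(batch)
        batches.append(batch)

    # Flatten, in place of SampleBatch.concat_samples().
    return [step for batch in batches for step in batch]


if __name__ == "__main__":
    out = sample(FakeSampler(50), batch_steps=100,
                 batch_mode="truncate_episodes", num_envs=2)
    print(len(out))  # 100: capped at num_envs batches of 50 steps each
```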
22 changes: 11 additions & 11 deletions python/ray/rllib/optimizers/async_samples_optimizer.py
@@ -58,6 +58,7 @@ def get_host(self):
return os.uname()[1]

def add_batch(self, batch):
PolicyOptimizer._check_not_multiagent(batch)
with self.add_batch_timer:
for row in batch.rows():
self.replay_buffer.add(row["obs"], row["actions"],
@@ -131,7 +132,7 @@ def step(self):
with self.grad_timer:
td_error = self.local_evaluator.compute_apply(replay)[
"td_error"]
self.outqueue.put((ra, replay, td_error))
self.outqueue.put((ra, replay, td_error, replay.count))
self.learner_queue_size.push(self.inqueue.qsize())
self.weights_updated = True

@@ -164,8 +165,6 @@ def _init(self,
self.replay_starts = learning_starts
self.prioritized_replay_beta = prioritized_replay_beta
self.prioritized_replay_eps = prioritized_replay_eps
self.train_batch_size = train_batch_size
self.sample_batch_size = sample_batch_size
self.max_weight_sync_delay = max_weight_sync_delay

self.learner = LearnerThread(self.local_evaluator)
@@ -205,7 +204,7 @@ def _init(self,
ev.set_weights.remote(weights)
self.steps_since_update[ev] = 0
for _ in range(SAMPLE_QUEUE_DEPTH):
self.sample_tasks.add(ev, ev.sample.remote())
self.sample_tasks.add(ev, ev.sample_with_count.remote())

def step(self):
start = time.time()
@@ -226,16 +225,17 @@ def _step(self):
weights = None

with self.timers["sample_processing"]:
for ev, sample_batch in self.sample_tasks.completed():
self._check_not_multiagent(sample_batch)
sample_timesteps += self.sample_batch_size
completed = list(self.sample_tasks.completed())
counts = ray.get([c[1][1] for c in completed])
for i, (ev, (sample_batch, count)) in enumerate(completed):
sample_timesteps += counts[i]

# Send the data to the replay buffer
random.choice(
self.replay_actors).add_batch.remote(sample_batch)

# Update weights if needed
self.steps_since_update[ev] += self.sample_batch_size
self.steps_since_update[ev] += counts[i]
if self.steps_since_update[ev] >= self.max_weight_sync_delay:
# Note that it's important to pull new weights once
# updated to avoid excessive correlation between actors
@@ -249,7 +249,7 @@ def _step(self):
self.steps_since_update[ev] = 0

# Kick off another sample request
self.sample_tasks.add(ev, ev.sample.remote())
self.sample_tasks.add(ev, ev.sample_with_count.remote())

with self.timers["replay_processing"]:
for ra, replay in self.replay_tasks.completed():
@@ -261,9 +261,9 @@ def _step(self):

with self.timers["update_priorities"]:
while not self.learner.outqueue.empty():
ra, replay, td_error = self.learner.outqueue.get()
ra, replay, td_error, count = self.learner.outqueue.get()
ra.update_priorities.remote(replay["batch_indexes"], td_error)
train_timesteps += self.train_batch_size
train_timesteps += count

return sample_timesteps, train_timesteps

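The key pattern above is that sample_with_count() returns two futures, so the optimizer can ray.get() only the small count while routing the full batch straight to a replay actor. A toy Ray program in that spirit follows; the `Evaluator` and `ReplayActor` classes are made up for the example, and it assumes the Ray API of this era (`num_return_vals`, as used in the diff; newer Ray renames that argument):

```python
import ray

ray.init()


@ray.remote
class Evaluator(object):
    @ray.method(num_return_vals=2)
    def sample_with_count(self):
        batch = list(range(50))  # stand-in for a large SampleBatch
        return batch, len(batch)


@ray.remote
class ReplayActor(object):
    def __init__(self):
        self.num_steps = 0

    def add_batch(self, batch):
        # The batch is resolved from the object store inside this actor;
        # it never has to be copied back to the driver.
        self.num_steps += len(batch)
        return self.num_steps


ev = Evaluator.remote()
replay = ReplayActor.remote()

batch_id, count_id = ev.sample_with_count.remote()
count = ray.get(count_id)                  # cheap: just an int
total = replay.add_batch.remote(batch_id)  # batch moves actor-to-actor
print(count, ray.get(total))               # -> 50 50
```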
3 changes: 2 additions & 1 deletion python/ray/rllib/optimizers/policy_optimizer.py
@@ -121,7 +121,8 @@ def foreach_evaluator_with_index(self, func):
])
return local_result + remote_results

def _check_not_multiagent(self, sample_batch):
@staticmethod
def _check_not_multiagent(sample_batch):
if isinstance(sample_batch, MultiAgentBatch):
raise NotImplementedError(
"This optimizer does not support multi-agent yet.")
10 changes: 8 additions & 2 deletions python/ray/rllib/utils/actors.py
@@ -11,16 +11,22 @@ class TaskPool(object):

def __init__(self):
self._tasks = {}
self._objects = {}

def add(self, worker, obj_id):
def add(self, worker, all_obj_ids):
if isinstance(all_obj_ids, list):
obj_id = all_obj_ids[0]
else:
obj_id = all_obj_ids
self._tasks[obj_id] = worker
self._objects[obj_id] = all_obj_ids

def completed(self):
pending = list(self._tasks)
if pending:
ready, _ = ray.wait(pending, num_returns=len(pending), timeout=10)
for obj_id in ready:
yield (self._tasks.pop(obj_id), obj_id)
yield (self._tasks.pop(obj_id), self._objects.pop(obj_id))

@property
def count(self):
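To make the TaskPool change concrete, here is a toy version that mirrors the new bookkeeping but drains tasks without ray.wait(), so it runs with no cluster; `ToyTaskPool` and the placeholder future strings are purely illustrative:

```python
class ToyTaskPool(object):
    """Mirrors the updated TaskPool: keyed by the first return future,
    but remembers (and later yields) all futures of each task."""

    def __init__(self):
        self._tasks = {}    # first future -> worker
        self._objects = {}  # first future -> all futures for that task

    def add(self, worker, all_obj_ids):
        if isinstance(all_obj_ids, list):
            obj_id = all_obj_ids[0]
        else:
            obj_id = all_obj_ids
        self._tasks[obj_id] = worker
        self._objects[obj_id] = all_obj_ids

    def completed(self):
        # The real TaskPool calls ray.wait() on the pending first futures;
        # this toy version simply drains everything that was added.
        for obj_id in list(self._tasks):
            yield (self._tasks.pop(obj_id), self._objects.pop(obj_id))


pool = ToyTaskPool()
pool.add("worker_1", ["batch_future", "count_future"])
for worker, (batch_id, count_id) in pool.completed():
    print(worker, batch_id, count_id)  # worker_1 batch_future count_future
```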
