[rllib] Add magic methods for rollouts (ray-project#2024)
alok authored and richardliaw committed May 17, 2018
1 parent 7549209 commit c0e4c9d
Showing 4 changed files with 83 additions and 47 deletions.
33 changes: 24 additions & 9 deletions python/ray/rllib/optimizers/sample_batch.py
@@ -36,8 +36,8 @@ def __init__(self, *args, **kwargs):
     @staticmethod
     def concat_samples(samples):
         out = {}
-        for k in samples[0].data.keys():
-            out[k] = np.concatenate([s.data[k] for s in samples])
+        for k in samples[0].keys():
+            out[k] = np.concatenate([s[k] for s in samples])
         return SampleBatch(out)

     def concat(self, other):
@@ -50,10 +50,10 @@ def concat(self, other):
{"a": [1, 2, 3, 4, 5]}
"""

assert self.data.keys() == other.data.keys(), "must have same columns"
assert self.keys() == other.keys(), "must have same columns"
out = {}
for k in self.data.keys():
out[k] = np.concatenate([self.data[k], other.data[k]])
for k in self.keys():
out[k] = np.concatenate([self[k], other[k]])
return SampleBatch(out)

def rows(self):
@@ -70,7 +70,7 @@ def rows(self):

         for i in range(self.count):
             row = {}
-            for k in self.data.keys():
+            for k in self.keys():
                 row[k] = self[k][i]
             yield row

@@ -85,19 +85,34 @@ def columns(self, keys):

         out = []
         for k in keys:
-            out.append(self.data[k])
+            out.append(self[k])
         return out

     def shuffle(self):
         permutation = np.random.permutation(self.count)
-        for key, val in self.data.items():
-            self.data[key] = val[permutation]
+        for key, val in self.items():
+            self[key] = val[permutation]

     def __getitem__(self, key):
         return self.data[key]

+    def __setitem__(self, key, item):
+        self.data[key] = item
+
     def __str__(self):
         return "SampleBatch({})".format(str(self.data))

     def __repr__(self):
         return "SampleBatch({})".format(str(self.data))
+
+    def keys(self):
+        return self.data.keys()
+
+    def items(self):
+        return self.data.items()
+
+    def __iter__(self):
+        return self.data.__iter__()
+
+    def __contains__(self, x):
+        return x in self.data
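
Taken together, the new magic methods let a SampleBatch be used as a read/write mapping over its underlying data dict instead of reaching into .data directly. A minimal usage sketch (the import path is assumed from the file location above and is not confirmed by this diff):

import numpy as np
from ray.rllib.optimizers.sample_batch import SampleBatch  # assumed path

batch = SampleBatch({"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])})

batch["a"]                    # __getitem__ forwards to batch.data["a"]
batch["c"] = batch["a"] * 2   # __setitem__ adds/overwrites a column
assert "b" in batch           # __contains__
for key in batch:             # __iter__ yields the keys of the underlying dict
    print(key, batch[key])
assert sorted(batch.keys()) == ["a", "b", "c"]  # keys() delegates to the dict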
14 changes: 7 additions & 7 deletions python/ray/rllib/test/test_optimizers.py
@@ -12,7 +12,6 @@


 class AsyncOptimizerTest(unittest.TestCase):
-
     def tearDown(self):
         ray.worker.cleanup()

@@ -21,8 +20,9 @@ def testBasic(self):
         local = _MockEvaluator()
         remotes = ray.remote(_MockEvaluator)
         remote_evaluators = [remotes.remote() for i in range(5)]
-        test_optimizer = AsyncOptimizer(
-            {"grads_per_step": 10}, local, remote_evaluators)
+        test_optimizer = AsyncOptimizer({
+            "grads_per_step": 10
+        }, local, remote_evaluators)
         test_optimizer.step()
         self.assertTrue(all(local.get_weights() == 0))

@@ -33,11 +33,11 @@ def testConcat(self):
         b2 = SampleBatch({"a": np.array([1]), "b": np.array([4])})
         b3 = SampleBatch({"a": np.array([1]), "b": np.array([5])})
         b12 = b1.concat(b2)
-        self.assertEqual(b12.data["a"].tolist(), [1, 2, 3, 1])
-        self.assertEqual(b12.data["b"].tolist(), [4, 5, 6, 4])
+        self.assertEqual(b12["a"].tolist(), [1, 2, 3, 1])
+        self.assertEqual(b12["b"].tolist(), [4, 5, 6, 4])
         b = SampleBatch.concat_samples([b1, b2, b3])
-        self.assertEqual(b.data["a"].tolist(), [1, 2, 3, 1, 1])
-        self.assertEqual(b.data["b"].tolist(), [4, 5, 6, 4, 5])
+        self.assertEqual(b["a"].tolist(), [1, 2, 3, 1, 1])
+        self.assertEqual(b["b"].tolist(), [4, 5, 6, 4, 5])


 if __name__ == '__main__':
16 changes: 8 additions & 8 deletions python/ray/rllib/utils/process_rollout.py
@@ -26,22 +26,22 @@ def process_rollout(rollout, reward_filter, gamma, lambda_=1.0, use_gae=True):
            processed rewards."""

     traj = {}
-    trajsize = len(rollout.data["actions"])
-    for key in rollout.data:
-        traj[key] = np.stack(rollout.data[key])
+    trajsize = len(rollout["actions"])
+    for key in rollout:
+        traj[key] = np.stack(rollout[key])

     if use_gae:
-        assert "vf_preds" in rollout.data, "Values not found!"
-        vpred_t = np.stack(
-            rollout.data["vf_preds"] + [np.array(rollout.last_r)]).squeeze()
+        assert "vf_preds" in rollout, "Values not found!"
+        vpred_t = np.stack(rollout["vf_preds"] +
+                           [np.array(rollout.last_r)]).squeeze()
         delta_t = traj["rewards"] + gamma * vpred_t[1:] - vpred_t[:-1]
         # This formula for the advantage comes
         # "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
         traj["advantages"] = discount(delta_t, gamma * lambda_)
         traj["value_targets"] = traj["advantages"] + traj["vf_preds"]
     else:
-        rewards_plus_v = np.stack(
-            rollout.data["rewards"] + [np.array(rollout.last_r)]).squeeze()
+        rewards_plus_v = np.stack(rollout["rewards"] +
+                                  [np.array(rollout.last_r)]).squeeze()
         traj["advantages"] = discount(rewards_plus_v, gamma)[:-1]

     for i in range(traj["advantages"].shape[0]):
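
The advantage computation above relies on a discount helper that sits outside this diff. For reference, here is a minimal NumPy sketch of a discounted cumulative sum consistent with the GAE formula used in the hunk; the actual rllib helper may be implemented differently (e.g. via scipy.signal.lfilter):

import numpy as np

def discount(x, gamma):
    """out[t] = x[t] + gamma * x[t + 1] + gamma**2 * x[t + 2] + ..."""
    out = np.zeros(len(x))
    running = 0.0
    for t in reversed(range(len(x))):
        running = x[t] + gamma * running
        out[t] = running
    return out

# With the names used in process_rollout above:
#   delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
#   advantages = discount(delta_t, gamma * lambda_)   # GAE(gamma, lambda)
#   value_targets = advantages + vf_preds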
67 changes: 44 additions & 23 deletions python/ray/rllib/utils/sampler.py
@@ -56,9 +56,30 @@ def is_terminal(self):
            terminal (bool): if rollout has terminated."""
         return self.data["dones"][-1]

+    def __getitem__(self, key):
+        return self.data[key]

-CompletedRollout = namedtuple(
-    "CompletedRollout", ["episode_length", "episode_reward"])
+    def __setitem__(self, key, item):
+        self.data[key] = item
+
+    def keys(self):
+        return self.data.keys()
+
+    def items(self):
+        return self.data.items()
+
+    def __iter__(self):
+        return self.data.__iter__()
+
+    def __next__(self):
+        return self.data.__next__()
+
+    def __contains__(self, x):
+        return x in self.data
+
+
+CompletedRollout = namedtuple("CompletedRollout",
+                              ["episode_length", "episode_reward"])


 class SyncSampler(object):
@@ -71,16 +92,15 @@ class SyncSampler(object):
    thread."""
     async = False

-    def __init__(self, env, policy, obs_filter,
-                 num_local_steps, horizon=None):
+    def __init__(self, env, policy, obs_filter, num_local_steps, horizon=None):
         self.num_local_steps = num_local_steps
         self.horizon = horizon
         self.env = env
         self.policy = policy
         self._obs_filter = obs_filter
-        self.rollout_provider = _env_runner(
-            self.env, self.policy, self.num_local_steps, self.horizon,
-            self._obs_filter)
+        self.rollout_provider = _env_runner(self.env, self.policy,
+                                            self.num_local_steps, self.horizon,
+                                            self._obs_filter)
         self.metrics_queue = queue.Queue()

     def get_data(self):
@@ -108,10 +128,10 @@ class AsyncSampler(threading.Thread):
    accumulate and the gradient can be calculated on up to 5 batches."""
     async = True

-    def __init__(self, env, policy, obs_filter,
-                 num_local_steps, horizon=None):
-        assert getattr(obs_filter, "is_concurrent", False), (
-            "Observation Filter must support concurrent updates.")
+    def __init__(self, env, policy, obs_filter, num_local_steps, horizon=None):
+        assert getattr(
+            obs_filter, "is_concurrent",
+            False), ("Observation Filter must support concurrent updates.")
         threading.Thread.__init__(self)
         self.queue = queue.Queue(5)
         self.metrics_queue = queue.Queue()
@@ -132,9 +152,9 @@ def run(self):
            raise e

     def _run(self):
-        rollout_provider = _env_runner(
-            self.env, self.policy, self.num_local_steps,
-            self.horizon, self._obs_filter)
+        rollout_provider = _env_runner(self.env, self.policy,
+                                       self.num_local_steps, self.horizon,
+                                       self._obs_filter)
         while True:
             # The timeout variable exists because apparently, if one worker
             # dies, the other workers won't die with it, unless the timeout is
@@ -232,13 +252,14 @@ def _env_runner(env, policy, num_local_steps, horizon, obs_filter):
            action = np.concatenate(action, axis=0).flatten()

             # Collect the experience.
-            rollout.add(obs=last_observation,
-                        actions=action,
-                        rewards=reward,
-                        dones=terminal,
-                        features=last_features,
-                        new_obs=observation,
-                        **pi_info)
+            rollout.add(
+                obs=last_observation,
+                actions=action,
+                rewards=reward,
+                dones=terminal,
+                features=last_features,
+                new_obs=observation,
+                **pi_info)

             last_observation = observation
             last_features = features
@@ -247,8 +268,8 @@ def _env_runner(env, policy, num_local_steps, horizon, obs_filter):
                terminal_end = True
                 yield CompletedRollout(length, rewards)

-                if (length >= horizon or
-                        not env.metadata.get("semantics.autoreset")):
+                if (length >= horizon
+                        or not env.metadata.get("semantics.autoreset")):
                     last_observation = obs_filter(env.reset())
                     if hasattr(policy, "get_initial_features"):
                         last_features = policy.get_initial_features()
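
The rollout object built by _env_runner gains the same mapping protocol (its class definition lies above the first hunk shown here), which is what lets process_rollout index rollout["vf_preds"] and iterate over rollout directly. A self-contained sketch of the delegation pattern, using a hypothetical stand-in class rather than the real rollout type from sampler.py:

# Hypothetical stand-in; the real rollout class also tracks fields such as
# last_r and implements add() and is_terminal().
class DictBackedRollout(object):
    def __init__(self):
        self.data = {"actions": [], "rewards": []}

    def __getitem__(self, key):
        return self.data[key]

    def __setitem__(self, key, item):
        self.data[key] = item

    def __contains__(self, x):
        return x in self.data

    def __iter__(self):
        return iter(self.data)

rollout = DictBackedRollout()
rollout["rewards"] = [1.0, 0.5]              # __setitem__
assert "rewards" in rollout                  # __contains__
assert len(rollout["actions"]) == 0          # __getitem__
assert set(rollout) == {"actions", "rewards"}  # __iter__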
