[rllib] Evaluators and Optimizers Refactoring (ray-project#1339)
richardliaw authored Dec 30, 2017
1 parent 22c7c87 commit 3304099
Showing 28 changed files with 634 additions and 351 deletions.
2 changes: 2 additions & 0 deletions .travis.yml
@@ -127,6 +127,8 @@ script:
- python -m pytest test/dataframe.py

- python -m pytest python/ray/rllib/test/test_catalog.py
- python -m pytest python/ray/rllib/test/test_filters.py
- python -m pytest python/ray/rllib/test/test_optimizers.py

deploy:
provider: s3
6 changes: 4 additions & 2 deletions python/ray/rllib/a3c/a3c.py
@@ -9,6 +9,7 @@
import ray
from ray.rllib.agent import Agent
from ray.rllib.optimizers import AsyncOptimizer
from ray.rllib.utils import FilterManager
from ray.rllib.a3c.a3c_evaluator import A3CEvaluator, RemoteA3CEvaluator
from ray.tune.result import TrainingResult

@@ -53,7 +54,7 @@
"optimizer": {
# Number of gradients applied for each `train` step
"grads_per_step": 100,
},
}
}


@@ -76,6 +77,8 @@ def _init(self):

def _train(self):
self.optimizer.step()
FilterManager.synchronize(
self.local_evaluator.filters, self.remote_evaluators)
res = self._fetch_metrics_from_remote_evaluators()
return res

@@ -105,7 +108,6 @@ def _fetch_metrics_from_remote_evaluators(self):
def _save(self):
checkpoint_path = os.path.join(
self.logdir, "checkpoint-{}".format(self.iteration))
# self.saver.save
agent_state = ray.get(
[a.save.remote() for a in self.remote_evaluators])
extra_data = {
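For orientation, here is a minimal sketch of the role FilterManager.synchronize plays in _train above. It uses only the evaluator-side methods introduced in this diff (as_serializable on filters, sync_filters on remote evaluators); the real implementation lives in ray.rllib.utils and is not part of this excerpt, so treat it as an approximation rather than the actual code.

import ray

def synchronize_filters(local_filters, remote_evaluators):
    # Hedged sketch, not ray.rllib.utils.FilterManager itself: after an
    # optimizer step, broadcast the driver's filter state so every remote
    # evaluator rebases onto the same running statistics. The real class
    # may also fold the remotes' accumulated buffers back into the local
    # filters before broadcasting.
    snapshot = {k: f.as_serializable() for k, f in local_filters.items()}
    ray.get([e.sync_filters.remote(snapshot) for e in remote_evaluators])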
45 changes: 34 additions & 11 deletions python/ray/rllib/a3c/a3c_evaluator.py
@@ -20,6 +20,7 @@ class A3CEvaluator(Evaluator):
Attributes:
policy: Copy of graph used for policy. Used by sampler and gradients.
obs_filter: Observation filter used in environment sampling
rew_filter: Reward filter used in rollout post-processing.
sampler: Component for interacting with environment and generating
rollouts.
@@ -40,16 +41,15 @@ def __init__(
self.obs_filter = get_filter(
config["observation_filter"], env.observation_space.shape)
self.rew_filter = get_filter(config["reward_filter"], ())
self.filters = {"obs_filter": self.obs_filter,
"rew_filter": self.rew_filter}
self.sampler = AsyncSampler(env, self.policy, self.obs_filter,
config["batch_size"])
if start_sampler and self.sampler.async:
self.sampler.start()
self.logdir = logdir

def sample(self):
"""
Returns:
trajectory (PartialRollout): Experience Samples from evaluator"""
rollout = self.sampler.get_data()
samples = process_rollout(
rollout, self.rew_filter, gamma=self.config["gamma"],
@@ -76,20 +76,43 @@ def get_weights(self):
def set_weights(self, params):
self.policy.set_weights(params)

def update_filters(self, obs_filter=None, rew_filter=None):
if rew_filter:
# No special handling required since outside of threaded code
self.rew_filter = rew_filter.copy()
if obs_filter:
self.sampler.update_obs_filter(obs_filter)

def save(self):
filters = self.get_filters(flush_after=True)
weights = self.get_weights()
return pickle.dumps({"weights": weights})
return pickle.dumps({
"filters": filters,
"weights": weights})

def restore(self, objs):
objs = pickle.loads(objs)
self.sync_filters(objs["filters"])
self.set_weights(objs["weights"])

def sync_filters(self, new_filters):
"""Changes self's filter to given and rebases any accumulated delta.
Args:
new_filters (dict): Filters with new state to update local copy.
"""
assert all(k in new_filters for k in self.filters)
for k in self.filters:
self.filters[k].sync(new_filters[k])

def get_filters(self, flush_after=False):
"""Returns a snapshot of filters.
Args:
flush_after (bool): Clears the filter buffer state.
Returns:
return_filters (dict): Dict for serializable filters
"""
return_filters = {}
for k, f in self.filters.items():
return_filters[k] = f.as_serializable()
if flush_after:
f.clear_buffer()
return return_filters


RemoteA3CEvaluator = ray.remote(A3CEvaluator)
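The sync_filters/get_filters pair above assumes each filter object exposes sync, as_serializable, and clear_buffer. The toy filter below is an illustrative stand-in (not rllib's actual filters, which come from ray.rllib.utils.get_filter) showing the intended buffer semantics: samples accumulate in a local buffer, clear_buffer drops the delta once it has been shipped, and sync adopts the driver's merged state while rebasing any unsent delta.

import copy

class ToyMeanFilter(object):
    """Illustrative stand-in for an rllib filter: a running mean whose
    locally accumulated delta is tracked in a separate buffer."""

    def __init__(self):
        self.count = 0
        self.total = 0.0
        self.buffer_count = 0
        self.buffer_total = 0.0

    def __call__(self, x):
        # Record the sample both in the global stats and the local buffer.
        self.count += 1
        self.total += x
        self.buffer_count += 1
        self.buffer_total += x
        return x - self.total / self.count  # center by the running mean

    def as_serializable(self):
        return copy.deepcopy(self)

    def clear_buffer(self):
        # Forget the local delta once it has been reported to the driver.
        self.buffer_count = 0
        self.buffer_total = 0.0

    def sync(self, other):
        # Adopt the other filter's merged statistics and rebase our unsent
        # delta on top (mirrors the sync_filters docstring above).
        self.count = other.count + self.buffer_count
        self.total = other.total + self.buffer_total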
2 changes: 1 addition & 1 deletion python/ray/rllib/a3c/policy.py
@@ -17,7 +17,7 @@ def get_weights(self):
def set_weights(self, weights):
raise NotImplementedError

def compute_gradients(self, batch):
def compute_gradients(self, samples):
raise NotImplementedError

def compute(self, observations):
12 changes: 5 additions & 7 deletions python/ray/rllib/a3c/shared_model.py
@@ -24,8 +24,6 @@ def _setup_graph(self, ob_space, ac_space):
self.registry, self.x, self.logit_dim, self.config["model"])
self.logits = self._model.outputs
self.curr_dist = dist_class(self.logits)
# with tf.variable_scope("vf"):
# vf_model = ModelCatalog.get_model(self.x, 1)
self.vf = tf.reshape(linear(self._model.last_layer, 1, "value",
normc_initializer(1.0)), [-1])

@@ -37,13 +35,13 @@ def _setup_graph(self, ob_space, ac_space):
initializer=tf.constant_initializer(0, dtype=tf.int32),
trainable=False)

def compute_gradients(self, trajectory):
def compute_gradients(self, samples):
info = {}
feed_dict = {
self.x: trajectory["observations"],
self.ac: trajectory["actions"],
self.adv: trajectory["advantages"],
self.r: trajectory["value_targets"],
self.x: samples["observations"],
self.ac: samples["actions"],
self.adv: samples["advantages"],
self.r: samples["value_targets"],
}
self.grads = [g for g in self.grads if g is not None]
self.local_steps += 1
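As a reminder of what the renamed samples argument carries: it is a column-oriented batch (see the SampleBatch changes later in this diff) containing at least the keys read by the feed_dict above. A hand-built stand-in for exercising a policy's compute_gradients might look like the following; the shapes are placeholders, not values from this commit.

import numpy as np

fake_samples = {
    "observations": np.zeros((32, 4), dtype=np.float32),  # 32 steps, 4-dim obs
    "actions": np.zeros((32,), dtype=np.int64),
    "advantages": np.zeros((32,), dtype=np.float32),
    "value_targets": np.zeros((32,), dtype=np.float32),
}
# grads, info = policy.compute_gradients(fake_samples)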
12 changes: 6 additions & 6 deletions python/ray/rllib/a3c/shared_model_lstm.py
@@ -49,18 +49,18 @@ def _setup_graph(self, ob_space, ac_space):
initializer=tf.constant_initializer(0, dtype=tf.int32),
trainable=False)

def compute_gradients(self, trajectory):
def compute_gradients(self, samples):
"""Computing the gradient is actually model-dependent.
The LSTM needs its hidden states in order to compute the gradient
accurately.
"""
features = trajectory["features"][0]
features = samples["features"][0]
feed_dict = {
self.x: trajectory["observations"],
self.ac: trajectory["actions"],
self.adv: trajectory["advantages"],
self.r: trajectory["value_targets"],
self.x: samples["observations"],
self.ac: samples["actions"],
self.adv: samples["advantages"],
self.r: samples["value_targets"],
self.state_in[0]: features[0],
self.state_in[1]: features[1]
}
2 changes: 1 addition & 1 deletion python/ray/rllib/a3c/tfpolicy.py
@@ -95,7 +95,7 @@ def get_weights(self):
def set_weights(self, weights):
self.variables.set_weights(weights)

def compute_gradients(self, batch):
def compute_gradients(self, samples):
raise NotImplementedError

def compute(self, observation):
6 changes: 3 additions & 3 deletions python/ray/rllib/a3c/torchpolicy.py
@@ -39,18 +39,18 @@ def set_weights(self, weights):
with self.lock:
self._model.load_state_dict(weights)

def compute_gradients(self, batch):
def compute_gradients(self, samples):
"""_backward generates the gradient in each model parameter.
This is taken out.
Args:
batch: Batch of data needed for gradient calculation.
samples: SampleBatch of data needed for gradient calculation.
Return:
gradients (list of np arrays): List of gradients
info (dict): Extra information (user-defined)"""
with self.lock:
self._backward(batch)
self._backward(samples)
# Note that return values are just references;
# calling zero_grad will modify the values
return [p.grad.data.numpy() for p in self._model.parameters()], {}
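The list-of-numpy-arrays convention returned above is the format the optimizer ships between processes. Below is a small sketch of the reverse direction, loading such a list back into a torch model before an optimizer step; it mirrors the return format only and is not code from this commit.

import numpy as np
import torch
import torch.nn as nn

def apply_numpy_gradients(model, gradients):
    # Pair each parameter with its shipped gradient and install it.
    for p, g in zip(model.parameters(), gradients):
        p.grad = torch.from_numpy(np.asarray(g, dtype=np.float32))

model = nn.Linear(4, 2)
fake_grads = [np.ones(tuple(p.shape), dtype=np.float32) for p in model.parameters()]
apply_numpy_gradients(model, fake_grads)
torch.optim.SGD(model.parameters(), lr=0.01).step()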
2 changes: 1 addition & 1 deletion python/ray/rllib/dqn/dqn.py
@@ -116,7 +116,7 @@ def _init(self):
self.registry, self.env_creator, self.config, self.logdir)
remote_cls = ray.remote(
num_cpus=1, num_gpus=self.config["num_gpus_per_worker"])(
DQNReplayEvaluator)
DQNReplayEvaluator)
remote_config = dict(self.config, num_workers=1)
# In async mode, we create N remote evaluators, each with their
# own replay buffer (i.e. the replay buffer is sharded).
5 changes: 4 additions & 1 deletion python/ray/rllib/dqn/dqn_evaluator.py
@@ -13,7 +13,9 @@


class DQNEvaluator(TFMultiGPUSupport):
"""The base DQN Evaluator that does not include the replay buffer."""
"""The base DQN Evaluator that does not include the replay buffer.
TODO(rliaw): Support observation/reward filters?"""

def __init__(self, registry, env_creator, config, logdir):
env = env_creator()
@@ -46,6 +48,7 @@ def __init__(self, registry, env_creator, config, logdir):
self.episode_rewards = [0.0]
self.episode_lengths = [0.0]
self.saved_mean_reward = None

self.obs = self.env.reset()

def set_global_timestep(self, global_timestep):
2 changes: 1 addition & 1 deletion python/ray/rllib/dqn/dqn_replay_evaluator.py
@@ -70,7 +70,7 @@ def sample(self, no_replay=False):
row["dones"])

if no_replay:
return samples
return SampleBatch.concat_samples(samples)

# Then return a batch sampled from the buffer
if self.config["prioritized_replay"]:
13 changes: 9 additions & 4 deletions python/ray/rllib/optimizers/sample_batch.py
@@ -23,6 +23,7 @@ def __init__(self, *args, **kwargs):
assert type(k) == str, self
lengths.append(len(v))
assert len(set(lengths)) == 1, "data columns must be same length"
self.count = lengths[0]

@staticmethod
def concat_samples(samples):
@@ -56,8 +57,7 @@ def rows(self):
{"a": 3, "b": 6}
"""

num_rows = len(list(self.data.values())[0])
for i in range(num_rows):
for i in range(self.count):
row = {}
for k in self.data.keys():
row[k] = self[k][i]
@@ -77,11 +77,16 @@ def columns(self, keys):
out.append(self.data[k])
return out

def shuffle(self):
permutation = np.random.permutation(self.count)
for key, val in self.data.items():
self.data[key] = val[permutation]

def __getitem__(self, key):
return self.data[key]

def __str__(self):
return str(self.data)
return "SampleBatch({})".format(str(self.data))

def __repr__(self):
return str(self.data)
return "SampleBatch({})".format(str(self.data))
33 changes: 0 additions & 33 deletions python/ray/rllib/optimizers/util.py

This file was deleted.

