Commit 5ff50c7

WIP.
sven1977 committed Oct 29, 2020
1 parent ddd9847 commit 5ff50c7
Showing 9 changed files with 64 additions and 42 deletions.
26 changes: 16 additions & 10 deletions rllib/agents/ppo/tests/test_ppo.py
@@ -19,7 +19,7 @@

# Fake CartPole episode of n time steps.
FAKE_BATCH = {
SampleBatch.CUR_OBS: np.array(
SampleBatch.OBS: np.array(
[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]],
dtype=np.float32),
SampleBatch.ACTIONS: np.array([0, 1, 1]),
@@ -31,6 +31,8 @@
SampleBatch.ACTION_DIST_INPUTS: np.array(
[[-2., 0.5], [-3., -0.3], [-0.1, 2.5]], dtype=np.float32),
SampleBatch.ACTION_LOGP: np.array([-0.5, -0.1, -0.2], dtype=np.float32),
SampleBatch.EPS_ID: np.array([0, 0, 0]),
SampleBatch.AGENT_INDEX: np.array([0, 0, 0]),
}


@@ -119,7 +121,10 @@ def test_ppo_exploration_setup(self):
# Test whether this is really the argmax action over the logits.
if fw != "tf":
last_out = trainer.get_policy().model.last_output()
check(a_, np.argmax(last_out.numpy(), 1)[0])
if fw == "torch":
check(a_, np.argmax(last_out.detach().cpu().numpy(), 1)[0])
else:
check(a_, np.argmax(last_out.numpy(), 1)[0])
for _ in range(50):
a = trainer.compute_action(
obs,
@@ -171,7 +176,7 @@ def get_value():
if fw == "tf":
return policy.get_session().run(log_std_var)[0]
elif fw == "torch":
return log_std_var.detach().numpy()[0]
return log_std_var.detach().cpu().numpy()[0]
else:
return log_std_var.numpy()[0]

@@ -180,9 +185,9 @@ def get_value():
assert init_std == 0.0, init_std

if fw in ["tf2", "tf", "tfe"]:
batch = postprocess_ppo_gae_tf(policy, FAKE_BATCH)
batch = postprocess_ppo_gae_tf(policy, FAKE_BATCH.copy())
else:
batch = postprocess_ppo_gae_torch(policy, FAKE_BATCH)
batch = postprocess_ppo_gae_torch(policy, FAKE_BATCH.copy())
batch = policy._lazy_tensor_dict(batch)
policy.learn_on_batch(batch)

@@ -222,9 +227,10 @@ def test_ppo_loss_function(self):
# A = [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5] =
# [0.50005, -0.505, 0.5]
if fw in ["tf2", "tf", "tfe"]:
train_batch = postprocess_ppo_gae_tf(policy, FAKE_BATCH)
train_batch = postprocess_ppo_gae_tf(policy, FAKE_BATCH.copy())
else:
train_batch = postprocess_ppo_gae_torch(policy, FAKE_BATCH)
train_batch = postprocess_ppo_gae_torch(
policy, FAKE_BATCH.copy())
train_batch = policy._lazy_tensor_dict(train_batch)

# Check Advantage values.
@@ -307,12 +313,12 @@ def _ppo_loss_helper(self,
policy.model)
expected_logp = dist.logp(train_batch[SampleBatch.ACTIONS])
if isinstance(model, TorchModelV2):
expected_rho = np.exp(expected_logp.detach().numpy() -
expected_rho = np.exp(expected_logp.detach().cpu().numpy() -
train_batch.get(SampleBatch.ACTION_LOGP))
# KL(prev vs current action dist)-loss component.
kl = np.mean(dist_prev.kl(dist).detach().numpy())
kl = np.mean(dist_prev.kl(dist).detach().cpu().numpy())
# Entropy-loss component.
entropy = np.mean(dist.entropy().detach().numpy())
entropy = np.mean(dist.entropy().detach().cpu().numpy())
else:
if sess:
expected_logp = sess.run(expected_logp)
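Side note (not part of the commit): the repeated .detach().cpu().numpy() edits above make the Torch branches of the test GPU-safe, and the FAKE_BATCH.copy() calls presumably guard against postprocess_ppo_gae_* mutating the shared FAKE_BATCH dict across the per-framework loop. A minimal, illustrative sketch of the conversion pattern (the helper name to_numpy is not from the repo):

import torch

def to_numpy(tensor):
    # Torch tensors must be detached from the autograd graph and moved
    # to CPU before .numpy(); eager TF tensors convert directly.
    if isinstance(tensor, torch.Tensor):
        return tensor.detach().cpu().numpy()
    return tensor.numpy()

# E.g. check(a_, np.argmax(to_numpy(last_out), 1)[0]) would cover both
# the "torch" and the "tf2"/"tfe" branches shown above.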
3 changes: 3 additions & 0 deletions rllib/contrib/maddpg/maddpg.py
@@ -112,6 +112,9 @@
"num_workers": 1,
# Prevent iterations from going lower than this time span
"min_iter_time_s": 0,

# Do not use with trajectory view API for now.
"_use_trajectory_view_api": False,
})
# __sphinx_doc_end__
# yapf: enable
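For context (sketch, not from the commit): with the new default in place, a user config for contrib/MADDPG simply leaves the flag at False. Only the three keys shown here are taken from the diff; everything else a real config would need is omitted:

config = {
    "num_workers": 1,
    "min_iter_time_s": 0,
    # Keep the classic collection path; the trajectory view API is not
    # yet supported by contrib/MADDPG, hence the new default above.
    "_use_trajectory_view_api": False,
}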
5 changes: 0 additions & 5 deletions rllib/contrib/maddpg/maddpg_policy.py
@@ -252,11 +252,6 @@ def _make_loss_inputs(placeholders):
loss_inputs=loss_inputs,
dist_inputs=actor_feature)

# Additional view requirements for postprocessing.
self.view_requirements["infos"] = \
ViewRequirement(used_for_training=False)
self.view_requirements["t"] = ViewRequirement()

self.sess.run(tf1.global_variables_initializer())

# Hard initial update
7 changes: 4 additions & 3 deletions rllib/evaluation/collectors/simple_list_collector.py
@@ -136,10 +136,11 @@ def build(self, view_requirements: Dict[str, ViewRequirement]) -> \
if data_col not in np_data:
np_data[data_col] = to_float_np_array(self.buffers[data_col])
if shift == 0:
batch_data[view_col] = np_data[data_col][self.shift_before:]
data = np_data[data_col][self.shift_before:]
else:
batch_data[view_col] = np_data[data_col][self.shift_before +
shift:shift]
data = np_data[data_col][self.shift_before + shift:shift]
if len(data) > 0:
batch_data[view_col] = data
batch = SampleBatch(batch_data)

if SampleBatch.UNROLL_ID not in batch.data:
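A standalone sketch of the slicing logic touched above (values invented for illustration): shift_before is the padding kept at the front of each buffer, a negative shift slices relative to it (e.g. -1 for a "previous action" view), and with this change empty results are no longer added to the batch:

import numpy as np

buffers = {"actions": np.array([0, 0, 1, 1, 0])}  # index 0 is padding
shift_before, shift = 1, -1

if shift == 0:
    data = buffers["actions"][shift_before:]
else:
    data = buffers["actions"][shift_before + shift:shift]

batch_data = {}
if len(data) > 0:                        # new guard: skip empty views
    batch_data["prev_actions"] = data    # -> array([0, 0, 1, 1])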
3 changes: 2 additions & 1 deletion rllib/execution/train_ops.py
@@ -180,7 +180,8 @@ def __call__(self,
policy._debug_vars()
tuples = policy._get_loss_inputs_dict(
batch, shuffle=self.shuffle_sequences)
data_keys = [ph for _, ph in policy._loss_inputs]
data_keys = \
[ph for ph in policy._loss_input_dict_no_rnn.values()]
if policy._state_inputs:
state_keys = policy._state_inputs + [policy._seq_lens]
else:
30 changes: 20 additions & 10 deletions rllib/policy/dynamic_tf_policy.py
@@ -2,6 +2,7 @@
import gym
import logging
import numpy as np
import re
from typing import Callable, Dict, List, Optional, Tuple, Type

from ray.util.debug import log_once
@@ -398,8 +399,18 @@ def _get_input_dict_and_dummy_batch(self, view_requirements,
input_dict = {}
dummy_batch = {}
for view_col, view_req in view_requirements.items():
# Point state_in to the already existing self._state_inputs.
mo = re.match("state_in_(\d+)", view_col)
if mo is not None:
input_dict[view_col] = self._state_inputs[int(mo.group(1))]
dummy_batch[view_col] = np.zeros_like(
[view_req.space.sample()])
# State-outs (no placeholders needed).
elif view_col.startswith("state_out_"):
dummy_batch[view_col] = np.zeros_like(
[view_req.space.sample()])
# Skip action dist inputs placeholder (do later).
if view_col == SampleBatch.ACTION_DIST_INPUTS:
elif view_col == SampleBatch.ACTION_DIST_INPUTS:
continue
elif view_col in existing_inputs:
input_dict[view_col] = existing_inputs[view_col]
@@ -523,13 +534,11 @@ def fake_array(tensor):
"Initializing loss function with dummy input:\n\n{}\n".format(
summarize(train_batch)))

self._loss_input_dict = train_batch
self._loss_input_dict = {k: v for k, v in train_batch.items()}
loss = self._do_loss_init(train_batch)
#for k in sorted(train_batch.accessed_keys):
# if k != "seq_lens" and not k.startswith("state_in_"):
# loss_inputs.append((k, train_batch[k]))

TFPolicy._initialize_loss(self, loss, [(k, v) for k, v in train_batch.items()]) #loss_inputs
del self._loss_input_dict["is_training"]
if self._grad_stats_fn:
self._stats_fetches.update(
self._grad_stats_fn(self, train_batch, self._grads))
@@ -546,7 +555,8 @@ def fake_array(tensor):
for key in batch_for_postproc.accessed_keys:
if key not in train_batch.accessed_keys:
self.view_requirements[key].used_for_training = False
del self._loss_input_dict[key]
if key in self._loss_input_dict:
del self._loss_input_dict[key]
# Remove those not needed at all (leave those that are needed
# by Sampler to properly execute sample collection).
for key in list(self.view_requirements.keys()):
@@ -564,10 +574,10 @@ def fake_array(tensor):
used_for_training = vr.data_col in train_batch.accessed_keys
self.view_requirements[vr.data_col] = ViewRequirement(space=vr.space, used_for_training=used_for_training)

self._loss_input_dict_no_rnn = {k: v for k, v in
self._loss_input_dict.items() if
not k.startswith(
"state_in_") and v != self._seq_lens}
self._loss_input_dict_no_rnn = {
k: v for k, v in self._loss_input_dict.items() if
not v in self._state_inputs and v != self._seq_lens
}

def _do_loss_init(self, train_batch: SampleBatch):
loss = self._loss_fn(self, self.model, self.dist_class, train_batch)
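To illustrate the new state_in_*/state_out_* branches above (standalone sketch; the space and the placeholder stand-in are invented): columns named state_in_<i> are routed to the i-th already existing state input instead of receiving a fresh placeholder, while state_out_* columns only need a dummy-batch entry:

import re
import numpy as np
from gym.spaces import Box

state_inputs = ["<placeholder for state_in_0>"]  # stands in for self._state_inputs
space = Box(-1.0, 1.0, shape=(4,))               # stands in for view_req.space

input_dict, dummy_batch = {}, {}
for view_col in ["state_in_0", "state_out_0"]:
    mo = re.match(r"state_in_(\d+)", view_col)
    if mo is not None:
        # Reuse the existing placeholder for this state slot.
        input_dict[view_col] = state_inputs[int(mo.group(1))]
        dummy_batch[view_col] = np.zeros_like([space.sample()])
    elif view_col.startswith("state_out_"):
        # No placeholder needed, only a dummy entry for loss init.
        dummy_batch[view_col] = np.zeros_like([space.sample()])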
1 change: 1 addition & 0 deletions rllib/policy/policy.py
@@ -568,6 +568,7 @@ def _get_default_view_requirements(self):
SampleBatch.ACTION_LOGP: ViewRequirement(),
SampleBatch.VF_PREDS: ViewRequirement(),
PRIO_WEIGHTS: ViewRequirement(),
"t": ViewRequirement(),
}
## Add the state-in/out views in case the policy has an RNN.
#if policy.is_recurrent():
6 changes: 5 additions & 1 deletion rllib/policy/tf_policy.py
@@ -145,6 +145,9 @@ def __init__(self,
"Model classes for TFPolicy other than `ModelV2` not allowed! " \
"You passed in {}.".format(model)
self.model = model
# Auto-update model's inference view requirements, if recurrent.
if self.model is not None:
self.model.update_view_requirements_from_init_state()

self.exploration = self._create_exploration()
self._sess = sess
@@ -803,7 +806,8 @@ def _get_loss_inputs_dict(self, batch, shuffle):
shuffle=shuffle,
max_seq_len=self._max_seq_len,
batch_divisibility_req=self._batch_divisibility_req,
feature_keys=list(self._loss_input_dict.keys()), #[k for k, v in self._loss_inputs])
feature_keys=[
k for k in self._loss_input_dict.keys() if k != "seq_lens"],
)
batch["is_training"] = True

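Put differently (standalone sketch; a plain dict stands in for self._loss_input_dict and the placeholder values are omitted): "seq_lens" is fed through the dedicated self._seq_lens placeholder by the RNN padding helper, so it is excluded from the ordinary feature columns:

loss_input_dict = {"obs": None, "actions": None, "advantages": None,
                   "seq_lens": None}
feature_keys = [k for k in loss_input_dict if k != "seq_lens"]
print(feature_keys)  # -> ['obs', 'actions', 'advantages']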
25 changes: 13 additions & 12 deletions rllib/policy/view_requirement.py
@@ -102,18 +102,19 @@ def initialize_loss_with_dummy_batch(policy, auto=True):
for key in all_accessed_keys:
if key not in policy.view_requirements:
policy.view_requirements[key] = ViewRequirement()
# Tag those only needed for post-processing.
for key in batch_for_postproc.accessed_keys:
if key not in train_batch.accessed_keys:
policy.view_requirements[key].used_for_training = False
# Remove those not needed at all (leave those that are needed
# by Sampler to properly execute sample collection).
for key in list(policy.view_requirements.keys()):
if key not in all_accessed_keys and key not in [
SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
SampleBatch.UNROLL_ID, SampleBatch.DONES] and \
key not in policy.model.inference_view_requirements:
del policy.view_requirements[key]
if policy._loss:
# Tag those only needed for post-processing.
for key in batch_for_postproc.accessed_keys:
if key not in train_batch.accessed_keys:
policy.view_requirements[key].used_for_training = False
# Remove those not needed at all (leave those that are needed
# by Sampler to properly execute sample collection).
for key in list(policy.view_requirements.keys()):
if key not in all_accessed_keys and key not in [
SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
SampleBatch.UNROLL_ID, SampleBatch.DONES] and \
key not in policy.model.inference_view_requirements:
del policy.view_requirements[key]
# Add those data_cols (again) that are missing and have
# dependencies by view_cols.
for key in list(policy.view_requirements.keys()):
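A standalone sketch of the pruning rule that is now guarded by "if policy._loss:" above (plain dicts and lowercase strings stand in for the real Policy object and the SampleBatch constants): the pruning only runs for policies that actually define a loss, and keys never accessed by postprocessing or the dummy loss pass are dropped unless the sampler or the model's inference views still need them:

all_accessed_keys = {"obs", "actions", "rewards"}
always_keep = {"eps_id", "agent_index", "unroll_id", "dones"}
inference_keys = {"obs", "prev_actions"}

view_requirements = {"obs": ..., "actions": ..., "rewards": ...,
                     "vf_preds": ..., "eps_id": ..., "dones": ...}
for key in list(view_requirements.keys()):
    if (key not in all_accessed_keys and key not in always_keep
            and key not in inference_keys):
        del view_requirements[key]      # only "vf_preds" is dropped here

print(sorted(view_requirements))
# -> ['actions', 'dones', 'eps_id', 'obs', 'rewards']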
