[RLlib] Redo issue 14533 tf enable eager exec #14984

Merged
Commit 761242d0098d3573722c7262266677ebe223f389 ("wip.")
sven1977 committed Mar 29, 2021
1 change: 1 addition & 0 deletions rllib/agents/ppo/tests/test_ppo.py
@@ -340,6 +340,7 @@ def _ppo_loss_helper(self,
                                  policy.model)
         expected_logp = dist.logp(train_batch[SampleBatch.ACTIONS])
         if isinstance(model, TorchModelV2):
+            train_batch.set_get_interceptor(None)
             expected_rho = np.exp(expected_logp.detach().cpu().numpy() -
                                   train_batch[SampleBatch.ACTION_LOGP])
         # KL(prev vs current action dist)-loss component.
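Note on the added line: `SampleBatch` can install a "get interceptor" that converts columns to framework tensors on access; clearing it makes `train_batch[SampleBatch.ACTION_LOGP]` return the raw numpy column that the numpy arithmetic above expects. A minimal sketch of this assumed behavior (the batch contents are illustrative):

```python
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

# Illustrative batch with a single float column.
batch = SampleBatch({SampleBatch.ACTION_LOGP: np.array([-0.5, -1.2, -0.3])})

# With the get interceptor cleared, __getitem__ returns plain numpy
# arrays instead of (possibly lazily converted) framework tensors.
batch.set_get_interceptor(None)
assert isinstance(batch[SampleBatch.ACTION_LOGP], np.ndarray)
```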
1 change: 0 additions & 1 deletion rllib/agents/trainer.py
@@ -1183,7 +1183,6 @@ def _validate_config(config: PartialTrainerConfigDict):
         raise ValueError("`simple_optimizer=False` not supported for "
                          "framework={}!".format(framework))
 
-
     # Offline RL settings.
     if isinstance(config["input_evaluation"], tuple):
         config["input_evaluation"] = list(config["input_evaluation"])
11 changes: 5 additions & 6 deletions rllib/tests/test_lstm.py
@@ -8,6 +8,7 @@
 from ray.rllib.examples.models.rnn_spy_model import RNNSpyModel
 from ray.rllib.models import ModelCatalog
 from ray.rllib.policy.rnn_sequencing import chop_into_sequences
+from ray.rllib.utils.test_utils import check
 from ray.tune.registry import register_env


@@ -199,15 +200,13 @@ def test_minibatch_sequencing(self):
             batch0, batch1 = batch1, batch0  # sort minibatches
             self.assertEqual(batch0["seq_lens"].tolist(), [4, 4])
             self.assertEqual(batch1["seq_lens"].tolist(), [4, 3])
-            self.assertEqual(batch0["sequences"].tolist(), [
+            check(batch0["sequences"], [
                 [[0], [1], [2], [3]],
                 [[4], [5], [6], [7]],
                 [[8], [9], [0], [0]],
             ])
-            self.assertEqual(batch1["sequences"].tolist(), [
+            check(batch1["sequences"], [
                 [[8], [9], [10], [11]],
                 [[12], [13], [14], [0]],
                 [[0], [1], [2], [0]],
             ])
 
             # second epoch: 20 observations get split into 2 minibatches of 8
@@ -220,11 +219,11 @@
             batch2, batch3 = batch3, batch2
             self.assertEqual(batch2["seq_lens"].tolist(), [4, 4])
             self.assertEqual(batch3["seq_lens"].tolist(), [2, 4])
-            self.assertEqual(batch2["sequences"].tolist(), [
+            check(batch2["sequences"], [
                 [[5], [6], [7], [8]],
                 [[9], [10], [11], [12]],
             ])
-            self.assertEqual(batch3["sequences"].tolist(), [
+            check(batch3["sequences"], [
                 [[13], [14], [0], [0]],
                 [[0], [1], [2], [3]],
             ])
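Why `check` instead of `assertEqual(... .tolist(), ...)`: once tf eager execution is enabled, these columns may come back as framework tensors rather than numpy arrays, and `check` from `ray.rllib.utils.test_utils` compares nested structures element-wise across both. A small usage sketch under that assumption:

```python
import numpy as np
from ray.rllib.utils.test_utils import check

# Recursively compares nested lists/arrays (with tolerance for floats)
# instead of requiring an exact .tolist() match.
check(np.array([[1, 2], [3, 4]]), [[1, 2], [3, 4]])  # passes silently

# On mismatch, check() raises an assertion error.
try:
    check(np.array([1, 2]), [1, 3])
except AssertionError:
    print("mismatch detected")
```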