
Commit

WIP.
sven1977 committed Oct 27, 2020
1 parent eec4bc5 commit a3aebdf
Showing 2 changed files with 4 additions and 0 deletions.
3 changes: 3 additions & 0 deletions rllib/agents/ppo/ppo_torch_policy.py
@@ -19,6 +19,7 @@
 from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \
     LearningRateSchedule
 from ray.rllib.policy.torch_policy_template import build_torch_policy
+from ray.rllib.policy.view_requirement import ViewRequirement
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_ops import convert_to_torch_tensor, \
     explained_variance, sequence_mask
@@ -228,6 +229,8 @@ def value(ob, prev_action, prev_reward, *state):
                 # [0] = remove the batch dim.
                 return self.model.value_function()[0]
 
+            self.view_requirements[SampleBatch.VF_PREDS] = ViewRequirement()
+
         # When not doing GAE, we do not require the value function's output.
         else:
 
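Note (not part of the diff): registering a ViewRequirement under the SampleBatch.VF_PREDS key tells RLlib's trajectory-view collection that the policy's per-timestep value-function predictions should be written into the sampled batches, so that postprocessing (e.g. GAE) can read them later. A minimal sketch of such an entry, assuming this RLlib version's trajectory-view API; the standalone dict name view_requirements is illustrative:

    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.policy.view_requirement import ViewRequirement

    # Illustrative sketch: a policy's view-requirements dict maps batch columns
    # to ViewRequirement objects. A default ViewRequirement() (no data_col, no
    # time shift) asks the sample collector to record the column named by the
    # key -- here the value-function predictions ("vf_preds") -- at each step.
    view_requirements = {
        SampleBatch.VF_PREDS: ViewRequirement(),
    }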
1 change: 1 addition & 0 deletions rllib/agents/ppo/tests/test_appo.py
@@ -24,6 +24,7 @@ def test_appo_compilation(self):
         for _ in framework_iterator(config):
             print("w/o v-trace")
             _config = config.copy()
+            _config["vtrace"] = False
             trainer = ppo.APPOTrainer(config=_config, env="CartPole-v0")
             for i in range(num_iterations):
                 print(trainer.train())
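Note (not part of the diff): with the added line, the test now also exercises APPO with v-trace switched off. A rough standalone sketch of that setup, assuming APPO's defaults are reachable as ppo.appo.DEFAULT_CONFIG; any config keys beyond "vtrace" shown here are illustrative:

    import ray
    import ray.rllib.agents.ppo as ppo

    ray.init(ignore_reinit_error=True)

    # Copy APPO's default config and disable v-trace so the non-v-trace
    # (GAE-based) advantage path is used for the loss.
    config = ppo.appo.DEFAULT_CONFIG.copy()
    config["num_workers"] = 1  # keep the sketch small; value is illustrative
    config["vtrace"] = False

    trainer = ppo.APPOTrainer(config=config, env="CartPole-v0")
    print(trainer.train())  # one training iteration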
