Skip to content

Commit

Permalink
[rllib] Fix truncate episodes mode in central critic example (#8073)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ericl authored Apr 20, 2020
1 parent 3812bfe commit 17e3c54
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions rllib/examples/centralized_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ def centralized_critic_postprocessing(policy,
other_agent_batches=None,
episode=None):
if policy.loss_initialized():
- assert sample_batch["dones"][-1], \
- "Not implemented for train_batch_mode=truncate_episodes"
assert other_agent_batches is not None
[(_, opponent_batch)] = list(other_agent_batches.values())

Expand All @@ -116,11 +114,17 @@ def centralized_critic_postprocessing(policy,
sample_batch[OPPONENT_ACTION] = np.zeros_like(
sample_batch[SampleBatch.ACTIONS])
sample_batch[SampleBatch.VF_PREDS] = np.zeros_like(
- sample_batch[SampleBatch.ACTIONS], dtype=np.float32)
+ sample_batch[SampleBatch.REWARDS], dtype=np.float32)

+ completed = sample_batch["dones"][-1]
+ if completed:
+ last_r = 0.0
+ else:
+ last_r = sample_batch[SampleBatch.VF_PREDS][-1]

train_batch = compute_advantages(
sample_batch,
- 0.0,
+ last_r,
policy.config["gamma"],
policy.config["lambda"],
use_gae=policy.config["use_gae"])
Expand Down

0 comments on commit 17e3c54

Please sign in to comment.