forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[rllib] A3C Refactoring (ray-project#1166)
* fixing policy * Compute Action is singular, fixed weird issue with arrays * remove vestige * extraneous ipdb * Can Drop in Pytorch Model * lint * naming * finish comments
- Loading branch information
1 parent
4cace09
commit dc66a2d
Showing
12 changed files
with
404 additions
and
344 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import numpy as np | ||
import scipy.signal | ||
from collections import namedtuple | ||
|
||
|
||
def discount(x, gamma):
    """Return the discounted cumulative sum of ``x`` along axis 0.

    The result ``y`` satisfies ``y[t] = x[t] + gamma * y[t + 1]`` with
    ``y[-1] = x[-1]``, i.e. ``y[t] = sum_k gamma**k * x[t + k]``.

    Args:
        x: Sequence or array to discount; the first axis is the one the
            sum runs over.
        gamma: Discount factor applied once per step.

    Returns:
        Array with the same shape as ``x``.
    """
    # lfilter([1], [1, -gamma]) computes y[n] = x[n] + gamma * y[n - 1].
    # Running it over the reversed input and reversing the output turns
    # that forward recurrence into the backward (future-looking) sum.
    return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]
|
||
|
||
def process_rollout(rollout, gamma, lambda_=1.0):
    """Given a rollout, compute its returns and the advantage.

    Args:
        rollout: Object exposing ``states``, ``actions``, ``rewards``,
            ``values``, ``features``, ``terminal`` and ``r``. ``rewards``
            and ``values`` are concatenated with ``[rollout.r]``, so they
            must be lists.
        gamma: Discount factor.
        lambda_: Extra per-step decay applied to the TD residuals when
            accumulating advantages (the GAE lambda).

    Returns:
        A ``Batch`` of (states, actions, advantages, discounted returns,
        ``rollout.terminal``, first entry of ``rollout.features``).
    """
    batch_si = np.asarray(rollout.states)
    batch_a = np.asarray(rollout.actions)
    rewards = np.asarray(rollout.rewards)
    # NOTE(review): rollout.r is presumably the bootstrap value estimate
    # for the state following the last step -- confirm against the caller.
    vpred_t = np.asarray(rollout.values + [rollout.r])

    # Append rollout.r so it is discounted into every return, then drop
    # the extra trailing element.
    rewards_plus_v = np.asarray(rollout.rewards + [rollout.r])
    batch_r = discount(rewards_plus_v, gamma)[:-1]
    # One-step TD residuals: r_t + gamma * V(s_{t+1}) - V(s_t).
    delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
    # This formula for the advantage comes from "Generalized Advantage
    # Estimation": https://arxiv.org/abs/1506.02438
    batch_adv = discount(delta_t, gamma * lambda_)

    features = rollout.features[0]
    return Batch(batch_si, batch_a, batch_adv, batch_r, rollout.terminal,
                 features)
|
||
|
||
# Training batch built by process_rollout. Fields, in order:
#   si       -- states
#   a        -- actions
#   adv      -- advantages
#   r        -- discounted returns
#   terminal -- the rollout's terminal flag
#   features -- first entry of the rollout's features
Batch = namedtuple(
    "Batch", ["si", "a", "adv", "r", "terminal", "features"])
|
||
# Record with two fields, episode_length and episode_reward.
# NOTE(review): presumably a per-episode summary (step count and total
# reward) -- confirm against the code that constructs it.
CompletedRollout = namedtuple(
    "CompletedRollout", ["episode_length", "episode_reward"])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.