Skip to content

Commit

Permalink
Fix.
Browse files (browse the repository at this point in the history)
  • Loading branch information
sven1977 committed Nov 2, 2020
1 parent f9cd241 commit dad163c
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
2 changes: 1 addition & 1 deletion rllib/agents/dqn/dqn_tf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def postprocess_nstep_and_prio(policy: Policy,
batch[SampleBatch.DONES], batch[PRIO_WEIGHTS])
new_priorities = (np.abs(convert_to_numpy(td_errors)) +
policy.config["prioritized_replay_eps"])
- batch.data[PRIO_WEIGHTS] = new_priorities
+ batch[PRIO_WEIGHTS] = new_priorities

return batch

Expand Down
9 changes: 8 additions & 1 deletion rllib/agents/dqn/tests/test_simple_q.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,14 @@ def test_simple_q_loss_function(self):
SampleBatch.ACTIONS: np.array([0, 1]),
SampleBatch.REWARDS: np.array([0.4, -1.23]),
SampleBatch.DONES: np.array([False, False]),
- SampleBatch.NEXT_OBS: np.random.random(size=(2, 4))
+ SampleBatch.NEXT_OBS: np.random.random(size=(2, 4)),
+ SampleBatch.EPS_ID: np.array([1234, 1234]),
+ SampleBatch.AGENT_INDEX: np.array([0, 0]),
+ SampleBatch.ACTION_LOGP: np.array([-0.1, -0.1]),
+ SampleBatch.ACTION_DIST_INPUTS: np.array(
+ [[0.1, 0.2], [-0.1, -0.2]]),
+ SampleBatch.ACTION_PROB: np.array([0.1, 0.2]),
+ "q_values": np.array([[0.1, 0.2], [0.2, 0.1]]),
}
# Get model vars for computing expected model outs (q-vals).
# 0=layer-kernel; 1=layer-bias; 2=q-val-kernel; 3=q-val-bias
Expand Down
11 changes: 6 additions & 5 deletions rllib/policy/torch_policy_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,17 +248,18 @@ def __init__(self, obs_space, action_space, config):
self.view_requirements.update(
self.model.inference_view_requirements)

- if before_loss_init:
- before_loss_init(self, self.observation_space,
- self.action_space, config)
+ _before_loss_init = before_loss_init or after_init
+ if _before_loss_init:
+ _before_loss_init(self, self.observation_space,
+ self.action_space, config)

self._initialize_loss_dynamically(
auto_remove_unneeded_view_reqs=view_requirements_fn is None,
stats_fn=stats_fn,
)

- if after_init:
- after_init(self, obs_space, action_space, config)
+ if _after_loss_init:
+ _after_loss_init(self, obs_space, action_space, config)

# Got to reset global_timestep again after this fake run-through.
self.global_timestep = 0
Expand Down

0 comments on commit dad163c

Please sign in to comment.