Commit

WIP.
sven1977 committed Oct 29, 2020
1 parent 5ff50c7 commit 5a0e682
Showing 2 changed files with 16 additions and 14 deletions.
rllib/policy/dynamic_tf_policy.py: 16 changes (8 additions, 8 deletions)
@@ -493,18 +493,18 @@ def fake_array(tensor):
                 SampleBatch.PREV_REWARDS: self._prev_reward_input,
                 SampleBatch.CUR_OBS: self._obs_input,
             })
-            loss_inputs = [
-                (SampleBatch.PREV_ACTIONS, self._prev_action_input),
-                (SampleBatch.PREV_REWARDS, self._prev_reward_input),
-                (SampleBatch.CUR_OBS, self._obs_input),
-            ]
+            #loss_inputs = [
+            #    (SampleBatch.PREV_ACTIONS, self._prev_action_input),
+            #    (SampleBatch.PREV_REWARDS, self._prev_reward_input),
+            #    (SampleBatch.CUR_OBS, self._obs_input),
+            #]
         else:
             train_batch = UsageTrackingDict({
                 SampleBatch.CUR_OBS: self._obs_input,
             })
-            loss_inputs = [
-                (SampleBatch.CUR_OBS, self._obs_input),
-            ]
+            #loss_inputs = [
+            #    (SampleBatch.CUR_OBS, self._obs_input),
+            #]
 
         for k, v in postprocessed_batch.items():
             if k in train_batch:
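
Note on the hunk above: the explicit loss_inputs list is commented out while the UsageTrackingDict-wrapped train_batch is kept, which suggests the loss's required inputs are meant to be discovered by tracking which keys the loss function actually reads, rather than by a hard-coded list. A minimal sketch of that key-tracking idea follows; KeyTrackingDict is a made-up name for illustration and is not RLlib's UsageTrackingDict implementation.

    # Illustration only: not RLlib's UsageTrackingDict, just the same basic idea.
    class KeyTrackingDict(dict):
        """A dict that remembers every key read from it."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.accessed_keys = set()

        def __getitem__(self, key):
            self.accessed_keys.add(key)
            return super().__getitem__(key)


    train_batch = KeyTrackingDict(obs=[[0.0, 1.0]], prev_rewards=[0.0])
    _ = train_batch["obs"]  # pretend the loss only ever reads "obs"
    print(train_batch.accessed_keys)  # -> {'obs'}
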
rllib/policy/view_requirement.py: 14 changes (8 additions, 6 deletions)
@@ -65,9 +65,10 @@ def __init__(self,
 def initialize_loss_with_dummy_batch(policy, auto=True):
     from ray.rllib.policy.sample_batch import SampleBatch
 
-    batch_size = max(policy.batch_divisibility_req, 2)
+    sample_batch_size = max(policy.batch_divisibility_req, 2)
+    B = 2  # For RNNs, have B=2, T=[depends on sample_batch_size]
     policy._dummy_batch = _get_dummy_batch(
-        policy, batch_size=batch_size)
+        policy, batch_size=sample_batch_size)
     input_dict = policy._lazy_tensor_dict(policy._dummy_batch)
     actions, state_outs, extra_outs = \
         policy.compute_actions_from_input_dict(input_dict)
@@ -80,17 +81,18 @@ def initialize_loss_with_dummy_batch(policy, auto=True):
     # view setup.
     i = 0
     while "state_in_{}".format(i) in sb:
-        sb["state_in_{}".format(i)] = sb["state_in_{}".format(i)][:batch_size]
+        sb["state_in_{}".format(i)] = sb["state_in_{}".format(i)][:B]
         if "state_out_{}".format(i) in sb:
             sb["state_out_{}".format(i)] = \
-                sb["state_out_{}".format(i)][:batch_size]
+                sb["state_out_{}".format(i)][:B]
         i += 1
     batch_for_postproc = policy._lazy_numpy_dict(sb)
     batch_for_postproc.count = sb.count
     postprocessed_batch = policy.postprocess_trajectory(batch_for_postproc)
     if state_outs:
-        seq_len = (policy.batch_divisibility_req // 2) or 1
-        postprocessed_batch["seq_lens"] = np.array([seq_len for _ in range(2)], dtype=np.int32)
+        seq_len = (policy.batch_divisibility_req // B) or 1
+        postprocessed_batch["seq_lens"] = \
+            np.array([seq_len for _ in range(B)], dtype=np.int32)
     train_batch = policy._lazy_tensor_dict(postprocessed_batch)
     if policy._loss is not None:
         policy._loss(policy, policy.model, policy.dist_class, train_batch)
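
Note on the hunks above: renaming batch_size to sample_batch_size and introducing B makes the RNN bookkeeping explicit. The dummy rollout of sample_batch_size timesteps is treated as B sequences of seq_len steps each, and the state_in_*/state_out_* columns are truncated to one row per sequence. A quick worked example of the seq_lens arithmetic, with a made-up batch_divisibility_req of 8:

    import numpy as np

    # Hypothetical value; actual policies define their own batch_divisibility_req.
    batch_divisibility_req = 8

    sample_batch_size = max(batch_divisibility_req, 2)  # -> 8 dummy timesteps
    B = 2                                               # -> 2 RNN sequences
    seq_len = (batch_divisibility_req // B) or 1        # -> 4 steps per sequence
    seq_lens = np.array([seq_len for _ in range(B)], dtype=np.int32)  # -> [4 4]

    assert seq_lens.sum() == sample_batch_size  # the 8 timesteps are fully covered
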
