Masking error. With t * valid_mask, masked inf entries produce np.inf * 0 = np.nan (#…
09wakharet authored Jul 12, 2020
1 parent 381c242 commit 3536d8e
Showing 1 changed file with 3 additions and 2 deletions.
rllib/agents/ppo/ppo_torch_policy.py (5 changes: 3 additions & 2 deletions)
@@ -70,7 +70,7 @@ def __init__(self,
             num_valid = torch.sum(valid_mask)
 
             def reduce_mean_valid(t):
-                return torch.sum(t * valid_mask) / num_valid
+                return torch.sum(t[valid_mask]) / num_valid
 
         else:
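
Why this one-line change matters: multiplying by a 0/1 mask sends every masked entry through np.inf * 0, which evaluates to nan and poisons the whole sum, while boolean indexing drops the invalid entries before the reduction ever sees them. Below is a minimal sketch of the difference, assuming a boolean valid_mask as in the padded-RNN code path; the tensor values are made up for illustration and are not from the commit.

```python
import torch

# Toy batch: the last timestep is padding and holds an inf;
# valid_mask marks which timesteps are real.
t = torch.tensor([1.0, 2.0, float("inf")])
valid_mask = torch.tensor([True, True, False])
num_valid = torch.sum(valid_mask)

# Old reduction: inf * 0 = nan, and the nan propagates through the sum.
print(torch.sum(t * valid_mask) / num_valid)  # tensor(nan)

# Fixed reduction: boolean indexing removes the padded entry entirely.
print(torch.sum(t[valid_mask]) / num_valid)   # tensor(1.5000)
```

Both forms compute the same mean when t is finite everywhere; they only diverge when masked positions hold non-finite values, which is exactly the case this commit guards against.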

@@ -195,7 +195,8 @@ def value(ob, prev_action, prev_reward, *state):
                     np.asarray([prev_reward])),
                 "is_training": False,
             }, [convert_to_torch_tensor(np.asarray([s])) for s in state],
-                convert_to_torch_tensor(np.asarray([1])))
+                convert_to_torch_tensor(
+                    np.asarray([1])))
             return self.model.value_function()[0]
 
         else:
