Skip to content

Commit

Permalink
Merge pull request facebookresearch#141 from JanS97/main
Browse files Browse the repository at this point in the history
[Bug-fix] Made truncated_normal_ completely in-place
  • Loading branch information
luisenp authored Jan 10, 2022
2 parents 08281cc + 2811181 commit 4aa4447
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
15 changes: 7 additions & 8 deletions mbrl/util/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def gaussian_nll(

# inplace truncated normal function for pytorch.
# credit to https://github.com/Xingyu-Lin/mbpo_pytorch/blob/main/model.py#L64
def truncated_normal_(tensor: torch.Tensor, mean: float = 0, std: float = 1):
def truncated_normal_(
tensor: torch.Tensor, mean: float = 0, std: float = 1
) -> torch.Tensor:
"""Samples from a truncated normal distribution in-place.
Args:
Expand All @@ -81,14 +83,11 @@ def truncated_normal_(tensor: torch.Tensor, mean: float = 0, std: float = 1):
torch.nn.init.normal_(tensor, mean=mean, std=std)
while True:
cond = torch.logical_or(tensor < mean - 2 * std, tensor > mean + 2 * std)
if not torch.sum(cond):
bound_violations = torch.sum(cond).item()
if bound_violations == 0:
break
tensor = torch.where(
cond,
torch.nn.init.normal_(
torch.ones(tensor.shape, device=tensor.device), mean=mean, std=std
),
tensor,
tensor[cond] = torch.normal(
mean, std, size=(bound_violations,), device=tensor.device
)
return tensor

Expand Down
7 changes: 7 additions & 0 deletions tests/core/test_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,3 +372,10 @@ def test_bootstrap_rb_sample_obs3d():
for i in range(ensemble_size):
for j in range(i + 1, ensemble_size):
assert not np.array_equal(batch.obs[i], batch.obs[j])


def test_truncated_normal():
t_original = torch.empty((100, 2))
t_new = mbrl.util.math.truncated_normal_(t_original)
assert t_original is t_new
assert (t_original > -2).all().item() and (t_original < 2).all().item()

0 comments on commit 4aa4447

Please sign in to comment.