diff --git a/tests/core/test_common_utils.py b/tests/core/test_common_utils.py index 23ddb7e7..d91150a9 100644 --- a/tests/core/test_common_utils.py +++ b/tests/core/test_common_utils.py @@ -372,3 +372,10 @@ def test_bootstrap_rb_sample_obs3d(): for i in range(ensemble_size): for j in range(i + 1, ensemble_size): assert not np.array_equal(batch.obs[i], batch.obs[j]) + + +def test_truncated_normal(): + t_original = torch.empty((100, 2)) + t_new = mbrl.util.math.truncated_normal_(t_original) + assert t_original is t_new + assert (t_original > -2).all().item() and (t_original < 2).all().item()