Skip to content

Commit 00b7c2e

Browse files
author
Vincent Moens
authored
[BugFix] Fix tanh normal mode (#2198)
1 parent 1062e3e commit 00b7c2e

20 files changed

+383
-129
lines changed

test/test_actors.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,8 @@ def test_probabilistic_actor_nested_normal(log_prob_key, nested_dim=5, n_actions
130130
out_keys=[("data", "action")],
131131
distribution_class=TanhNormal,
132132
distribution_kwargs={
133-
"min": action_spec.space.low,
134-
"max": action_spec.space.high,
133+
"low": action_spec.space.low,
134+
"high": action_spec.space.high,
135135
},
136136
log_prob_key=log_prob_key,
137137
return_log_prob=True,
@@ -153,8 +153,8 @@ def test_probabilistic_actor_nested_normal(log_prob_key, nested_dim=5, n_actions
153153
out_keys=[("data", "action")],
154154
distribution_class=TanhNormal,
155155
distribution_kwargs={
156-
"min": action_spec.space.low,
157-
"max": action_spec.space.high,
156+
"low": action_spec.space.low,
157+
"high": action_spec.space.high,
158158
},
159159
log_prob_key=log_prob_key,
160160
return_log_prob=True,

test/test_cost.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13519,17 +13519,36 @@ def __init__(self):
1351913519

1352013520
def test_loss_exploration():
1352113521
class DummyLoss(LossModule):
13522-
def forward(self, td):
13523-
assert exploration_type() == InteractionType.MODE
13522+
def forward(self, td, mode):
13523+
if mode is None:
13524+
mode = self.deterministic_sampling_mode
13525+
assert exploration_type() == mode
1352413526
with set_exploration_type(ExplorationType.RANDOM):
1352513527
assert exploration_type() == ExplorationType.RANDOM
13526-
assert exploration_type() == ExplorationType.MODE
13528+
assert exploration_type() == mode
1352713529
return td
1352813530

1352913531
loss_fn = DummyLoss()
1353013532
with set_exploration_type(ExplorationType.RANDOM):
1353113533
assert exploration_type() == ExplorationType.RANDOM
13532-
loss_fn(None)
13534+
loss_fn(None, None)
13535+
assert exploration_type() == ExplorationType.RANDOM
13536+
13537+
with set_exploration_type(ExplorationType.RANDOM):
13538+
assert exploration_type() == ExplorationType.RANDOM
13539+
loss_fn(None, ExplorationType.DETERMINISTIC)
13540+
assert exploration_type() == ExplorationType.RANDOM
13541+
13542+
loss_fn.deterministic_sampling_mode = ExplorationType.MODE
13543+
with set_exploration_type(ExplorationType.RANDOM):
13544+
assert exploration_type() == ExplorationType.RANDOM
13545+
loss_fn(None, ExplorationType.MODE)
13546+
assert exploration_type() == ExplorationType.RANDOM
13547+
13548+
loss_fn.deterministic_sampling_mode = ExplorationType.MEAN
13549+
with set_exploration_type(ExplorationType.RANDOM):
13550+
assert exploration_type() == ExplorationType.RANDOM
13551+
loss_fn(None, ExplorationType.MEAN)
1353313552
assert exploration_type() == ExplorationType.RANDOM
1353413553

1353513554

test/test_distributions.py

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,10 @@ def _map_all(*tensors_or_other, device):
8585

8686
class TestTanhNormal:
8787
@pytest.mark.parametrize(
88-
"min", [-torch.ones(3), -1, 3 * torch.tensor([-1.0, -2.0, -0.5]), -0.1]
88+
"low", [-torch.ones(3), -1, 3 * torch.tensor([-1.0, -2.0, -0.5]), -0.1]
8989
)
9090
@pytest.mark.parametrize(
91-
"max", [torch.ones(3), 1, 3 * torch.tensor([1.0, 2.0, 0.5]), 0.1]
91+
"high", [torch.ones(3), 1, 3 * torch.tensor([1.0, 2.0, 0.5]), 0.1]
9292
)
9393
@pytest.mark.parametrize(
9494
"vecs",
@@ -102,25 +102,64 @@ class TestTanhNormal:
102102
)
103103
@pytest.mark.parametrize("shape", [torch.Size([]), torch.Size([3, 4])])
104104
@pytest.mark.parametrize("device", get_default_devices())
105-
def test_tanhnormal(self, min, max, vecs, upscale, shape, device):
106-
min, max, vecs, upscale, shape = _map_all(
107-
min, max, vecs, upscale, shape, device=device
105+
def test_tanhnormal(self, low, high, vecs, upscale, shape, device):
106+
torch.manual_seed(0)
107+
low, high, vecs, upscale, shape = _map_all(
108+
low, high, vecs, upscale, shape, device=device
108109
)
109110
torch.manual_seed(0)
110111
d = TanhNormal(
111112
*vecs,
112113
upscale=upscale,
113-
min=min,
114-
max=max,
114+
low=low,
115+
high=high,
115116
)
116117
for _ in range(100):
117118
a = d.rsample(shape)
118119
assert a.shape[: len(shape)] == shape
119-
assert (a >= d.min).all()
120-
assert (a <= d.max).all()
120+
assert (a >= d.low).all()
121+
assert (a <= d.high).all()
121122
lp = d.log_prob(a)
122123
assert torch.isfinite(lp).all()
123124

125+
def test_tanhnormal_mode(self):
126+
# Checks that the std of the mode computed by tanh normal is within a certain range
127+
# when starting from close points
128+
129+
torch.manual_seed(0)
130+
# 10 start points with 1000 jitters around that
131+
# std of the loc is about 1e-4
132+
loc = torch.randn(10) + torch.randn(1000, 10) / 10000
133+
134+
t = TanhNormal(loc=loc, scale=0.5, low=-1, high=1, event_dims=0)
135+
136+
mode = t.get_mode()
137+
assert mode.shape == loc.shape
138+
empirical_mode, empirical_mode_lp = torch.zeros_like(loc), -float("inf")
139+
for v in torch.arange(-1, 1, step=0.01):
140+
lp = t.log_prob(v.expand_as(t.loc))
141+
empirical_mode = torch.where(lp > empirical_mode_lp, v, empirical_mode)
142+
empirical_mode_lp = torch.where(
143+
lp > empirical_mode_lp, lp, empirical_mode_lp
144+
)
145+
assert abs(empirical_mode - mode).max() < 0.1, abs(empirical_mode - mode).max()
146+
assert mode.shape == loc.shape
147+
assert (mode.std(0).max() < 0.1).all(), mode.std(0)
148+
149+
@pytest.mark.parametrize("event_dims", [0, 1, 2])
150+
def test_tanhnormal_event_dims(self, event_dims):
151+
scale = 1
152+
loc = torch.randn(1, 2, 3, 4)
153+
t = TanhNormal(loc=loc, scale=scale, event_dims=event_dims)
154+
sample = t.sample()
155+
assert sample.shape == loc.shape
156+
exp_shape = loc.shape[:-event_dims] if event_dims > 0 else loc.shape
157+
assert t.log_prob(sample).shape == exp_shape, (
158+
t.log_prob(sample).shape,
159+
event_dims,
160+
exp_shape,
161+
)
162+
124163

125164
class TestTruncatedNormal:
126165
@pytest.mark.parametrize(
@@ -159,13 +198,13 @@ def test_truncnormal(self, min, max, vecs, upscale, shape, device):
159198
a = d.rsample(shape)
160199
assert a.device == device
161200
assert a.shape[: len(shape)] == shape
162-
assert (a >= d.min).all()
163-
assert (a <= d.max).all()
201+
assert (a >= d.low).all()
202+
assert (a <= d.high).all()
164203
lp = d.log_prob(a)
165204
assert torch.isfinite(lp).all()
166-
oob_min = d.min.expand((*d.batch_shape, *d.event_shape)) - 1e-2
205+
oob_min = d.low.expand((*d.batch_shape, *d.event_shape)) - 1e-2
167206
assert not torch.isfinite(d.log_prob(oob_min)).any()
168-
oob_max = d.max.expand((*d.batch_shape, *d.event_shape)) + 1e-2
207+
oob_max = d.high.expand((*d.batch_shape, *d.event_shape)) + 1e-2
169208
assert not torch.isfinite(d.log_prob(oob_max)).any()
170209

171210
@pytest.mark.skipif(not _has_scipy, reason="scipy not installed")

test/test_exploration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,7 @@ def test_gsde(
585585
wrapper = NormalParamWrapper(model)
586586
module = SafeModule(wrapper, in_keys=in_keys, out_keys=["loc", "scale"])
587587
distribution_class = TanhNormal
588-
distribution_kwargs = {"min": -bound, "max": bound}
588+
distribution_kwargs = {"low": -bound, "high": bound}
589589
spec = BoundedTensorSpec(
590590
-torch.ones(action_dim) * bound, torch.ones(action_dim) * bound, (action_dim,)
591591
).to(device)

test/test_tensordictmodules.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1416,8 +1416,8 @@ def test_dt_inference_wrapper(self, online):
14161416
)
14171417
dist_class = TanhDelta
14181418
dist_kwargs = {
1419-
"min": -1.0,
1420-
"max": 1.0,
1419+
"low": -1.0,
1420+
"high": 1.0,
14211421
}
14221422
actor = ProbabilisticActor(
14231423
in_keys=in_keys,

torchrl/collectors/collectors.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -335,9 +335,9 @@ class SyncDataCollector(DataCollectorBase):
335335
information.
336336
Defaults to ``False``.
337337
exploration_type (ExplorationType, optional): interaction mode to be used when
338-
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``,
339-
``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``.
340-
Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
338+
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
339+
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
340+
or ``torchrl.envs.utils.ExplorationType.MEAN``. Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
341341
return_same_td (bool, optional): if ``True``, the same TensorDict
342342
will be returned at each iteration, with its values
343343
updated. This feature should be used cautiously: if the same
@@ -1336,9 +1336,9 @@ class _MultiDataCollector(DataCollectorBase):
13361336
information.
13371337
Defaults to ``False``.
13381338
exploration_type (ExplorationType, optional): interaction mode to be used when
1339-
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``,
1340-
``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``.
1341-
Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
1339+
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
1340+
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
1341+
or ``torchrl.envs.utils.ExplorationType.MEAN``. Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
13421342
reset_when_done (bool, optional): if ``True`` (default), an environment
13431343
that returns a ``True`` value in its ``"done"`` or ``"truncated"``
13441344
entry will be reset at the corresponding indices.
@@ -2635,9 +2635,9 @@ class aSyncDataCollector(MultiaSyncDataCollector):
26352635
information.
26362636
Defaults to ``False``.
26372637
exploration_type (ExplorationType, optional): interaction mode to be used when
2638-
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``,
2639-
``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``.
2640-
Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
2638+
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
2639+
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
2640+
or ``torchrl.envs.utils.ExplorationType.MEAN``. Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
26412641
reset_when_done (bool, optional): if ``True`` (default), an environment
26422642
that returns a ``True`` value in its ``"done"`` or ``"truncated"``
26432643
entry will be reset at the corresponding indices.

torchrl/collectors/distributed/generic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -346,9 +346,9 @@ class DistributedDataCollector(DataCollectorBase):
346346
information.
347347
Defaults to ``False``.
348348
exploration_type (ExplorationType, optional): interaction mode to be used when
349-
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``,
350-
``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``.
351-
Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
349+
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
350+
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
351+
or ``torchrl.envs.utils.ExplorationType.MEAN``. Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
352352
collector_class (type or str, optional): a collector class for the remote node. Can be
353353
:class:`~torchrl.collectors.SyncDataCollector`,
354354
:class:`~torchrl.collectors.MultiSyncDataCollector`,

torchrl/collectors/distributed/ray.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,9 @@ class RayCollector(DataCollectorBase):
211211
information.
212212
Defaults to ``False``.
213213
exploration_type (ExplorationType, optional): interaction mode to be used when
214-
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``,
215-
``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``.
216-
Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
214+
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
215+
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
216+
or ``torchrl.envs.utils.ExplorationType.MEAN``. Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
217217
collector_class (Python class): a collector class to be remotely instantiated. Can be
218218
:class:`~torchrl.collectors.SyncDataCollector`,
219219
:class:`~torchrl.collectors.MultiSyncDataCollector`,

torchrl/collectors/distributed/rpc.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,9 @@ class RPCDataCollector(DataCollectorBase):
187187
information.
188188
Defaults to ``False``.
189189
exploration_type (ExplorationType, optional): interaction mode to be used when
190-
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``,
191-
``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``.
190+
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
191+
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
192+
or ``torchrl.envs.utils.ExplorationType.MEAN``.
192193
Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
193194
collector_class (type or str, optional): a collector class for the remote node. Can be
194195
:class:`~torchrl.collectors.SyncDataCollector`,

torchrl/collectors/distributed/sync.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,9 @@ class DistributedSyncDataCollector(DataCollectorBase):
226226
information.
227227
Defaults to ``False``.
228228
exploration_type (ExplorationType, optional): interaction mode to be used when
229-
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``,
230-
``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``.
231-
Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
229+
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
230+
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
231+
or ``torchrl.envs.utils.ExplorationType.MEAN``. Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
232232
collector_class (type or str, optional): a collector class for the remote node. Can be
233233
:class:`~torchrl.collectors.SyncDataCollector`,
234234
:class:`~torchrl.collectors.MultiSyncDataCollector`,

0 commit comments

Comments
 (0)