Skip to content

Commit 68ae620

Browse files
author
Vincent Moens
committed
[Feature] non-functional SAC loss
ghstack-source-id: f904e47 Pull Request resolved: #2393
1 parent b11b641 commit 68ae620

21 files changed

+591
-181
lines changed

test/test_cost.py

Lines changed: 125 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@
113113
from torchrl.objectives.redq import REDQLoss
114114
from torchrl.objectives.reinforce import ReinforceLoss
115115
from torchrl.objectives.utils import (
116-
_vmap_func,
116+
_maybe_vmap_maybe_func,
117117
HardUpdate,
118118
hold_out_net,
119119
SoftUpdate,
@@ -249,11 +249,11 @@ def __init__(self):
249249
layers.append(nn.Linear(4, 4))
250250
net = nn.Sequential(*layers).to(device)
251251
model = TensorDictModule(net, in_keys=["obs"], out_keys=["action"])
252-
self.convert_to_functional(model, "model", expand_dim=4)
252+
self.maybe_convert_to_functional(model, "model", expand_dim=4)
253253
self._make_vmap()
254254

255255
def _make_vmap(self):
256-
self.vmap_model = _vmap_func(
256+
self.vmap_model = _maybe_vmap_maybe_func(
257257
self.model,
258258
(None, 0),
259259
randomness=(
@@ -3871,6 +3871,116 @@ def test_sac_vmap_equiv(
38713871

38723872
assert_allclose_td(loss_vmap, loss_novmap)
38733873

3874+
@pytest.mark.parametrize("device", get_default_devices())
3875+
@pytest.mark.parametrize("as_list", [True, False])
3876+
@pytest.mark.parametrize("provide_target", [True, False])
3877+
@pytest.mark.parametrize("delay_value", (True, False))
3878+
@pytest.mark.parametrize("delay_actor", (True, False))
3879+
@pytest.mark.parametrize("delay_qvalue", (True, False))
3880+
def test_sac_nofunc(
3881+
self,
3882+
device,
3883+
version,
3884+
as_list,
3885+
provide_target,
3886+
delay_value,
3887+
delay_actor,
3888+
delay_qvalue,
3889+
num_qvalue=4,
3890+
td_est=None,
3891+
):
3892+
if (delay_actor or delay_qvalue) and not delay_value:
3893+
pytest.skip("incompatible config")
3894+
3895+
torch.manual_seed(self.seed)
3896+
td = self._create_mock_data_sac(device=device)
3897+
3898+
kwargs = {}
3899+
3900+
actor = self._create_mock_actor(device=device)
3901+
if delay_actor:
3902+
kwargs["delay_actor"] = True
3903+
if provide_target:
3904+
kwargs["target_actor_network"] = self._create_mock_actor(device=device)
3905+
kwargs["target_actor_network"].load_state_dict(actor.state_dict())
3906+
if as_list:
3907+
qvalue = [
3908+
self._create_mock_qvalue(device=device) for _ in range(num_qvalue)
3909+
]
3910+
else:
3911+
qvalue = self._create_mock_qvalue(device=device)
3912+
3913+
if delay_qvalue:
3914+
kwargs["delay_qvalue"] = True
3915+
if provide_target:
3916+
if as_list:
3917+
kwargs["target_qvalue_network"] = [
3918+
self._create_mock_qvalue(device=device) for _ in qvalue
3919+
]
3920+
for qval_t, qval in zip(kwargs["target_qvalue_network"], qvalue):
3921+
qval_t.load_state_dict(qval.state_dict())
3922+
3923+
else:
3924+
kwargs["target_qvalue_network"] = self._create_mock_qvalue(
3925+
device=device
3926+
)
3927+
kwargs["target_qvalue_network"].load_state_dict(qvalue.state_dict())
3928+
3929+
if version == 1:
3930+
value = self._create_mock_value(device=device)
3931+
else:
3932+
value = None
3933+
if delay_value:
3934+
kwargs["delay_value"] = True
3935+
if provide_target and version == 1:
3936+
kwargs["target_value_network"] = self._create_mock_value(device=device)
3937+
kwargs["target_value_network"].load_state_dict(value.state_dict())
3938+
3939+
rng_state = torch.random.get_rng_state()
3940+
with pytest.warns(
3941+
UserWarning, match="The target network is ignored as the"
3942+
) if delay_qvalue and not as_list and provide_target else contextlib.nullcontext():
3943+
loss_fn_nofunc = SACLoss(
3944+
actor_network=actor,
3945+
qvalue_network=qvalue,
3946+
value_network=value,
3947+
num_qvalue_nets=num_qvalue,
3948+
loss_function="l2",
3949+
use_vmap=False,
3950+
functional=False,
3951+
**kwargs,
3952+
)
3953+
torch.random.set_rng_state(rng_state)
3954+
loss_fn_func = SACLoss(
3955+
actor_network=actor,
3956+
qvalue_network=qvalue,
3957+
value_network=value,
3958+
num_qvalue_nets=num_qvalue,
3959+
loss_function="l2",
3960+
use_vmap=False,
3961+
functional=True,
3962+
**kwargs,
3963+
)
3964+
assert_allclose_td(
3965+
torch.stack(
3966+
list(
3967+
TensorDict.from_module(loss_fn_nofunc.qvalue_network)[
3968+
"module"
3969+
].values()
3970+
)
3971+
),
3972+
loss_fn_func.qvalue_network_params.data,
3973+
)
3974+
with torch.no_grad(), _check_td_steady(td), pytest.warns(
3975+
UserWarning, match="No target network updater"
3976+
):
3977+
rng_state = torch.random.get_rng_state()
3978+
loss_nofunc = loss_fn_nofunc(td.clone())
3979+
torch.random.set_rng_state(rng_state)
3980+
loss_func = loss_fn_func(td.clone())
3981+
3982+
assert_allclose_td(loss_func, loss_nofunc)
3983+
38743984
@pytest.mark.parametrize("delay_value", (True, False))
38753985
@pytest.mark.parametrize("delay_actor", (True, False))
38763986
@pytest.mark.parametrize("delay_qvalue", (True, False))
@@ -12378,7 +12488,7 @@ class MyLoss(LossModule):
1237812488

1237912489
def __init__(self, actor_network):
1238012490
super().__init__()
12381-
self.convert_to_functional(
12491+
self.maybe_convert_to_functional(
1238212492
actor_network,
1238312493
"actor_network",
1238412494
create_target_params=create_target_params,
@@ -12527,7 +12637,7 @@ class custom_module(LossModule):
1252712637
def __init__(self, delay_module=True):
1252812638
super().__init__()
1252912639
module1 = torch.nn.BatchNorm2d(10).eval()
12530-
self.convert_to_functional(
12640+
self.maybe_convert_to_functional(
1253112641
module1, "module1", create_target_params=delay_module
1253212642
)
1253312643

@@ -14291,12 +14401,12 @@ class MyLoss(LossModule):
1429114401

1429214402
def __init__(self, actor_network, qvalue_network):
1429314403
super().__init__()
14294-
self.convert_to_functional(
14404+
self.maybe_convert_to_functional(
1429514405
actor_network,
1429614406
"actor_network",
1429714407
create_target_params=True,
1429814408
)
14299-
self.convert_to_functional(
14409+
self.maybe_convert_to_functional(
1430014410
qvalue_network,
1430114411
"qvalue_network",
1430214412
3,
@@ -14864,8 +14974,8 @@ def __init__(self, compare_against, expand_dim):
1486414974
module_b = TensorDictModule(
1486514975
nn.Sequential(module1, module3), in_keys=["b"], out_keys=["c"]
1486614976
)
14867-
self.convert_to_functional(module_a, "module_a")
14868-
self.convert_to_functional(
14977+
self.maybe_convert_to_functional(module_a, "module_a")
14978+
self.maybe_convert_to_functional(
1486914979
module_b,
1487014980
"module_b",
1487114981
compare_against=module_a.parameters() if compare_against else [],
@@ -14913,8 +15023,8 @@ def __init__(self, expand_dim=2):
1491315023
module_b = TensorDictModule(
1491415024
nn.Sequential(module1, module3), in_keys=["b"], out_keys=["c"]
1491515025
)
14916-
self.convert_to_functional(module_a, "module_a")
14917-
self.convert_to_functional(
15026+
self.maybe_convert_to_functional(module_a, "module_a")
15027+
self.maybe_convert_to_functional(
1491815028
module_b,
1491915029
"module_b",
1492015030
compare_against=module_a.parameters(),
@@ -14962,8 +15072,8 @@ class MyLoss(LossModule):
1496215072

1496315073
def __init__(self, module_a, module_b0, module_b1, expand_dim=2):
1496415074
super().__init__()
14965-
self.convert_to_functional(module_a, "module_a")
14966-
self.convert_to_functional(
15075+
self.maybe_convert_to_functional(module_a, "module_a")
15076+
self.maybe_convert_to_functional(
1496715077
[module_b0, module_b1],
1496815078
"module_b",
1496915079
# This will be ignored
@@ -15332,14 +15442,14 @@ def __init__(self):
1533215442
TensorDictModule(value, in_keys=["hidden"], out_keys=["value"]),
1533315443
)
1533415444
super().__init__()
15335-
self.convert_to_functional(
15445+
self.maybe_convert_to_functional(
1533615446
actor,
1533715447
"actor",
1533815448
expand_dim=None,
1533915449
create_target_params=False,
1534015450
compare_against=None,
1534115451
)
15342-
self.convert_to_functional(
15452+
self.maybe_convert_to_functional(
1534315453
value,
1534415454
"value",
1534515455
expand_dim=2,

torchrl/objectives/a2c.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def __init__(
278278
)
279279

280280
if functional:
281-
self.convert_to_functional(
281+
self.maybe_convert_to_functional(
282282
actor_network,
283283
"actor_network",
284284
)
@@ -292,7 +292,7 @@ def __init__(
292292
else:
293293
policy_params = None
294294
if functional:
295-
self.convert_to_functional(
295+
self.maybe_convert_to_functional(
296296
critic_network, "critic_network", compare_against=policy_params
297297
)
298298
else:

0 commit comments

Comments (0)