Skip to content

Commit 0dbb1a5

Browse files
author
Vincent Moens
committed
[Feature] non-functional SAC loss
ghstack-source-id: fd766d1 Pull Request resolved: #2393
1 parent 63a1457 commit 0dbb1a5

21 files changed

+591
-181
lines changed

test/test_cost.py

Lines changed: 125 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@
113113
from torchrl.objectives.redq import REDQLoss
114114
from torchrl.objectives.reinforce import ReinforceLoss
115115
from torchrl.objectives.utils import (
116-
_vmap_func,
116+
_maybe_vmap_maybe_func,
117117
HardUpdate,
118118
hold_out_net,
119119
SoftUpdate,
@@ -254,11 +254,11 @@ def __init__(self):
254254
layers.append(nn.Linear(4, 4))
255255
net = nn.Sequential(*layers).to(device)
256256
model = TensorDictModule(net, in_keys=["obs"], out_keys=["action"])
257-
self.convert_to_functional(model, "model", expand_dim=4)
257+
self.maybe_convert_to_functional(model, "model", expand_dim=4)
258258
self._make_vmap()
259259

260260
def _make_vmap(self):
261-
self.vmap_model = _vmap_func(
261+
self.vmap_model = _maybe_vmap_maybe_func(
262262
self.model,
263263
(None, 0),
264264
randomness=(
@@ -3876,6 +3876,116 @@ def test_sac_vmap_equiv(
38763876

38773877
assert_allclose_td(loss_vmap, loss_novmap)
38783878

3879+
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("as_list", [True, False])
@pytest.mark.parametrize("provide_target", [True, False])
@pytest.mark.parametrize("delay_value", (True, False))
@pytest.mark.parametrize("delay_actor", (True, False))
@pytest.mark.parametrize("delay_qvalue", (True, False))
def test_sac_nofunc(
    self,
    device,
    version,
    as_list,
    provide_target,
    delay_value,
    delay_actor,
    delay_qvalue,
    num_qvalue=4,
    td_est=None,
):
    """Check that a non-functional SACLoss matches its functional counterpart.

    Builds two SACLoss instances over the same networks — one with
    ``functional=False``, one with ``functional=True`` — seeds the RNG
    identically before each construction and each forward pass, and asserts
    that (a) the q-value parameters agree and (b) the computed losses agree.

    Args:
        device: device the mock networks/data are created on.
        version: SAC version fixture (1 uses an explicit value network).
        as_list: whether the q-value network is passed as a list of modules.
        provide_target: whether explicit target networks are passed in.
        delay_value / delay_actor / delay_qvalue: target-network delays.
        num_qvalue: number of q-value networks (default 4).
        td_est: unused; kept for signature parity with sibling tests.
    """
    if (delay_actor or delay_qvalue) and not delay_value:
        pytest.skip("incompatible config")

    torch.manual_seed(self.seed)
    td = self._create_mock_data_sac(device=device)

    kwargs = {}

    actor = self._create_mock_actor(device=device)
    if delay_actor:
        kwargs["delay_actor"] = True
        if provide_target:
            kwargs["target_actor_network"] = self._create_mock_actor(device=device)
            kwargs["target_actor_network"].load_state_dict(actor.state_dict())
    if as_list:
        qvalue = [
            self._create_mock_qvalue(device=device) for _ in range(num_qvalue)
        ]
    else:
        qvalue = self._create_mock_qvalue(device=device)

    if delay_qvalue:
        kwargs["delay_qvalue"] = True
        if provide_target:
            if as_list:
                # One freshly-built target per q-value net, then sync weights.
                kwargs["target_qvalue_network"] = [
                    self._create_mock_qvalue(device=device) for _ in qvalue
                ]
                for qval_t, qval in zip(kwargs["target_qvalue_network"], qvalue):
                    qval_t.load_state_dict(qval.state_dict())

            else:
                kwargs["target_qvalue_network"] = self._create_mock_qvalue(
                    device=device
                )
                kwargs["target_qvalue_network"].load_state_dict(qvalue.state_dict())

    if version == 1:
        value = self._create_mock_value(device=device)
    else:
        value = None
    if delay_value:
        kwargs["delay_value"] = True
        if provide_target and version == 1:
            kwargs["target_value_network"] = self._create_mock_value(device=device)
            kwargs["target_value_network"].load_state_dict(value.state_dict())

    # Save/restore RNG state so both loss modules are built from identical
    # random draws (e.g. parameter expansion), making them comparable.
    rng_state = torch.random.get_rng_state()
    with pytest.warns(
        UserWarning, match="The target network is ignored as the"
    ) if delay_qvalue and not as_list and provide_target else contextlib.nullcontext():
        loss_fn_nofunc = SACLoss(
            actor_network=actor,
            qvalue_network=qvalue,
            value_network=value,
            num_qvalue_nets=num_qvalue,
            loss_function="l2",
            use_vmap=False,
            functional=False,
            **kwargs,
        )
    torch.random.set_rng_state(rng_state)
    loss_fn_func = SACLoss(
        actor_network=actor,
        qvalue_network=qvalue,
        value_network=value,
        num_qvalue_nets=num_qvalue,
        loss_function="l2",
        use_vmap=False,
        functional=True,
        **kwargs,
    )
    # The non-functional module keeps real nn.Modules; stack their state
    # and compare against the functional module's batched params.
    assert_allclose_td(
        torch.stack(
            list(
                TensorDict.from_module(loss_fn_nofunc.qvalue_network)[
                    "module"
                ].values()
            )
        ),
        loss_fn_func.qvalue_network_params.data,
    )
    with torch.no_grad(), _check_td_steady(td), pytest.warns(
        UserWarning, match="No target network updater"
    ):
        rng_state = torch.random.get_rng_state()
        # NOTE: fixed swapped names — the non-functional result was
        # previously bound to ``loss_func`` and vice versa.
        loss_nofunc = loss_fn_nofunc(td.clone())
        torch.random.set_rng_state(rng_state)
        loss_func = loss_fn_func(td.clone())

    assert_allclose_td(loss_nofunc, loss_func)
3988+
38793989
@pytest.mark.parametrize("delay_value", (True, False))
38803990
@pytest.mark.parametrize("delay_actor", (True, False))
38813991
@pytest.mark.parametrize("delay_qvalue", (True, False))
@@ -12383,7 +12493,7 @@ class MyLoss(LossModule):
1238312493

1238412494
def __init__(self, actor_network):
1238512495
super().__init__()
12386-
self.convert_to_functional(
12496+
self.maybe_convert_to_functional(
1238712497
actor_network,
1238812498
"actor_network",
1238912499
create_target_params=create_target_params,
@@ -12532,7 +12642,7 @@ class custom_module(LossModule):
1253212642
def __init__(self, delay_module=True):
1253312643
super().__init__()
1253412644
module1 = torch.nn.BatchNorm2d(10).eval()
12535-
self.convert_to_functional(
12645+
self.maybe_convert_to_functional(
1253612646
module1, "module1", create_target_params=delay_module
1253712647
)
1253812648

@@ -14296,12 +14406,12 @@ class MyLoss(LossModule):
1429614406

1429714407
def __init__(self, actor_network, qvalue_network):
1429814408
super().__init__()
14299-
self.convert_to_functional(
14409+
self.maybe_convert_to_functional(
1430014410
actor_network,
1430114411
"actor_network",
1430214412
create_target_params=True,
1430314413
)
14304-
self.convert_to_functional(
14414+
self.maybe_convert_to_functional(
1430514415
qvalue_network,
1430614416
"qvalue_network",
1430714417
3,
@@ -14869,8 +14979,8 @@ def __init__(self, compare_against, expand_dim):
1486914979
module_b = TensorDictModule(
1487014980
nn.Sequential(module1, module3), in_keys=["b"], out_keys=["c"]
1487114981
)
14872-
self.convert_to_functional(module_a, "module_a")
14873-
self.convert_to_functional(
14982+
self.maybe_convert_to_functional(module_a, "module_a")
14983+
self.maybe_convert_to_functional(
1487414984
module_b,
1487514985
"module_b",
1487614986
compare_against=module_a.parameters() if compare_against else [],
@@ -14918,8 +15028,8 @@ def __init__(self, expand_dim=2):
1491815028
module_b = TensorDictModule(
1491915029
nn.Sequential(module1, module3), in_keys=["b"], out_keys=["c"]
1492015030
)
14921-
self.convert_to_functional(module_a, "module_a")
14922-
self.convert_to_functional(
15031+
self.maybe_convert_to_functional(module_a, "module_a")
15032+
self.maybe_convert_to_functional(
1492315033
module_b,
1492415034
"module_b",
1492515035
compare_against=module_a.parameters(),
@@ -14967,8 +15077,8 @@ class MyLoss(LossModule):
1496715077

1496815078
def __init__(self, module_a, module_b0, module_b1, expand_dim=2):
1496915079
super().__init__()
14970-
self.convert_to_functional(module_a, "module_a")
14971-
self.convert_to_functional(
15080+
self.maybe_convert_to_functional(module_a, "module_a")
15081+
self.maybe_convert_to_functional(
1497215082
[module_b0, module_b1],
1497315083
"module_b",
1497415084
# This will be ignored
@@ -15337,14 +15447,14 @@ def __init__(self):
1533715447
TensorDictModule(value, in_keys=["hidden"], out_keys=["value"]),
1533815448
)
1533915449
super().__init__()
15340-
self.convert_to_functional(
15450+
self.maybe_convert_to_functional(
1534115451
actor,
1534215452
"actor",
1534315453
expand_dim=None,
1534415454
create_target_params=False,
1534515455
compare_against=None,
1534615456
)
15347-
self.convert_to_functional(
15457+
self.maybe_convert_to_functional(
1534815458
value,
1534915459
"value",
1535015460
expand_dim=2,

torchrl/objectives/a2c.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def __init__(
278278
)
279279

280280
if functional:
281-
self.convert_to_functional(
281+
self.maybe_convert_to_functional(
282282
actor_network,
283283
"actor_network",
284284
)
@@ -292,7 +292,7 @@ def __init__(
292292
else:
293293
policy_params = None
294294
if functional:
295-
self.convert_to_functional(
295+
self.maybe_convert_to_functional(
296296
critic_network, "critic_network", compare_against=policy_params
297297
)
298298
else:

0 commit comments

Comments
 (0)