113 | 113 | from torchrl.objectives.redq import REDQLoss
114 | 114 | from torchrl.objectives.reinforce import ReinforceLoss
115 | 115 | from torchrl.objectives.utils import (
116 |     | -     _vmap_func,
    | 116 | +     _maybe_vmap_maybe_func,
117 | 117 |     HardUpdate,
118 | 118 |     hold_out_net,
119 | 119 |     SoftUpdate,
@@ -249,11 +249,11 @@ def __init__(self):
249 | 249 |         layers.append(nn.Linear(4, 4))
250 | 250 |         net = nn.Sequential(*layers).to(device)
251 | 251 |         model = TensorDictModule(net, in_keys=["obs"], out_keys=["action"])
252 |     | -         self.convert_to_functional(model, "model", expand_dim=4)
    | 252 | +         self.maybe_convert_to_functional(model, "model", expand_dim=4)
253 | 253 |         self._make_vmap()
254 | 254 |
255 | 255 |     def _make_vmap(self):
256 |     | -         self.vmap_model = _vmap_func(
    | 256 | +         self.vmap_model = _maybe_vmap_maybe_func(
257 | 257 |             self.model,
258 | 258 |             (None, 0),
259 | 259 |             randomness=(
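Note: the rename from `_vmap_func` to `_maybe_vmap_maybe_func` (and `convert_to_functional` to `maybe_convert_to_functional`) tracks the new `use_vmap`/`functional` switches exercised by the test below. A minimal sketch of what such a dispatcher could look like; only the name comes from this diff, while the signature, the `use_vmap` flag, and the loop fallback are assumptions (the real helper lives in `torchrl/objectives/utils.py` and also handles the functional/non-functional call path):

```python
import torch


def _maybe_vmap_maybe_func(fn, in_dims, *, randomness="error", use_vmap=True):
    """Hypothetical sketch of a vmap-or-loop dispatcher."""
    if use_vmap:
        # Old behavior: batch the call over stacked parameter sets.
        return torch.vmap(fn, in_dims, randomness=randomness)

    def looped(td, params):
        # Fallback without vmap, assuming in_dims == (None, 0): call the
        # network once per parameter set and stack the outputs, matching
        # the shape the vmapped version would produce.
        return torch.stack([fn(td, p) for p in params.unbind(0)])

    return looped
```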
@@ -3871,6 +3871,116 @@ def test_sac_vmap_equiv(
3871 | 3871 |
3872 | 3872 |         assert_allclose_td(loss_vmap, loss_novmap)
3873 | 3873 |
     | 3874 | +     @pytest.mark.parametrize("device", get_default_devices())
     | 3875 | +     @pytest.mark.parametrize("as_list", [True, False])
     | 3876 | +     @pytest.mark.parametrize("provide_target", [True, False])
     | 3877 | +     @pytest.mark.parametrize("delay_value", (True, False))
     | 3878 | +     @pytest.mark.parametrize("delay_actor", (True, False))
     | 3879 | +     @pytest.mark.parametrize("delay_qvalue", (True, False))
     | 3880 | +     def test_sac_nofunc(
     | 3881 | +         self,
     | 3882 | +         device,
     | 3883 | +         version,
     | 3884 | +         as_list,
     | 3885 | +         provide_target,
     | 3886 | +         delay_value,
     | 3887 | +         delay_actor,
     | 3888 | +         delay_qvalue,
     | 3889 | +         num_qvalue=4,
     | 3890 | +         td_est=None,
     | 3891 | +     ):
     | 3892 | +         if (delay_actor or delay_qvalue) and not delay_value:
     | 3893 | +             pytest.skip("incompatible config")
     | 3894 | +
     | 3895 | +         torch.manual_seed(self.seed)
     | 3896 | +         td = self._create_mock_data_sac(device=device)
     | 3897 | +
     | 3898 | +         kwargs = {}
     | 3899 | +
     | 3900 | +         actor = self._create_mock_actor(device=device)
     | 3901 | +         if delay_actor:
     | 3902 | +             kwargs["delay_actor"] = True
     | 3903 | +             if provide_target:
     | 3904 | +                 kwargs["target_actor_network"] = self._create_mock_actor(device=device)
     | 3905 | +                 kwargs["target_actor_network"].load_state_dict(actor.state_dict())
     | 3906 | +         if as_list:
     | 3907 | +             qvalue = [
     | 3908 | +                 self._create_mock_qvalue(device=device) for _ in range(num_qvalue)
     | 3909 | +             ]
     | 3910 | +         else:
     | 3911 | +             qvalue = self._create_mock_qvalue(device=device)
     | 3912 | +
     | 3913 | +         if delay_qvalue:
     | 3914 | +             kwargs["delay_qvalue"] = True
     | 3915 | +             if provide_target:
     | 3916 | +                 if as_list:
     | 3917 | +                     kwargs["target_qvalue_network"] = [
     | 3918 | +                         self._create_mock_qvalue(device=device) for _ in qvalue
     | 3919 | +                     ]
     | 3920 | +                     for qval_t, qval in zip(kwargs["target_qvalue_network"], qvalue):
     | 3921 | +                         qval_t.load_state_dict(qval.state_dict())
     | 3922 | +
     | 3923 | +                 else:
     | 3924 | +                     kwargs["target_qvalue_network"] = self._create_mock_qvalue(
     | 3925 | +                         device=device
     | 3926 | +                     )
     | 3927 | +                     kwargs["target_qvalue_network"].load_state_dict(qvalue.state_dict())
     | 3928 | +
     | 3929 | +         if version == 1:
     | 3930 | +             value = self._create_mock_value(device=device)
     | 3931 | +         else:
     | 3932 | +             value = None
     | 3933 | +         if delay_value:
     | 3934 | +             kwargs["delay_value"] = True
     | 3935 | +             if provide_target and version == 1:
     | 3936 | +                 kwargs["target_value_network"] = self._create_mock_value(device=device)
     | 3937 | +                 kwargs["target_value_network"].load_state_dict(value.state_dict())
     | 3938 | +
     | 3939 | +         rng_state = torch.random.get_rng_state()
     | 3940 | +         with pytest.warns(
     | 3941 | +             UserWarning, match="The target network is ignored as the"
     | 3942 | +         ) if delay_qvalue and not as_list and provide_target else contextlib.nullcontext():
     | 3943 | +             loss_fn_nofunc = SACLoss(
     | 3944 | +                 actor_network=actor,
     | 3945 | +                 qvalue_network=qvalue,
     | 3946 | +                 value_network=value,
     | 3947 | +                 num_qvalue_nets=num_qvalue,
     | 3948 | +                 loss_function="l2",
     | 3949 | +                 use_vmap=False,
     | 3950 | +                 functional=False,
     | 3951 | +                 **kwargs,
     | 3952 | +             )
     | 3953 | +         torch.random.set_rng_state(rng_state)
     | 3954 | +         loss_fn_func = SACLoss(
     | 3955 | +             actor_network=actor,
     | 3956 | +             qvalue_network=qvalue,
     | 3957 | +             value_network=value,
     | 3958 | +             num_qvalue_nets=num_qvalue,
     | 3959 | +             loss_function="l2",
     | 3960 | +             use_vmap=False,
     | 3961 | +             functional=True,
     | 3962 | +             **kwargs,
     | 3963 | +         )
     | 3964 | +         assert_allclose_td(
     | 3965 | +             torch.stack(
     | 3966 | +                 list(
     | 3967 | +                     TensorDict.from_module(loss_fn_nofunc.qvalue_network)[
     | 3968 | +                         "module"
     | 3969 | +                     ].values()
     | 3970 | +                 )
     | 3971 | +             ),
     | 3972 | +             loss_fn_func.qvalue_network_params.data,
     | 3973 | +         )
     | 3974 | +         with torch.no_grad(), _check_td_steady(td), pytest.warns(
     | 3975 | +             UserWarning, match="No target network updater"
     | 3976 | +         ):
     | 3977 | +             rng_state = torch.random.get_rng_state()
     | 3978 | +             loss_nofunc = loss_fn_nofunc(td.clone())
     | 3979 | +             torch.random.set_rng_state(rng_state)
     | 3980 | +             loss_func = loss_fn_func(td.clone())
     | 3981 | +
     | 3982 | +         assert_allclose_td(loss_nofunc, loss_func)
     | 3983 | +
3874 | 3984 |     @pytest.mark.parametrize("delay_value", (True, False))
3875 | 3985 |     @pytest.mark.parametrize("delay_actor", (True, False))
3876 | 3986 |     @pytest.mark.parametrize("delay_qvalue", (True, False))
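Note on `test_sac_nofunc` above: both `SACLoss` instances are built and evaluated after rewinding the global RNG to the same state, so any randomness (parameter init, stochastic sampling) is identical across the functional and non-functional paths, and outputs can be compared exactly. The pattern in isolation (`build_and_run` is a stand-in helper, not part of the test suite):

```python
import torch


def build_and_run():
    # Stand-in for "construct a loss module and evaluate it": anything
    # that consumes the global RNG works to demonstrate the pattern.
    return torch.nn.Linear(3, 3).weight.detach().clone()


rng_state = torch.random.get_rng_state()  # capture
out_a = build_and_run()                   # variant A consumes random draws
torch.random.set_rng_state(rng_state)     # rewind
out_b = build_and_run()                   # variant B sees the same draws
assert torch.allclose(out_a, out_b)
```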
@@ -12378,7 +12488,7 @@ class MyLoss(LossModule):
12378 | 12488 |
12379 | 12489 |     def __init__(self, actor_network):
12380 | 12490 |         super().__init__()
12381 |       | -         self.convert_to_functional(
      | 12491 | +         self.maybe_convert_to_functional(
12382 | 12492 |             actor_network,
12383 | 12493 |             "actor_network",
12384 | 12494 |             create_target_params=create_target_params,
@@ -12527,7 +12637,7 @@ class custom_module(LossModule):
12527 | 12637 |     def __init__(self, delay_module=True):
12528 | 12638 |         super().__init__()
12529 | 12639 |         module1 = torch.nn.BatchNorm2d(10).eval()
12530 |       | -         self.convert_to_functional(
      | 12640 | +         self.maybe_convert_to_functional(
12531 | 12641 |             module1, "module1", create_target_params=delay_module
12532 | 12642 |         )
12533 | 12643 |
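The `BatchNorm2d(10).eval()` module in this hunk is presumably there because batch norm carries buffers (running statistics) in addition to learnable parameters, so target-parameter creation has to capture both. A quick way to see what gets captured:

```python
import torch
from tensordict import TensorDict

# BatchNorm exposes learnable params (weight, bias) plus buffers
# (running_mean, running_var, num_batches_tracked); from_module picks
# up all of them, which is what target params need to track.
bn = torch.nn.BatchNorm2d(10).eval()
td = TensorDict.from_module(bn)
print(sorted(td.keys()))
```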
@@ -14291,12 +14401,12 @@ class MyLoss(LossModule):
14291 | 14401 |
14292 | 14402 |     def __init__(self, actor_network, qvalue_network):
14293 | 14403 |         super().__init__()
14294 |       | -         self.convert_to_functional(
      | 14404 | +         self.maybe_convert_to_functional(
14295 | 14405 |             actor_network,
14296 | 14406 |             "actor_network",
14297 | 14407 |             create_target_params=True,
14298 | 14408 |         )
14299 |       | -         self.convert_to_functional(
      | 14409 | +         self.maybe_convert_to_functional(
14300 | 14410 |             qvalue_network,
14301 | 14411 |             "qvalue_network",
14302 | 14412 |             3,
@@ -14864,8 +14974,8 @@ def __init__(self, compare_against, expand_dim):
14864 | 14974 |         module_b = TensorDictModule(
14865 | 14975 |             nn.Sequential(module1, module3), in_keys=["b"], out_keys=["c"]
14866 | 14976 |         )
14867 |       | -         self.convert_to_functional(module_a, "module_a")
14868 |       | -         self.convert_to_functional(
      | 14977 | +         self.maybe_convert_to_functional(module_a, "module_a")
      | 14978 | +         self.maybe_convert_to_functional(
14869 | 14979 |             module_b,
14870 | 14980 |             "module_b",
14871 | 14981 |             compare_against=module_a.parameters() if compare_against else [],
@@ -14913,8 +15023,8 @@ def __init__(self, expand_dim=2):
14913 | 15023 |         module_b = TensorDictModule(
14914 | 15024 |             nn.Sequential(module1, module3), in_keys=["b"], out_keys=["c"]
14915 | 15025 |         )
14916 |       | -         self.convert_to_functional(module_a, "module_a")
14917 |       | -         self.convert_to_functional(
      | 15026 | +         self.maybe_convert_to_functional(module_a, "module_a")
      | 15027 | +         self.maybe_convert_to_functional(
14918 | 15028 |             module_b,
14919 | 15029 |             "module_b",
14920 | 15030 |             compare_against=module_a.parameters(),
@@ -14962,8 +15072,8 @@ class MyLoss(LossModule):
14962 | 15072 |
14963 | 15073 |     def __init__(self, module_a, module_b0, module_b1, expand_dim=2):
14964 | 15074 |         super().__init__()
14965 |       | -         self.convert_to_functional(module_a, "module_a")
14966 |       | -         self.convert_to_functional(
      | 15075 | +         self.maybe_convert_to_functional(module_a, "module_a")
      | 15076 | +         self.maybe_convert_to_functional(
14967 | 15077 |             [module_b0, module_b1],
14968 | 15078 |             "module_b",
14969 | 15079 |             # This will be ignored
@@ -15332,14 +15442,14 @@ def __init__(self):
15332 | 15442 |             TensorDictModule(value, in_keys=["hidden"], out_keys=["value"]),
15333 | 15443 |         )
15334 | 15444 |         super().__init__()
15335 |       | -         self.convert_to_functional(
      | 15445 | +         self.maybe_convert_to_functional(
15336 | 15446 |             actor,
15337 | 15447 |             "actor",
15338 | 15448 |             expand_dim=None,
15339 | 15449 |             create_target_params=False,
15340 | 15450 |             compare_against=None,
15341 | 15451 |         )
15342 |       | -         self.convert_to_functional(
      | 15452 | +         self.maybe_convert_to_functional(
15343 | 15453 |             value,
15344 | 15454 |             "value",
15345 | 15455 |             expand_dim=2,