|
@@ -113,7 +113,7 @@
 from torchrl.objectives.redq import REDQLoss
 from torchrl.objectives.reinforce import ReinforceLoss
 from torchrl.objectives.utils import (
-    _vmap_func,
+    _maybe_vmap_maybe_func,
     HardUpdate,
     hold_out_net,
     SoftUpdate,
|
@@ -254,11 +254,11 @@ def __init__(self):
         layers.append(nn.Linear(4, 4))
         net = nn.Sequential(*layers).to(device)
         model = TensorDictModule(net, in_keys=["obs"], out_keys=["action"])
-        self.convert_to_functional(model, "model", expand_dim=4)
+        self.maybe_convert_to_functional(model, "model", expand_dim=4)
         self._make_vmap()

     def _make_vmap(self):
-        self.vmap_model = _vmap_func(
+        self.vmap_model = _maybe_vmap_maybe_func(
             self.model,
             (None, 0),
             randomness=(
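
Note: the helper's implementation is not part of this diff, so the sketch below is an assumption about what a "maybe vmap, maybe functional" dispatcher could look like, not TorchRL's actual code. It mimics vmap(func, (None, 0)) with a plain Python loop when vmap is disabled:

    import torch

    def _sketch_maybe_vmap(func, in_dims=(None, 0), *, use_vmap=True, randomness="error"):
        """Hypothetical stand-in for _maybe_vmap_maybe_func (name and signature assumed)."""
        if use_vmap:
            # Functional path: batch func over the stacked parameter dim.
            return torch.vmap(func, in_dims, randomness=randomness)

        def looped(td, params):
            # Non-functional fallback, assuming in_dims == (None, 0):
            # call func once per parameter set and stack the results.
            return torch.stack([func(td, p) for p in params.unbind(0)])

        return looped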
|
@@ -3876,6 +3876,116 @@ def test_sac_vmap_equiv(
 
         assert_allclose_td(loss_vmap, loss_novmap)
 
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize("as_list", [True, False])
+    @pytest.mark.parametrize("provide_target", [True, False])
+    @pytest.mark.parametrize("delay_value", (True, False))
+    @pytest.mark.parametrize("delay_actor", (True, False))
+    @pytest.mark.parametrize("delay_qvalue", (True, False))
+    def test_sac_nofunc(
+        self,
+        device,
+        version,
+        as_list,
+        provide_target,
+        delay_value,
+        delay_actor,
+        delay_qvalue,
+        num_qvalue=4,
+        td_est=None,
+    ):
+        if (delay_actor or delay_qvalue) and not delay_value:
+            pytest.skip("incompatible config")
+
+        torch.manual_seed(self.seed)
+        td = self._create_mock_data_sac(device=device)
+
+        kwargs = {}
+
+        actor = self._create_mock_actor(device=device)
+        if delay_actor:
+            kwargs["delay_actor"] = True
+            if provide_target:
+                kwargs["target_actor_network"] = self._create_mock_actor(device=device)
+                kwargs["target_actor_network"].load_state_dict(actor.state_dict())
+        if as_list:
+            qvalue = [
+                self._create_mock_qvalue(device=device) for _ in range(num_qvalue)
+            ]
+        else:
+            qvalue = self._create_mock_qvalue(device=device)
+
+        if delay_qvalue:
+            kwargs["delay_qvalue"] = True
+            if provide_target:
+                if as_list:
+                    kwargs["target_qvalue_network"] = [
+                        self._create_mock_qvalue(device=device) for _ in qvalue
+                    ]
+                    for qval_t, qval in zip(kwargs["target_qvalue_network"], qvalue):
+                        qval_t.load_state_dict(qval.state_dict())
+
+                else:
+                    kwargs["target_qvalue_network"] = self._create_mock_qvalue(
+                        device=device
+                    )
+                    kwargs["target_qvalue_network"].load_state_dict(qvalue.state_dict())
+
+        if version == 1:
+            value = self._create_mock_value(device=device)
+        else:
+            value = None
+        if delay_value:
+            kwargs["delay_value"] = True
+            if provide_target and version == 1:
+                kwargs["target_value_network"] = self._create_mock_value(device=device)
+                kwargs["target_value_network"].load_state_dict(value.state_dict())
+
+        rng_state = torch.random.get_rng_state()
+        with pytest.warns(
+            UserWarning, match="The target network is ignored as the"
+        ) if delay_qvalue and not as_list and provide_target else contextlib.nullcontext():
+            loss_fn_nofunc = SACLoss(
+                actor_network=actor,
+                qvalue_network=qvalue,
+                value_network=value,
+                num_qvalue_nets=num_qvalue,
+                loss_function="l2",
+                use_vmap=False,
+                functional=False,
+                **kwargs,
+            )
+        torch.random.set_rng_state(rng_state)
+        loss_fn_func = SACLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            value_network=value,
+            num_qvalue_nets=num_qvalue,
+            loss_function="l2",
+            use_vmap=False,
+            functional=True,
+            **kwargs,
+        )
+        assert_allclose_td(
+            torch.stack(
+                list(
+                    TensorDict.from_module(loss_fn_nofunc.qvalue_network)[
+                        "module"
+                    ].values()
+                )
+            ),
+            loss_fn_func.qvalue_network_params.data,
+        )
+        with torch.no_grad(), _check_td_steady(td), pytest.warns(
+            UserWarning, match="No target network updater"
+        ):
+            rng_state = torch.random.get_rng_state()
+            loss_nofunc = loss_fn_nofunc(td.clone())
+            torch.random.set_rng_state(rng_state)
+            loss_func = loss_fn_func(td.clone())
+
+        assert_allclose_td(loss_nofunc, loss_func)
+
     @pytest.mark.parametrize("delay_value", (True, False))
     @pytest.mark.parametrize("delay_actor", (True, False))
     @pytest.mark.parametrize("delay_qvalue", (True, False))
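
The new test exercises the non-functional path: two SACLoss instances are built from the same modules and RNG state, one with functional=False (plain nn.Module calls) and one with functional=True (extracted parameters), and their Q-value parameters and loss outputs are asserted to match. In user code the switch would look roughly like this (keyword names taken from this diff; the surrounding setup is illustrative only):

    # Sketch: actor, qvalue and data come from your own setup.
    loss_fn = SACLoss(
        actor_network=actor,
        qvalue_network=qvalue,   # a single module or a list of modules
        num_qvalue_nets=4,
        loss_function="l2",
        functional=False,        # keep modules as-is, no parameter extraction
        use_vmap=False,          # evaluate the Q-networks without vmap
    )
    loss_td = loss_fn(data)

Note how the test reads the parameters differently on each side: TensorDict.from_module(loss_fn_nofunc.qvalue_network) for the non-functional loss versus the loss_fn_func.qvalue_network_params buffer for the functional one.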
|
@@ -12383,7 +12493,7 @@ class MyLoss(LossModule):
 
     def __init__(self, actor_network):
         super().__init__()
-        self.convert_to_functional(
+        self.maybe_convert_to_functional(
             actor_network,
             "actor_network",
             create_target_params=create_target_params,
|
@@ -12532,7 +12642,7 @@ class custom_module(LossModule):
     def __init__(self, delay_module=True):
         super().__init__()
         module1 = torch.nn.BatchNorm2d(10).eval()
-        self.convert_to_functional(
+        self.maybe_convert_to_functional(
             module1, "module1", create_target_params=delay_module
         )
 
|
@@ -14296,12 +14406,12 @@ class MyLoss(LossModule):
 
     def __init__(self, actor_network, qvalue_network):
         super().__init__()
-        self.convert_to_functional(
+        self.maybe_convert_to_functional(
             actor_network,
             "actor_network",
             create_target_params=True,
         )
-        self.convert_to_functional(
+        self.maybe_convert_to_functional(
             qvalue_network,
             "qvalue_network",
             3,
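
These call sites keep the exact argument list; only the method name changes, signaling that parameter extraction now happens only when the loss is functional. As a recap of the arguments visible in this diff (semantics inferred from the surrounding tests, not from the method's docstring):

    self.maybe_convert_to_functional(
        qvalue_network,             # a module, or a list of modules
        "qvalue_network",           # attribute name registered on the loss
        3,                          # expand_dim: replicate the params 3 times
        create_target_params=True,  # keep a detached target copy of the params
    )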
|
@@ -14869,8 +14979,8 @@ def __init__(self, compare_against, expand_dim):
         module_b = TensorDictModule(
             nn.Sequential(module1, module3), in_keys=["b"], out_keys=["c"]
         )
-        self.convert_to_functional(module_a, "module_a")
-        self.convert_to_functional(
+        self.maybe_convert_to_functional(module_a, "module_a")
+        self.maybe_convert_to_functional(
             module_b,
             "module_b",
             compare_against=module_a.parameters() if compare_against else [],
|
@@ -14918,8 +15028,8 @@ def __init__(self, expand_dim=2):
         module_b = TensorDictModule(
             nn.Sequential(module1, module3), in_keys=["b"], out_keys=["c"]
         )
-        self.convert_to_functional(module_a, "module_a")
-        self.convert_to_functional(
+        self.maybe_convert_to_functional(module_a, "module_a")
+        self.maybe_convert_to_functional(
             module_b,
             "module_b",
             compare_against=module_a.parameters(),
|
@@ -14967,8 +15077,8 @@ class MyLoss(LossModule):
 
     def __init__(self, module_a, module_b0, module_b1, expand_dim=2):
         super().__init__()
-        self.convert_to_functional(module_a, "module_a")
-        self.convert_to_functional(
+        self.maybe_convert_to_functional(module_a, "module_a")
+        self.maybe_convert_to_functional(
             [module_b0, module_b1],
             "module_b",
             # This will be ignored
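
When a list of modules is passed, as above, the stacking dimension is fixed by the list length, which is presumably why the explicit expand_dim is flagged as ignored in the test. Likewise, compare_against (used in the two preceding hunks) marks parameters that are shared with an already-registered module so they are not duplicated or expanded a second time. A hedged illustration of the two patterns, with semantics inferred from these tests:

    # Variant 1: a list of modules fixes the expansion size itself.
    self.maybe_convert_to_functional(
        [module_b0, module_b1],   # two parameter sets stacked along dim 0
        "module_b",
        expand_dim=2,             # ignored: the list length already gives 2
    )
    # Variant 2: a single module expanded, sharing weights with module_a.
    self.maybe_convert_to_functional(
        module_b,
        "module_b",
        expand_dim=2,
        compare_against=module_a.parameters(),  # shared params are not expanded
    )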
|
@@ -15337,14 +15447,14 @@ def __init__(self):
             TensorDictModule(value, in_keys=["hidden"], out_keys=["value"]),
         )
         super().__init__()
-        self.convert_to_functional(
+        self.maybe_convert_to_functional(
             actor,
             "actor",
             expand_dim=None,
             create_target_params=False,
             compare_against=None,
         )
-        self.convert_to_functional(
+        self.maybe_convert_to_functional(
             value,
             "value",
             expand_dim=2,
|
|