
Commit 26e3083

[Feature] Implemented device argument for modules.models (#524)
Co-authored-by: Yu Shiyang <yushiyang@fb.com>
1 parent 59b1f2b · commit 26e3083


5 files changed, +249 −90 lines

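The change lets TorchRL model classes take a `device` argument at construction time, so the tests below can drop their trailing `.to(device)` calls. A minimal before/after sketch, assuming `device` is any device accepted by torch and using layer sizes chosen only for illustration:

import torch
from torchrl.modules.models import MLP, NoisyLinear

device = "cpu"
# before this commit: build on the default device, then move
#     net = MLP(in_features=7, out_features=3, depth=2).to(device)
# after this commit: allocate parameters directly on the target device
net = MLP(in_features=7, out_features=3, depth=2, device=device)
layer = NoisyLinear(3, 4, device=device)
x = torch.randn(10, 7, device=device)
y = net(x)  # no device mismatch and no extra copy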

test/test_exploration.py

Lines changed: 5 additions & 7 deletions
@@ -227,25 +227,23 @@ def test_gsde(
 ):
     torch.manual_seed(0)
     if gSDE:
-        model = torch.nn.LazyLinear(action_dim)
+        model = torch.nn.LazyLinear(action_dim, device=device)
         in_keys = ["observation"]
         module = TensorDictSequential(
             TensorDictModule(model, in_keys=in_keys, out_keys=["action"]),
             TensorDictModule(
-                LazygSDEModule(),
+                LazygSDEModule(device=device),
                 in_keys=["action", "observation", "_eps_gSDE"],
                 out_keys=["loc", "scale", "action", "_eps_gSDE"],
             ),
-        ).to(device)
+        )
         distribution_class = IndependentNormal
         distribution_kwargs = {}
     else:
         in_keys = ["observation"]
-        model = torch.nn.LazyLinear(action_dim * 2)
+        model = torch.nn.LazyLinear(action_dim * 2, device=device)
         wrapper = NormalParamWrapper(model)
-        module = TensorDictModule(
-            wrapper, in_keys=in_keys, out_keys=["loc", "scale"]
-        ).to(device)
+        module = TensorDictModule(wrapper, in_keys=in_keys, out_keys=["loc", "scale"])
         distribution_class = TanhNormal
         distribution_kwargs = {"min": -bound, "max": bound}
     spec = NdBoundedTensorSpec(
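
As context for the change above, the gSDE actor stack in this test is now assembled entirely on the target device; a minimal sketch mirroring it (the action size and device value are placeholders, and the imports follow the ones used elsewhere in this commit):

import torch
from torchrl.modules import TensorDictModule, TensorDictSequential
from torchrl.modules.models.exploration import LazygSDEModule

action_dim, device = 4, "cpu"
model = torch.nn.LazyLinear(action_dim, device=device)
module = TensorDictSequential(
    TensorDictModule(model, in_keys=["observation"], out_keys=["action"]),
    TensorDictModule(
        # device is passed to the lazy module instead of calling .to(device) on the stack
        LazygSDEModule(device=device),
        in_keys=["action", "observation", "_eps_gSDE"],
        out_keys=["loc", "scale", "action", "_eps_gSDE"],
    ),
)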

test/test_modules.py

Lines changed: 85 additions & 11 deletions
@@ -25,13 +25,17 @@
     FunctionalModule,
     FunctionalModuleWithBuffers,
 )
-from torchrl.modules.models import MLP, NoisyLazyLinear, NoisyLinear
+from torchrl.modules.models import ConvNet, MLP, NoisyLazyLinear, NoisyLinear
+from torchrl.modules.models.utils import SquashDims


 @pytest.mark.parametrize("in_features", [3, 10, None])
 @pytest.mark.parametrize("out_features", [3, (3, 10)])
 @pytest.mark.parametrize("depth, num_cells", [(3, 32), (None, (32, 32, 32))])
-@pytest.mark.parametrize("activation_kwargs", [{"inplace": True}, {}])
+@pytest.mark.parametrize(
+    "activation_class, activation_kwargs",
+    [(nn.ReLU, {"inplace": True}), (nn.ReLU, {}), (nn.PReLU, {})],
+)
 @pytest.mark.parametrize(
     "norm_class, norm_kwargs",
     [(nn.LazyBatchNorm1d, {}), (nn.BatchNorm1d, {"num_features": 32})],
@@ -45,6 +49,7 @@ def test_mlp(
     out_features,
     depth,
     num_cells,
+    activation_class,
     activation_kwargs,
     bias_last_layer,
     norm_class,
@@ -61,14 +66,15 @@ def test_mlp(
         out_features=out_features,
         depth=depth,
         num_cells=num_cells,
-        activation_class=nn.ReLU,
+        activation_class=activation_class,
         activation_kwargs=activation_kwargs,
         norm_class=norm_class,
         norm_kwargs=norm_kwargs,
         bias_last_layer=bias_last_layer,
         single_bias_last_layer=False,
         layer_class=layer_class,
-    ).to(device)
+        device=device,
+    )
     if in_features is None:
         in_features = 5
     x = torch.randn(batch, in_features, device=device)
@@ -77,6 +83,72 @@ def test_mlp(
     assert y.shape == torch.Size([batch, *out_features])


+@pytest.mark.parametrize("in_features", [3, 10, None])
+@pytest.mark.parametrize(
+    "input_size, depth, num_cells, kernel_sizes, strides, paddings, expected_features",
+    [(100, None, None, 3, 1, 0, 32 * 94 * 94), (100, 3, 32, 3, 1, 1, 32 * 100 * 100)],
+)
+@pytest.mark.parametrize(
+    "activation_class, activation_kwargs",
+    [(nn.ReLU, {"inplace": True}), (nn.ReLU, {}), (nn.PReLU, {})],
+)
+@pytest.mark.parametrize(
+    "norm_class, norm_kwargs",
+    [(None, None), (nn.LazyBatchNorm2d, {}), (nn.BatchNorm2d, {"num_features": 32})],
+)
+@pytest.mark.parametrize("bias_last_layer", [True, False])
+@pytest.mark.parametrize(
+    "aggregator_class, aggregator_kwargs",
+    [(SquashDims, {})],
+)
+@pytest.mark.parametrize("squeeze_output", [False])
+@pytest.mark.parametrize("device", get_available_devices())
+def test_convnet(
+    in_features,
+    depth,
+    num_cells,
+    kernel_sizes,
+    strides,
+    paddings,
+    activation_class,
+    activation_kwargs,
+    norm_class,
+    norm_kwargs,
+    bias_last_layer,
+    aggregator_class,
+    aggregator_kwargs,
+    squeeze_output,
+    device,
+    input_size,
+    expected_features,
+    seed=0,
+):
+    torch.manual_seed(seed)
+    batch = 2
+    convnet = ConvNet(
+        in_features=in_features,
+        depth=depth,
+        num_cells=num_cells,
+        kernel_sizes=kernel_sizes,
+        strides=strides,
+        paddings=paddings,
+        activation_class=activation_class,
+        activation_kwargs=activation_kwargs,
+        norm_class=norm_class,
+        norm_kwargs=norm_kwargs,
+        bias_last_layer=bias_last_layer,
+        aggregator_class=aggregator_class,
+        aggregator_kwargs=aggregator_kwargs,
+        squeeze_output=squeeze_output,
+        device=device,
+    )
+    if in_features is None:
+        in_features = 5
+    x = torch.randn(batch, in_features, input_size, input_size, device=device)
+    y = convnet(x)
+    assert y.shape == torch.Size([batch, expected_features])
+
+
 @pytest.mark.parametrize(
     "layer_class",
     [
@@ -87,7 +159,7 @@ def test_mlp(
 @pytest.mark.parametrize("device", get_available_devices())
 def test_noisy(layer_class, device, seed=0):
     torch.manual_seed(seed)
-    layer = layer_class(3, 4).to(device)
+    layer = layer_class(3, 4, device=device)
     x = torch.randn(10, 3, device=device)
     y1 = layer(x)
     layer.reset_noise()
@@ -106,25 +178,25 @@ def test_value_based_policy(device):
     action_spec = OneHotDiscreteTensorSpec(action_dim)

     def make_net():
-        net = MLP(in_features=obs_dim, out_features=action_dim, depth=2)
+        net = MLP(in_features=obs_dim, out_features=action_dim, depth=2, device=device)
         for mod in net.modules():
             if hasattr(mod, "bias") and mod.bias is not None:
                 mod.bias.data.zero_()
         return net

-    actor = QValueActor(spec=action_spec, module=make_net(), safe=True).to(device)
+    actor = QValueActor(spec=action_spec, module=make_net(), safe=True)
     obs = torch.zeros(2, obs_dim, device=device)
     td = TensorDict(batch_size=[2], source={"observation": obs})
     action = actor(td).get("action")
     assert (action.sum(-1) == 1).all()

-    actor = QValueActor(spec=action_spec, module=make_net(), safe=False).to(device)
+    actor = QValueActor(spec=action_spec, module=make_net(), safe=False)
     obs = torch.randn(2, obs_dim, device=device)
     td = TensorDict(batch_size=[2], source={"observation": obs})
     action = actor(td).get("action")
     assert (action.sum(-1) == 1).all()

-    actor = QValueActor(spec=action_spec, module=make_net(), safe=False).to(device)
+    actor = QValueActor(spec=action_spec, module=make_net(), safe=False)
     obs = torch.zeros(2, obs_dim, device=device)
     td = TensorDict(batch_size=[2], source={"observation": obs})
     action = actor(td).get("action")
@@ -198,7 +270,8 @@ def test_lstm_net(device, out_features, hidden_size, num_layers, has_precond_hid
             "num_layers": num_layers,
         },
         {"out_features": hidden_size},
-    ).to(device)
+        device=device,
+    )
     # test single step vs multi-step
     x = torch.randn(batch, time_steps, in_features, device=device)
     x_unbind = x.unbind(1)
@@ -264,7 +337,8 @@ def test_lstm_net_nobatch(device, out_features, hidden_size):
         out_features,
         {"input_size": hidden_size, "hidden_size": hidden_size},
         {"out_features": hidden_size},
-    ).to(device)
+        device=device,
+    )
     # test single step vs multi-step
     x = torch.randn(time_steps, in_features, device=device)
     x_unbind = x.unbind(0)
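
The first test_convnet parametrization above corresponds to the following sketch; the sizes come from the test itself, while the choice of "cpu" and the explicit keyword values are assumptions for illustration:

import torch
from torchrl.modules.models import ConvNet
from torchrl.modules.models.utils import SquashDims

device = "cpu"
convnet = ConvNet(
    in_features=3,
    depth=3,
    num_cells=32,
    kernel_sizes=3,
    strides=1,
    paddings=0,
    aggregator_class=SquashDims,
    aggregator_kwargs={},
    device=device,
)
x = torch.randn(2, 3, 100, 100, device=device)
y = convnet(x)
# three 3x3 convolutions with stride 1 and no padding shrink 100 -> 98 -> 96 -> 94,
# and SquashDims flattens the (32, 94, 94) feature map into a single dimension
assert y.shape == torch.Size([2, 32 * 94 * 94])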

torchrl/modules/models/exploration.py

Lines changed: 22 additions & 16 deletions
@@ -35,7 +35,7 @@ class NoisyLinear(nn.Linear):
         out_features (int): out features dimension
         bias (bool): if True, a bias term will be added to the matrix multiplication: Ax + b.
             default: True
-        device (str, int or torch.device, optional): device of the layer.
+        device (DEVICE_TYPING, optional): device of the layer.
             default: "cpu"
         dtype (torch.dtype, optional): dtype of the parameters.
             default: None
@@ -157,8 +157,7 @@ class NoisyLazyLinear(LazyModuleMixin, NoisyLinear):
         out_features (int): out features dimension
         bias (bool): if True, a bias term will be added to the matrix multiplication: Ax + b.
             default: True
-        device (str, int or torch.device, optional): device of the layer.
-            default: "cpu"
+        device (DEVICE_TYPING, optional): device of the layer.
         dtype (torch.dtype, optional): dtype of the parameters.
             default: None
         std_init (scalar): initial value of the Gaussian standard deviation before optimization.
@@ -173,7 +172,7 @@ def __init__(
         dtype: Optional[torch.dtype] = None,
         std_init: float = 0.1,
     ):
-        super().__init__(0, 0, False)
+        super().__init__(0, 0, False, device=device)
         self.out_features = out_features
         self.std_init = std_init

@@ -260,6 +259,7 @@ class gSDEModule(nn.Module):
         scale_max (float, optional): max value of the scale.
         transform (torch.distribution.Transform, optional): a transform to apply
             to the sampled action.
+        device (DEVICE_TYPING, optional): device to create the model on.

     Examples:
         >>> from torchrl.modules import TensorDictModule, TensorDictSequential, ProbabilisticActor, TanhNormal
@@ -308,6 +308,7 @@ def __init__(
         scale_max: float = 10.0,
         learn_sigma: bool = True,
         transform: Optional[d.Transform] = None,
+        device: Optional[DEVICE_TYPING] = None,
     ) -> None:
         super().__init__()
         self.action_dim = action_dim
@@ -321,18 +322,22 @@ def __init__(
                 sigma_init = inv_softplus(math.sqrt((1.0 - scale_min) / state_dim))
             self.register_parameter(
                 "log_sigma",
-                nn.Parameter(torch.zeros((action_dim, state_dim), requires_grad=True)),
+                nn.Parameter(
+                    torch.zeros(
+                        (action_dim, state_dim), requires_grad=True, device=device
+                    )
+                ),
             )
         else:
             if sigma_init is None:
                 sigma_init = math.sqrt((1.0 - scale_min) / state_dim)
             self.register_buffer(
                 "_sigma",
-                torch.full((action_dim, state_dim), sigma_init),
+                torch.full((action_dim, state_dim), sigma_init, device=device),
             )

         if sigma_init != 0.0:
-            self.register_buffer("sigma_init", torch.tensor(sigma_init))
+            self.register_buffer("sigma_init", torch.tensor(sigma_init, device=device))

     @property
     def sigma(self):
@@ -417,11 +422,8 @@ def __init__(
         scale_max: float = 10.0,
         learn_sigma: bool = True,
         transform: Optional[d.Transform] = None,
+        device: Optional[DEVICE_TYPING] = None,
     ) -> None:
-        factory_kwargs = {
-            "device": torch.device("cpu"),
-            "dtype": torch.get_default_dtype(),
-        }
         super().__init__(
             0,
             0,
@@ -430,7 +432,12 @@ def __init__(
             scale_max=scale_max,
             learn_sigma=learn_sigma,
             transform=transform,
+            device=device,
         )
+        factory_kwargs = {
+            "device": device,
+            "dtype": torch.get_default_dtype(),
+        }
         self._sigma_init = sigma_init
         self.sigma_init = UninitializedBuffer(**factory_kwargs)
         if learn_sigma:
@@ -445,18 +452,17 @@ def initialize_parameters(
         self, mu: torch.Tensor, state: torch.Tensor, _eps_gSDE: torch.Tensor
     ) -> None:
         if self.has_uninitialized_params():
-            device = mu.device
             action_dim = mu.shape[-1]
             state_dim = state.shape[-1]
             with torch.no_grad():
                 if state.ndimension() > 2:
                     state = state.flatten(0, -2).squeeze(0)
                 if state.ndimension() == 1:
-                    state_flatten_var = torch.ones(1).to(device)
+                    state_flatten_var = torch.ones(1, device=mu.device)
                 else:
                     state_flatten_var = state.pow(2).mean(dim=0).reciprocal()

-                self.sigma_init.materialize(state_flatten_var.shape, device=device)
+                self.sigma_init.materialize(state_flatten_var.shape)
                 if self.learn_sigma:
                     if self._sigma_init is None:
                         state_flatten_var.clamp_min_(self.scale_min)
@@ -471,7 +477,7 @@ def initialize_parameters(
                             )
                         )

-                    self.log_sigma.materialize((action_dim, state_dim), device=device)
+                    self.log_sigma.materialize((action_dim, state_dim))
                     self.log_sigma.data.copy_(self.sigma_init.expand_as(self.log_sigma))

                 else:
@@ -483,5 +489,5 @@ def initialize_parameters(
                         self.sigma_init.data.copy_(
                             (state_flatten_var / state_dim).sqrt() * self._sigma_init
                         )
-                    self._sigma.materialize((action_dim, state_dim), device=device)
+                    self._sigma.materialize((action_dim, state_dim))
                     self._sigma.data.copy_(self.sigma_init.expand_as(self._sigma))
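
With these changes, gSDEModule allocates its sigma parameter and buffers directly on the requested device, and the lazy variant records that device for later materialization. A minimal sketch, assuming a CPU device and arbitrary action/state sizes chosen only for illustration:

import torch
from torchrl.modules.models.exploration import gSDEModule

device = "cpu"
gsde = gSDEModule(action_dim=4, state_dim=7, device=device)
# with learn_sigma=True (the default), log_sigma is created on `device`
# instead of being moved there afterwards with .to(device)
assert gsde.log_sigma.device == torch.device(device)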
