```diff
@@ -25,13 +25,17 @@
     FunctionalModule,
     FunctionalModuleWithBuffers,
 )
-from torchrl.modules.models import MLP, NoisyLazyLinear, NoisyLinear
+from torchrl.modules.models import ConvNet, MLP, NoisyLazyLinear, NoisyLinear
+from torchrl.modules.models.utils import SquashDims


 @pytest.mark.parametrize("in_features", [3, 10, None])
 @pytest.mark.parametrize("out_features", [3, (3, 10)])
 @pytest.mark.parametrize("depth, num_cells", [(3, 32), (None, (32, 32, 32))])
-@pytest.mark.parametrize("activation_kwargs", [{"inplace": True}, {}])
+@pytest.mark.parametrize(
+    "activation_class, activation_kwargs",
+    [(nn.ReLU, {"inplace": True}), (nn.ReLU, {}), (nn.PReLU, {})],
+)
 @pytest.mark.parametrize(
     "norm_class, norm_kwargs",
     [(nn.LazyBatchNorm1d, {}), (nn.BatchNorm1d, {"num_features": 32})],
@@ -45,6 +49,7 @@ def test_mlp(
     out_features,
     depth,
     num_cells,
+    activation_class,
     activation_kwargs,
     bias_last_layer,
     norm_class,
@@ -61,14 +66,15 @@ def test_mlp(
         out_features=out_features,
         depth=depth,
         num_cells=num_cells,
-        activation_class=nn.ReLU,
+        activation_class=activation_class,
         activation_kwargs=activation_kwargs,
         norm_class=norm_class,
         norm_kwargs=norm_kwargs,
         bias_last_layer=bias_last_layer,
         single_bias_last_layer=False,
         layer_class=layer_class,
-    ).to(device)
+        device=device,
+    )
     if in_features is None:
         in_features = 5
     x = torch.randn(batch, in_features, device=device)
```
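The hunk above shows the pattern this change applies throughout: instead of building a module on the default device and moving it with `.to(device)`, the target device is passed to the constructor so parameters are allocated in place. A minimal sketch of the two styles, assuming only that `MLP` accepts the `device` keyword exercised by the updated test:

```python
import torch
from torchrl.modules.models import MLP

# Old style: parameters are created on the default device, then copied over.
mlp_old = MLP(in_features=3, out_features=4, depth=2).to("cpu")

# New style: parameters are allocated directly on the target device,
# skipping the extra copy.
mlp_new = MLP(in_features=3, out_features=4, depth=2, device="cpu")

x = torch.randn(2, 3)
assert mlp_new(x).shape == torch.Size([2, 4])
```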
```diff
@@ -77,6 +83,72 @@ def test_mlp(
     assert y.shape == torch.Size([batch, *out_features])


+@pytest.mark.parametrize("in_features", [3, 10, None])
+@pytest.mark.parametrize(
+    "input_size, depth, num_cells, kernel_sizes, strides, paddings, expected_features",
+    [(100, None, None, 3, 1, 0, 32 * 94 * 94), (100, 3, 32, 3, 1, 1, 32 * 100 * 100)],
+)
+@pytest.mark.parametrize(
+    "activation_class, activation_kwargs",
+    [(nn.ReLU, {"inplace": True}), (nn.ReLU, {}), (nn.PReLU, {})],
+)
+@pytest.mark.parametrize(
+    "norm_class, norm_kwargs",
+    [(None, None), (nn.LazyBatchNorm2d, {}), (nn.BatchNorm2d, {"num_features": 32})],
+)
+@pytest.mark.parametrize("bias_last_layer", [True, False])
+@pytest.mark.parametrize(
+    "aggregator_class, aggregator_kwargs",
+    [(SquashDims, {})],
+)
+@pytest.mark.parametrize("squeeze_output", [False])
+@pytest.mark.parametrize("device", get_available_devices())
+def test_convnet(
+    in_features,
+    depth,
+    num_cells,
+    kernel_sizes,
+    strides,
+    paddings,
+    activation_class,
+    activation_kwargs,
+    norm_class,
+    norm_kwargs,
+    bias_last_layer,
+    aggregator_class,
+    aggregator_kwargs,
+    squeeze_output,
+    device,
+    input_size,
+    expected_features,
+    seed=0,
+):
+    torch.manual_seed(seed)
+    batch = 2
+    convnet = ConvNet(
+        in_features=in_features,
+        depth=depth,
+        num_cells=num_cells,
+        kernel_sizes=kernel_sizes,
+        strides=strides,
+        paddings=paddings,
+        activation_class=activation_class,
+        activation_kwargs=activation_kwargs,
+        norm_class=norm_class,
+        norm_kwargs=norm_kwargs,
+        bias_last_layer=bias_last_layer,
+        aggregator_class=aggregator_class,
+        aggregator_kwargs=aggregator_kwargs,
+        squeeze_output=squeeze_output,
+        device=device,
+    )
+    if in_features is None:
+        in_features = 5
+    x = torch.randn(batch, in_features, input_size, input_size, device=device)
+    y = convnet(x)
+    assert y.shape == torch.Size([batch, expected_features])
+
+
 @pytest.mark.parametrize(
     "layer_class",
     [
```
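The `expected_features` values in the new test encode the standard convolution size arithmetic: with kernel 3, stride 1 and no padding, each layer shrinks a 100-pixel side by 2, so the first parameter set (relying on `ConvNet` defaults, evidently three 32-channel layers) yields 94x94 maps, and `SquashDims` flattens them to 32 * 94 * 94 features. A quick check of that arithmetic; the helper name is made up for illustration:

```python
def conv_out_size(size: int, depth: int, kernel: int, stride: int, padding: int) -> int:
    # Per layer: out = floor((in + 2*padding - kernel) / stride) + 1
    for _ in range(depth):
        size = (size + 2 * padding - kernel) // stride + 1
    return size

# padding=0: 100 -> 98 -> 96 -> 94; padding=1 preserves the size.
assert 32 * conv_out_size(100, depth=3, kernel=3, stride=1, padding=0) ** 2 == 32 * 94 * 94
assert 32 * conv_out_size(100, depth=3, kernel=3, stride=1, padding=1) ** 2 == 32 * 100 * 100
```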
```diff
@@ -87,7 +159,7 @@ def test_mlp(
 @pytest.mark.parametrize("device", get_available_devices())
 def test_noisy(layer_class, device, seed=0):
     torch.manual_seed(seed)
-    layer = layer_class(3, 4).to(device)
+    layer = layer_class(3, 4, device=device)
     x = torch.randn(10, 3, device=device)
     y1 = layer(x)
     layer.reset_noise()
```
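For context, the property this test relies on is that a noisy layer's output changes once its noise is resampled. A minimal sketch, assuming `NoisyLinear` keeps its sampled noise fixed between `reset_noise()` calls:

```python
import torch
from torchrl.modules.models import NoisyLinear

layer = NoisyLinear(3, 4, device="cpu")
x = torch.randn(10, 3)

y1 = layer(x)
y2 = layer(x)        # same noise sample -> same output
layer.reset_noise()  # draw a fresh noise sample
y3 = layer(x)        # different output for the same input

assert torch.allclose(y1, y2) and not torch.allclose(y1, y3)
```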
```diff
@@ -106,25 +178,25 @@ def test_value_based_policy(device):
     action_spec = OneHotDiscreteTensorSpec(action_dim)

     def make_net():
-        net = MLP(in_features=obs_dim, out_features=action_dim, depth=2)
+        net = MLP(in_features=obs_dim, out_features=action_dim, depth=2, device=device)
         for mod in net.modules():
             if hasattr(mod, "bias") and mod.bias is not None:
                 mod.bias.data.zero_()
         return net

-    actor = QValueActor(spec=action_spec, module=make_net(), safe=True).to(device)
+    actor = QValueActor(spec=action_spec, module=make_net(), safe=True)
     obs = torch.zeros(2, obs_dim, device=device)
     td = TensorDict(batch_size=[2], source={"observation": obs})
     action = actor(td).get("action")
     assert (action.sum(-1) == 1).all()

-    actor = QValueActor(spec=action_spec, module=make_net(), safe=False).to(device)
+    actor = QValueActor(spec=action_spec, module=make_net(), safe=False)
     obs = torch.randn(2, obs_dim, device=device)
     td = TensorDict(batch_size=[2], source={"observation": obs})
     action = actor(td).get("action")
     assert (action.sum(-1) == 1).all()

-    actor = QValueActor(spec=action_spec, module=make_net(), safe=False).to(device)
+    actor = QValueActor(spec=action_spec, module=make_net(), safe=False)
     obs = torch.zeros(2, obs_dim, device=device)
     td = TensorDict(batch_size=[2], source={"observation": obs})
     action = actor(td).get("action")
```
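Since the MLP built inside `make_net` now lives on `device`, the actor no longer needs its own `.to(device)` call; it inherits placement from its module. A hedged standalone sketch of the one-hot check, with import paths assumed from current torchrl/tensordict layouts (they may differ in the version this change targets):

```python
import torch
from tensordict import TensorDict
from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.modules import MLP, QValueActor

action_spec = OneHotDiscreteTensorSpec(5)
net = MLP(in_features=4, out_features=5, depth=2, device="cpu")
actor = QValueActor(spec=action_spec, module=net, safe=True)

td = TensorDict({"observation": torch.randn(2, 4)}, batch_size=[2])
action = actor(td).get("action")
assert (action.sum(-1) == 1).all()  # actions come out one-hot
```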
```diff
@@ -198,7 +270,8 @@ def test_lstm_net(device, out_features, hidden_size, num_layers, has_precond_hid
             "num_layers": num_layers,
         },
         {"out_features": hidden_size},
-    ).to(device)
+        device=device,
+    )
     # test single step vs multi-step
     x = torch.randn(batch, time_steps, in_features, device=device)
     x_unbind = x.unbind(1)
@@ -264,7 +337,8 @@ def test_lstm_net_nobatch(device, out_features, hidden_size):
         out_features,
         {"input_size": hidden_size, "hidden_size": hidden_size},
         {"out_features": hidden_size},
-    ).to(device)
+        device=device,
+    )
     # test single step vs multi-step
     x = torch.randn(time_steps, in_features, device=device)
     x_unbind = x.unbind(0)
```
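The last two hunks apply the same constructor-device pattern to `LSTMNet`. A standalone sketch of the updated call, with the argument layout (out_features, LSTM kwargs, MLP kwargs) taken from the diff and toy sizes made up for illustration:

```python
import torch
from torchrl.modules import LSTMNet

net = LSTMNet(
    4,                                    # out_features
    {"input_size": 8, "hidden_size": 8},  # kwargs forwarded to the LSTM
    {"out_features": 8},                  # kwargs forwarded to the input MLP
    device="cpu",
)
x = torch.randn(2, 5, 3)  # (batch, time, features)
out = net(x)
# LSTMNet may also return hidden states alongside the output.
y = out[0] if isinstance(out, tuple) else out
assert y.shape == torch.Size([2, 5, 4])
```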