@@ -2178,18 +2178,8 @@ def test_transform_no_env(self, keys, h, nchannels, batch, device):
21782178 assert observation_spec [key ].shape == torch .Size ([nchannels , 20 , h ])
21792179
21802180 @pytest .mark .parametrize ("nchannels" , [3 ])
2181- @pytest .mark .parametrize (
2182- "batch" ,
2183- [
2184- [2 ]
2185- ]
2186- )
2187- @pytest .mark .parametrize (
2188- "h" ,
2189- [
2190- None
2191- ]
2192- )
2181+ @pytest .mark .parametrize ("batch" , [[2 ]])
2182+ @pytest .mark .parametrize ("h" , [None ])
21932183 @pytest .mark .parametrize ("keys" , [["observation_pixels" ]])
21942184 @pytest .mark .parametrize ("device" , get_default_devices ())
21952185 def test_transform_model (self , keys , h , nchannels , batch , device ):
@@ -2214,18 +2204,8 @@ def test_transform_model(self, keys, h, nchannels, batch, device):
22142204 assert (td .get ("dont touch" ) == dont_touch ).all ()
22152205
22162206 @pytest .mark .parametrize ("nchannels" , [3 ])
2217- @pytest .mark .parametrize (
2218- "batch" ,
2219- [
2220- [2 ]
2221- ]
2222- )
2223- @pytest .mark .parametrize (
2224- "h" ,
2225- [
2226- None
2227- ]
2228- )
2207+ @pytest .mark .parametrize ("batch" , [[2 ]])
2208+ @pytest .mark .parametrize ("h" , [None ])
22292209 @pytest .mark .parametrize ("keys" , [["observation_pixels" ]])
22302210 @pytest .mark .parametrize ("device" , get_default_devices ())
22312211 def test_transform_compose (self , keys , h , nchannels , batch , device ):
@@ -2254,18 +2234,8 @@ def test_transform_compose(self, keys, h, nchannels, batch, device):
22542234 assert (tdc .get ("dont touch" ) == dont_touch ).all ()
22552235
22562236 @pytest .mark .parametrize ("nchannels" , [3 ])
2257- @pytest .mark .parametrize (
2258- "batch" ,
2259- [
2260- [2 ]
2261- ]
2262- )
2263- @pytest .mark .parametrize (
2264- "h" ,
2265- [
2266- None
2267- ]
2268- )
2237+ @pytest .mark .parametrize ("batch" , [[2 ]])
2238+ @pytest .mark .parametrize ("h" , [None ])
22692239 @pytest .mark .parametrize ("keys" , [["observation_pixels" ]])
22702240 @pytest .mark .parametrize ("rbclass" , [ReplayBuffer , TensorDictReplayBuffer ])
22712241 def test_transform_rb (
0 commit comments