@@ -280,11 +280,10 @@ def _get_policy_and_device(
         device = torch.device(device) if device is not None else policy_device
         get_weights_fn = None
         if policy_device != device:
-            param_and_buf = dict(policy.named_parameters())
-            param_and_buf.update(dict(policy.named_buffers()))
+            param_and_buf = TensorDict.from_module(policy, as_module=True)

            def get_weights_fn(param_and_buf=param_and_buf):
-                return TensorDict(param_and_buf, []).apply(lambda x: x.data)
+                return param_and_buf.data

            policy_cast = deepcopy(policy).requires_grad_(False).to(device)

            # here things may break bc policy.to("cuda") gives us weights on cuda:0 (same
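The hunk above swaps the manual `named_parameters`/`named_buffers` merge for `TensorDict.from_module`. A minimal sketch of that pattern, outside the diff (the `nn.Sequential` module here is illustrative, not from the patch):

```python
import torch
from torch import nn
from tensordict import TensorDict

# from_module collects a module's parameters and buffers into one nested
# TensorDict; as_module=True keeps the result tied to the module's tensors.
module = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4))
param_and_buf = TensorDict.from_module(module, as_module=True)

# Keys mirror the module hierarchy, e.g. ("0", "weight"), and .data yields
# the same detached leaves that the old
# TensorDict(param_and_buf, []).apply(lambda x: x.data) round-trip produced,
# without rebuilding the tree on every call.
detached = param_and_buf.data
assert not detached["0", "weight"].requires_grad
```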
@@ -308,9 +307,9 @@ def update_policy_weights_(

        """
        if policy_weights is not None:
-            self.policy_weights.apply(lambda x: x.data).update_(policy_weights)
+            self.policy_weights.data.update_(policy_weights)
        elif self.get_weights_fn is not None:
-            self.policy_weights.apply(lambda x: x.data).update_(self.get_weights_fn())
+            self.policy_weights.data.update_(self.get_weights_fn())

    def __iter__(self) -> Iterator[TensorDictBase]:
        return self.iterator()
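For the weight-sync side, a minimal sketch of what the new `.data.update_(...)` chain does; the `trainer_policy`/`collector_policy` modules are hypothetical stand-ins for the trainer's and the collector's copies of the policy:

```python
import torch
from torch import nn
from tensordict import TensorDict

# Two copies of the same architecture: the trainer's policy and the
# collector's cast policy.
trainer_policy = nn.Linear(3, 4)
collector_policy = nn.Linear(3, 4)

collector_weights = TensorDict.from_module(collector_policy, as_module=True)
fresh_weights = TensorDict.from_module(trainer_policy).data

# .data exposes the leaves without their autograd wrappers, so update_
# copies the new values straight into the existing parameter storage;
# the old .apply(lambda x: x.data) detour built a full intermediate
# tensordict on every sync just to get the same effect.
collector_weights.data.update_(fresh_weights)
assert torch.equal(collector_policy.weight, trainer_policy.weight)
```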
@@ -559,10 +558,7 @@ def __init__(
        )

        if isinstance(self.policy, nn.Module):
-            self.policy_weights = TensorDict(dict(self.policy.named_parameters()), [])
-            self.policy_weights.update(
-                TensorDict(dict(self.policy.named_buffers()), [])
-            )
+            self.policy_weights = TensorDict.from_module(self.policy, as_module=True)
        else:
            self.policy_weights = TensorDict({}, [])
@@ -1200,9 +1196,9 @@ def device_err_msg(device_name, devices_list):
            )
            self._policy_dict[_device] = _policy
            if isinstance(_policy, nn.Module):
-                param_dict = dict(_policy.named_parameters())
-                param_dict.update(_policy.named_buffers())
-                self._policy_weights_dict[_device] = TensorDict(param_dict, [])
+                self._policy_weights_dict[_device] = TensorDict.from_module(
+                    _policy, as_module=True
+                )
            else:
                self._policy_weights_dict[_device] = TensorDict({}, [])
@@ -1288,11 +1284,9 @@ def frames_per_batch_worker(self):
    def update_policy_weights_(self, policy_weights=None) -> None:
        for _device in self._policy_dict:
            if policy_weights is not None:
-                self._policy_weights_dict[_device].apply(lambda x: x.data).update_(
-                    policy_weights
-                )
+                self._policy_weights_dict[_device].data.update_(policy_weights)
            elif self._get_weights_fn_dict[_device] is not None:
-                self._policy_weights_dict[_device].update_(
+                self._policy_weights_dict[_device].data.update_(
                    self._get_weights_fn_dict[_device]()
                )