Skip to content

Commit d545364

Browse files
Authored by Vincent Moens
[Refactor] Better weight update in collectors (#1723)
1 parent 6c27bdb commit d545364

File tree

1 file changed

+10
-16
lines changed

1 file changed

+10
-16
lines changed

torchrl/collectors/collectors.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -280,11 +280,10 @@ def _get_policy_and_device(
280280
device = torch.device(device) if device is not None else policy_device
281281
get_weights_fn = None
282282
if policy_device != device:
283-
param_and_buf = dict(policy.named_parameters())
284-
param_and_buf.update(dict(policy.named_buffers()))
283+
param_and_buf = TensorDict.from_module(policy, as_module=True)
285284

286285
def get_weights_fn(param_and_buf=param_and_buf):
287-
return TensorDict(param_and_buf, []).apply(lambda x: x.data)
286+
return param_and_buf.data
288287

289288
policy_cast = deepcopy(policy).requires_grad_(False).to(device)
290289
# here things may break bc policy.to("cuda") gives us weights on cuda:0 (same
@@ -308,9 +307,9 @@ def update_policy_weights_(
308307
309308
"""
310309
if policy_weights is not None:
311-
self.policy_weights.apply(lambda x: x.data).update_(policy_weights)
310+
self.policy_weights.data.update_(policy_weights)
312311
elif self.get_weights_fn is not None:
313-
self.policy_weights.apply(lambda x: x.data).update_(self.get_weights_fn())
312+
self.policy_weights.data.update_(self.get_weights_fn())
314313

315314
def __iter__(self) -> Iterator[TensorDictBase]:
316315
return self.iterator()
@@ -559,10 +558,7 @@ def __init__(
559558
)
560559

561560
if isinstance(self.policy, nn.Module):
562-
self.policy_weights = TensorDict(dict(self.policy.named_parameters()), [])
563-
self.policy_weights.update(
564-
TensorDict(dict(self.policy.named_buffers()), [])
565-
)
561+
self.policy_weights = TensorDict.from_module(self.policy, as_module=True)
566562
else:
567563
self.policy_weights = TensorDict({}, [])
568564

@@ -1200,9 +1196,9 @@ def device_err_msg(device_name, devices_list):
12001196
)
12011197
self._policy_dict[_device] = _policy
12021198
if isinstance(_policy, nn.Module):
1203-
param_dict = dict(_policy.named_parameters())
1204-
param_dict.update(_policy.named_buffers())
1205-
self._policy_weights_dict[_device] = TensorDict(param_dict, [])
1199+
self._policy_weights_dict[_device] = TensorDict.from_module(
1200+
_policy, as_module=True
1201+
)
12061202
else:
12071203
self._policy_weights_dict[_device] = TensorDict({}, [])
12081204

@@ -1288,11 +1284,9 @@ def frames_per_batch_worker(self):
12881284
def update_policy_weights_(self, policy_weights=None) -> None:
12891285
for _device in self._policy_dict:
12901286
if policy_weights is not None:
1291-
self._policy_weights_dict[_device].apply(lambda x: x.data).update_(
1292-
policy_weights
1293-
)
1287+
self._policy_weights_dict[_device].data.update_(policy_weights)
12941288
elif self._get_weights_fn_dict[_device] is not None:
1295-
self._policy_weights_dict[_device].update_(
1289+
self._policy_weights_dict[_device].data.update_(
12961290
self._get_weights_fn_dict[_device]()
12971291
)
12981292

0 commit comments

Comments (0)