Closed
Description
Motivation
Provide an in-place version of .masked_select() for TensorDictBase.
Solution
Add a method named .masked_select_() that behaves like .masked_select() but modifies the TensorDictBase in place.
Main steps to achieve this:
- Iterate over the key-value pairs and, for each value that is a leaf tensor, replace it with its masked version
- For each value that is a nested TensorDict, call .masked_select_() recursively
- Update batch_size to match the masked result
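The steps above can be sketched roughly as follows. This is only an illustration, not the proposed implementation: it uses a plain nested dict of tensors as a stand-in for a TensorDict, and the helper name masked_select_inplace is hypothetical. It assumes the mask is a boolean tensor whose shape matches the leading (batch) dimensions of every leaf.

```python
import torch


def masked_select_inplace(data: dict, mask: torch.Tensor) -> torch.Size:
    """Hypothetical helper: mask a nested dict of tensors in place.

    A plain dict stands in for a TensorDict here. Returns the batch size
    that the container should have after masking.
    """
    # Copy the items so that reassigning values cannot disturb iteration.
    for key, value in list(data.items()):
        if isinstance(value, dict):
            # Nested container: recurse with the same mask.
            masked_select_inplace(value, mask)
        else:
            # Leaf tensor: boolean indexing keeps the entries where mask is True.
            data[key] = value[mask]
    # The masked batch dimensions collapse into one of size mask.sum().
    return torch.Size([int(mask.sum())])


data = {"a": torch.zeros(3, 4), "nested": {"b": torch.arange(3)}}
mask = torch.tensor([True, False, True])
new_batch_size = masked_select_inplace(data, mask)
print(new_batch_size, data["a"].shape, data["nested"]["b"])
```

In an actual TensorDictBase method, the last step would presumably assign the new batch size to the object rather than return it.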
Examples:
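import torch
from tensordict import TensorDict

td = TensorDict(source={'a': torch.zeros(3, 4)}, batch_size=[3])
mask = torch.tensor([True, False, False])
td.masked_select_(mask)  # proposed method: keeps only the entries where mask is True
td.get("a")
# output: tensor([[0., 0., 0., 0.]])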