[Feature Request] support for TensorDictBase.masked_select inplace #1178

Closed
@xmaples

Motivation

Add an in-place variant of .masked_select() to TensorDictBase.

Solution

Add a method named .masked_select_() that behaves like .masked_select() but modifies the TensorDictBase in place.

Main steps to achieve this:

  1. Iterate over the key-value pairs and replace each leaf tensor with its masked selection
  2. For each value that is a nested TensorDict, call .masked_select_() recursively
  3. Update batch_size to match the masked result
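
The steps above can be sketched roughly as follows. This is only an illustration under stated assumptions: ToyTensorDict and _mask_leaves_ are hypothetical stand-ins built on a plain nested dict of tensors, not the real TensorDictBase internals, and the mask is assumed to have exactly the shape of batch_size.

```python
import torch


def _mask_leaves_(data: dict, mask: torch.Tensor) -> None:
    """Hypothetical helper: apply the mask in place to a nested dict of tensors."""
    for key, value in list(data.items()):
        if isinstance(value, dict):
            # Step 2: recurse into nested containers.
            _mask_leaves_(value, mask)
        else:
            # Step 1: boolean indexing keeps only the batch entries where mask is True.
            data[key] = value[mask]


class ToyTensorDict:
    """Hypothetical stand-in for a TensorDict: a nested dict plus a batch_size."""

    def __init__(self, source: dict, batch_size):
        self.source = source
        self.batch_size = torch.Size(batch_size)

    def masked_select_(self, mask: torch.Tensor) -> "ToyTensorDict":
        if mask.dtype != torch.bool or mask.shape != self.batch_size:
            raise ValueError("mask must be a bool tensor shaped like batch_size")
        _mask_leaves_(self.source, mask)
        # Step 3: the batch dimensions collapse into one dimension whose
        # length is the number of selected entries.
        self.batch_size = torch.Size([int(mask.sum())])
        return self


td = ToyTensorDict(
    {"a": torch.zeros(3, 4), "nested": {"b": torch.arange(3)}},
    batch_size=[3],
)
mask = torch.tensor([True, False, False])
td.masked_select_(mask)
```

In this sketch the leaves are rebound to new tensors, so "in place" refers to the container and not to the underlying tensor storage. A real implementation would also need to keep the batch sizes of nested TensorDicts consistent.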

Examples:

import torch
from tensordict import TensorDict

td = TensorDict(source={'a': torch.zeros(3, 4)},
    batch_size=[3])
mask = torch.tensor([True, False, False])
td.masked_select_(mask)  # proposed method requested by this issue
td.get("a")
# expected output: tensor([[0., 0., 0., 0.]])

Labels

enhancement: New feature or request
