
Commit 4c0eb1d

[Feature] First class dim compatibility (#525)

Authored by Vincent Moens
1 parent 302c342, commit 4c0eb1d

File tree

4 files changed: +178 -60 lines

tensordict/nn/params.py
tensordict/tensordict.py
tensordict/utils.py
test/test_tensordict.py

tensordict/nn/params.py

Lines changed: 4 additions & 3 deletions
@@ -13,6 +13,7 @@
 from typing import Any, Callable, Iterator, OrderedDict, Sequence

 import torch
+from functorch import dim as ftdim

 from tensordict import TensorDictBase
 from tensordict.nn.utils import Buffer
@@ -72,7 +73,7 @@ def _get_args_dict(func, args, kwargs):

 def _maybe_make_param(tensor):
     if (
-        isinstance(tensor, Tensor)
+        isinstance(tensor, (Tensor, ftdim.Tensor))
         and not isinstance(tensor, nn.Parameter)
         and tensor.dtype in (torch.float, torch.double, torch.half)
     ):
@@ -82,7 +83,7 @@ def _maybe_make_param(tensor):

 def _maybe_make_param_or_buffer(tensor):
     if (
-        isinstance(tensor, Tensor)
+        isinstance(tensor, (Tensor, ftdim.Tensor))
         and not isinstance(tensor, nn.Parameter)
         and tensor.dtype in (torch.float, torch.double, torch.half)
     ):
@@ -319,7 +320,7 @@ def __torch_function__(
         if kwargs is None:
            kwargs = {}
         if func not in TDPARAM_HANDLED_FUNCTIONS or not all(
-            issubclass(t, (Tensor, TensorDictBase)) for t in types
+            issubclass(t, (Tensor, ftdim.Tensor, TensorDictBase)) for t in types
         ):
             return NotImplemented
         return TDPARAM_HANDLED_FUNCTIONS[func](*args, **kwargs)
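
Taken together, these checks let TensorDictParams hold and dispatch on functorch first-class-dim tensors. A minimal sketch of what this enables, mirroring the test_modules test added in test/test_tensordict.py below (the Linear(3, 4) module and batch of 5 are illustrative):

import torch
from functorch import dim as ftdim
from tensordict import TensorDict
from tensordict.nn import TensorDictParams

module = torch.nn.Linear(3, 4)
# five stacked copies of the module's parameters, wrapped as TensorDictParams
params = TensorDictParams(TensorDict.from_module(module).expand(5).clone())
d0 = ftdim.dims(1)
# indexing with a first-class dim is no longer rejected by the checks above
params_d0 = params[d0]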

tensordict/tensordict.py

Lines changed: 33 additions & 56 deletions
@@ -32,11 +32,13 @@
     TypeVar,
     Union,
 )
+
 from warnings import warn

 import numpy as np

 import torch
+from functorch import dim as ftdim
 from tensordict._tensordict import _unravel_key_to_tuple
 from tensordict.memmap import memmap_tensor_as_tensor, MemmapTensor
 from tensordict.utils import (
@@ -69,7 +71,7 @@
     NestedKey,
     prod,
 )
-from torch import distributed as dist, multiprocessing as mp, Tensor
+from torch import distributed as dist, multiprocessing as mp, nn, Tensor
 from torch.utils._pytree import tree_map

 try:
@@ -408,6 +410,24 @@ def from_module(module, as_module: bool = False):
             return TensorDictParams(td, no_convert=True)
         return td

+    def to_module(self, module):
+        from tensordict.nn.functional_modules import set_tensor_dict
+
+        __base__setattr__ = nn.Module.__setattr__
+        # we use __dict__ directly to avoid the getattr/setattr overhead whenever we can
+        __dict__ = module.__dict__
+
+        for key, value in self.items():
+            cls = value.__class__
+            if _is_tensor_collection(cls) or issubclass(cls, dict):
+                value.to_module(__dict__["_modules"][key])
+            else:
+                if module.__class__.__setattr__ is __base__setattr__:
+                    set_tensor_dict(__dict__, module, key, value)
+                else:
+                    # use specialized __setattr__ if needed
+                    setattr(module, key, value)
+
     @property
     def shape(self) -> torch.Size:
         """See :obj:`TensorDictBase.batch_size`."""
@@ -3515,6 +3535,8 @@ def _get_names_idx(self, idx):
         else:

             def is_boolean(idx):
+                if isinstance(idx, ftdim.Dim):
+                    return None
                 if isinstance(idx, tuple) and len(idx) == 1:
                     return is_boolean(idx[0])
                 if hasattr(idx, "dtype") and idx.dtype is torch.bool:
@@ -3886,6 +3908,7 @@ def type(self, dst_type):
     Tensor,
     MemmapTensor,
     TensorDictBase,
+    ftdim.Tensor,
 ]
 if _has_torchrec:
     _ACCEPTED_CLASSES += [KeyedJaggedTensor]
@@ -4584,11 +4607,6 @@ def memmap_(
             raise RuntimeError(
                 "memmap and shared memory are mutually exclusive features."
             )
-        # if not self._tensordict.keys():
-        #     raise Exception(
-        #         "memmap_() must be called when the TensorDict is (partially) "
-        #         "populated. Set a tensor first."
-        #     )
         for key, value in self.items():
             if value.requires_grad:
                 raise Exception(
@@ -6527,7 +6545,13 @@ def _split_index(self, index):
                 continue
             if cursor == self.stack_dim:
                 # we need to check which tds need to be indexed
-                if isinstance(idx, slice) or _is_number(idx):
+                if isinstance(idx, ftdim.Dim):
+                    raise ValueError(
+                        "Cannot index a lazy stacked tensordict along the stack dimension with "
+                        "a first-class dimension index. Consider consolidating the tensordict first "
+                        "using `tensordict.contiguous()`."
+                    )
+                elif isinstance(idx, slice) or _is_number(idx):
                     selected_td_idx = range(len(self.tensordicts))[idx]
                     if not isinstance(selected_td_idx, range):
                         isinteger = True
@@ -6559,6 +6583,7 @@ def _split_index(self, index):
                     idx,
                     (
                         int,
+                        ftdim.Dim,
                         slice,
                         list,
                         range,
@@ -7372,54 +7397,6 @@ def __getitem__(self, index: IndexType) -> T:
             out._td_dim_name = self._td_dim_name
             return out

-        # index_dict = _convert_index_lazystack(index, self.stack_dim, self.batch_size)
-        # if index_dict is None:
-        #     # then we use a sub-tensordict
-        #     return self.get_sub_tensordict(index)
-        # td_index = index_dict["remaining_index"]
-        # stack_index = index_dict["stack_index"]
-        # new_stack_dim = index_dict["new_stack_dim"]
-        # if new_stack_dim is not None:
-        #     if isinstance(stack_index, slice):
-        #         # we can't iterate but we can index the list directly
-        #         out = LazyStackedTensorDict(
-        #             *[td[td_index] for td in self.tensordicts[stack_index]],
-        #             stack_dim=new_stack_dim,
-        #         )
-        #     elif isinstance(stack_index, (list, range)):
-        #         # then we can iterate
-        #         out = LazyStackedTensorDict(
-        #             *[self.tensordicts[idx][td_index] for idx in stack_index],
-        #             stack_dim=new_stack_dim,
-        #         )
-        #     elif isinstance(stack_index, Tensor):
-        #         # td_index is a nested tuple that mimics the shape of stack_index
-        #         def _nested_stack(t: list, stack_idx: Tensor, td_index):
-        #             if stack_idx.ndim:
-        #                 out = LazyStackedTensorDict(
-        #                     *[
-        #                         _nested_stack(t, _idx, td_index[i])
-        #                         for i, _idx in enumerate(stack_idx.unbind(0))
-        #                     ],
-        #                     stack_dim=new_stack_dim,
-        #                 )
-        #                 return out
-        #             return t[stack_idx][td_index]
-        #
-        #         # print(index, td_index, stack_index)
-        #         out = _nested_stack(self.tensordicts, stack_index, td_index)
-        #     else:
-        #         raise TypeError("Invalid index used for stack dimension.")
-        #     out._td_dim_name = self._td_dim_name
-        #     return out
-        # out = self.tensordicts[stack_index]
-        # if td_index:
-        #     return out[td_index]
-        # return out
-
-    # def __hash__(self):
-    #     return hash(self.tensordicts)
-
     def __eq__(self, other):
         if is_tensorclass(other):
             return other == self
@@ -9084,7 +9061,7 @@ def _clone_value(value: CompatibleType, recurse: bool) -> CompatibleType:


 def _is_number(item):
-    if isinstance(item, Number):
+    if isinstance(item, (Number, ftdim.Dim)):
         return True
     if isinstance(item, Tensor) and item.ndim == 0:
         return True
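
The new to_module method is the write-back counterpart of from_module: it pushes the tensordict's leaves onto an nn.Module, going through module.__dict__ whenever the module still uses the stock nn.Module.__setattr__. Combined with first-class-dimension indexing, this gives a vmap-style functional call over a batch of parameters. A minimal sketch mirroring the new test_modules test (the module and batch size are illustrative):

import torch
from functorch import dim as ftdim
from tensordict import TensorDict
from tensordict.nn import TensorDictParams

module = torch.nn.Linear(3, 4)
params = TensorDict.from_module(module).expand(5).clone()  # batch of 5 parameter sets
d0 = ftdim.dims(1)
params = TensorDictParams(params)[d0]  # bind the parameter batch to the first-class dim
params.to_module(module)               # write the dim-indexed parameters back into the module
y = module(torch.randn(2, 3))
assert y.dims == (d0,)                 # the output carries the first-class dimension
assert y._tensor.shape[0] == 5         # its underlying tensor keeps the parameter batch

Note that indexing a LazyStackedTensorDict with a first-class dim along its stack dimension is explicitly rejected by _split_index above, with a ValueError suggesting contiguous().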

tensordict/utils.py

Lines changed: 5 additions & 1 deletion
@@ -20,6 +20,7 @@

 import numpy as np
 import torch
+from functorch import dim as ftdim

 from packaging.version import parse
 from tensordict._tensordict import ( # noqa: F401
@@ -150,7 +151,7 @@ def _getitem_batch_size(batch_size, index):
             out.extend(bs_shape)
             bs_shape = None
             continue
-        elif isinstance(idx, int):
+        elif isinstance(idx, (int, ftdim.Dim)):
             # could be spared for efficiency
             continue
         elif isinstance(idx, slice):
@@ -761,9 +762,12 @@ def _is_shared(tensor: torch.Tensor) -> bool:
         if torch._C._functorch.is_batchedtensor(tensor):
             return None
         return tensor.is_shared()
+    if isinstance(tensor, ftdim.Tensor):
+        return None
     elif isinstance(tensor, KeyedJaggedTensor):
         return False
     else:
+        print(type(tensor))
         return tensor.is_shared()
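
With these changes a first-class dim behaves like an integer index when batch sizes are computed, so it consumes the batch dimension it is bound to. A minimal sketch of the resulting shapes, matching the assertions of the new test_fcd test (the [3, 4] batch size is illustrative):

import torch
from functorch import dim as ftdim
from tensordict import TensorDict

td = TensorDict({"x": torch.zeros(3, 4, 5)}, batch_size=[3, 4])
d0 = ftdim.dims(1)
assert td[d0].shape == torch.Size([4])     # the first batch dim is bound to d0
assert td[:, d0].shape == torch.Size([3])  # the second batch dim is bound to d0
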
test/test_tensordict.py

Lines changed: 136 additions & 0 deletions
@@ -12,6 +12,7 @@
 import pytest
 import torch

+from tensordict.nn import TensorDictParams

 try:
     import torchsnapshot
@@ -30,6 +31,7 @@
     _has_h5py = False

 from _utils_internal import decompose, get_available_devices, prod, TestTensorDictsBase
+from functorch import dim as ftdim

 from tensordict import LazyStackedTensorDict, MemmapTensor, TensorDict
 from tensordict.tensordict import (
@@ -6239,6 +6241,140 @@ def _pool_fixt():
     yield pool


+class TestFCD(TestTensorDictsBase):
+    """Test stack for first-class dimension."""
+
+    @pytest.mark.parametrize(
+        "td_name",
+        [
+            "td",
+            "stacked_td",
+            "sub_td",
+            "sub_td2",
+            "idx_td",
+            "memmap_td",
+            "unsqueezed_td",
+            "squeezed_td",
+            "td_reset_bs",
+            "nested_td",
+            "nested_tensorclass",
+            "permute_td",
+            "nested_stacked_td",
+            "td_params",
+            pytest.param(
+                "td_h5",
+                marks=pytest.mark.skipif(not _has_h5py, reason="h5py not found."),
+            ),
+        ],
+    )
+    @pytest.mark.parametrize("device", get_available_devices())
+    def test_fcd(self, td_name, device):
+        td = getattr(self, td_name)(device)
+        d0 = ftdim.dims(1)
+        if isinstance(td, LazyStackedTensorDict) and td.stack_dim == 0:
+            with pytest.raises(ValueError, match="Cannot index"):
+                td[d0]
+        else:
+            assert td[d0].shape == td.shape[1:]
+        d0, d1 = ftdim.dims(2)
+        if isinstance(td, LazyStackedTensorDict) and td.stack_dim in (0, 1):
+            with pytest.raises(ValueError, match="Cannot index"):
+                td[d0, d1]
+        else:
+            assert td[d0, d1].shape == td.shape[2:]
+        d0, d1, d2 = ftdim.dims(3)
+        if isinstance(td, LazyStackedTensorDict) and td.stack_dim in (0, 1, 2):
+            with pytest.raises(ValueError, match="Cannot index"):
+                td[d0, d1, d2]
+        else:
+            assert td[d0, d1, d2].shape == td.shape[3:]
+        d0 = ftdim.dims(1)
+        if isinstance(td, LazyStackedTensorDict) and td.stack_dim == 1:
+            with pytest.raises(ValueError, match="Cannot index"):
+                td[:, d0]
+        else:
+            assert td[:, d0].shape == torch.Size((td.shape[0], *td.shape[2:]))
+
+    @pytest.mark.parametrize(
+        "td_name",
+        [
+            "td",
+            "stacked_td",
+            "idx_td",
+            "memmap_td",
+            "td_reset_bs",
+            "nested_td",
+            "nested_tensorclass",
+            "nested_stacked_td",
+            "td_params",
+            pytest.param(
+                "td_h5",
+                marks=pytest.mark.skipif(not _has_h5py, reason="h5py not found."),
+            ),
+            # these tds cannot see their dim names edited:
+            # "sub_td",
+            # "sub_td2",
+            # "unsqueezed_td",
+            # "squeezed_td",
+            # "permute_td",
+        ],
+    )
+    @pytest.mark.parametrize("device", get_available_devices())
+    def test_fcd_names(self, td_name, device):
+        td = getattr(self, td_name)(device)
+        td.names = ["a", "b", "c", "d"]
+        d0 = ftdim.dims(1)
+        if isinstance(td, LazyStackedTensorDict) and td.stack_dim == 0:
+            with pytest.raises(ValueError, match="Cannot index"):
+                td[d0]
+        else:
+            assert td[d0].names == ["b", "c", "d"]
+        d0, d1 = ftdim.dims(2)
+        if isinstance(td, LazyStackedTensorDict) and td.stack_dim in (0, 1):
+            with pytest.raises(ValueError, match="Cannot index"):
+                td[d0, d1]
+        else:
+            assert td[d0, d1].names == ["c", "d"]
+        d0, d1, d2 = ftdim.dims(3)
+        if isinstance(td, LazyStackedTensorDict) and td.stack_dim in (0, 1, 2):
+            with pytest.raises(ValueError, match="Cannot index"):
+                td[d0, d1, d2]
+        else:
+            assert td[d0, d1, d2].names == ["d"]
+        d0 = ftdim.dims(1)
+        if isinstance(td, LazyStackedTensorDict) and td.stack_dim == 1:
+            with pytest.raises(ValueError, match="Cannot index"):
+                td[:, d0]
+        else:
+            assert td[:, d0].names == ["a", "c", "d"]
+
+    @pytest.mark.parametrize("as_module", [False, True])
+    def test_modules(self, as_module):
+        modules = [
+            lambda: nn.Linear(3, 4),
+            lambda: nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 4)),
+            lambda: nn.Transformer(16, 4, 2, 2, 8),
+            lambda: nn.Sequential(nn.Conv2d(3, 4, 3), nn.Conv2d(4, 4, 3)),
+        ]
+        inputs = [
+            lambda: (torch.randn(2, 3),),
+            lambda: (torch.randn(2, 3),),
+            lambda: (torch.randn(2, 3, 16), torch.randn(2, 3, 16)),
+            lambda: (torch.randn(2, 3, 16, 16),),
+        ]
+        param_batch = 5
+        for make_module, make_input in zip(modules, inputs):
+            module = make_module()
+            td = TensorDict.from_module(module, as_module=as_module)
+            td = td.expand(param_batch).clone()
+            d0 = ftdim.dims(1)
+            td = TensorDictParams(td)[d0]
+            td.to_module(module)
+            y = module(*make_input())
+            assert y.dims == (d0,)
+            assert y._tensor.shape[0] == param_batch
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

0 commit comments
