Skip to content

Commit 75b33c4

Browse files
author
Vincent Moens
committed
[Feature] param_count
ghstack-source-id: 69b76ae Pull Request resolved: #1046
1 parent ee49fc7 commit 75b33c4

File tree

2 files changed

+133
-1
lines changed

2 files changed

+133
-1
lines changed

tensordict/base.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3142,6 +3142,92 @@ def _set_device(self, device: torch.device) -> T:
31423142
value._set_device(device=device)
31433143
return self
31443144

3145+
@cache  # noqa: B019
def param_count(self, *, count_duplicates: bool = True) -> int:
    """Returns the total number of elements stored in the tensors of this object.

    The result is the sum of ``numel()`` over the values returned by
    ``self._values_list(True, True)``.

    Keyword Args:
        count_duplicates (bool): Whether a tensor that appears several times
            is counted once per occurrence. If ``False``, the values are passed
            through :class:`set` first, so only strictly identical tensors are
            discarded (different views of a common base tensor have different
            ids and are still counted separately). Defaults to ``True`` (each
            tensor is assumed to be a single copy).

    Returns:
        int: the accumulated element count (``0`` when there are no values).

    """
    tensors = self._values_list(True, True)
    counted = tensors if count_duplicates else set(tensors)
    return sum(tensor.numel() for tensor in counted)
3163+
3164+
@cache  # noqa: B019
def bytes(self, *, count_duplicates: bool = True) -> int:
    """Counts the number of bytes of the contained tensors.

    Every value returned by ``self._values_list(True, True)`` is visited:

    - jagged nested tensors and other tensor subclasses exposing
      ``__tensor_flatten__`` are decomposed into their inner tensors;
    - for a tensor whose ``grad`` is not ``None``, both the gradient and
      ``tensor.data`` are counted;
    - ``MemoryMappedTensor`` instances and plain tensors are counted as
      ``numel() * dtype.itemsize``.

    Keyword Args:
        count_duplicates (bool): Whether to count duplicated tensors as
            independent or not. If ``False``, only strictly identical tensors
            will be discarded (different views of a common base tensor have
            different ids and will be counted twice). Defaults to ``True``
            (each tensor is assumed to be a single copy).

    Returns:
        int: the total number of bytes.

    Raises:
        RuntimeError: if a nested tensor does not use the jagged layout.

    Warns:
        UserWarning: if a tensor subclass has no ``__tensor_flatten__``
            attribute; it is then counted through ``torch.as_tensor``.

    """
    # A list keeps every occurrence, a set drops repeated tensor objects.
    collected = [] if count_duplicates else set()
    add = collected.append if count_duplicates else collected.add

    def count_bytes(tensor):
        if tensor.is_nested:
            if tensor.layout != torch.jagged:
                raise RuntimeError(
                    "NTs that are not jagged are not supported by the bytes method. Please use the jagged layout instead "
                    "or raise an issue on https://github.com/pytorch/tensordict/issues instead."
                )
            attrs, _ = tensor.__tensor_flatten__()
            for attr in attrs:
                count_bytes(getattr(tensor, attr))
            return
        if not isinstance(tensor, torch.Tensor):
            return
        if isinstance(tensor, MemoryMappedTensor):
            add(tensor)
            return
        if type(tensor) is not torch.Tensor:
            # Only the __tensor_flatten__ call is guarded: an AttributeError
            # raised while counting the inner tensors must not be mistaken
            # for a missing attribute (that would also count some inner
            # tensors twice through the fallback).
            try:
                attrs, _ = tensor.__tensor_flatten__()
            except AttributeError:
                warnings.warn(
                    "The sub-tensor doesn't have a __tensor_flatten__ attribute, making it "
                    "impossible to count the bytes it contains. Falling back on regular count.",
                    category=UserWarning,
                )
                # NOTE(review): this assumes torch.as_tensor returns a plain
                # torch.Tensor here; if it hands back the same subclass
                # instance this recursion would not terminate -- confirm.
                count_bytes(torch.as_tensor(tensor))
            else:
                for attr in attrs:
                    count_bytes(getattr(tensor, attr))
            return

        grad = getattr(tensor, "grad", None)
        if grad is not None:
            # Count the gradient and the data separately.
            # NOTE(review): ``tensor.data`` presumably yields a new tensor
            # object on each access, in which case de-duplication does not
            # apply to tensors carrying a gradient -- confirm.
            count_bytes(grad)
            count_bytes(tensor.data)
        else:
            add(tensor)

    for value in self._values_list(True, True):
        count_bytes(value)
    return sum(t.numel() * t.dtype.itemsize for t in collected)
3230+
31453231
def pin_memory(self, num_threads: int | None = None, inplace: bool = False) -> T:
31463232
"""Calls :meth:`~torch.Tensor.pin_memory` on the stored tensors.
31473233

test/test_tensordict.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,17 @@
132132
mp_ctx = "fork" if (not torch.cuda.is_available() and not _IS_WINDOWS) else "spawn"
133133

134134

135+
@pytest.fixture
def device_fixture():
    """Runs the test with an accelerator as torch's default device, if one exists.

    ``cuda:0`` is preferred, then ``mps:0``; when neither is available the
    default device is left untouched. The default device that was active
    before the test is set again after the ``yield``.
    """
    previous = torch.get_default_device()
    if torch.cuda.is_available():
        accelerator = torch.device("cuda:0")
    elif torch.backends.mps.is_available():
        accelerator = torch.device("mps:0")
    else:
        accelerator = None
    if accelerator is not None:
        torch.set_default_device(accelerator)
    yield
    torch.set_default_device(previous)
144+
145+
135146
def _compare_tensors_identity(td0, td1):
136147
if isinstance(td0, LazyStackedTensorDict):
137148
if not isinstance(td1, LazyStackedTensorDict):
@@ -242,7 +253,32 @@ def test_batchsize_reset(self):
242253
td_u.batch_size = [1]
243254
td_u.to_tensordict().batch_size = [1]
244255

245-
def test_depth(ggself):
256+
@pytest.mark.parametrize("count_duplicates", [False, True])
def test_bytes(self, count_duplicates, device_fixture):
    """Checks ``bytes()`` on plain, grad-carrying and jagged nested tensors."""
    plain = torch.zeros(3)  # 3 * 4 = 12 bytes
    with_grad = torch.ones(3, requires_grad=True)
    # Run a backward pass so that ``with_grad.grad`` is populated.
    (with_grad + 1).sum().backward()
    values = torch.ones(3) * 2  # 12 bytes
    offsets = torch.tensor([0, 1, 3])  # 24 bytes
    lengths = torch.tensor([1, 2])  # 16 bytes
    # values + offsets + lengths = 12 + 24 + 16 = 52 bytes
    njt = torch.nested.nested_tensor_from_jagged(values, offsets, lengths=lengths)
    # Built only from tensors already present elsewhere in the tensordict:
    # 52 bytes when duplicates are counted, 0 otherwise.
    tricky = torch.nested.nested_tensor_from_jagged(plain, offsets, lengths=lengths)
    td = TensorDict(
        tensor=plain,  # 12 bytes
        tensor_with_grad=with_grad,  # data + grad: 3 * 4 * 2 = 24 bytes
        njt=njt,  # 52 bytes
        tricky=tricky,  # 52 bytes or 0
    )
    expected = 12 + 24 + 52
    expected += 52 if count_duplicates else 0
    assert td.bytes(count_duplicates=count_duplicates) == expected
280+
281+
def test_depth(self):
246282
td = TensorDict({"a": {"b": {"c": {"d": 0}, "e": 0}, "f": 0}, "g": 0}).lock_()
247283
assert td.depth == 3
248284
with td.unlock_():
@@ -1903,6 +1939,16 @@ def test_pad_sequence_pad_dim1(self, make_mask):
19031939
else:
19041940
assert "masks" not in padded_td.keys()
19051941

1942+
@pytest.mark.parametrize("count_duplicates", [False, True])
def test_param_count(self, count_duplicates):
    """Checks ``param_count()`` with and without counting a repeated entry."""
    td = TensorDict(a=torch.randn(3), b=torch.randn(6))
    # Register the value stored under "a" a second time, under "c".
    td["c"] = td["a"]
    leaves = td._values_list(True, True)
    assert len(leaves) == 3
    # 3 + 6 + 3 elements when every entry counts, 3 + 6 when "c" is discarded.
    expected = 12 if count_duplicates else 9
    assert td.param_count(count_duplicates=count_duplicates) == expected
1951+
19061952
@pytest.mark.parametrize("device", get_available_devices())
19071953
def test_permute(self, device):
19081954
torch.manual_seed(1)

0 commit comments

Comments
 (0)