Skip to content

Commit c195982

Browse files
TroyGarden and facebook-github-bot
authored and committed
convert stride_per_key_per_rank to tensor inside KJT (#2959)
Summary: # context * this diff is part of the "variable-batch KJT refactoring" project ([doc](https://fburl.com/gdoc/svfysfai)) * previously the `stride_per_key_per_rank` variable was `List[List[int]] | None`, which can't be handled correctly in PT2 IR (torch.export) * this change makes the KJT class variable `_stride_per_key_per_rank` a `torch.IntTensor | None` so that it is compatible with PT2 IR. # equivalency * checking whether `self._stride_per_key_per_rank` is `None`: this logic is used to distinguish the variable-batch case, and it should behave the same after this diff * using `self._stride_per_key_per_rank` as `List[List[int]]`: most call sites obtain the list through the function `def stride_per_key_per_rank(self) -> List[List[int]]:`, and this function is modified to convert the `torch.IntTensor` to a list via `_stride_per_key_per_rank.tolist()`, so the results should be the same. NOTE: the `self._stride_per_key_per_rank` tensor should always be on CPU since it is effectively the metadata of a KJT. Generic torch APIs like `.to(...)`, `record_stream()`, etc. should in general avoid altering this variable. Reviewed By: jd7-tr Differential Revision: D74366343
1 parent 8f55f3a commit c195982

File tree

3 files changed

+41
-25
lines changed

3 files changed

+41
-25
lines changed

torchrec/schema/api_tests/test_jagged_tensor_schema.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import inspect
1111
import unittest
12-
from typing import Dict, List, Optional, Tuple
12+
from typing import Dict, List, Optional, Tuple, Union
1313

1414
import torch
1515
from torchrec.schema.utils import is_signature_compatible
@@ -112,7 +112,9 @@ def __init__(
112112
lengths: Optional[torch.Tensor] = None,
113113
offsets: Optional[torch.Tensor] = None,
114114
stride: Optional[int] = None,
115-
stride_per_key_per_rank: Optional[List[List[int]]] = None,
115+
stride_per_key_per_rank: Optional[
116+
Union[List[List[int]], torch.IntTensor]
117+
] = None,
116118
# Below exposed to ensure torch.script-able
117119
stride_per_key: Optional[List[int]] = None,
118120
length_per_key: Optional[List[int]] = None,

torchrec/sparse/jagged_tensor.py

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,13 +1096,15 @@ def _maybe_compute_stride_kjt(
10961096
stride: Optional[int],
10971097
lengths: Optional[torch.Tensor],
10981098
offsets: Optional[torch.Tensor],
1099-
stride_per_key_per_rank: Optional[List[List[int]]],
1099+
stride_per_key_per_rank: Optional[torch.IntTensor],
11001100
) -> int:
11011101
if stride is None:
11021102
if len(keys) == 0:
11031103
stride = 0
1104-
elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
1105-
stride = max([sum(s) for s in stride_per_key_per_rank])
1104+
elif (
1105+
stride_per_key_per_rank is not None and stride_per_key_per_rank.numel() > 0
1106+
):
1107+
stride = int(stride_per_key_per_rank.sum(dim=1).max().item())
11061108
elif offsets is not None and offsets.numel() > 0:
11071109
stride = (offsets.numel() - 1) // len(keys)
11081110
elif lengths is not None:
@@ -1668,14 +1670,18 @@ def _maybe_compute_lengths_offset_per_key(
16681670

16691671
def _maybe_compute_stride_per_key(
16701672
stride_per_key: Optional[List[int]],
1671-
stride_per_key_per_rank: Optional[List[List[int]]],
1673+
stride_per_key_per_rank: Optional[torch.IntTensor],
16721674
stride: Optional[int],
16731675
keys: List[str],
16741676
) -> Optional[List[int]]:
16751677
if stride_per_key is not None:
16761678
return stride_per_key
16771679
elif stride_per_key_per_rank is not None:
1678-
return [sum(s) for s in stride_per_key_per_rank]
1680+
if stride_per_key_per_rank.dim() != 2:
1681+
# after permute the kjt could be empty
1682+
return []
1683+
rt: List[int] = stride_per_key_per_rank.sum(dim=1).tolist()
1684+
return rt
16791685
elif stride is not None:
16801686
return [stride] * len(keys)
16811687
else:
@@ -1766,7 +1772,9 @@ def __init__(
17661772
lengths: Optional[torch.Tensor] = None,
17671773
offsets: Optional[torch.Tensor] = None,
17681774
stride: Optional[int] = None,
1769-
stride_per_key_per_rank: Optional[List[List[int]]] = None,
1775+
stride_per_key_per_rank: Optional[
1776+
Union[torch.IntTensor, List[List[int]]]
1777+
] = None,
17701778
# Below exposed to ensure torch.script-able
17711779
stride_per_key: Optional[List[int]] = None,
17721780
length_per_key: Optional[List[int]] = None,
@@ -1788,8 +1796,10 @@ def __init__(
17881796
self._lengths: Optional[torch.Tensor] = lengths
17891797
self._offsets: Optional[torch.Tensor] = offsets
17901798
self._stride: Optional[int] = stride
1791-
self._stride_per_key_per_rank: Optional[List[List[int]]] = (
1792-
stride_per_key_per_rank
1799+
self._stride_per_key_per_rank: Optional[torch.IntTensor] = (
1800+
torch.IntTensor(stride_per_key_per_rank, device="cpu")
1801+
if isinstance(stride_per_key_per_rank, list)
1802+
else stride_per_key_per_rank
17931803
)
17941804
self._stride_per_key: Optional[List[int]] = stride_per_key
17951805
self._length_per_key: Optional[List[int]] = length_per_key
@@ -1816,8 +1826,7 @@ def _init_pt2_checks(self) -> None:
18161826
if self._stride_per_key is not None:
18171827
pt2_checks_all_is_size(self._stride_per_key)
18181828
if self._stride_per_key_per_rank is not None:
1819-
# pyre-ignore [16]
1820-
for s in self._stride_per_key_per_rank:
1829+
for s in self.stride_per_key_per_rank():
18211830
pt2_checks_all_is_size(s)
18221831

18231832
@staticmethod
@@ -2028,7 +2037,7 @@ def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor":
20282037
kjt_stride, kjt_stride_per_key_per_rank = (
20292038
(stride_per_key[0], None)
20302039
if all(s == stride_per_key[0] for s in stride_per_key)
2031-
else (None, [[stride] for stride in stride_per_key])
2040+
else (None, torch.IntTensor(stride_per_key, device="cpu").reshape(-1, 1))
20322041
)
20332042
kjt = KeyedJaggedTensor(
20342043
keys=kjt_keys,
@@ -2193,8 +2202,13 @@ def stride_per_key_per_rank(self) -> List[List[int]]:
21932202
Returns:
21942203
List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
21952204
"""
2196-
stride_per_key_per_rank = self._stride_per_key_per_rank
2197-
return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
2205+
# making a local reference to the class variable to make jit.script behave
2206+
_stride_per_key_per_rank = self._stride_per_key_per_rank
2207+
return (
2208+
[]
2209+
if _stride_per_key_per_rank is None
2210+
else _stride_per_key_per_rank.tolist()
2211+
)
21982212

21992213
def variable_stride_per_key(self) -> bool:
22002214
"""
@@ -2514,17 +2528,17 @@ def permute(
25142528

25152529
length_per_key = self.length_per_key()
25162530
permuted_keys: List[str] = []
2517-
permuted_stride_per_key_per_rank: List[List[int]] = []
25182531
permuted_length_per_key: List[int] = []
25192532
permuted_length_per_key_sum = 0
25202533
for index in indices:
25212534
key = self.keys()[index]
25222535
permuted_keys.append(key)
25232536
permuted_length_per_key.append(length_per_key[index])
2524-
if self.variable_stride_per_key():
2525-
permuted_stride_per_key_per_rank.append(
2526-
self.stride_per_key_per_rank()[index]
2527-
)
2537+
_stride_per_key_per_rank = self._stride_per_key_per_rank
2538+
if self.variable_stride_per_key() and _stride_per_key_per_rank is not None:
2539+
permuted_stride_per_key_per_rank = _stride_per_key_per_rank[indices, :]
2540+
else:
2541+
permuted_stride_per_key_per_rank = None
25282542

25292543
permuted_length_per_key_sum = sum(permuted_length_per_key)
25302544
if not torch.jit.is_scripting() and is_non_strict_exporting():
@@ -2576,17 +2590,15 @@ def permute(
25762590
self.weights_or_none(),
25772591
permuted_length_per_key_sum,
25782592
)
2579-
stride_per_key_per_rank = (
2580-
permuted_stride_per_key_per_rank if self.variable_stride_per_key() else None
2581-
)
2593+
25822594
kjt = KeyedJaggedTensor(
25832595
keys=permuted_keys,
25842596
values=permuted_values,
25852597
weights=permuted_weights,
25862598
lengths=permuted_lengths.view(-1),
25872599
offsets=None,
25882600
stride=self._stride,
2589-
stride_per_key_per_rank=stride_per_key_per_rank,
2601+
stride_per_key_per_rank=permuted_stride_per_key_per_rank,
25902602
stride_per_key=None,
25912603
length_per_key=permuted_length_per_key if len(permuted_keys) > 0 else None,
25922604
lengths_offset_per_key=None,

torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,9 @@ def bench(
465465
keys=kjt.keys(),
466466
values=kjt._values,
467467
lengths=kjt._lengths,
468-
stride_per_key_per_rank=kjt._stride_per_key_per_rank,
468+
stride_per_key_per_rank=torch.IntTensor(kjt._stride_per_key).reshape(
469+
-1, 1
470+
),
469471
)
470472
vbe_fn_kwargs = fn_kwargs.copy()
471473
if "kjt" in fn_kwargs:

0 commit comments

Comments
 (0)