Skip to content

Commit 5a71b86

Browse files
TroyGarden authored
and facebook-github-bot committed
convert stride_per_key_per_rank to tensor inside KJT (#2959)
Summary: # context * this diff is part of the "variable-batch KJT refactoring" project ([doc](https://fburl.com/gdoc/svfysfai)) * previously the `stride_per_key_per_rank` variable is `List[List[int]] | None`, which can't be handled correctly in PT2 IR (torch.export) * this change makes the KJT class variable `_stride_per_key_per_rank` a `torch.IntTensor | None` so it is compatible with PT2 IR. # equivalency * to check if `self._stride_per_key_per_rank` is `None`: this logic is used to differentiate the variable_batch case, and should have the same behavior after this diff * to use `self._stride_per_key_per_rank` as `List[List[int]]`: most of the callsites use the function `def stride_per_key_per_rank(self) -> List[List[int]]:` to get the list, and this function is modified to convert the `torch.IntTensor` to a list via `_stride_per_key_per_rank.tolist()`; the results should be the same. NOTE: this `self._stride_per_key_per_rank` tensor should always be on CPU since it is effectively the metadata of a KJT. Generic torch APIs like `.to(...)`, `record_stream()`, etc. should in general avoid altering this variable. Differential Revision: D74366343
1 parent d6031f9 commit 5a71b86

File tree

4 files changed

+80
-51
lines changed

4 files changed

+80
-51
lines changed

torchrec/distributed/tests/test_pt2_multiprocess.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -703,31 +703,31 @@ def disable_cuda_tf32(self) -> bool:
703703
ShardingType.TABLE_WISE.value,
704704
_InputType.SINGLE_BATCH,
705705
_ConvertToVariableBatch.TRUE,
706-
"inductor",
706+
"aot_eager",
707707
_TestConfig(),
708708
),
709709
(
710710
_ModelType.EBC,
711711
ShardingType.COLUMN_WISE.value,
712712
_InputType.SINGLE_BATCH,
713713
_ConvertToVariableBatch.TRUE,
714-
"inductor",
714+
"aot_eager",
715715
_TestConfig(),
716716
),
717717
(
718718
_ModelType.EBC,
719719
ShardingType.TABLE_WISE.value,
720720
_InputType.SINGLE_BATCH,
721721
_ConvertToVariableBatch.FALSE,
722-
"inductor",
722+
"aot_eager",
723723
_TestConfig(),
724724
),
725725
(
726726
_ModelType.EBC,
727727
ShardingType.COLUMN_WISE.value,
728728
_InputType.SINGLE_BATCH,
729729
_ConvertToVariableBatch.FALSE,
730-
"inductor",
730+
"aot_eager",
731731
_TestConfig(),
732732
),
733733
]

torchrec/pt2/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def kjt_for_pt2_tracing(
5454
values=values,
5555
lengths=lengths,
5656
weights=kjt.weights_or_none(),
57-
stride_per_key_per_rank=[[stride]] * n,
57+
stride_per_key_per_rank=torch.IntTensor([[stride]] * n, device="cpu"),
5858
inverse_indices=(kjt.keys(), inverse_indices_tensor),
5959
)
6060

@@ -85,7 +85,7 @@ def kjt_for_pt2_tracing(
8585
lengths=lengths,
8686
weights=weights,
8787
stride=stride if not is_vb else None,
88-
stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None,
88+
stride_per_key_per_rank=kjt._stride_per_key_per_rank if is_vb else None,
8989
inverse_indices=inverse_indices,
9090
)
9191

torchrec/schema/api_tests/test_jagged_tensor_schema.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import inspect
1111
import unittest
12-
from typing import Dict, List, Optional, Tuple
12+
from typing import Dict, List, Optional, Tuple, Union
1313

1414
import torch
1515
from torchrec.schema.utils import is_signature_compatible
@@ -112,7 +112,9 @@ def __init__(
112112
lengths: Optional[torch.Tensor] = None,
113113
offsets: Optional[torch.Tensor] = None,
114114
stride: Optional[int] = None,
115-
stride_per_key_per_rank: Optional[List[List[int]]] = None,
115+
stride_per_key_per_rank: Optional[
116+
Union[List[List[int]], torch.IntTensor]
117+
] = None,
116118
# Below exposed to ensure torch.script-able
117119
stride_per_key: Optional[List[int]] = None,
118120
length_per_key: Optional[List[int]] = None,

torchrec/sparse/jagged_tensor.py

Lines changed: 70 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,13 +1096,15 @@ def _maybe_compute_stride_kjt(
10961096
stride: Optional[int],
10971097
lengths: Optional[torch.Tensor],
10981098
offsets: Optional[torch.Tensor],
1099-
stride_per_key_per_rank: Optional[List[List[int]]],
1099+
stride_per_key_per_rank: Optional[torch.IntTensor],
11001100
) -> int:
11011101
if stride is None:
11021102
if len(keys) == 0:
11031103
stride = 0
1104-
elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
1105-
stride = max([sum(s) for s in stride_per_key_per_rank])
1104+
elif (
1105+
stride_per_key_per_rank is not None and stride_per_key_per_rank.numel() > 0
1106+
):
1107+
stride = int(stride_per_key_per_rank.sum(dim=1).max().item())
11061108
elif offsets is not None and offsets.numel() > 0:
11071109
stride = (offsets.numel() - 1) // len(keys)
11081110
elif lengths is not None:
@@ -1481,8 +1483,8 @@ def _strides_from_kjt(
14811483
def _kjt_empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
14821484
# empty like function fx wrapped, also avoids device hardcoding
14831485
stride, stride_per_key_per_rank = (
1484-
(None, kjt.stride_per_key_per_rank())
1485-
if kjt.variable_stride_per_key()
1486+
(None, kjt._stride_per_key_per_rank)
1487+
if kjt._stride_per_key_per_rank is not None
14861488
else (kjt.stride(), None)
14871489
)
14881490

@@ -1668,14 +1670,20 @@ def _maybe_compute_lengths_offset_per_key(
16681670

16691671
def _maybe_compute_stride_per_key(
16701672
stride_per_key: Optional[List[int]],
1671-
stride_per_key_per_rank: Optional[List[List[int]]],
1673+
stride_per_key_per_rank: Optional[torch.IntTensor],
16721674
stride: Optional[int],
16731675
keys: List[str],
16741676
) -> Optional[List[int]]:
16751677
if stride_per_key is not None:
16761678
return stride_per_key
16771679
elif stride_per_key_per_rank is not None:
1678-
return [sum(s) for s in stride_per_key_per_rank]
1680+
if stride_per_key_per_rank.dim() != 2:
1681+
# after permute the kjt could be empty
1682+
return []
1683+
rt: List[int] = stride_per_key_per_rank.sum(dim=1).tolist()
1684+
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
1685+
pt2_checks_all_is_size(rt)
1686+
return rt
16791687
elif stride is not None:
16801688
return [stride] * len(keys)
16811689
else:
@@ -1766,7 +1774,9 @@ def __init__(
17661774
lengths: Optional[torch.Tensor] = None,
17671775
offsets: Optional[torch.Tensor] = None,
17681776
stride: Optional[int] = None,
1769-
stride_per_key_per_rank: Optional[List[List[int]]] = None,
1777+
stride_per_key_per_rank: Optional[
1778+
Union[torch.IntTensor, List[List[int]]]
1779+
] = None,
17701780
# Below exposed to ensure torch.script-able
17711781
stride_per_key: Optional[List[int]] = None,
17721782
length_per_key: Optional[List[int]] = None,
@@ -1788,8 +1798,14 @@ def __init__(
17881798
self._lengths: Optional[torch.Tensor] = lengths
17891799
self._offsets: Optional[torch.Tensor] = offsets
17901800
self._stride: Optional[int] = stride
1791-
self._stride_per_key_per_rank: Optional[List[List[int]]] = (
1792-
stride_per_key_per_rank
1801+
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
1802+
# in pt2.compile the stride_per_key_per_rank has to be torch.Tensor or None
1803+
# does not take List[List[int]]
1804+
assert not isinstance(stride_per_key_per_rank, list)
1805+
self._stride_per_key_per_rank: Optional[torch.IntTensor] = (
1806+
torch.IntTensor(stride_per_key_per_rank, device="cpu")
1807+
if isinstance(stride_per_key_per_rank, list)
1808+
else stride_per_key_per_rank
17931809
)
17941810
self._stride_per_key: Optional[List[int]] = stride_per_key
17951811
self._length_per_key: Optional[List[int]] = length_per_key
@@ -1815,10 +1831,11 @@ def _init_pt2_checks(self) -> None:
18151831
return
18161832
if self._stride_per_key is not None:
18171833
pt2_checks_all_is_size(self._stride_per_key)
1818-
if self._stride_per_key_per_rank is not None:
1819-
# pyre-ignore [16]
1820-
for s in self._stride_per_key_per_rank:
1821-
pt2_checks_all_is_size(s)
1834+
_stride_per_key_per_rank = self._stride_per_key_per_rank
1835+
if _stride_per_key_per_rank is not None:
1836+
stride_per_key_per_rank = _stride_per_key_per_rank.tolist()
1837+
for stride_per_rank in stride_per_key_per_rank:
1838+
pt2_checks_all_is_size(stride_per_rank)
18221839

18231840
@staticmethod
18241841
def from_offsets_sync(
@@ -2028,7 +2045,7 @@ def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor":
20282045
kjt_stride, kjt_stride_per_key_per_rank = (
20292046
(stride_per_key[0], None)
20302047
if all(s == stride_per_key[0] for s in stride_per_key)
2031-
else (None, [[stride] for stride in stride_per_key])
2048+
else (None, torch.IntTensor(stride_per_key, device="cpu").reshape(-1, 1))
20322049
)
20332050
kjt = KeyedJaggedTensor(
20342051
keys=kjt_keys,
@@ -2193,8 +2210,22 @@ def stride_per_key_per_rank(self) -> List[List[int]]:
21932210
Returns:
21942211
List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
21952212
"""
2196-
stride_per_key_per_rank = self._stride_per_key_per_rank
2197-
return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
2213+
# making a local reference to the class variable to make jit.script behave
2214+
_stride_per_key_per_rank = self._stride_per_key_per_rank
2215+
if (
2216+
not torch.jit.is_scripting()
2217+
and is_torchdynamo_compiling()
2218+
and _stride_per_key_per_rank is not None
2219+
):
2220+
stride_per_key_per_rank = _stride_per_key_per_rank.tolist()
2221+
for stride_per_rank in stride_per_key_per_rank:
2222+
pt2_checks_all_is_size(stride_per_rank)
2223+
return stride_per_key_per_rank
2224+
return (
2225+
[]
2226+
if _stride_per_key_per_rank is None
2227+
else _stride_per_key_per_rank.tolist()
2228+
)
21982229

21992230
def variable_stride_per_key(self) -> bool:
22002231
"""
@@ -2343,13 +2374,16 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
23432374
start_offset = 0
23442375
_length_per_key = self.length_per_key()
23452376
_offset_per_key = self.offset_per_key()
2377+
# use local copy/ref for self._stride_per_key_per_rank to satisfy jit.script
2378+
_stride_per_key_per_rank = self._stride_per_key_per_rank
23462379
for segment in segments:
23472380
end = start + segment
23482381
end_offset = _offset_per_key[end]
23492382
keys: List[str] = self._keys[start:end]
23502383
stride_per_key_per_rank = (
2351-
self.stride_per_key_per_rank()[start:end]
2352-
if self.variable_stride_per_key()
2384+
_stride_per_key_per_rank[start:end, :]
2385+
if _stride_per_key_per_rank is not None
2386+
and self.variable_stride_per_key()
23532387
else None
23542388
)
23552389
if segment == len(self._keys):
@@ -2514,17 +2548,17 @@ def permute(
25142548

25152549
length_per_key = self.length_per_key()
25162550
permuted_keys: List[str] = []
2517-
permuted_stride_per_key_per_rank: List[List[int]] = []
25182551
permuted_length_per_key: List[int] = []
25192552
permuted_length_per_key_sum = 0
25202553
for index in indices:
25212554
key = self.keys()[index]
25222555
permuted_keys.append(key)
25232556
permuted_length_per_key.append(length_per_key[index])
2524-
if self.variable_stride_per_key():
2525-
permuted_stride_per_key_per_rank.append(
2526-
self.stride_per_key_per_rank()[index]
2527-
)
2557+
_stride_per_key_per_rank = self._stride_per_key_per_rank
2558+
if self.variable_stride_per_key() and _stride_per_key_per_rank is not None:
2559+
permuted_stride_per_key_per_rank = _stride_per_key_per_rank[indices, :]
2560+
else:
2561+
permuted_stride_per_key_per_rank = None
25282562

25292563
permuted_length_per_key_sum = sum(permuted_length_per_key)
25302564
if not torch.jit.is_scripting() and is_non_strict_exporting():
@@ -2576,17 +2610,15 @@ def permute(
25762610
self.weights_or_none(),
25772611
permuted_length_per_key_sum,
25782612
)
2579-
stride_per_key_per_rank = (
2580-
permuted_stride_per_key_per_rank if self.variable_stride_per_key() else None
2581-
)
2613+
25822614
kjt = KeyedJaggedTensor(
25832615
keys=permuted_keys,
25842616
values=permuted_values,
25852617
weights=permuted_weights,
25862618
lengths=permuted_lengths.view(-1),
25872619
offsets=None,
25882620
stride=self._stride,
2589-
stride_per_key_per_rank=stride_per_key_per_rank,
2621+
stride_per_key_per_rank=permuted_stride_per_key_per_rank,
25902622
stride_per_key=None,
25912623
length_per_key=permuted_length_per_key if len(permuted_keys) > 0 else None,
25922624
lengths_offset_per_key=None,
@@ -2904,7 +2936,7 @@ def dist_init(
29042936

29052937
if variable_stride_per_key:
29062938
assert stride_per_rank_per_key is not None
2907-
stride_per_key_per_rank_tensor: torch.Tensor = stride_per_rank_per_key.view(
2939+
stride_per_key_per_rank: torch.Tensor = stride_per_rank_per_key.view(
29082940
num_workers, len(keys)
29092941
).T.cpu()
29102942

@@ -2941,23 +2973,18 @@ def dist_init(
29412973
weights,
29422974
)
29432975

2944-
stride_per_key_per_rank = torch.jit.annotate(
2945-
List[List[int]], stride_per_key_per_rank_tensor.tolist()
2946-
)
2976+
if stride_per_key_per_rank.numel() == 0:
2977+
stride_per_key_per_rank = torch.zeros(
2978+
(len(keys), 1), device="cpu", dtype=torch.int64
2979+
)
29472980

2948-
if not stride_per_key_per_rank:
2949-
stride_per_key_per_rank = [[0]] * len(keys)
29502981
if stagger > 1:
2951-
stride_per_key_per_rank_stagger: List[List[int]] = []
29522982
local_world_size = num_workers // stagger
2953-
for i in range(len(keys)):
2954-
stride_per_rank_stagger: List[int] = []
2955-
for j in range(local_world_size):
2956-
stride_per_rank_stagger.extend(
2957-
stride_per_key_per_rank[i][j::local_world_size]
2958-
)
2959-
stride_per_key_per_rank_stagger.append(stride_per_rank_stagger)
2960-
stride_per_key_per_rank = stride_per_key_per_rank_stagger
2983+
indices = [
2984+
list(range(i, num_workers, local_world_size))
2985+
for i in range(local_world_size)
2986+
]
2987+
stride_per_key_per_rank = stride_per_key_per_rank[:, indices]
29612988

29622989
kjt = KeyedJaggedTensor(
29632990
keys=keys,

0 commit comments

Comments (0)