convert stride_per_key_per_rank to tensor inside KJT #2959

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
8 changes: 4 additions & 4 deletions torchrec/distributed/tests/test_pt2_multiprocess.py
```diff
@@ -703,31 +703,31 @@ def disable_cuda_tf32(self) -> bool:
                 ShardingType.TABLE_WISE.value,
                 _InputType.SINGLE_BATCH,
                 _ConvertToVariableBatch.TRUE,
-                "inductor",
+                "aot_eager",
                 _TestConfig(),
             ),
             (
                 _ModelType.EBC,
                 ShardingType.COLUMN_WISE.value,
                 _InputType.SINGLE_BATCH,
                 _ConvertToVariableBatch.TRUE,
-                "inductor",
+                "aot_eager",
                 _TestConfig(),
             ),
             (
                 _ModelType.EBC,
                 ShardingType.TABLE_WISE.value,
                 _InputType.SINGLE_BATCH,
                 _ConvertToVariableBatch.FALSE,
-                "inductor",
+                "aot_eager",
                 _TestConfig(),
             ),
             (
                 _ModelType.EBC,
                 ShardingType.COLUMN_WISE.value,
                 _InputType.SINGLE_BATCH,
                 _ConvertToVariableBatch.FALSE,
-                "inductor",
+                "aot_eager",
                 _TestConfig(),
             ),
         ]
```
4 changes: 2 additions & 2 deletions torchrec/pt2/utils.py
```diff
@@ -54,7 +54,7 @@ def kjt_for_pt2_tracing(
         values=values,
         lengths=lengths,
         weights=kjt.weights_or_none(),
-        stride_per_key_per_rank=[[stride]] * n,
+        stride_per_key_per_rank=torch.IntTensor([[stride]] * n, device="cpu"),
         inverse_indices=(kjt.keys(), inverse_indices_tensor),
     )

@@ -85,7 +85,7 @@ def kjt_for_pt2_tracing(
         lengths=lengths,
         weights=weights,
         stride=stride if not is_vb else None,
-        stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None,
+        stride_per_key_per_rank=kjt._stride_per_key_per_rank if is_vb else None,
         inverse_indices=inverse_indices,
     )
```
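
Note: the first hunk above replaces the `List[List[int]]` literal with an `IntTensor` of the same shape. A minimal standalone sketch of that equivalence (the values of `n` and `stride` are invented for illustration):

```python
import torch

# Sketch of the conversion performed in the hunk above; n and stride are made up.
n, stride = 3, 8
spkpr_list = [[stride]] * n                        # old form: List[List[int]]
spkpr = torch.IntTensor(spkpr_list, device="cpu")  # new form: (n, 1) int32 CPU tensor
assert spkpr.shape == (n, 1)
assert spkpr.tolist() == spkpr_list                # round-trips back to the list form
```
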
6 changes: 4 additions & 2 deletions torchrec/schema/api_tests/test_jagged_tensor_schema.py
```diff
@@ -9,7 +9,7 @@

 import inspect
 import unittest
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 from torchrec.schema.utils import is_signature_compatible
@@ -112,7 +112,9 @@ def __init__(
         lengths: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
         stride: Optional[int] = None,
-        stride_per_key_per_rank: Optional[List[List[int]]] = None,
+        stride_per_key_per_rank: Optional[
+            Union[List[List[int]], torch.IntTensor]
+        ] = None,
         # Below exposed to ensure torch.script-able
         stride_per_key: Optional[List[int]] = None,
         length_per_key: Optional[List[int]] = None,
```
109 changes: 67 additions & 42 deletions torchrec/sparse/jagged_tensor.py
```diff
@@ -1096,13 +1096,15 @@ def _maybe_compute_stride_kjt(
     stride: Optional[int],
     lengths: Optional[torch.Tensor],
     offsets: Optional[torch.Tensor],
-    stride_per_key_per_rank: Optional[List[List[int]]],
+    stride_per_key_per_rank: Optional[torch.IntTensor],
 ) -> int:
     if stride is None:
         if len(keys) == 0:
             stride = 0
-        elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
-            stride = max([sum(s) for s in stride_per_key_per_rank])
+        elif (
+            stride_per_key_per_rank is not None and stride_per_key_per_rank.numel() > 0
+        ):
+            stride = int(stride_per_key_per_rank.sum(dim=1).max().item())
         elif offsets is not None and offsets.numel() > 0:
             stride = (offsets.numel() - 1) // len(keys)
         elif lengths is not None:
```
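
Note: the stride computation moves from a Python max-of-row-sums to a tensor reduction; the two agree. A quick standalone check, with made-up values:

```python
import torch

spkpr = torch.IntTensor([[2, 3], [4, 1]])    # (num_keys, num_ranks), made-up values
old = max([sum(s) for s in spkpr.tolist()])  # list-based reduction (removed code)
new = int(spkpr.sum(dim=1).max().item())     # tensor-based reduction (added code)
assert old == new == 5
```
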
```diff
@@ -1481,8 +1483,8 @@ def _strides_from_kjt(
 def _kjt_empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
     # empty like function fx wrapped, also avoids device hardcoding
     stride, stride_per_key_per_rank = (
-        (None, kjt.stride_per_key_per_rank())
-        if kjt.variable_stride_per_key()
+        (None, kjt._stride_per_key_per_rank)
+        if kjt._stride_per_key_per_rank is not None and kjt.variable_stride_per_key()
         else (kjt.stride(), None)
     )

```
```diff
@@ -1668,14 +1670,20 @@ def _maybe_compute_lengths_offset_per_key(

 def _maybe_compute_stride_per_key(
     stride_per_key: Optional[List[int]],
-    stride_per_key_per_rank: Optional[List[List[int]]],
+    stride_per_key_per_rank: Optional[torch.IntTensor],
     stride: Optional[int],
     keys: List[str],
 ) -> Optional[List[int]]:
     if stride_per_key is not None:
         return stride_per_key
     elif stride_per_key_per_rank is not None:
-        return [sum(s) for s in stride_per_key_per_rank]
+        if stride_per_key_per_rank.dim() != 2:
+            # after permute the kjt could be empty
+            return []
+        rt: List[int] = stride_per_key_per_rank.sum(dim=1).tolist()
+        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
+            pt2_checks_all_is_size(rt)
+        return rt
     elif stride is not None:
         return [stride] * len(keys)
     else:
```
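
Note: per-key stride is now the row sum of the tensor, and the new `dim() != 2` guard handles the case where an empty KJT carries a 1-D empty tensor. A standalone sketch with invented values:

```python
import torch

spkpr = torch.IntTensor([[2, 3], [4, 1]])    # made-up (keys x ranks) strides
# Same result as the removed [sum(s) for s in ...] list comprehension:
assert spkpr.sum(dim=1).tolist() == [5, 5]

# An empty KJT (e.g. after permuting away every key) can carry a 1-D empty
# tensor here, which is what the dim() != 2 guard short-circuits on:
assert torch.IntTensor([]).dim() == 1
```
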
```diff
@@ -1766,7 +1774,9 @@ def __init__(
         lengths: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
         stride: Optional[int] = None,
-        stride_per_key_per_rank: Optional[List[List[int]]] = None,
+        stride_per_key_per_rank: Optional[
+            Union[torch.IntTensor, List[List[int]]]
+        ] = None,
         # Below exposed to ensure torch.script-able
         stride_per_key: Optional[List[int]] = None,
         length_per_key: Optional[List[int]] = None,
```
```diff
@@ -1788,8 +1798,14 @@
         self._lengths: Optional[torch.Tensor] = lengths
         self._offsets: Optional[torch.Tensor] = offsets
         self._stride: Optional[int] = stride
-        self._stride_per_key_per_rank: Optional[List[List[int]]] = (
-            stride_per_key_per_rank
+        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
+            # under torch.compile, stride_per_key_per_rank must already be a
+            # torch.Tensor (or None); List[List[int]] is not supported
+            assert not isinstance(stride_per_key_per_rank, list)
+        self._stride_per_key_per_rank: Optional[torch.IntTensor] = (
+            torch.IntTensor(stride_per_key_per_rank, device="cpu")
+            if isinstance(stride_per_key_per_rank, list)
+            else stride_per_key_per_rank
         )
         self._stride_per_key: Optional[List[int]] = stride_per_key
         self._length_per_key: Optional[List[int]] = length_per_key
```
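
Note: with this constructor change, eager callers can keep passing the list form (converted on the spot), while compiled paths must pass a tensor or `None`. A hedged usage sketch of the eager path; the keys, values, and strides below are invented toy data, and `_stride_per_key_per_rank` is a private attribute inspected only for illustration:

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Invented toy input: 2 keys, 2 ranks, per-key-per-rank strides [[2, 1], [2, 1]].
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5, 6]),
    lengths=torch.tensor([1, 2, 1, 1, 1, 0]),   # one length per (key, batch) slot
    stride_per_key_per_rank=[[2, 1], [2, 1]],   # list form, converted eagerly
)
assert isinstance(kjt._stride_per_key_per_rank, torch.Tensor)
```
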
```diff
@@ -1815,10 +1831,6 @@ def _init_pt2_checks(self) -> None:
             return
         if self._stride_per_key is not None:
             pt2_checks_all_is_size(self._stride_per_key)
-        if self._stride_per_key_per_rank is not None:
-            # pyre-ignore [16]
-            for s in self._stride_per_key_per_rank:
-                pt2_checks_all_is_size(s)

     @staticmethod
     def from_offsets_sync(
```
```diff
@@ -2028,7 +2040,7 @@ def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor":
         kjt_stride, kjt_stride_per_key_per_rank = (
             (stride_per_key[0], None)
             if all(s == stride_per_key[0] for s in stride_per_key)
-            else (None, [[stride] for stride in stride_per_key])
+            else (None, torch.IntTensor(stride_per_key, device="cpu").reshape(-1, 1))
         )
         kjt = KeyedJaggedTensor(
             keys=kjt_keys,
```
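
Note: `reshape(-1, 1)` turns the flat per-key stride list into the single-rank column layout that replaces the old `[[stride] for stride in stride_per_key]`. Standalone, with made-up strides:

```python
import torch

stride_per_key = [2, 3, 4]   # made-up per-key strides
t = torch.IntTensor(stride_per_key, device="cpu").reshape(-1, 1)
assert t.tolist() == [[2], [3], [4]]   # one single-rank row per key
```
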
```diff
@@ -2193,12 +2205,29 @@ def stride_per_key_per_rank(self) -> List[List[int]]:

         Returns:
             List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
         """
-        stride_per_key_per_rank = self._stride_per_key_per_rank
-        return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
+        # making a local reference to the class variable to make jit.script behave
+        _stride_per_key_per_rank = self._stride_per_key_per_rank
+        if (
+            not torch.jit.is_scripting()
+            and is_torchdynamo_compiling()
+            and _stride_per_key_per_rank is not None
+        ):
+            stride_per_key_per_rank = _stride_per_key_per_rank.tolist()
+            for stride_per_rank in stride_per_key_per_rank:
+                pt2_checks_all_is_size(stride_per_rank)
+            return stride_per_key_per_rank
+        return (
+            []
+            if _stride_per_key_per_rank is None
+            else _stride_per_key_per_rank.tolist()
+        )

     def variable_stride_per_key(self) -> bool:
         """
         Returns whether the KeyedJaggedTensor has variable stride per key.
+
+        NOTE: `self._variable_stride_per_key` can be `False` even when
+        `self._stride_per_key_per_rank` is not `None`: it may be set to `False`
+        externally or intentionally, typically when `self._stride_per_key_per_rank`
+        is trivial.

         Returns:
             bool: whether the KeyedJaggedTensor has variable stride per key.
```
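
Note: even though the internal storage is now a tensor, the public accessor keeps its `List[List[int]]` contract by calling `.tolist()`. Standalone, with a made-up internal tensor:

```python
import torch

spkpr = torch.IntTensor([[2, 1], [4, 3]])   # made-up internal (keys x ranks) tensor
# What the accessor now hands back to callers, unchanged in type:
assert spkpr.tolist() == [[2, 1], [4, 3]]
```
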
```diff
@@ -2343,13 +2372,16 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
         start_offset = 0
         _length_per_key = self.length_per_key()
         _offset_per_key = self.offset_per_key()
+        # use local copy/ref for self._stride_per_key_per_rank to satisfy jit.script
+        _stride_per_key_per_rank = self._stride_per_key_per_rank
         for segment in segments:
             end = start + segment
             end_offset = _offset_per_key[end]
             keys: List[str] = self._keys[start:end]
             stride_per_key_per_rank = (
-                self.stride_per_key_per_rank()[start:end]
+                _stride_per_key_per_rank[start:end, :]
                 if self.variable_stride_per_key()
+                and _stride_per_key_per_rank is not None
                 else None
             )
             if segment == len(self._keys):
```
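
Note: splitting now slices rows of the stride tensor instead of a nested list; the two slices agree. Standalone, with made-up values:

```python
import torch

spkpr = torch.IntTensor([[2, 1], [4, 3], [5, 5]])   # made-up (keys x ranks)
start, end = 0, 2
# Tensor row slice matches the old list slice element for element:
assert spkpr[start:end, :].tolist() == spkpr.tolist()[start:end]
```
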
```diff
@@ -2514,17 +2546,17 @@ def permute(

         length_per_key = self.length_per_key()
         permuted_keys: List[str] = []
-        permuted_stride_per_key_per_rank: List[List[int]] = []
         permuted_length_per_key: List[int] = []
         permuted_length_per_key_sum = 0
         for index in indices:
             key = self.keys()[index]
             permuted_keys.append(key)
             permuted_length_per_key.append(length_per_key[index])
-            if self.variable_stride_per_key():
-                permuted_stride_per_key_per_rank.append(
-                    self.stride_per_key_per_rank()[index]
-                )
+        _stride_per_key_per_rank = self._stride_per_key_per_rank
+        if self.variable_stride_per_key() and _stride_per_key_per_rank is not None:
+            permuted_stride_per_key_per_rank = _stride_per_key_per_rank[indices, :]
+        else:
+            permuted_stride_per_key_per_rank = None

         permuted_length_per_key_sum = sum(permuted_length_per_key)
         if not torch.jit.is_scripting() and is_non_strict_exporting():
```
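
Note: the per-key append loop becomes a single advanced-indexing gather over tensor rows. Standalone sketch, with made-up values:

```python
import torch

spkpr = torch.IntTensor([[2, 1], [4, 3], [5, 5]])   # made-up (keys x ranks)
indices = [2, 0, 1]                                  # a permutation of the keys
permuted = spkpr[indices, :]                         # replaces the per-key Python loop
assert permuted.tolist() == [[5, 5], [2, 1], [4, 3]]
```
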
```diff
@@ -2576,17 +2608,15 @@ def permute(
                 self.weights_or_none(),
                 permuted_length_per_key_sum,
             )
-        stride_per_key_per_rank = (
-            permuted_stride_per_key_per_rank if self.variable_stride_per_key() else None
-        )

         kjt = KeyedJaggedTensor(
             keys=permuted_keys,
             values=permuted_values,
             weights=permuted_weights,
             lengths=permuted_lengths.view(-1),
             offsets=None,
             stride=self._stride,
-            stride_per_key_per_rank=stride_per_key_per_rank,
+            stride_per_key_per_rank=permuted_stride_per_key_per_rank,
             stride_per_key=None,
             length_per_key=permuted_length_per_key if len(permuted_keys) > 0 else None,
             lengths_offset_per_key=None,
```
```diff
@@ -2904,7 +2934,7 @@ def dist_init(

     if variable_stride_per_key:
         assert stride_per_rank_per_key is not None
-        stride_per_key_per_rank_tensor: torch.Tensor = stride_per_rank_per_key.view(
+        stride_per_key_per_rank: torch.Tensor = stride_per_rank_per_key.view(
             num_workers, len(keys)
         ).T.cpu()
```
```diff
@@ -2941,23 +2971,18 @@ def dist_init(
                 weights,
             )

-        stride_per_key_per_rank = torch.jit.annotate(
-            List[List[int]], stride_per_key_per_rank_tensor.tolist()
-        )
+        if stride_per_key_per_rank.numel() == 0:
+            stride_per_key_per_rank = torch.zeros(
+                (len(keys), 1), device="cpu", dtype=torch.int64
+            )

-        if not stride_per_key_per_rank:
-            stride_per_key_per_rank = [[0]] * len(keys)
         if stagger > 1:
-            stride_per_key_per_rank_stagger: List[List[int]] = []
             local_world_size = num_workers // stagger
-            for i in range(len(keys)):
-                stride_per_rank_stagger: List[int] = []
-                for j in range(local_world_size):
-                    stride_per_rank_stagger.extend(
-                        stride_per_key_per_rank[i][j::local_world_size]
-                    )
-                stride_per_key_per_rank_stagger.append(stride_per_rank_stagger)
-            stride_per_key_per_rank = stride_per_key_per_rank_stagger
+            indices = [
+                list(range(i, num_workers, local_world_size))
+                for i in range(local_world_size)
+            ]
+            stride_per_key_per_rank = stride_per_key_per_rank[:, indices]

         kjt = KeyedJaggedTensor(
             keys=keys,
```