Test: include stride_per_key in KJT's flatten and unflatten #2903

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open. Wants to merge 1 commit into base: main.
50 changes: 20 additions & 30 deletions torchrec/ir/tests/test_serializer.py
@@ -206,8 +206,14 @@ def forward(
num_embeddings=10,
feature_names=["f2"],
)
config3 = EmbeddingBagConfig(
name="t3",
embedding_dim=5,
num_embeddings=10,
feature_names=["f3"],
)
ebc = EmbeddingBagCollection(
tables=[config1, config2],
tables=[config1, config2, config3],
is_weighted=False,
)

@@ -292,15 +298,17 @@ def test_serialize_deserialize_ebc(self) -> None:
self.assertEqual(deserialized.shape, orginal.shape)
self.assertTrue(torch.allclose(deserialized, orginal))

@unittest.skip("Adding test for demonstrating VBE KJT flattening issue for now.")
def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
model = self.generate_model_for_vbe_kjt()
id_list_features = KeyedJaggedTensor(
keys=["f1", "f2"],
values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
lengths=torch.tensor([3, 3, 2]),
stride_per_key_per_rank=[[2], [1]],
inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
keys=["f1", "f2", "f3"],
values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
stride_per_key_per_rank=[[3], [2], [1]],
inverse_indices=(
["f1", "f2", "f3"],
torch.tensor([[0, 1, 2], [0, 1, 0], [0, 0, 0]]),
),
)

eager_out = model(id_list_features)
@@ -319,15 +327,16 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
# Run forward on ExportedProgram
ep_output = ep.module()(id_list_features)

self.assertEqual(len(ep_output), len(id_list_features.keys()))
for i, tensor in enumerate(ep_output):
self.assertEqual(eager_out[i].shape, tensor.shape)
self.assertEqual(eager_out[i].shape[1], tensor.shape[1])

# Deserialize EBC
unflatten_ep = torch.export.unflatten(ep)
deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)

# check EBC config
for i in range(5):
for i in range(1):
ebc_name = f"ebc{i + 1}"
self.assertIsInstance(
getattr(deserialized_model, ebc_name), EmbeddingBagCollection
@@ -342,29 +351,9 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
self.assertEqual(deserialized.feature_names, orginal.feature_names)

# check FPEBC config
for i in range(2):
fpebc_name = f"fpebc{i + 1}"
assert isinstance(
getattr(deserialized_model, fpebc_name),
FeatureProcessedEmbeddingBagCollection,
)

for deserialized, orginal in zip(
getattr(
deserialized_model, fpebc_name
)._embedding_bag_collection.embedding_bag_configs(),
getattr(
model, fpebc_name
)._embedding_bag_collection.embedding_bag_configs(),
):
self.assertEqual(deserialized.name, orginal.name)
self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim)
self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
self.assertEqual(deserialized.feature_names, orginal.feature_names)

# Run forward on deserialized model and compare the output
deserialized_model.load_state_dict(model.state_dict())

deserialized_out = deserialized_model(id_list_features)

self.assertEqual(len(deserialized_out), len(eager_out))
@@ -385,6 +374,7 @@ def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None:
values=torch.tensor([0, 1, 2, 3, 2, 3, 4]),
offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7]),
)

eager_out = model(feature2)

# Serialize EBC
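For context on the VBE inputs used in the test above: each key's pooled output has stride_per_key rows, and inverse_indices maps every sample of the full batch back to one of those rows. The snippet below is an illustrative sketch only (not part of the diff; pooled_f2 and its embedding dim are made up), showing the expansion conceptually for "f2" with stride 2 and a full batch of 3:

import torch

# Hypothetical pooled EBC output for "f2": stride 2 rows, embedding_dim 4.
pooled_f2 = torch.randn(2, 4)
# Inverse indices for "f2" from the test above: sample i reads pooled_f2[idx_f2[i]].
idx_f2 = torch.tensor([0, 1, 0])
expanded_f2 = torch.index_select(pooled_f2, 0, idx_f2)  # shape (3, 4), one row per sample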
2 changes: 1 addition & 1 deletion torchrec/modules/embedding_modules.py
@@ -27,7 +27,7 @@ def reorder_inverse_indices(
inverse_indices: Optional[Tuple[List[str], torch.Tensor]],
feature_names: List[str],
) -> torch.Tensor:
if inverse_indices is None:
if inverse_indices is None or inverse_indices[1].numel() == 0:
return torch.empty(0)
index_per_name = {name: i for i, name in enumerate(inverse_indices[0])}
index = torch.tensor(
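The extra numel() == 0 check matters because, with the pytree changes below, an unflattened KJT always carries an inverse-indices tensor (empty when the original had none), and that empty tensor must be treated the same as None. A minimal sketch, assuming the module-level helper is imported directly:

import torch
from torchrec.modules.embedding_modules import reorder_inverse_indices

# Both calls should return an empty tensor instead of attempting an index lookup.
print(reorder_inverse_indices(None, ["f1"]))                      # tensor([])
print(reorder_inverse_indices((["f1"], torch.empty(0)), ["f1"]))  # tensor([])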
51 changes: 44 additions & 7 deletions torchrec/sparse/jagged_tensor.py
@@ -8,9 +8,11 @@
# pyre-strict

import abc
import dataclasses
import logging

import operator
from dataclasses import dataclass

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -1756,6 +1758,7 @@ class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
"_weights",
"_lengths",
"_offsets",
"_inverse_indices_tensor",
]

def __init__(
@@ -1801,6 +1804,12 @@ def __init__(
inverse_indices
)

# Initialize _inverse_indices_tensor to an empty tensor so torch.export exports it as a FakeTensor.
# Otherwise, when it is None, it would be exported as a ConstantArgument, which triggers an "Unsupported data type" error.
self._inverse_indices_tensor: Optional[torch.Tensor] = torch.empty(0)
if inverse_indices is not None:
self._inverse_indices_tensor = inverse_indices[1]

# legacy attribute, for backward compatibility
self._variable_stride_per_key: Optional[bool] = None

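As an illustration (not part of the diff): with this default, a KJT constructed without inverse_indices now holds an empty tensor rather than None, which is what lets torch.export trace it as a tensor input.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

kjt = KeyedJaggedTensor(
    keys=["f1"],
    values=torch.tensor([0, 1, 2]),
    lengths=torch.tensor([3]),
)
# The attribute defaults to an empty tensor instead of None.
assert kjt._inverse_indices_tensor is not None
assert kjt._inverse_indices_tensor.numel() == 0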
@@ -3030,15 +3039,32 @@ def dist_init(
return kjt.sync()


@dataclass
class KjtTreeSpecs:
keys: List[str]
stride_per_key_per_rank: Optional[List[List[int]]]

def to_dict(self) -> dict[str, Any]:
return {
field.name: getattr(self, field.name) for field in dataclasses.fields(self)
}


def _kjt_flatten(
t: KeyedJaggedTensor,
) -> Tuple[List[Optional[torch.Tensor]], List[str]]:
return [getattr(t, a) for a in KeyedJaggedTensor._fields], t._keys
) -> Tuple[List[Optional[torch.Tensor]], Tuple[List[str], Optional[List[List[int]]]]]:
return [getattr(t, a) for a in KeyedJaggedTensor._fields], (
t._keys,
t._stride_per_key_per_rank,
)


def _kjt_flatten_with_keys(
t: KeyedJaggedTensor,
) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], List[str]]:
) -> Tuple[
List[Tuple[KeyEntry, Optional[torch.Tensor]]],
Tuple[List[str], Optional[List[List[int]]]],
]:
values, context = _kjt_flatten(t)
# pyre can't tell that GetAttrKey implements the KeyEntry protocol
return [ # pyre-ignore[7]
@@ -3047,9 +3073,17 @@ def _kjt_flatten_with_keys(


def _kjt_unflatten(
values: List[Optional[torch.Tensor]], context: List[str] # context is the _keys
values: List[Optional[torch.Tensor]],
context: Tuple[
List[str], Optional[List[List[int]]]
], # context is the (_keys, _stride_per_key_per_rank) tuple; the inverse-indices tensor travels in values
) -> KeyedJaggedTensor:
return KeyedJaggedTensor(context, *values)
return KeyedJaggedTensor(
context[0],
*values[:-1],
stride_per_key_per_rank=context[1],
inverse_indices=(context[0], values[-1]),
)


def _kjt_flatten_spec(
@@ -3070,7 +3104,9 @@ def _kjt_flatten_spec(

def flatten_kjt_list(
kjt_arr: List[KeyedJaggedTensor],
) -> Tuple[List[Optional[torch.Tensor]], List[List[str]]]:
) -> Tuple[
List[Optional[torch.Tensor]], List[Tuple[List[str], Optional[List[List[int]]]]]
]:
_flattened_data = []
_flattened_context = []
for t in kjt_arr:
@@ -3081,7 +3117,8 @@ def flatten_kjt_list(


def unflatten_kjt_list(
values: List[Optional[torch.Tensor]], contexts: List[List[str]]
values: List[Optional[torch.Tensor]],
contexts: List[Tuple[List[str], Optional[List[List[int]]]]],
) -> List[KeyedJaggedTensor]:
num_kjt_fields = len(KeyedJaggedTensor._fields)
length = len(values)
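Taken together, the new contract is: keys and stride_per_key_per_rank travel in the pytree context, while the inverse-indices tensor travels as the last leaf alongside values/weights/lengths/offsets. A minimal round-trip sketch (illustrative only, using the private helpers changed above):

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, _kjt_flatten, _kjt_unflatten

kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
    lengths=torch.tensor([3, 3, 2]),
    stride_per_key_per_rank=[[2], [1]],
    inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
)

tensors, context = _kjt_flatten(kjt)
# context holds the non-tensor metadata, e.g. (["f1", "f2"], [[2], [1]]) here.
# tensors[-1] is _inverse_indices_tensor, the last entry of KeyedJaggedTensor._fields.
rebuilt = _kjt_unflatten(tensors, context)

assert rebuilt.stride_per_key_per_rank() == kjt.stride_per_key_per_rank()
assert torch.equal(rebuilt.inverse_indices()[1], kjt.inverse_indices()[1])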
17 changes: 17 additions & 0 deletions torchrec/sparse/tests/test_keyed_jagged_tensor.py
@@ -1017,6 +1017,23 @@ def test_meta_device_compatibility(self) -> None:
lengths=torch.tensor([], device=torch.device("meta")),
)

def test_flatten_unflatten_with_vbe(self) -> None:
kjt = KeyedJaggedTensor(
keys=["f1", "f2"],
values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
lengths=torch.tensor([3, 3, 2]),
stride_per_key_per_rank=[[2], [1]],
inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
)

flat_kjt, spec = pytree.tree_flatten(kjt)
unflattened_kjt = pytree.tree_unflatten(flat_kjt, spec)

self.assertEqual(
kjt.stride_per_key_per_rank(), unflattened_kjt.stride_per_key_per_rank()
)
self.assertEqual(kjt.inverse_indices(), unflattened_kjt.inverse_indices())


class TestKeyedJaggedTensorScripting(unittest.TestCase):
def test_scriptable_forward(self) -> None: