
feature-batched inference QEC #1491

Open · wants to merge 2 commits into main
49 changes: 42 additions & 7 deletions torchrec/quant/embedding_modules.py
@@ -69,6 +69,16 @@
DEFAULT_ROW_ALIGNMENT = 16


@torch.fx.wrap
def set_fake_stbe_offsets(values: torch.Tensor) -> torch.Tensor:
return torch.arange(
0,
values.numel() + 1,
device=values.device,
dtype=values.dtype,
)
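# Reviewer sketch (not part of this change): for values.numel() == 3 this
# returns tensor([0, 1, 2, 3]), i.e. one single-element bag per index, so a
# pooled TBE lookup degenerates to a plain per-index embedding lookup.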


def for_each_module_of_type_do(
module: nn.Module,
module_types: List[Type[torch.nn.Module]],
@@ -659,7 +669,7 @@ def __init__( # noqa C901
else EmbeddingLocation.DEVICE,
)
],
pooling_mode=PoolingMode.NONE,
pooling_mode=PoolingMode.SUM,
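# Reviewer note (not part of this change): with the fake per-element offsets
# produced by set_fake_stbe_offsets, every bag holds exactly one index, so SUM
# pooling returns the same per-index rows that PoolingMode.NONE used to return.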
weight_lists=weight_lists,
device=device,
output_dtype=data_type_to_sparse_type(dtype_to_data_type(output_dtype)),
@@ -728,23 +738,48 @@ def forward(
self._embedding_names_by_table,
self._emb_modules,
):
for feature_name, embedding_name in zip(
config.feature_names, embedding_names
):
indices_list = []
length_list = []
if len(config.feature_names) == 1:
feature_name = config.feature_names[0]
f = jt_dict[feature_name]
values = f.values()
offsets = f.offsets()
# Call the registered module so that FX generates call_module instead of call_function, keeping the TBE copied unchanged into the fx.GraphModule; this is only possible for registered modules.
offsets = set_fake_stbe_offsets(values)
lookup = (
emb_module(indices=values, offsets=offsets)
if self.register_tbes
else emb_module.forward(indices=values, offsets=offsets)
)
feature_embeddings[embedding_name] = JaggedTensor(
feature_embeddings[embedding_names[0]] = JaggedTensor(
values=lookup,
lengths=f.lengths(),
weights=f.values() if self.need_indices else None,
)
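# Reviewer note (not part of this change): the one-feature-per-table case keeps
# the existing behavior and avoids the torch.cat / split round trip below.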
else:
for feature_name in config.feature_names:
f = jt_dict[feature_name]
values = f.values()
length_list.append(f.lengths().view(1, -1))
indices_list.append(values)
indices = torch.cat(indices_list)
offsets = set_fake_stbe_offsets(indices)
length_all = torch.cat(length_list, dim=0)
# Call the registered module so that FX generates call_module instead of call_function, keeping the TBE copied unchanged into the fx.GraphModule; this is only possible for registered modules.
lookup = (
emb_module(indices=indices, offsets=offsets)
if self.register_tbes
else emb_module.forward(indices=indices, offsets=offsets)
).view(-1, self._embedding_dim)
length_split = torch.sum(length_all, dim=1).tolist()
splits = lookup.split(length_split, dim=0)
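# Reviewer sketch (hypothetical sizes, not in this diff): with two features
# holding 1 and 2 ids, `indices` has 3 entries, the single TBE call yields a
# (3, embedding_dim) `lookup`, length_split == [1, 2], and
# lookup.split([1, 2]) restores the per-feature embeddings in feature order.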
for embedding_name, length, indice, embedding in zip(
embedding_names, length_list, indices_list, splits
):
feature_embeddings[embedding_name] = JaggedTensor(
values=embedding,
lengths=length.view(-1),
weights=indice if self.need_indices else None,
)
return feature_embeddings

@classmethod
10 changes: 5 additions & 5 deletions torchrec/quant/tests/test_embedding_modules.py
@@ -542,20 +542,20 @@ def test_ec(
name="t1",
embedding_dim=16,
num_embeddings=10,
feature_names=["f1"],
feature_names=["f1", "f2"],
data_type=data_type,
)
eb2_config = EmbeddingConfig(
name="t2",
embedding_dim=16,
num_embeddings=10,
feature_names=["f2"],
feature_names=["f3", "f4"],
data_type=data_type,
)
features = KeyedJaggedTensor(
keys=["f1", "f2"],
values=torch.as_tensor([0, 1]),
lengths=torch.as_tensor([1, 1]),
keys=["f1", "f2", "f3", "f4"],
values=torch.as_tensor([0, 1, 2, 3, 4, 5, 6, 7]),
lengths=torch.as_tensor([1, 2, 3, 2]),
)
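# Reviewer note (not part of this change): with lengths [1, 2, 3, 2], table t1
# batches f1 (id [0]) and f2 (ids [1, 2]) into one TBE call, and t2 batches
# f3 (ids [3, 4, 5]) and f4 (ids [6, 7]).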
self._test_ec(
tables=[eb1_config, eb2_config],