Skip to content

Commit 254becf

Browse files
Leon Gaofacebook-github-bot
Leon Gao
authored andcommitted
simple fx rule for get length tensor (#1767)
Summary: ATT Reviewed By: jingsh, jiayisuse Differential Revision: D54603545
1 parent 1d6ce32 commit 254becf

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

torchrec/quant/embedding_modules.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ def _get_batching_hinted_output(lengths: Tensor, output: Tensor) -> Tensor:
8181
return output
8282

8383

84+
@torch.fx.wrap
85+
def _get_feature_length(feature: KeyedJaggedTensor) -> Tensor:
86+
return feature.lengths()
87+
88+
8489
def for_each_module_of_type_do(
8590
module: nn.Module,
8691
module_types: List[Type[torch.nn.Module]],
@@ -863,9 +868,8 @@ def forward(
863868
):
864869
f = kjts_per_key[i]
865870
indices = f.values()
866-
lengths = f.lengths()
871+
lengths = _get_feature_length(f)
867872
offsets = f.offsets()
868-
lengths = f.lengths()
869873
lookup = (
870874
emb_module(indices=indices, offsets=offsets)
871875
if self.register_tbes

0 commit comments

Comments
 (0)