Skip to content

Commit 1ffff9b

Browse files
Pengchao Wang and facebook-github-bot
Pengchao Wang
authored and committed
Add reverse module for ComputeKJTToJTDict to combine jt_dict to kjt (#1399)
Summary: Pull Request resolved: #1399 so that we don't need to fx-wrap KeyedJaggedTensor.from_jt_dict(jt_dict) manually everywhere. Also, based on this, we can do graph pattern matching to cancel (ComputeKJTToJTDict, ComputeJTDictToKJT) pairs during publish to save compute cycles. (see next diff in the stack) Reviewed By: houseroad, YazhiGao Differential Revision: D49423522 Privacy Context Container: 314155190942957 fbshipit-source-id: bbf5633d13b7003c3c81a0f676faf430f0a885a3
1 parent c1602e4 commit 1ffff9b

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,6 +768,36 @@ def forward(
768768
)
769769

770770

771+
class ComputeJTDictToKJT(torch.nn.Module):
    """Combines a dict of JaggedTensors into a single KeyedJaggedTensor.

    This is the inverse of ``ComputeKJTToJTDict``: each dict key becomes a
    feature key of the resulting KJT.

    Example::

        # passing in jt_dict
        # {
        #     "Feature0": JaggedTensor([[V0, V1], None, V2]),
        #     "Feature1": JaggedTensor([V3, V4, [V5, V6, V7]]),
        # }
        # returns a KJT with content:
        #              0        1      2          <-- dim_1
        # "Feature0"   [V0,V1]  None   [V2]
        # "Feature1"   [V3]     [V4]   [V5,V6,V7]
        # ^
        # dim_0
    """

    def forward(self, jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor":
        """Build a KeyedJaggedTensor from the given mapping.

        Args:
            jt_dict (Dict[str, JaggedTensor]): mapping of feature name to
                JaggedTensor.

        Returns:
            KeyedJaggedTensor: the combined tensor, constructed via
            ``KeyedJaggedTensor.from_jt_dict``.
        """
        return KeyedJaggedTensor.from_jt_dict(jt_dict)
800+
771801
@torch.fx.wrap
772802
def _maybe_compute_kjt_to_jt_dict(
773803
stride: int,

torchrec/sparse/tests/test_jagged_tensor.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from torch.testing import FileCheck
1515
from torchrec.fx import symbolic_trace
1616
from torchrec.sparse.jagged_tensor import (
17+
ComputeJTDictToKJT,
1718
ComputeKJTToJTDict,
1819
JaggedTensor,
1920
jt_is_equal,
@@ -707,6 +708,37 @@ def test_pytree(self) -> None:
707708
self.assertTrue(torch.equal(j0.weights(), j1.weights()))
708709
self.assertTrue(torch.equal(j0.values(), j1.values()))
709710

711+
def test_compute_jt_dict_to_kjt_module(self) -> None:
    """Round-trips a KJT through to_dict() and ComputeJTDictToKJT.

    Builds a weighted KeyedJaggedTensor, converts it to a dict of
    JaggedTensors, converts back with the module, and checks that
    lengths/weights/values per key are preserved.
    """
    compute_jt_dict_to_kjt = ComputeJTDictToKJT()
    values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
    weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5])
    keys = ["index_0", "index_1"]
    offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])

    jag_tensor = KeyedJaggedTensor(
        values=values,
        keys=keys,
        offsets=offsets,
        weights=weights,
    )
    jag_tensor_dict = jag_tensor.to_dict()
    kjt = compute_jt_dict_to_kjt(jag_tensor_dict)
    j0 = kjt["index_0"]
    j1 = kjt["index_1"]

    self.assertTrue(isinstance(j0, JaggedTensor))
    # Fixed: this previously re-checked j0; it must cover j1 as well.
    self.assertTrue(isinstance(j1, JaggedTensor))
    self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([2, 0, 1])))
    self.assertTrue(torch.equal(j0.weights(), torch.Tensor([1.0, 0.5, 1.5])))
    self.assertTrue(torch.equal(j0.values(), torch.Tensor([1.0, 2.0, 3.0])))
    self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([1, 1, 3])))
    self.assertTrue(
        torch.equal(j1.weights(), torch.Tensor([1.0, 0.5, 1.0, 1.0, 1.5]))
    )
    self.assertTrue(
        torch.equal(j1.values(), torch.Tensor([4.0, 5.0, 6.0, 7.0, 8.0]))
    )
741+
710742
def test_from_jt_dict(self) -> None:
711743
values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
712744
weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5])

0 commit comments

Comments
 (0)