|
14 | 14 | from torch.testing import FileCheck
|
15 | 15 | from torchrec.fx import symbolic_trace
|
16 | 16 | from torchrec.sparse.jagged_tensor import (
|
| 17 | + ComputeJTDictToKJT, |
17 | 18 | ComputeKJTToJTDict,
|
18 | 19 | JaggedTensor,
|
19 | 20 | jt_is_equal,
|
@@ -707,6 +708,37 @@ def test_pytree(self) -> None:
|
707 | 708 | self.assertTrue(torch.equal(j0.weights(), j1.weights()))
|
708 | 709 | self.assertTrue(torch.equal(j0.values(), j1.values()))
|
709 | 710 |
|
| 711 | + def test_compute_jt_dict_to_kjt_module(self) -> None: |
| 712 | + compute_jt_dict_to_kjt = ComputeJTDictToKJT() |
| 713 | + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) |
| 714 | + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) |
| 715 | + keys = ["index_0", "index_1"] |
| 716 | + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) |
| 717 | + |
| 718 | + jag_tensor = KeyedJaggedTensor( |
| 719 | + values=values, |
| 720 | + keys=keys, |
| 721 | + offsets=offsets, |
| 722 | + weights=weights, |
| 723 | + ) |
| 724 | + jag_tensor_dict = jag_tensor.to_dict() |
| 725 | + kjt = compute_jt_dict_to_kjt(jag_tensor_dict) |
| 726 | + j0 = kjt["index_0"] |
| 727 | + j1 = kjt["index_1"] |
| 728 | + |
| 729 | + self.assertTrue(isinstance(j0, JaggedTensor)) |
| 730 | + self.assertTrue(isinstance(j0, JaggedTensor)) |
| 731 | + self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([2, 0, 1]))) |
| 732 | + self.assertTrue(torch.equal(j0.weights(), torch.Tensor([1.0, 0.5, 1.5]))) |
| 733 | + self.assertTrue(torch.equal(j0.values(), torch.Tensor([1.0, 2.0, 3.0]))) |
| 734 | + self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([1, 1, 3]))) |
| 735 | + self.assertTrue( |
| 736 | + torch.equal(j1.weights(), torch.Tensor([1.0, 0.5, 1.0, 1.0, 1.5])) |
| 737 | + ) |
| 738 | + self.assertTrue( |
| 739 | + torch.equal(j1.values(), torch.Tensor([4.0, 5.0, 6.0, 7.0, 8.0])) |
| 740 | + ) |
| 741 | + |
710 | 742 | def test_from_jt_dict(self) -> None:
|
711 | 743 | values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
|
712 | 744 | weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5])
|
|
0 commit comments