@@ -1096,13 +1096,15 @@ def _maybe_compute_stride_kjt(
1096
1096
stride : Optional [int ],
1097
1097
lengths : Optional [torch .Tensor ],
1098
1098
offsets : Optional [torch .Tensor ],
1099
- stride_per_key_per_rank : Optional [List [ List [ int ]] ],
1099
+ stride_per_key_per_rank : Optional [torch . IntTensor ],
1100
1100
) -> int :
1101
1101
if stride is None :
1102
1102
if len (keys ) == 0 :
1103
1103
stride = 0
1104
- elif stride_per_key_per_rank is not None and len (stride_per_key_per_rank ) > 0 :
1105
- stride = max ([sum (s ) for s in stride_per_key_per_rank ])
1104
+ elif (
1105
+ stride_per_key_per_rank is not None and stride_per_key_per_rank .numel () > 0
1106
+ ):
1107
+ stride = int (stride_per_key_per_rank .sum (dim = 1 ).max ().item ())
1106
1108
elif offsets is not None and offsets .numel () > 0 :
1107
1109
stride = (offsets .numel () - 1 ) // len (keys )
1108
1110
elif lengths is not None :
@@ -1668,14 +1670,18 @@ def _maybe_compute_lengths_offset_per_key(
1668
1670
1669
1671
def _maybe_compute_stride_per_key (
1670
1672
stride_per_key : Optional [List [int ]],
1671
- stride_per_key_per_rank : Optional [List [ List [ int ]] ],
1673
+ stride_per_key_per_rank : Optional [torch . IntTensor ],
1672
1674
stride : Optional [int ],
1673
1675
keys : List [str ],
1674
1676
) -> Optional [List [int ]]:
1675
1677
if stride_per_key is not None :
1676
1678
return stride_per_key
1677
1679
elif stride_per_key_per_rank is not None :
1678
- return [sum (s ) for s in stride_per_key_per_rank ]
1680
+ if stride_per_key_per_rank .dim () != 2 :
1681
+ # after permute the kjt could be empty
1682
+ return []
1683
+ rt : List [int ] = stride_per_key_per_rank .sum (dim = 1 ).tolist ()
1684
+ return rt
1679
1685
elif stride is not None :
1680
1686
return [stride ] * len (keys )
1681
1687
else :
@@ -1766,7 +1772,9 @@ def __init__(
1766
1772
lengths : Optional [torch .Tensor ] = None ,
1767
1773
offsets : Optional [torch .Tensor ] = None ,
1768
1774
stride : Optional [int ] = None ,
1769
- stride_per_key_per_rank : Optional [List [List [int ]]] = None ,
1775
+ stride_per_key_per_rank : Optional [
1776
+ Union [torch .IntTensor , List [List [int ]]]
1777
+ ] = None ,
1770
1778
# Below exposed to ensure torch.script-able
1771
1779
stride_per_key : Optional [List [int ]] = None ,
1772
1780
length_per_key : Optional [List [int ]] = None ,
@@ -1788,8 +1796,10 @@ def __init__(
1788
1796
self ._lengths : Optional [torch .Tensor ] = lengths
1789
1797
self ._offsets : Optional [torch .Tensor ] = offsets
1790
1798
self ._stride : Optional [int ] = stride
1791
- self ._stride_per_key_per_rank : Optional [List [List [int ]]] = (
1792
- stride_per_key_per_rank
1799
+ self ._stride_per_key_per_rank : Optional [torch .IntTensor ] = (
1800
+ torch .IntTensor (stride_per_key_per_rank , device = "cpu" )
1801
+ if isinstance (stride_per_key_per_rank , list )
1802
+ else stride_per_key_per_rank
1793
1803
)
1794
1804
self ._stride_per_key : Optional [List [int ]] = stride_per_key
1795
1805
self ._length_per_key : Optional [List [int ]] = length_per_key
@@ -1816,8 +1826,7 @@ def _init_pt2_checks(self) -> None:
1816
1826
if self ._stride_per_key is not None :
1817
1827
pt2_checks_all_is_size (self ._stride_per_key )
1818
1828
if self ._stride_per_key_per_rank is not None :
1819
- # pyre-ignore [16]
1820
- for s in self ._stride_per_key_per_rank :
1829
+ for s in self .stride_per_key_per_rank ():
1821
1830
pt2_checks_all_is_size (s )
1822
1831
1823
1832
@staticmethod
@@ -2028,7 +2037,7 @@ def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor":
2028
2037
kjt_stride , kjt_stride_per_key_per_rank = (
2029
2038
(stride_per_key [0 ], None )
2030
2039
if all (s == stride_per_key [0 ] for s in stride_per_key )
2031
- else (None , [[ stride ] for stride in stride_per_key ] )
2040
+ else (None , torch . IntTensor ( stride_per_key , device = "cpu" ). reshape ( - 1 , 1 ) )
2032
2041
)
2033
2042
kjt = KeyedJaggedTensor (
2034
2043
keys = kjt_keys ,
@@ -2193,8 +2202,13 @@ def stride_per_key_per_rank(self) -> List[List[int]]:
2193
2202
Returns:
2194
2203
List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
2195
2204
"""
2196
- stride_per_key_per_rank = self ._stride_per_key_per_rank
2197
- return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
2205
+ # making a local reference to the class variable to make jit.script behave
2206
+ _stride_per_key_per_rank = self ._stride_per_key_per_rank
2207
+ return (
2208
+ []
2209
+ if _stride_per_key_per_rank is None
2210
+ else _stride_per_key_per_rank .tolist ()
2211
+ )
2198
2212
2199
2213
def variable_stride_per_key (self ) -> bool :
2200
2214
"""
@@ -2514,17 +2528,17 @@ def permute(
2514
2528
2515
2529
length_per_key = self .length_per_key ()
2516
2530
permuted_keys : List [str ] = []
2517
- permuted_stride_per_key_per_rank : List [List [int ]] = []
2518
2531
permuted_length_per_key : List [int ] = []
2519
2532
permuted_length_per_key_sum = 0
2520
2533
for index in indices :
2521
2534
key = self .keys ()[index ]
2522
2535
permuted_keys .append (key )
2523
2536
permuted_length_per_key .append (length_per_key [index ])
2524
- if self .variable_stride_per_key ():
2525
- permuted_stride_per_key_per_rank .append (
2526
- self .stride_per_key_per_rank ()[index ]
2527
- )
2537
+ _stride_per_key_per_rank = self ._stride_per_key_per_rank
2538
+ if self .variable_stride_per_key () and _stride_per_key_per_rank is not None :
2539
+ permuted_stride_per_key_per_rank = _stride_per_key_per_rank [indices , :]
2540
+ else :
2541
+ permuted_stride_per_key_per_rank = None
2528
2542
2529
2543
permuted_length_per_key_sum = sum (permuted_length_per_key )
2530
2544
if not torch .jit .is_scripting () and is_non_strict_exporting ():
@@ -2576,17 +2590,15 @@ def permute(
2576
2590
self .weights_or_none (),
2577
2591
permuted_length_per_key_sum ,
2578
2592
)
2579
- stride_per_key_per_rank = (
2580
- permuted_stride_per_key_per_rank if self .variable_stride_per_key () else None
2581
- )
2593
+
2582
2594
kjt = KeyedJaggedTensor (
2583
2595
keys = permuted_keys ,
2584
2596
values = permuted_values ,
2585
2597
weights = permuted_weights ,
2586
2598
lengths = permuted_lengths .view (- 1 ),
2587
2599
offsets = None ,
2588
2600
stride = self ._stride ,
2589
- stride_per_key_per_rank = stride_per_key_per_rank ,
2601
+ stride_per_key_per_rank = permuted_stride_per_key_per_rank ,
2590
2602
stride_per_key = None ,
2591
2603
length_per_key = permuted_length_per_key if len (permuted_keys ) > 0 else None ,
2592
2604
lengths_offset_per_key = None ,
0 commit comments