@@ -1096,13 +1096,15 @@ def _maybe_compute_stride_kjt(
1096
1096
stride : Optional [int ],
1097
1097
lengths : Optional [torch .Tensor ],
1098
1098
offsets : Optional [torch .Tensor ],
1099
- stride_per_key_per_rank : Optional [List [ List [ int ]] ],
1099
+ stride_per_key_per_rank : Optional [torch . IntTensor ],
1100
1100
) -> int :
1101
1101
if stride is None :
1102
1102
if len (keys ) == 0 :
1103
1103
stride = 0
1104
- elif stride_per_key_per_rank is not None and len (stride_per_key_per_rank ) > 0 :
1105
- stride = max ([sum (s ) for s in stride_per_key_per_rank ])
1104
+ elif (
1105
+ stride_per_key_per_rank is not None and stride_per_key_per_rank .numel () > 0
1106
+ ):
1107
+ stride = int (stride_per_key_per_rank .sum (dim = 1 ).max ().item ())
1106
1108
elif offsets is not None and offsets .numel () > 0 :
1107
1109
stride = (offsets .numel () - 1 ) // len (keys )
1108
1110
elif lengths is not None :
@@ -1668,14 +1670,18 @@ def _maybe_compute_lengths_offset_per_key(
1668
1670
1669
1671
def _maybe_compute_stride_per_key (
1670
1672
stride_per_key : Optional [List [int ]],
1671
- stride_per_key_per_rank : Optional [List [ List [ int ]] ],
1673
+ stride_per_key_per_rank : Optional [torch . IntTensor ],
1672
1674
stride : Optional [int ],
1673
1675
keys : List [str ],
1674
1676
) -> Optional [List [int ]]:
1675
1677
if stride_per_key is not None :
1676
1678
return stride_per_key
1677
1679
elif stride_per_key_per_rank is not None :
1678
- return [sum (s ) for s in stride_per_key_per_rank ]
1680
+ if stride_per_key_per_rank .dim () != 2 :
1681
+ # after permute the kjt could be empty
1682
+ return []
1683
+ rt : List [int ] = stride_per_key_per_rank .sum (dim = 1 ).tolist ()
1684
+ return rt
1679
1685
elif stride is not None :
1680
1686
return [stride ] * len (keys )
1681
1687
else :
@@ -1766,7 +1772,9 @@ def __init__(
1766
1772
lengths : Optional [torch .Tensor ] = None ,
1767
1773
offsets : Optional [torch .Tensor ] = None ,
1768
1774
stride : Optional [int ] = None ,
1769
- stride_per_key_per_rank : Optional [List [List [int ]]] = None ,
1775
+ stride_per_key_per_rank : Optional [
1776
+ Union [torch .IntTensor , List [List [int ]]]
1777
+ ] = None ,
1770
1778
# Below exposed to ensure torch.script-able
1771
1779
stride_per_key : Optional [List [int ]] = None ,
1772
1780
length_per_key : Optional [List [int ]] = None ,
@@ -1788,8 +1796,10 @@ def __init__(
1788
1796
self ._lengths : Optional [torch .Tensor ] = lengths
1789
1797
self ._offsets : Optional [torch .Tensor ] = offsets
1790
1798
self ._stride : Optional [int ] = stride
1791
- self ._stride_per_key_per_rank : Optional [List [List [int ]]] = (
1792
- stride_per_key_per_rank
1799
+ self ._stride_per_key_per_rank : Optional [torch .IntTensor ] = (
1800
+ torch .IntTensor (stride_per_key_per_rank , device = "cpu" )
1801
+ if isinstance (stride_per_key_per_rank , list )
1802
+ else stride_per_key_per_rank
1793
1803
)
1794
1804
self ._stride_per_key : Optional [List [int ]] = stride_per_key
1795
1805
self ._length_per_key : Optional [List [int ]] = length_per_key
@@ -1815,10 +1825,11 @@ def _init_pt2_checks(self) -> None:
1815
1825
return
1816
1826
if self ._stride_per_key is not None :
1817
1827
pt2_checks_all_is_size (self ._stride_per_key )
1818
- if self ._stride_per_key_per_rank is not None :
1819
- # pyre-ignore [16]
1820
- for s in self ._stride_per_key_per_rank :
1821
- pt2_checks_all_is_size (s )
1828
+ _stride_per_key_per_rank = self ._stride_per_key_per_rank
1829
+ if _stride_per_key_per_rank is not None :
1830
+ for stride_per_rank in _stride_per_key_per_rank :
1831
+ for s in stride_per_rank :
1832
+ torch ._check_is_size (s .item ())
1822
1833
1823
1834
@staticmethod
1824
1835
def from_offsets_sync (
@@ -2028,7 +2039,7 @@ def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor":
2028
2039
kjt_stride , kjt_stride_per_key_per_rank = (
2029
2040
(stride_per_key [0 ], None )
2030
2041
if all (s == stride_per_key [0 ] for s in stride_per_key )
2031
- else (None , [[ stride ] for stride in stride_per_key ] )
2042
+ else (None , torch . IntTensor ( stride_per_key , device = "cpu" ). reshape ( - 1 , 1 ) )
2032
2043
)
2033
2044
kjt = KeyedJaggedTensor (
2034
2045
keys = kjt_keys ,
@@ -2193,8 +2204,13 @@ def stride_per_key_per_rank(self) -> List[List[int]]:
2193
2204
Returns:
2194
2205
List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
2195
2206
"""
2196
- stride_per_key_per_rank = self ._stride_per_key_per_rank
2197
- return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
2207
+ # making a local reference to the class variable to make jit.script behave
2208
+ _stride_per_key_per_rank = self ._stride_per_key_per_rank
2209
+ return (
2210
+ []
2211
+ if _stride_per_key_per_rank is None
2212
+ else _stride_per_key_per_rank .tolist ()
2213
+ )
2198
2214
2199
2215
def variable_stride_per_key (self ) -> bool :
2200
2216
"""
@@ -2514,17 +2530,17 @@ def permute(
2514
2530
2515
2531
length_per_key = self .length_per_key ()
2516
2532
permuted_keys : List [str ] = []
2517
- permuted_stride_per_key_per_rank : List [List [int ]] = []
2518
2533
permuted_length_per_key : List [int ] = []
2519
2534
permuted_length_per_key_sum = 0
2520
2535
for index in indices :
2521
2536
key = self .keys ()[index ]
2522
2537
permuted_keys .append (key )
2523
2538
permuted_length_per_key .append (length_per_key [index ])
2524
- if self .variable_stride_per_key ():
2525
- permuted_stride_per_key_per_rank .append (
2526
- self .stride_per_key_per_rank ()[index ]
2527
- )
2539
+ _stride_per_key_per_rank = self ._stride_per_key_per_rank
2540
+ if self .variable_stride_per_key () and _stride_per_key_per_rank is not None :
2541
+ permuted_stride_per_key_per_rank = _stride_per_key_per_rank [indices , :]
2542
+ else :
2543
+ permuted_stride_per_key_per_rank = None
2528
2544
2529
2545
permuted_length_per_key_sum = sum (permuted_length_per_key )
2530
2546
if not torch .jit .is_scripting () and is_non_strict_exporting ():
@@ -2576,17 +2592,15 @@ def permute(
2576
2592
self .weights_or_none (),
2577
2593
permuted_length_per_key_sum ,
2578
2594
)
2579
- stride_per_key_per_rank = (
2580
- permuted_stride_per_key_per_rank if self .variable_stride_per_key () else None
2581
- )
2595
+
2582
2596
kjt = KeyedJaggedTensor (
2583
2597
keys = permuted_keys ,
2584
2598
values = permuted_values ,
2585
2599
weights = permuted_weights ,
2586
2600
lengths = permuted_lengths .view (- 1 ),
2587
2601
offsets = None ,
2588
2602
stride = self ._stride ,
2589
- stride_per_key_per_rank = stride_per_key_per_rank ,
2603
+ stride_per_key_per_rank = permuted_stride_per_key_per_rank ,
2590
2604
stride_per_key = None ,
2591
2605
length_per_key = permuted_length_per_key if len (permuted_keys ) > 0 else None ,
2592
2606
lengths_offset_per_key = None ,
0 commit comments