@@ -1096,13 +1096,15 @@ def _maybe_compute_stride_kjt(
1096
1096
stride : Optional [int ],
1097
1097
lengths : Optional [torch .Tensor ],
1098
1098
offsets : Optional [torch .Tensor ],
1099
- stride_per_key_per_rank : Optional [List [ List [ int ]] ],
1099
+ stride_per_key_per_rank : Optional [torch . IntTensor ],
1100
1100
) -> int :
1101
1101
if stride is None :
1102
1102
if len (keys ) == 0 :
1103
1103
stride = 0
1104
- elif stride_per_key_per_rank is not None and len (stride_per_key_per_rank ) > 0 :
1105
- stride = max ([sum (s ) for s in stride_per_key_per_rank ])
1104
+ elif (
1105
+ stride_per_key_per_rank is not None and stride_per_key_per_rank .numel () > 0
1106
+ ):
1107
+ stride = int (stride_per_key_per_rank .sum (dim = 1 ).max ().item ())
1106
1108
elif offsets is not None and offsets .numel () > 0 :
1107
1109
stride = (offsets .numel () - 1 ) // len (keys )
1108
1110
elif lengths is not None :
@@ -1481,8 +1483,8 @@ def _strides_from_kjt(
1481
1483
def _kjt_empty_like (kjt : "KeyedJaggedTensor" ) -> "KeyedJaggedTensor" :
1482
1484
# empty like function fx wrapped, also avoids device hardcoding
1483
1485
stride , stride_per_key_per_rank = (
1484
- (None , kjt .stride_per_key_per_rank () )
1485
- if kjt .variable_stride_per_key ()
1486
+ (None , kjt ._stride_per_key_per_rank )
1487
+ if kjt ._stride_per_key_per_rank is not None
1486
1488
else (kjt .stride (), None )
1487
1489
)
1488
1490
@@ -1668,14 +1670,20 @@ def _maybe_compute_lengths_offset_per_key(
1668
1670
1669
1671
def _maybe_compute_stride_per_key (
1670
1672
stride_per_key : Optional [List [int ]],
1671
- stride_per_key_per_rank : Optional [List [ List [ int ]] ],
1673
+ stride_per_key_per_rank : Optional [torch . IntTensor ],
1672
1674
stride : Optional [int ],
1673
1675
keys : List [str ],
1674
1676
) -> Optional [List [int ]]:
1675
1677
if stride_per_key is not None :
1676
1678
return stride_per_key
1677
1679
elif stride_per_key_per_rank is not None :
1678
- return [sum (s ) for s in stride_per_key_per_rank ]
1680
+ if stride_per_key_per_rank .dim () != 2 :
1681
+ # after permute the kjt could be empty
1682
+ return []
1683
+ rt : List [int ] = stride_per_key_per_rank .sum (dim = 1 ).tolist ()
1684
+ if not torch .jit .is_scripting () and is_torchdynamo_compiling ():
1685
+ pt2_checks_all_is_size (rt )
1686
+ return rt
1679
1687
elif stride is not None :
1680
1688
return [stride ] * len (keys )
1681
1689
else :
@@ -1766,7 +1774,9 @@ def __init__(
1766
1774
lengths : Optional [torch .Tensor ] = None ,
1767
1775
offsets : Optional [torch .Tensor ] = None ,
1768
1776
stride : Optional [int ] = None ,
1769
- stride_per_key_per_rank : Optional [List [List [int ]]] = None ,
1777
+ stride_per_key_per_rank : Optional [
1778
+ Union [torch .IntTensor , List [List [int ]]]
1779
+ ] = None ,
1770
1780
# Below exposed to ensure torch.script-able
1771
1781
stride_per_key : Optional [List [int ]] = None ,
1772
1782
length_per_key : Optional [List [int ]] = None ,
@@ -1788,8 +1798,14 @@ def __init__(
1788
1798
self ._lengths : Optional [torch .Tensor ] = lengths
1789
1799
self ._offsets : Optional [torch .Tensor ] = offsets
1790
1800
self ._stride : Optional [int ] = stride
1791
- self ._stride_per_key_per_rank : Optional [List [List [int ]]] = (
1792
- stride_per_key_per_rank
1801
+ if not torch .jit .is_scripting () and is_torchdynamo_compiling ():
1802
+ # in pt2.compile the stride_per_key_per_rank has to be torch.Tensor or None
1803
+ # does not take List[List[int]]
1804
+ assert not isinstance (stride_per_key_per_rank , list )
1805
+ self ._stride_per_key_per_rank : Optional [torch .IntTensor ] = (
1806
+ torch .IntTensor (stride_per_key_per_rank , device = "cpu" )
1807
+ if isinstance (stride_per_key_per_rank , list )
1808
+ else stride_per_key_per_rank
1793
1809
)
1794
1810
self ._stride_per_key : Optional [List [int ]] = stride_per_key
1795
1811
self ._length_per_key : Optional [List [int ]] = length_per_key
@@ -1815,10 +1831,11 @@ def _init_pt2_checks(self) -> None:
1815
1831
return
1816
1832
if self ._stride_per_key is not None :
1817
1833
pt2_checks_all_is_size (self ._stride_per_key )
1818
- if self ._stride_per_key_per_rank is not None :
1819
- # pyre-ignore [16]
1820
- for s in self ._stride_per_key_per_rank :
1821
- pt2_checks_all_is_size (s )
1834
+ _stride_per_key_per_rank = self ._stride_per_key_per_rank
1835
+ if _stride_per_key_per_rank is not None :
1836
+ stride_per_key_per_rank = _stride_per_key_per_rank .tolist ()
1837
+ for stride_per_rank in stride_per_key_per_rank :
1838
+ pt2_checks_all_is_size (stride_per_rank )
1822
1839
1823
1840
@staticmethod
1824
1841
def from_offsets_sync (
@@ -2028,7 +2045,7 @@ def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor":
2028
2045
kjt_stride , kjt_stride_per_key_per_rank = (
2029
2046
(stride_per_key [0 ], None )
2030
2047
if all (s == stride_per_key [0 ] for s in stride_per_key )
2031
- else (None , [[ stride ] for stride in stride_per_key ] )
2048
+ else (None , torch . IntTensor ( stride_per_key , device = "cpu" ). reshape ( - 1 , 1 ) )
2032
2049
)
2033
2050
kjt = KeyedJaggedTensor (
2034
2051
keys = kjt_keys ,
@@ -2193,8 +2210,22 @@ def stride_per_key_per_rank(self) -> List[List[int]]:
2193
2210
Returns:
2194
2211
List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
2195
2212
"""
2196
- stride_per_key_per_rank = self ._stride_per_key_per_rank
2197
- return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
2213
+ # making a local reference to the class variable to make jit.script behave
2214
+ _stride_per_key_per_rank = self ._stride_per_key_per_rank
2215
+ if (
2216
+ not torch .jit .is_scripting ()
2217
+ and is_torchdynamo_compiling ()
2218
+ and _stride_per_key_per_rank is not None
2219
+ ):
2220
+ stride_per_key_per_rank = _stride_per_key_per_rank .tolist ()
2221
+ for stride_per_rank in stride_per_key_per_rank :
2222
+ pt2_checks_all_is_size (stride_per_rank )
2223
+ return stride_per_key_per_rank
2224
+ return (
2225
+ []
2226
+ if _stride_per_key_per_rank is None
2227
+ else _stride_per_key_per_rank .tolist ()
2228
+ )
2198
2229
2199
2230
def variable_stride_per_key (self ) -> bool :
2200
2231
"""
@@ -2343,13 +2374,16 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
2343
2374
start_offset = 0
2344
2375
_length_per_key = self .length_per_key ()
2345
2376
_offset_per_key = self .offset_per_key ()
2377
+ # use local copy/ref for self._stride_per_key_per_rank to satisfy jit.script
2378
+ _stride_per_key_per_rank = self ._stride_per_key_per_rank
2346
2379
for segment in segments :
2347
2380
end = start + segment
2348
2381
end_offset = _offset_per_key [end ]
2349
2382
keys : List [str ] = self ._keys [start :end ]
2350
2383
stride_per_key_per_rank = (
2351
- self .stride_per_key_per_rank ()[start :end ]
2352
- if self .variable_stride_per_key ()
2384
+ _stride_per_key_per_rank [start :end , :]
2385
+ if _stride_per_key_per_rank is not None
2386
+ and self .variable_stride_per_key ()
2353
2387
else None
2354
2388
)
2355
2389
if segment == len (self ._keys ):
@@ -2514,17 +2548,17 @@ def permute(
2514
2548
2515
2549
length_per_key = self .length_per_key ()
2516
2550
permuted_keys : List [str ] = []
2517
- permuted_stride_per_key_per_rank : List [List [int ]] = []
2518
2551
permuted_length_per_key : List [int ] = []
2519
2552
permuted_length_per_key_sum = 0
2520
2553
for index in indices :
2521
2554
key = self .keys ()[index ]
2522
2555
permuted_keys .append (key )
2523
2556
permuted_length_per_key .append (length_per_key [index ])
2524
- if self .variable_stride_per_key ():
2525
- permuted_stride_per_key_per_rank .append (
2526
- self .stride_per_key_per_rank ()[index ]
2527
- )
2557
+ _stride_per_key_per_rank = self ._stride_per_key_per_rank
2558
+ if self .variable_stride_per_key () and _stride_per_key_per_rank is not None :
2559
+ permuted_stride_per_key_per_rank = _stride_per_key_per_rank [indices , :]
2560
+ else :
2561
+ permuted_stride_per_key_per_rank = None
2528
2562
2529
2563
permuted_length_per_key_sum = sum (permuted_length_per_key )
2530
2564
if not torch .jit .is_scripting () and is_non_strict_exporting ():
@@ -2576,17 +2610,15 @@ def permute(
2576
2610
self .weights_or_none (),
2577
2611
permuted_length_per_key_sum ,
2578
2612
)
2579
- stride_per_key_per_rank = (
2580
- permuted_stride_per_key_per_rank if self .variable_stride_per_key () else None
2581
- )
2613
+
2582
2614
kjt = KeyedJaggedTensor (
2583
2615
keys = permuted_keys ,
2584
2616
values = permuted_values ,
2585
2617
weights = permuted_weights ,
2586
2618
lengths = permuted_lengths .view (- 1 ),
2587
2619
offsets = None ,
2588
2620
stride = self ._stride ,
2589
- stride_per_key_per_rank = stride_per_key_per_rank ,
2621
+ stride_per_key_per_rank = permuted_stride_per_key_per_rank ,
2590
2622
stride_per_key = None ,
2591
2623
length_per_key = permuted_length_per_key if len (permuted_keys ) > 0 else None ,
2592
2624
lengths_offset_per_key = None ,
@@ -2904,7 +2936,7 @@ def dist_init(
2904
2936
2905
2937
if variable_stride_per_key :
2906
2938
assert stride_per_rank_per_key is not None
2907
- stride_per_key_per_rank_tensor : torch .Tensor = stride_per_rank_per_key .view (
2939
+ stride_per_key_per_rank : torch .Tensor = stride_per_rank_per_key .view (
2908
2940
num_workers , len (keys )
2909
2941
).T .cpu ()
2910
2942
@@ -2941,23 +2973,18 @@ def dist_init(
2941
2973
weights ,
2942
2974
)
2943
2975
2944
- stride_per_key_per_rank = torch .jit .annotate (
2945
- List [List [int ]], stride_per_key_per_rank_tensor .tolist ()
2946
- )
2976
+ if stride_per_key_per_rank .numel () == 0 :
2977
+ stride_per_key_per_rank = torch .zeros (
2978
+ (len (keys ), 1 ), device = "cpu" , dtype = torch .int64
2979
+ )
2947
2980
2948
- if not stride_per_key_per_rank :
2949
- stride_per_key_per_rank = [[0 ]] * len (keys )
2950
2981
if stagger > 1 :
2951
- stride_per_key_per_rank_stagger : List [List [int ]] = []
2952
2982
local_world_size = num_workers // stagger
2953
- for i in range (len (keys )):
2954
- stride_per_rank_stagger : List [int ] = []
2955
- for j in range (local_world_size ):
2956
- stride_per_rank_stagger .extend (
2957
- stride_per_key_per_rank [i ][j ::local_world_size ]
2958
- )
2959
- stride_per_key_per_rank_stagger .append (stride_per_rank_stagger )
2960
- stride_per_key_per_rank = stride_per_key_per_rank_stagger
2983
+ indices = [
2984
+ list (range (i , num_workers , local_world_size ))
2985
+ for i in range (local_world_size )
2986
+ ]
2987
+ stride_per_key_per_rank = stride_per_key_per_rank [:, indices ]
2961
2988
2962
2989
kjt = KeyedJaggedTensor (
2963
2990
keys = keys ,
0 commit comments