@@ -1098,15 +1098,13 @@ def _maybe_compute_stride_kjt(
10981098 stride : Optional [int ],
10991099 lengths : Optional [torch .Tensor ],
11001100 offsets : Optional [torch .Tensor ],
1101- stride_per_key_per_rank : Optional [torch . IntTensor ],
1101+ stride_per_key_per_rank : Optional [List [ List [ int ]] ],
11021102) -> int :
11031103 if stride is None :
11041104 if len (keys ) == 0 :
11051105 stride = 0
1106- elif (
1107- stride_per_key_per_rank is not None and stride_per_key_per_rank .numel () > 0
1108- ):
1109- stride = int (stride_per_key_per_rank .sum (dim = 1 ).max ().item ())
1106+ elif stride_per_key_per_rank is not None and len (stride_per_key_per_rank ) > 0 :
1107+ stride = max ([sum (s ) for s in stride_per_key_per_rank ])
11101108 elif offsets is not None and offsets .numel () > 0 :
11111109 stride = (offsets .numel () - 1 ) // len (keys )
11121110 elif lengths is not None :
@@ -1485,8 +1483,8 @@ def _strides_from_kjt(
14851483def _kjt_empty_like (kjt : "KeyedJaggedTensor" ) -> "KeyedJaggedTensor" :
14861484 # empty like function fx wrapped, also avoids device hardcoding
14871485 stride , stride_per_key_per_rank = (
1488- (None , kjt ._stride_per_key_per_rank )
1489- if kjt ._stride_per_key_per_rank is not None and kjt . variable_stride_per_key ()
1486+ (None , kjt .stride_per_key_per_rank () )
1487+ if kjt .variable_stride_per_key ()
14901488 else (kjt .stride (), None )
14911489 )
14921490
@@ -1672,20 +1670,14 @@ def _maybe_compute_lengths_offset_per_key(
16721670
16731671def _maybe_compute_stride_per_key (
16741672 stride_per_key : Optional [List [int ]],
1675- stride_per_key_per_rank : Optional [torch . IntTensor ],
1673+ stride_per_key_per_rank : Optional [List [ List [ int ]] ],
16761674 stride : Optional [int ],
16771675 keys : List [str ],
16781676) -> Optional [List [int ]]:
16791677 if stride_per_key is not None :
16801678 return stride_per_key
16811679 elif stride_per_key_per_rank is not None :
1682- if stride_per_key_per_rank .dim () != 2 :
1683- # after permute the kjt could be empty
1684- return []
1685- rt : List [int ] = stride_per_key_per_rank .sum (dim = 1 ).tolist ()
1686- if not torch .jit .is_scripting () and is_torchdynamo_compiling ():
1687- pt2_checks_all_is_size (rt )
1688- return rt
1680+ return [sum (s ) for s in stride_per_key_per_rank ]
16891681 elif stride is not None :
16901682 return [stride ] * len (keys )
16911683 else :
@@ -1776,9 +1768,7 @@ def __init__(
17761768 lengths : Optional [torch .Tensor ] = None ,
17771769 offsets : Optional [torch .Tensor ] = None ,
17781770 stride : Optional [int ] = None ,
1779- stride_per_key_per_rank : Optional [
1780- Union [torch .IntTensor , List [List [int ]]]
1781- ] = None ,
1771+ stride_per_key_per_rank : Optional [List [List [int ]]] = None ,
17821772 # Below exposed to ensure torch.script-able
17831773 stride_per_key : Optional [List [int ]] = None ,
17841774 length_per_key : Optional [List [int ]] = None ,
@@ -1800,14 +1790,8 @@ def __init__(
18001790 self ._lengths : Optional [torch .Tensor ] = lengths
18011791 self ._offsets : Optional [torch .Tensor ] = offsets
18021792 self ._stride : Optional [int ] = stride
1803- if not torch .jit .is_scripting () and is_torchdynamo_compiling ():
1804- # in pt2.compile the stride_per_key_per_rank has to be torch.Tensor or None
1805- # does not take List[List[int]]
1806- assert not isinstance (stride_per_key_per_rank , list )
1807- self ._stride_per_key_per_rank : Optional [torch .IntTensor ] = (
1808- torch .IntTensor (stride_per_key_per_rank , device = "cpu" )
1809- if isinstance (stride_per_key_per_rank , list )
1810- else stride_per_key_per_rank
1793+ self ._stride_per_key_per_rank : Optional [List [List [int ]]] = (
1794+ stride_per_key_per_rank
18111795 )
18121796 self ._stride_per_key : Optional [List [int ]] = stride_per_key
18131797 self ._length_per_key : Optional [List [int ]] = length_per_key
@@ -1818,8 +1802,6 @@ def __init__(
18181802 self ._inverse_indices : Optional [Tuple [List [str ], torch .Tensor ]] = (
18191803 inverse_indices
18201804 )
1821- # this is only needed for torch.compile case
1822- self ._pt2_stride_per_key_per_rank : Optional [List [List [int ]]] = None
18231805
18241806 # legacy attribute, for backward compatibility
18251807 self ._variable_stride_per_key : Optional [bool ] = None
@@ -1835,6 +1817,10 @@ def _init_pt2_checks(self) -> None:
18351817 return
18361818 if self ._stride_per_key is not None :
18371819 pt2_checks_all_is_size (self ._stride_per_key )
1820+ if self ._stride_per_key_per_rank is not None :
1821+ # pyre-ignore [16]
1822+ for s in self ._stride_per_key_per_rank :
1823+ pt2_checks_all_is_size (s )
18381824
18391825 @staticmethod
18401826 def from_offsets_sync (
@@ -2044,7 +2030,7 @@ def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor":
20442030 kjt_stride , kjt_stride_per_key_per_rank = (
20452031 (stride_per_key [0 ], None )
20462032 if all (s == stride_per_key [0 ] for s in stride_per_key )
2047- else (None , torch . IntTensor ( stride_per_key , device = "cpu" ). reshape ( - 1 , 1 ) )
2033+ else (None , [[ stride ] for stride in stride_per_key ] )
20482034 )
20492035 kjt = KeyedJaggedTensor (
20502036 keys = kjt_keys ,
@@ -2209,32 +2195,12 @@ def stride_per_key_per_rank(self) -> List[List[int]]:
22092195 Returns:
22102196 List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
22112197 """
2212- # making a local reference to the class variable to make jit.script behave
2213- _stride_per_key_per_rank = self ._stride_per_key_per_rank
2214- if (
2215- not torch .jit .is_scripting ()
2216- and is_torchdynamo_compiling ()
2217- and _stride_per_key_per_rank is not None
2218- ):
2219- if self ._pt2_stride_per_key_per_rank is not None :
2220- return self ._pt2_stride_per_key_per_rank
2221- stride_per_key_per_rank = _stride_per_key_per_rank .tolist ()
2222- for stride_per_rank in stride_per_key_per_rank :
2223- pt2_checks_all_is_size (stride_per_rank )
2224- self ._pt2_stride_per_key_per_rank = stride_per_key_per_rank
2225- return stride_per_key_per_rank
2226- return (
2227- []
2228- if _stride_per_key_per_rank is None
2229- else _stride_per_key_per_rank .tolist ()
2230- )
2198+ stride_per_key_per_rank = self ._stride_per_key_per_rank
2199+ return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
22312200
22322201 def variable_stride_per_key (self ) -> bool :
22332202 """
22342203 Returns whether the KeyedJaggedTensor has variable stride per key.
2235- NOTE: `self._variable_stride_per_key` could be `False` when `self._stride_per_key_per_rank`
2236- is not `None`. It might be assigned to False externally/intentionally, usually the
2237- `self._stride_per_key_per_rank` is trivial.
22382204
22392205 Returns:
22402206 bool: whether the KeyedJaggedTensor has variable stride per key.
@@ -2379,16 +2345,13 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
23792345 start_offset = 0
23802346 _length_per_key = self .length_per_key ()
23812347 _offset_per_key = self .offset_per_key ()
2382- # use local copy/ref for self._stride_per_key_per_rank to satisfy jit.script
2383- _stride_per_key_per_rank = self ._stride_per_key_per_rank
23842348 for segment in segments :
23852349 end = start + segment
23862350 end_offset = _offset_per_key [end ]
23872351 keys : List [str ] = self ._keys [start :end ]
23882352 stride_per_key_per_rank = (
2389- _stride_per_key_per_rank [start :end , : ]
2353+ self . stride_per_key_per_rank () [start :end ]
23902354 if self .variable_stride_per_key ()
2391- and _stride_per_key_per_rank is not None
23922355 else None
23932356 )
23942357 if segment == len (self ._keys ):
@@ -2536,24 +2499,17 @@ def permute(
25362499
25372500 length_per_key = self .length_per_key ()
25382501 permuted_keys : List [str ] = []
2502+ permuted_stride_per_key_per_rank : List [List [int ]] = []
25392503 permuted_length_per_key : List [int ] = []
25402504 permuted_length_per_key_sum = 0
25412505 for index in indices :
25422506 key = self .keys ()[index ]
25432507 permuted_keys .append (key )
25442508 permuted_length_per_key .append (length_per_key [index ])
2545-
2546- stride_per_key = self ._stride_per_key
2547- permuted_stride_per_key = (
2548- [stride_per_key [i ] for i in indices ] if stride_per_key is not None else None
2549- )
2550-
2551- _stride_per_key_per_rank = self ._stride_per_key_per_rank
2552- permuted_stride_per_key_per_rank = (
2553- _stride_per_key_per_rank [indices , :]
2554- if self .variable_stride_per_key () and _stride_per_key_per_rank is not None
2555- else None
2556- )
2509+ if self .variable_stride_per_key ():
2510+ permuted_stride_per_key_per_rank .append (
2511+ self .stride_per_key_per_rank ()[index ]
2512+ )
25572513
25582514 permuted_length_per_key_sum = sum (permuted_length_per_key )
25592515 if not torch .jit .is_scripting () and is_non_strict_exporting ():
@@ -2605,16 +2561,18 @@ def permute(
26052561 self .weights_or_none (),
26062562 permuted_length_per_key_sum ,
26072563 )
2608-
2564+ stride_per_key_per_rank = (
2565+ permuted_stride_per_key_per_rank if self .variable_stride_per_key () else None
2566+ )
26092567 kjt = KeyedJaggedTensor (
26102568 keys = permuted_keys ,
26112569 values = permuted_values ,
26122570 weights = permuted_weights ,
26132571 lengths = permuted_lengths .view (- 1 ),
26142572 offsets = None ,
26152573 stride = self ._stride ,
2616- stride_per_key_per_rank = permuted_stride_per_key_per_rank ,
2617- stride_per_key = permuted_stride_per_key ,
2574+ stride_per_key_per_rank = stride_per_key_per_rank ,
2575+ stride_per_key = None ,
26182576 length_per_key = permuted_length_per_key if len (permuted_keys ) > 0 else None ,
26192577 lengths_offset_per_key = None ,
26202578 offset_per_key = None ,
@@ -2933,7 +2891,7 @@ def dist_init(
29332891
29342892 if variable_stride_per_key :
29352893 assert stride_per_rank_per_key is not None
2936- stride_per_key_per_rank : torch .Tensor = stride_per_rank_per_key .view (
2894+ stride_per_key_per_rank_tensor : torch .Tensor = stride_per_rank_per_key .view (
29372895 num_workers , len (keys )
29382896 ).T .cpu ()
29392897
@@ -2970,18 +2928,23 @@ def dist_init(
29702928 weights ,
29712929 )
29722930
2973- if stride_per_key_per_rank .numel () == 0 :
2974- stride_per_key_per_rank = torch .zeros (
2975- (len (keys ), 1 ), device = "cpu" , dtype = torch .int64
2976- )
2931+ stride_per_key_per_rank = torch .jit .annotate (
2932+ List [List [int ]], stride_per_key_per_rank_tensor .tolist ()
2933+ )
29772934
2935+ if not stride_per_key_per_rank :
2936+ stride_per_key_per_rank = [[0 ]] * len (keys )
29782937 if stagger > 1 :
2938+ stride_per_key_per_rank_stagger : List [List [int ]] = []
29792939 local_world_size = num_workers // stagger
2980- indices = [
2981- list (range (i , num_workers , local_world_size ))
2982- for i in range (local_world_size )
2983- ]
2984- stride_per_key_per_rank = stride_per_key_per_rank [:, indices ]
2940+ for i in range (len (keys )):
2941+ stride_per_rank_stagger : List [int ] = []
2942+ for j in range (local_world_size ):
2943+ stride_per_rank_stagger .extend (
2944+ stride_per_key_per_rank [i ][j ::local_world_size ]
2945+ )
2946+ stride_per_key_per_rank_stagger .append (stride_per_rank_stagger )
2947+ stride_per_key_per_rank = stride_per_key_per_rank_stagger
29852948
29862949 kjt = KeyedJaggedTensor (
29872950 keys = keys ,
0 commit comments