@@ -80,10 +80,6 @@ def _to_offsets(lengths: torch.Tensor) -> torch.Tensor:
8080 return torch .ops .fbgemm .asynchronous_complete_cumsum (lengths )
8181
8282
83- def _to_lengths (offsets : torch .Tensor ) -> torch .Tensor :
84- return offsets [1 :] - offsets [:- 1 ]
85-
86-
8783@torch .jit .script_if_tracing
8884def _batched_lengths_to_offsets (lengths : torch .Tensor ) -> torch .Tensor :
8985 (f , b ) = lengths .shape
@@ -1452,33 +1448,6 @@ def _maybe_compute_kjt_to_jt_dict(
14521448 return _jt_dict
14531449
14541450
1455- @torch .fx .wrap
1456- def _merge_weights_or_none (
1457- a_weights : Optional [torch .Tensor ],
1458- b_weights : Optional [torch .Tensor ],
1459- ) -> Optional [torch .Tensor ]:
1460- assert not (
1461- (a_weights is None ) ^ (b_weights is None )
1462- ), "Can only merge weighted or unweighted KJTs."
1463- if a_weights is None :
1464- return None
1465- # pyre-ignore[6]
1466- return torch .cat ([a_weights , b_weights ], dim = 0 )
1467-
1468-
1469- @torch .fx .wrap
1470- def _strides_from_kjt (
1471- kjt : "KeyedJaggedTensor" ,
1472- ) -> Tuple [Optional [int ], Optional [List [List [int ]]]]:
1473- stride , stride_per_key_per_rank = (
1474- (None , kjt .stride_per_key_per_rank ())
1475- if kjt .variable_stride_per_key ()
1476- else (kjt .stride (), None )
1477- )
1478-
1479- return stride , stride_per_key_per_rank
1480-
1481-
14821451@torch .fx .wrap
14831452def _kjt_empty_like (kjt : "KeyedJaggedTensor" ) -> "KeyedJaggedTensor" :
14841453 # empty like function fx wrapped, also avoids device hardcoding
@@ -1684,18 +1653,6 @@ def _maybe_compute_stride_per_key(
16841653 return None
16851654
16861655
1687- def _maybe_compute_variable_stride_per_key (
1688- variable_stride_per_key : Optional [bool ],
1689- stride_per_key_per_rank : Optional [List [List [int ]]],
1690- ) -> bool :
1691- if variable_stride_per_key is not None :
1692- return variable_stride_per_key
1693- elif stride_per_key_per_rank is not None :
1694- return True
1695- else :
1696- return False
1697-
1698-
16991656class KeyedJaggedTensor (Pipelineable , metaclass = JaggedTensorMeta ):
17001657 """Represents an (optionally weighted) keyed jagged tensor.
17011658
0 commit comments