You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
convert stride_per_key_per_rank to tensor inside KJT (#2959)
Summary:
# context
* this diff is part of the "variable-batch KJT refactoring" project ([doc](https://fburl.com/gdoc/svfysfai))
* previously the `stride_per_key_per_rank` variable is `List[List[int]] | None` which can't be handled correctly in PT2 IR (torch.export)
* this change makes the KJT class variable `_stride_per_key_per_rank` as `torch.IntTensor | None` so it would be compatible with PT2 IR.
# equivalency
* to check if `self._stride_per_key_per_rank` is `None`
this logic is used to differentiate variable_batch case, and should have the same behavior after this diff
* to use `self._stride_per_key_per_rank` as `List[List[int]]`
most of the callsite use the function to get the list: `def stride_per_key_per_rank(self) -> List[List[int]]:`, and this function is modified to covert the `torch.IntTensor` to list as ` _stride_per_key_per_rank.tolist()`, the results should be the same
NOTE: this `self. _stride_per_key_per_rank.tolist()` tensor should always be on CPU since it's effective the meta data of a KJT. For generic torch APIs like `.to(...)`, `record_stream()`, etc. should in general avoid altering this variable.
Differential Revision: D74366343
0 commit comments