convert stride_per_key_per_rank to tensor inside KJT #2959
base: main
Conversation
This pull request was exported from Phabricator. Differential Revision: D74366343
Summary:

# context
* this diff is part of the "variable-batch KJT refactoring" project ([doc](https://fburl.com/gdoc/svfysfai))
* previously the `stride_per_key_per_rank` variable was `List[List[int]] | None`, which can't be handled correctly in PT2 IR (torch.export)
* this change makes the KJT class variable `_stride_per_key_per_rank` a `torch.IntTensor | None` so that it is compatible with PT2 IR

# equivalency
* checking whether `self._stride_per_key_per_rank` is `None`: this logic is used to differentiate the variable-batch case, and it behaves the same after this diff
* using `self._stride_per_key_per_rank` as `List[List[int]]`: most call sites get the list through `def stride_per_key_per_rank(self) -> List[List[int]]:`, and this function is modified to convert the `torch.IntTensor` to a list via `_stride_per_key_per_rank.tolist()`, so the results are the same (see the sketch below)

NOTE: the `self._stride_per_key_per_rank` tensor should always be on CPU since it is effectively the metadata of a KJT. Generic torch APIs like `.to(...)`, `record_stream()`, etc. should in general avoid altering this variable.

Differential Revision: D74366343
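For illustration, a minimal sketch of the equivalency described above. This is not the actual torchrec implementation; `KJTSketch` and its members are hypothetical simplifications of the KJT class:

```python
from typing import List, Optional

import torch


class KJTSketch:
    """Toy stand-in for KeyedJaggedTensor, storing only the stride table."""

    def __init__(self, stride_per_key_per_rank: Optional[List[List[int]]]) -> None:
        # After the diff: the 2D stride table is stored as a (CPU) IntTensor,
        # or None in the non-variable-batch case.
        self._stride_per_key_per_rank: Optional[torch.Tensor] = (
            torch.IntTensor(stride_per_key_per_rank)
            if stride_per_key_per_rank is not None
            else None
        )

    def variable_stride_per_key(self) -> bool:
        # Equivalency (1): the None-check that flags the variable-batch
        # case is unchanged by the refactor.
        return self._stride_per_key_per_rank is not None

    def stride_per_key_per_rank(self) -> List[List[int]]:
        # Equivalency (2): call sites still receive List[List[int]];
        # the tensor is converted back with .tolist() on demand.
        spkpr = self._stride_per_key_per_rank
        return [] if spkpr is None else spkpr.tolist()


# Round-trip: the accessor returns the same nested list that went in.
kjt = KJTSketch([[2, 3], [4, 1]])
assert kjt.variable_stride_per_key()
assert kjt.stride_per_key_per_rank() == [[2, 3], [4, 1]]
```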
Force-pushed from d9ba5cf to 1ce1f10
Force-pushed from 1ce1f10 to c195982
Force-pushed from c195982 to 7b44f11
Force-pushed from 7b44f11 to bae3f97
Force-pushed from bae3f97 to 808901b
Force-pushed from 808901b to 5a71b86
Force-pushed from 5a71b86 to bd7bf5e
Summary:

# context
* this diff is part of the "variable-batch KJT refactoring" project ([doc](https://fburl.com/gdoc/svfysfai))
* previously the `stride_per_key_per_rank` variable was `List[List[int]] | None`, which can't be handled correctly in PT2 IR (torch.export)
* this change makes the KJT class variable `_stride_per_key_per_rank` a `torch.IntTensor | None` so that it is compatible with PT2 IR

# equivalency
* checking whether `self._stride_per_key_per_rank` is `None`: this logic is used to differentiate the variable-batch case, and it behaves the same after this diff
* using `self._stride_per_key_per_rank` as `List[List[int]]`: most call sites get the list through `def stride_per_key_per_rank(self) -> List[List[int]]:`, and this function is modified to convert the `torch.IntTensor` to a list via `_stride_per_key_per_rank.tolist()`, so the results are the same

NOTE: currently the `self._stride_per_key_per_rank` tensor is always on CPU since it is effectively the metadata of a KJT. Ideally, however, it would live on the GPU side, since it exists after input_dist, and we should avoid moving it to CPU unless really needed (see the sketch below).

Reviewed By: jd7-tr

Differential Revision: D74366343
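To illustrate the device caveat in the NOTE: calling `.tolist()` on a CUDA tensor forces a blocking device-to-host copy, which is why the placement of this metadata tensor matters. A hedged sketch (variable names are illustrative, not torchrec API):

```python
import torch

# Stride metadata as an IntTensor; .tolist() on a CPU tensor is cheap.
spkpr = torch.IntTensor([[2, 3], [4, 1]])
print(spkpr.tolist())  # [[2, 3], [4, 1]], no device sync involved

if torch.cuda.is_available():
    # If the metadata ever lived on GPU (as the NOTE suggests it ideally
    # might, post input_dist), every .tolist() would trigger an implicit
    # device-to-host copy and stream synchronization.
    spkpr_gpu = spkpr.to("cuda")
    print(spkpr_gpu.tolist())  # same values, but with a blocking D2H transfer
```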
Force-pushed from bd7bf5e to f2bdf5e