Update the _maybe_compute_stride_kjt logic to calculate stride based off of inverse_indices for VBE KJTs. #2925
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary:
See D73051959 for context.
Update the
_maybe_compute_stride_kjt
logic to calculate stride based off ofinverse_indices
for VBE KJTs.Currently, stride of VBE KJT with
stride_per_key_per_rank
is calculated as the max "stride per key". This is different from the batch size of the EBC output KeyedTensor which is based off ofinverse_indices
. This causes issues in IR module serialization: [debug doc].(https://docs.google.com/document/d/1yQhI484cgVloSqIBPAeTQhzfb3ltjvMRiLaQDceHGOU/edit?tab=t.0#heading=h.c66chahhl8df).Differential Revision: D73824764