Bringing Tensor::gather()
behavior closer to torch.gather()
#2567
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Definition:
In our case we were effectively constraining the requirement to:
index.size(d) != input.size(d); where d != dim
This PR removes the
!=
constraint and allows the op to go through as long as:index.dims()[d] <= input.dims()[d]; where d != dim
Added tests with data generated and validated against torch's ScatterGather tests.
Tested on
arm64
,x86_64
,cuda
andmetal
. I don't have access tomkl
and others, so unsure if thekernels
hold true!