[Fix] expand axes for dimension with integer indices in mlx_slice_update #1035
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Proposed changes
Fix #1015
For dimensions with integer indices,
src
should squeeze these dimensions.For example, an array
a = mx.zeros(5, 4, 3)
indexed bya[:, 0]
should end up with shape (5, 3), and the second dimension is squeezed, as the implementation inmlx_get_item
.However it's a bit complex for
mlx_set_item
since the array returned bymlx_slice_update
should have the original, unsqueezed shape. To achieve the same effect and make it easier, we could expand new axes in the update array on these dimensions that are supposed to be squeezed.For example, to update
a[:, 0] = mx.ones((5, 3))
, instead of squeezing a to shape (5, 3), we expand a new axis for the update array, so its shape becomes (5, 1, 3), which is broadcast compatible with a's unsqueezed shape (5, 4, 3). In other words,a[:, 0] = mx.ones((5, 3))
now achieve the same effect asa[:, 0:1] = mx.ones((5, 3))[:, None]
I'm a bit unsure if this is the correct way to implement it, and I don't really understand the old dimension expansion logic
maybe @jagrit06 can explain it a bit? I guess this tries to do something similar but can't see how it's done.
Checklist
Put an
x
in the boxes that apply.pre-commit run --all-files
to format my code / installed pre-commit prior to committing changes