Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] expand axes for dimension with integer indices in mlx_slice_update #1035

Merged
merged 6 commits into from
Apr 29, 2024

Conversation

PRESIDENT810
Copy link
Contributor

Proposed changes

Fix #1015

For dimensions with integer indices, src should squeeze these dimensions.
For example, an array a = mx.zeros(5, 4, 3) indexed by a[:, 0] should end up with shape (5, 3), and the second dimension is squeezed, as the implementation in mlx_get_item.

However it's a bit complex for mlx_set_item since the array returned by mlx_slice_update should have the original, unsqueezed shape. To achieve the same effect and make it easier, we could expand new axes in the update array on these dimensions that are supposed to be squeezed.
For example, to update a[:, 0] = mx.ones((5, 3)), instead of squeezing a to shape (5, 3), we expand a new axis for the update array, so its shape becomes (5, 1, 3), which is broadcast compatible with a's unsqueezed shape (5, 4, 3). In other words, a[:, 0] = mx.ones((5, 3)) now achieve the same effect as a[:, 0:1] = mx.ones((5, 3))[:, None]

I'm a bit unsure if this is the correct way to implement it, and I don't really understand the old dimension expansion logic

if (src.ndim() - ax < up.ndim()) {
    upd_expand_dims.push_back(ax - src.ndim());
}

maybe @jagrit06 can explain it a bit? I guess this tries to do something similar but can't see how it's done.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@PRESIDENT810 PRESIDENT810 changed the title Fix/1015 [Fix] expand axes for dimension with integer indices in mlx_slice_update Apr 26, 2024
Copy link
Member

@awni awni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix!!

@awni
Copy link
Member

awni commented Apr 29, 2024

Closes #1050

@awni awni merged commit 490c0c4 into ml-explore:main Apr 29, 2024
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[BUG] ValueError with mx.array index
2 participants