Skip to content

Commit

Permalink
fix slice update indexing (#1053)
Browse files Browse the repository at this point in the history
  • Loading branch information
awni authored Apr 29, 2024
1 parent 490c0c4 commit 09f1777
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 23 deletions.
6 changes: 2 additions & 4 deletions mlx/backend/metal/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,7 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load build metal library from source"
<< "\n";
msg << "[metal::Device] Unable to build metal library from source" << "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
Expand All @@ -285,8 +284,7 @@ MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load build stitched metal library"
<< "\n";
msg << "[metal::Device] Unable to build stitched metal library" << "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
Expand Down
5 changes: 2 additions & 3 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ array expand_dims(const array& a, int axis, StreamOrDevice s /* = {} */) {
int ax = axis < 0 ? axis + out_dim : axis;
if (ax < 0 || ax >= out_dim) {
std::ostringstream msg;
msg << "[expand_dims] Invalid axes " << axis << " for output array with "
msg << "[expand_dims] Invalid axis " << axis << " for output array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
Expand All @@ -452,7 +452,7 @@ array expand_dims(
ax = ax < 0 ? ax + out_ndim : ax;
if (ax < 0 || ax >= out_ndim) {
std::ostringstream msg;
msg << "[expand_dims] Invalid axes " << ax << " for output array with "
msg << "[expand_dims] Invalid axis " << ax << " for output array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
Expand Down Expand Up @@ -591,7 +591,6 @@ array slice_update(
if (!has_neg_strides && upd_shape == src.shape()) {
return astype(update_broadcasted, src.dtype(), s);
}

return array(
src.shape(),
src.dtype(),
Expand Down
42 changes: 26 additions & 16 deletions python/src/indexing.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.

#include <numeric>
#include <sstream>

Expand Down Expand Up @@ -767,6 +766,14 @@ auto mlx_slice_update(
(!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj))) {
return std::make_pair(false, src);
}
if (nb::isinstance<nb::tuple>(obj)) {
// Can't route to slice update if any arrays are present
for (auto idx : nb::cast<nb::tuple>(obj)) {
if (nb::isinstance<array>(idx)) {
return std::make_pair(false, src);
}
}
}

// Should be able to route to slice update

Expand Down Expand Up @@ -804,14 +811,6 @@ auto mlx_slice_update(
// It must be a tuple
auto entries = nb::cast<nb::tuple>(obj);

// Can't route to slice update if any arrays are present
for (int i = 0; i < entries.size(); i++) {
auto idx = entries[i];
if (nb::isinstance<array>(idx)) {
return std::make_pair(false, src);
}
}

// Expand ellipses into a series of ':' slices
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);

Expand All @@ -828,9 +827,19 @@ auto mlx_slice_update(
}

// Process entries
std::vector<int> upd_expand_dims;
int ax = 0;
for (int i = 0; i < indices.size(); ++i) {
std::vector<int> up_reshape(src.ndim());
int ax = src.ndim() - 1;
int up_ax = up.ndim() - 1;
for (; ax >= non_none_indices; ax--) {
if (up_ax >= 0) {
up_reshape[ax] = up.shape(up_ax);
up_ax--;
} else {
up_reshape[ax] = 1;
}
}

for (int i = indices.size() - 1; i >= 0; --i) {
auto& pyidx = indices[i];
if (nb::isinstance<nb::slice>(pyidx)) {
get_slice_params(
Expand All @@ -839,18 +848,19 @@ auto mlx_slice_update(
strides[ax],
nb::cast<nb::slice>(pyidx),
src.shape(ax));
ax++;
up_reshape[ax] = (up_ax >= 0) ? up.shape(up_ax--) : 1;
ax--;
} else if (nb::isinstance<nb::int_>(pyidx)) {
int st = nb::cast<int>(pyidx);
st = (st < 0) ? st + src.shape(ax) : st;
starts[ax] = st;
stops[ax] = st + 1;
upd_expand_dims.push_back(ax);
ax++;
up_reshape[ax] = 1;
ax--;
}
}

up = expand_dims(up, upd_expand_dims);
up = reshape(up, std::move(up_reshape));
auto out = slice_update(src, up, starts, stops, strides);
return std::make_pair(true, out);
}
Expand Down
8 changes: 8 additions & 0 deletions python/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,6 +1262,14 @@ def check_slices(arr_np, update_np, *idx_np):
np.ones((3, 4, 4, 4)), np.zeros((4, 4)), 0, slice(0, 4), 3, slice(0, 4)
)

x = mx.zeros((2, 3, 4, 5, 3))
x[..., 0] = 1.0
self.assertTrue(mx.array_equal(x[..., 0], mx.ones((2, 3, 4, 5))))

x = mx.zeros((2, 3, 4, 5, 3))
x[:, 0] = 1.0
self.assertTrue(mx.array_equal(x[:, 0], mx.ones((2, 4, 5, 3))))

def test_array_at(self):
a = mx.array(1)
a = a.at[None].add(1)
Expand Down

0 comments on commit 09f1777

Please sign in to comment.