Skip to content

Commit

Permalink
fix for #76 (#78)
Browse files Browse the repository at this point in the history
- fix array indexing issue

* adopt changes from ml-explore/mlx#1053
  • Loading branch information
davidkoski authored May 2, 2024
1 parent b43bdff commit e0a2918
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 17 deletions.
50 changes: 33 additions & 17 deletions Source/MLX/MLXArray+Indexing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,11 @@ func updateSlice(
return nil
}

// Can't route to slice update if any arrays are present
if operations.contains(where: { $0.isArray }) {
return nil
}

// Remove leading singletons dimensions from the update
var update = update.leadingSingletonDimensionsRemoved(stream: stream)

Expand All @@ -858,54 +863,65 @@ func updateSlice(
return result
}

// Can't route to slice update if any arrays are present
if operations.contains(where: { $0.isArray }) {
return nil
}

// Expand ellipses into a series of ':' (full slice) slices
let operations = expandEllipsisOperations(shape: src.shape.asInt32, operations: operations)

// If no non-None indices return the broadcasted update
if countNonNewAxisOperations(operations) == 0 {
let nonNewAxisOperationCount = countNonNewAxisOperations(operations)
if nonNewAxisOperationCount == 0 {
return broadcast(update, to: src.shape)
}

// Process entries
var updateExpandDimensions = [Int]()
var axis = 0
for item in operations {
var updateReshape = [Int](repeating: 0, count: src.ndim)
var axis = src.ndim - 1
var updateAxis = update.ndim - 1

while axis >= nonNewAxisOperationCount {
if updateAxis >= 0 {
updateReshape[axis] = update.dim(updateAxis)
updateAxis -= 1
} else {
updateReshape[axis] = 1
}
axis -= 1
}

for item in operations.reversed() {
switch item {
case .ellipsis, .array:
// these were replaced or rejected earlier
fatalError("unexpected item \(item) in updateSlice")

case .index(let index):
let size = src.dim(axis).int32
let index = index < 0 ? index + size : index
starts[axis] = index
ends[axis] = index + 1
if ndim - axis < update.ndim {
updateExpandDimensions.append(axis - ndim)
}

axis += 1
updateReshape[axis] = 1
axis -= 1

case .slice(let slice):
let size = src.dim(axis).int32
starts[axis] = slice.start(size)
ends[axis] = slice.end(size)
strides[axis] = slice.stride

axis += 1
if updateAxis >= 0 {
updateReshape[axis] = update.dim(updateAxis)
updateAxis -= 1
} else {
updateReshape[axis] = 1
}
axis -= 1

case .newAxis:
break
}
}

if !updateExpandDimensions.isEmpty {
update = update.expandedDimensions(axes: updateExpandDimensions, stream: stream)
}
update = reshaped(update, updateReshape)

let result = MLXArray(
mlx_slice_update(
Expand Down
22 changes: 22 additions & 0 deletions Tests/MLXTests/MLXArray+IndexingTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -640,4 +640,26 @@ class MLXArrayIndexingTests: XCTestCase {
[.stride(by: -1), ..<2, i, 2..., .ellipsis, .newAxis, .stride(by: 2)], 128142)
}

public func testSliceWithBroadcast() {
// https://github.com/ml-explore/mlx-swift/issues/76

let a = MLXArray.ones([2, 6, 6, 6])
let b = MLXArray.zeros([3, 4, 4, 4])

b[0, 0 ..< 4, 3, 0 ..< 4] = a[0, 1 ..< 5, 5, 1 ..< 5]

let expected = MLXArray(
[
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
], [3, 4, 4, 4])

assertEqual(b, expected)
}

}

0 comments on commit e0a2918

Please sign in to comment.