Skip to content

Commit

Permalink
adopt changes from ml-explore/mlx#1053
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkoski committed Apr 29, 2024
1 parent c6a171e commit c04c305
Showing 1 changed file with 33 additions and 18 deletions.
51 changes: 33 additions & 18 deletions Source/MLX/MLXArray+Indexing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,11 @@ func updateSlice(
return nil
}

// Can't route to slice update if any arrays are present
if operations.contains(where: { $0.isArray }) {
return nil
}

// Remove leading singletons dimensions from the update
var update = update.leadingSingletonDimensionsRemoved(stream: stream)

Expand All @@ -858,55 +863,65 @@ func updateSlice(
return result
}

// Can't route to slice update if any arrays are present
if operations.contains(where: { $0.isArray }) {
return nil
}

// Expand ellipses into a series of ':' (full slice) slices
let operations = expandEllipsisOperations(shape: src.shape.asInt32, operations: operations)

// If no non-None indices return the broadcasted update
if countNonNewAxisOperations(operations) == 0 {
let nonNewAxisOperationCount = countNonNewAxisOperations(operations)
if nonNewAxisOperationCount == 0 {
return broadcast(update, to: src.shape)
}

// Process entries
var updateExpandDimensions = [Int]()
var axis = 0
for item in operations {
var updateReshape = [Int](repeating: 0, count: src.ndim)
var axis = src.ndim - 1
var updateAxis = update.ndim - 1

while axis >= nonNewAxisOperationCount {
if updateAxis >= 0 {
updateReshape[axis] = update.dim(updateAxis)
updateAxis -= 1
} else {
updateReshape[axis] = 1
}
axis -= 1
}

for item in operations.reversed() {
switch item {
case .ellipsis, .array:
// these were replaced or rejected earlier
fatalError("unexpected item \(item) in updateSlice")

case .index(let index):
let size = src.dim(axis).int32
let index = index < 0 ? index + size : index
starts[axis] = index
ends[axis] = index + 1
// if ndim - axis < update.ndim {
// updateExpandDimensions.append(axis - ndim)
// }
updateExpandDimensions.append(axis)

axis += 1
updateReshape[axis] = 1
axis -= 1

case .slice(let slice):
let size = src.dim(axis).int32
starts[axis] = slice.start(size)
ends[axis] = slice.end(size)
strides[axis] = slice.stride

axis += 1
if updateAxis >= 0 {
updateReshape[axis] = update.dim(updateAxis)
updateAxis -= 1
} else {
updateReshape[axis] = 1
}
axis -= 1

case .newAxis:
break
}
}

if !updateExpandDimensions.isEmpty {
update = update.expandedDimensions(axes: updateExpandDimensions, stream: stream)
}
update = reshaped(update, updateReshape)

let result = MLXArray(
mlx_slice_update(
Expand Down

0 comments on commit c04c305

Please sign in to comment.