Skip to content
This repository was archived by the owner on Mar 2, 2025. It is now read-only.

Index #80

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 135 additions & 1 deletion basalt/autograd/ops/mlops.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ from math.limit import min_finite, max_finite

from basalt import Tensor, TensorShape
from basalt.utils.tensorutils import elwise_transform
from basalt.utils.itertools import product
from basalt.autograd.attributes import Attribute, AttributeVector


Expand Down Expand Up @@ -491,4 +492,137 @@ struct SLICE:

Self.slice_kernel[ug_shape, t1_shape, steps, starts, ends, True](res_grad, ug)

return res_grad ^
return res_grad ^


struct INDEX:
@staticmethod
fn adjust_boundary(slice: Int, dim_size: Int) -> Int:
# Adjust negative indices & ensure they are within bounds.
var s = slice if slice >= 0 else dim_size + slice
return max(min(s, dim_size), 0)

@staticmethod
fn to_indeces(shape: TensorShape, attrs: AttributeVector) -> List[List[Int]]:
var SLICE_LITERALS = List[StringLiteral]("dim_0s", "dim_1s", "dim_2s", "dim_3s", "dim_4s", "dim_5s", "dim_6s", "dim_7s")
var INDEX_LITERALS = List[StringLiteral]("dim_0i", "dim_1i", "dim_2i", "dim_3i", "dim_4i", "dim_5i", "dim_6i", "dim_7i")

var rank = shape.rank()
var indeces = List[List[Int]]()
indeces.reserve(rank)

for dim in range(rank):
var temp = List[Int]()

# Option 1: Slice
if attrs[SLICE_LITERALS[dim]]:
var slice = attrs[SLICE_LITERALS[dim]].value().to_shape()
var step = slice[2] if slice.rank() == 3 else 1
for i in range(
start=Self.adjust_boundary(slice[0], shape[dim]),
end=Self.adjust_boundary(slice[1], shape[dim]),
step=step
):
temp.append(i)

# Option 2: Indeces
elif attrs[INDEX_LITERALS[dim]]:
var indeces = attrs[INDEX_LITERALS[dim]].value().to_shape()
for i in range(indeces.rank()):
temp.append(indeces[i])

# All indeces
else:
for i in range(shape[dim]):
temp.append(i)

indeces.append(temp)

return indeces ^

@staticmethod
fn result_shape(shape: TensorShape, attrs: AttributeVector) -> TensorShape:
var indeces = Self.to_indeces(shape, attrs)
var rank = shape.rank()
var new_shape = List[Int]()
new_shape.reserve(rank)
for i in range(rank):
new_shape.append(len(indeces[i]))
return TensorShape(new_shape)

@staticmethod
fn map_indeces[
nelts: Int,
strides: TensorShape,
indeces: List[List[Int]],
](idx: Int) -> SIMD[DType.int64, nelts]:
alias indeces_product = product(indeces)

var temp = SIMD[DType.int64, nelts]()
for i in range(idx, idx + nelts):
var comb = indeces_product[i]
var flat_index = 0

for dim in range(len(comb)):
flat_index += comb[dim] * strides[dim]

temp[i % nelts] = flat_index

return temp

@staticmethod
fn forward[
t1_shape: TensorShape,
attributes: AttributeVector,
](inout res: Tensor[dtype], t1: Tensor[dtype]):
alias indeces = Self.to_indeces(t1_shape, attributes)
alias strides = t1_shape.strides()
alias total_length = len(product(indeces))

@parameter
fn vec_index[nelts: Int](i: Int):

res.store[nelts](i,
t1.data().gather(Self.map_indeces[nelts, strides, indeces](i))
)

vectorize[vec_index, nelts](total_length)


@staticmethod
fn backward[
ug_shape: TensorShape,
t1_shape: TensorShape,
attributes: AttributeVector = AttributeVector(),
](ug: Tensor[dtype], t1: Tensor[dtype]) -> Tensor[dtype]:
alias indeces = Self.to_indeces(t1_shape, attributes)
alias strides = t1_shape.strides()
alias total_length = len(product(indeces))

var res_grad = Tensor[dtype](t1_shape)

@parameter
fn vec_index[nelts: Int](i: Int):

var offset = Self.map_indeces[nelts, strides, indeces](i)

# res_grad.data().scatter(
# offset,
# res_grad.data().gather(offset) + ug.load[nelts](i),
# )

# NOTE: Scatter (reduce SUM) required
# When the offset = [0, 2, 4, 0] and ug = [1, 1, 1, 1]
# The standard scatter will overwrite the values with overlapping indices.
# It doesn't accumulate index 0 twice as it should be: res_grad[0] += 1 + 1
# cfr. https://github.com/ml-explore/mlx/blob/main/mlx/backend/common/indexing.cpp#L256-L258
# cfr. https://github.com/modularml/mojo/blob/main/stdlib/src/sys/intrinsics.mojo#L903

# Workaround
var u = ug.load[nelts](i)
for j in range(nelts):
res_grad[int(offset[j])] += u[j]

vectorize[vec_index, nelts](total_length)

return res_grad^
9 changes: 8 additions & 1 deletion basalt/autograd/ops/ops.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ from .basics import (
TRANSPOSE,
FMA,
)
from .mlops import SIGMOID, RELU, TANH, CLIP, SQUEEZE, UNSQUEEZE, SLICE
from .mlops import SIGMOID, RELU, TANH, CLIP, SQUEEZE, UNSQUEEZE, SLICE, INDEX
from .dynamics import CONCAT, SPLIT
from .conv import CONV2D
from .pool import MAXPOOL2D
Expand Down Expand Up @@ -61,6 +61,7 @@ struct OP(Stringable):
alias CONCAT = OP(23, "CONCAT", dynamic=True)
alias SPLIT = OP(24, "SPLIT", dynamic=True)
alias SLICE = OP(25, "SLICE")
alias INDEX = OP(26, "INDEX")

var id: UInt8
var name: Bytes[16]
Expand Down Expand Up @@ -135,6 +136,8 @@ fn static_result_shape(
return UNSQUEEZE.result_shape(t1_shape, attributes)
elif op == OP.SLICE:
return SLICE.result_shape(t1_shape, attributes)
elif op == OP.INDEX:
return INDEX.result_shape(t1_shape, attributes)
else:
print("[ERROR] Operator not found.")
return TensorShape(-1)
Expand Down Expand Up @@ -249,6 +252,8 @@ fn forward_op[
UNSQUEEZE.forward[t1_shape, attributes](res, t1)
elif op == OP.SLICE:
SLICE.forward[t1_shape, attributes](res, t1)
elif op == OP.INDEX:
INDEX.forward[t1_shape, attributes](res, t1)
else:
print("[ERROR] Operator not found.")

Expand Down Expand Up @@ -361,6 +366,8 @@ fn backward_op[
res_grad = UNSQUEEZE.backward[ug_shape, t1_shape](ug, t1)
elif op == OP.SLICE:
res_grad = SLICE.backward[ug_shape, t1_shape, attributes](ug, t1)
elif op == OP.INDEX:
res_grad = INDEX.backward[ug_shape, t1_shape, attributes](ug, t1)
else:
print("[ERROR] Operator not found.")
res_grad = Tensor[dtype](-1)
Expand Down
49 changes: 49 additions & 0 deletions basalt/utils/itertools.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@

@value
struct _ProductIterator(Sized):
var lists: List[List[Int]]
var _current: Int
var _iters: Int

@always_inline("nodebug")
fn __init__(inout self, lists: List[List[Int]]):
self.lists = lists
self._current = 0

self._iters = 1
for lst in self.lists:
self._iters *= len(lst[])

@always_inline("nodebug")
fn __len__(self) -> Int:
return self._iters

@always_inline("nodebug")
fn __iter__(self) -> Self:
return self

@always_inline("nodebug")
fn __next__(inout self) -> List[Int]:
self._current += 1
self._iters -= 1
return self._get_combination(self._current - 1)

@always_inline("nodebug")
fn _get_combination(self, current: Int) -> List[Int]:
var combination = List[Int]()
var count = current
for i in reversed(range(len(self.lists))):
var index = count % len(self.lists[i])
combination.append(self.lists[i][index])
count //= len(self.lists[i])
combination._reverse()
return combination ^

@always_inline("nodebug")
fn __getitem__(self, index: Int) -> List[Int]:
return self._get_combination(index)


@always_inline("nodebug")
fn product(lists: List[List[Int]]) -> _ProductIterator:
return _ProductIterator(lists)
55 changes: 55 additions & 0 deletions tests/mojo/test_mlops.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,59 @@ fn test_backward_SLICE_multiple_axes() raises:
](t1, ug, expected_ug)


fn test_INDEX() raises:
alias t1_shape = TensorShape(2, 3, 5)
var t = Tensor[dtype](t1_shape)
for i in range(t.num_elements()):
t[i] = i

# t[:, [0, 0], 0:5:2]
# TODO: need for a list attribute as this only supports to specify indeces of MAX_RANK
alias attr_1 = Attribute("dim_1i", TensorShape(0, 0))
alias attr_2 = Attribute("dim_2s", TensorShape(0, 5, 2))

var expected = Tensor[dtype](2, 2, 3)
for i in range(2):
for j in range(2):
for k in range(3):
expected[i*2*3 + j*3 + k] = i * 3 * 5 + k * 2

test_unary_op[
OP.INDEX, t1_shape, AttributeVector(
attr_1,
attr_2,
)
](t, expected)


fn test_INDEX_backward() raises:
alias t1_shape = TensorShape(2, 3, 5)
var t = Tensor[dtype](t1_shape)
for i in range(t.num_elements()):
t[i] = i

alias attr_1 = Attribute("dim_1i", TensorShape(0, 0))
alias attr_2 = Attribute("dim_2s", TensorShape(0, 5, 2))

alias ug_shape = TensorShape(2, 2, 3)
var ug = Tensor[dtype](ug_shape)
fill(ug, 1.0)

var expected = Tensor[dtype](t1_shape)
for i in range(2):
for j in range(2):
for k in range(3):
# NOTE: `+=` because selected indeces [0, 0] can repeat
expected[i * 3 * 5 + k * 2] += 1.0

test_unary_op_backward[
OP.INDEX, t1_shape, ug_shape, AttributeVector(
attr_1,
attr_2,
)
](t, ug, expected)


fn main():
try:
test_SIGMOID()
Expand All @@ -632,6 +685,7 @@ fn main():
test_SLICE_step()
test_SLICE_neg()
test_SLICE_multiple_axes()
test_INDEX()
except e:
print("[ERROR] Error in forward mlops")
print(e)
Expand All @@ -646,6 +700,7 @@ fn main():
test_backward_UNSQUEEZE()
test_backward_SLICE()
test_backward_SLICE_multiple_axes()
test_INDEX_backward()
except e:
print("[ERROR] Error in backward mlops")
print(e)
Expand Down