Skip to content
This repository was archived by the owner on Mar 2, 2025. It is now read-only.

Commit 8b46a7c

Browse files
committed
Updated test_attributes
1 parent 6b4d80d commit 8b46a7c

File tree

4 files changed

+38
-31
lines changed

4 files changed

+38
-31
lines changed

basalt/autograd/attributes.mojo

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,16 @@ struct Attribute(Stringable, CollectionElement):
124124
for i in range(self.size):
125125
self.data_shape[i] = value[i]
126126

127+
fn __init__(inout self, name: String, value: StaticTuple[Int, _]):
128+
self.data_shape = IndexList[MAX_RANK]()
129+
self.name = Bytes[MAX_NAME_CHARS](name)
130+
self.data = Bytes[MAX_DATA_BYTES]()
131+
self.type = AttributeType.INTS
132+
self.size = len(value)
133+
134+
for i in range(self.size):
135+
self.data_shape[i] = value[i]
136+
127137
fn __init__[dtype: DType](inout self, name: String, value: Scalar[dtype]):
128138
constrained[dtype.is_numeric(), "Attribute value must be numeric."]()
129139

tests/mojo/test_attributes.mojo

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from testing import assert_equal, assert_true
2+
from utils.index import IndexList
23

34
from basalt.nn import TensorShape
45
from basalt.autograd.attributes import Attribute
@@ -32,7 +33,7 @@ fn test_attribute_tensor_shape() raises:
3233

3334

3435
fn test_attribute_static_int_tuple() raises:
35-
alias value: StaticIntTuple[7] = StaticIntTuple[7](1, 2, 3, 4, 5, 6, 7)
36+
alias value: IndexList[7] = IndexList[7](1, 2, 3, 4, 5, 6, 7)
3637
alias a = Attribute(name="test", value=value)
3738

3839
assert_true(a.to_static[7]() == value)

tests/mojo/test_tensorutils.mojo

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@ from random import rand
22
from testing import assert_equal, assert_almost_equal
33
from math import sqrt, exp
44

5-
from utils.static_tuple import StaticTuple
6-
7-
alias StaticIntTuple = StaticTuple[Int, _]
5+
from utils.index import IndexList
86

97
from basalt import dtype, nelts
108
from basalt.autograd.ops.matmul import dot
@@ -332,27 +330,27 @@ fn test_max() raises:
332330
@parameter
333331
fn fill_tensor[
334332
size: Int
335-
](inout tensor: Tensor[dtype], values: StaticIntTuple[size]):
333+
](inout tensor: Tensor[dtype], values: IndexList[size]):
336334
for i in range(tensor.num_elements()):
337335
tensor[i] = values[i]
338336

339337
var tensor_max_axis_0 = Tensor[dtype](get_reduce_shape(t.shape(), axis=0))
340338
tmax(tensor_max_axis_0, t, axis=0)
341-
var expected_max_axis_0_temp = StaticIntTuple[6](7, 8, 9, 10, 11, 12)
339+
var expected_max_axis_0_temp = IndexList[6](7, 8, 9, 10, 11, 12)
342340
var expected_max_axis_0 = Tensor[dtype](1, 3, 2)
343341
fill_tensor(expected_max_axis_0, expected_max_axis_0_temp)
344342
assert_tensors_equal(tensor_max_axis_0, expected_max_axis_0)
345343

346344
var tensor_max_axis_1 = Tensor[dtype](get_reduce_shape(t.shape(), axis=1))
347345
tmax(tensor_max_axis_1, t, axis=1)
348-
var expected_max_axis_1_temp = StaticIntTuple[4](5, 6, 11, 12)
346+
var expected_max_axis_1_temp = IndexList[4](5, 6, 11, 12)
349347
var expected_max_axis_1 = Tensor[dtype](2, 1, 2)
350348
fill_tensor(expected_max_axis_1, expected_max_axis_1_temp)
351349
assert_tensors_equal(tensor_max_axis_1, expected_max_axis_1)
352350

353351
var tensor_max_axis_2 = Tensor[dtype](get_reduce_shape(t.shape(), axis=2))
354352
tmax(tensor_max_axis_2, t, axis=2)
355-
var expected_max_axis_2_temp = StaticIntTuple[6](2, 4, 6, 8, 10, 12)
353+
var expected_max_axis_2_temp = IndexList[6](2, 4, 6, 8, 10, 12)
356354
var expected_max_axis_2 = Tensor[dtype](2, 3, 1)
357355
fill_tensor(expected_max_axis_2, expected_max_axis_2_temp)
358356
assert_tensors_equal(tensor_max_axis_2, expected_max_axis_2)

tests/mojo/test_tensorutils_data.mojo

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@ from basalt.nn import Tensor, TensorShape
33
from basalt.utils.tensorutils import fill, elwise_op, accumulate_op
44
from basalt.utils.math_util import add
55

6-
from utils.static_tuple import StaticTuple
7-
8-
alias StaticIntTuple = StaticTuple[Int, _]
6+
from utils.index import IndexList
97

108

119
fn generate_tensor(*shape: Int) -> Tensor[dtype]:
@@ -18,7 +16,7 @@ fn generate_tensor(*shape: Int) -> Tensor[dtype]:
1816

1917
fn generate_expected_tensor[
2018
size: Int
21-
](data: StaticIntTuple[size], *shape: Int) -> Tensor[dtype]:
19+
](data: IndexList[size], *shape: Int) -> Tensor[dtype]:
2220
var A = Tensor[dtype](shape)
2321
for i in range(size):
2422
A[i] = data[i]
@@ -43,7 +41,7 @@ struct TransposeData:
4341
@staticmethod
4442
fn generate_1_2dim_test_case() -> TransposeData:
4543
var A = generate_tensor(2, 3)
46-
var expected = StaticIntTuple[6](1, 4, 2, 5, 3, 6)
44+
var expected = IndexList[6](1, 4, 2, 5, 3, 6)
4745
var tranpose_dims = VariadicList[Int](1, 0)
4846
var B = generate_expected_tensor(expected, 3, 2)
4947

@@ -52,7 +50,7 @@ struct TransposeData:
5250
@staticmethod
5351
fn generate_2_2dim_test_case() -> TransposeData:
5452
var A = generate_tensor(2, 3, 2)
55-
var expected = StaticIntTuple[12](1, 7, 3, 9, 5, 11, 2, 8, 4, 10, 6, 12)
53+
var expected = IndexList[12](1, 7, 3, 9, 5, 11, 2, 8, 4, 10, 6, 12)
5654
var tranpose_dims = VariadicList[Int](2, 1, 0)
5755
var B = generate_expected_tensor(expected, 2, 3, 2)
5856

@@ -61,7 +59,7 @@ struct TransposeData:
6159
@staticmethod
6260
fn generate_3_2dim_test_case() -> TransposeData:
6361
var A = generate_tensor(2, 3, 2, 3)
64-
var expected = StaticIntTuple[36](
62+
var expected = IndexList[36](
6563
1,
6664
2,
6765
3,
@@ -107,7 +105,7 @@ struct TransposeData:
107105
@staticmethod
108106
fn generate_4_2dim_test_case() -> TransposeData:
109107
var A = generate_tensor(3, 2, 3, 2, 3)
110-
var expected = StaticIntTuple[108](
108+
var expected = IndexList[108](
111109
1,
112110
2,
113111
3,
@@ -225,7 +223,7 @@ struct TransposeData:
225223
@staticmethod
226224
fn generate_1_alldim_test_case() -> TransposeData:
227225
var A = generate_tensor(2, 3, 2, 3)
228-
var expected = StaticIntTuple[36](
226+
var expected = IndexList[36](
229227
1,
230228
4,
231229
2,
@@ -271,7 +269,7 @@ struct TransposeData:
271269
@staticmethod
272270
fn generate_2_alldim_test_case() -> TransposeData:
273271
var A = generate_tensor(2, 3, 4)
274-
var expected = StaticIntTuple[24](
272+
var expected = IndexList[24](
275273
1,
276274
13,
277275
2,
@@ -306,7 +304,7 @@ struct TransposeData:
306304
@staticmethod
307305
fn generate_1_transpose_test_case() -> TransposeData:
308306
var A = generate_tensor(2, 3, 2, 3)
309-
var expected = StaticIntTuple[36](
307+
var expected = IndexList[36](
310308
1,
311309
19,
312310
7,
@@ -369,7 +367,7 @@ struct PaddingData:
369367
fn generate_1d_test_case_after() -> PaddingData:
370368
var A = generate_tensor(2)
371369

372-
var expected = StaticIntTuple[4](1, 2, 0, 0)
370+
var expected = IndexList[4](1, 2, 0, 0)
373371
var pad_with = List[Int]()
374372
pad_with.append(0) # before
375373
pad_with.append(2) # after
@@ -382,7 +380,7 @@ struct PaddingData:
382380
fn generate_1d_test_case_before_after() -> PaddingData:
383381
var A = generate_tensor(3)
384382

385-
var expected = StaticIntTuple[6](0, 0, 1, 2, 3, 0)
383+
var expected = IndexList[6](0, 0, 1, 2, 3, 0)
386384
var pad_with = List[Int]()
387385
pad_with.append(2) # before
388386
pad_with.append(1) # after
@@ -395,7 +393,7 @@ struct PaddingData:
395393
fn generate_2d_test_case() -> PaddingData:
396394
var A = generate_tensor(2, 2)
397395

398-
var expected = StaticIntTuple[45](
396+
var expected = IndexList[45](
399397
0,
400398
0,
401399
0,
@@ -456,7 +454,7 @@ struct PaddingData:
456454
fn generate_3d_test_case_simple() -> PaddingData:
457455
var A = generate_tensor(2, 2, 2)
458456

459-
var expected = StaticIntTuple[16](
457+
var expected = IndexList[16](
460458
0, 0, 1, 2, 3, 4, 0, 0, 0, 0, 5, 6, 7, 8, 0, 0
461459
)
462460
var pad_with = List[Int]()
@@ -475,7 +473,7 @@ struct PaddingData:
475473
fn generate_3d_test_case() -> PaddingData:
476474
var A = generate_tensor(1, 2, 3)
477475

478-
var expected = StaticIntTuple[45](
476+
var expected = IndexList[45](
479477
0,
480478
0,
481479
0,
@@ -538,7 +536,7 @@ struct PaddingData:
538536
fn generate_4d_test_case() -> PaddingData:
539537
var A = generate_tensor(2, 2, 2, 2)
540538

541-
var expected = StaticIntTuple[81](
539+
var expected = IndexList[81](
542540
1,
543541
2,
544542
0,
@@ -663,7 +661,7 @@ struct SumMeanStdData:
663661
var A = generate_tensor(3, 4, 5)
664662
var axis = 0
665663

666-
var expected_sum = StaticIntTuple[20](
664+
var expected_sum = IndexList[20](
667665
63,
668666
66,
669667
69,
@@ -686,7 +684,7 @@ struct SumMeanStdData:
686684
120,
687685
)
688686

689-
var expected_mean = StaticIntTuple[20](
687+
var expected_mean = IndexList[20](
690688
21,
691689
22,
692690
23,
@@ -722,7 +720,7 @@ struct SumMeanStdData:
722720
var A = generate_tensor(3, 4, 5)
723721
var axis = 1
724722

725-
var expected_sum = StaticIntTuple[15](
723+
var expected_sum = IndexList[15](
726724
34,
727725
38,
728726
42,
@@ -740,7 +738,7 @@ struct SumMeanStdData:
740738
210,
741739
)
742740

743-
var expected_mean = StaticIntTuple[15](
741+
var expected_mean = IndexList[15](
744742
8,
745743
9,
746744
10,
@@ -772,7 +770,7 @@ struct SumMeanStdData:
772770
var A = generate_tensor(3, 4, 5)
773771
var axis = 2
774772

775-
var expected_sum = StaticIntTuple[12](
773+
var expected_sum = IndexList[12](
776774
15,
777775
40,
778776
65,
@@ -787,7 +785,7 @@ struct SumMeanStdData:
787785
290,
788786
)
789787

790-
var expected_mean = StaticIntTuple[12](
788+
var expected_mean = IndexList[12](
791789
3,
792790
8,
793791
13,

0 commit comments

Comments
 (0)