@@ -3,9 +3,7 @@ from basalt.nn import Tensor, TensorShape
3
3
from basalt.utils.tensorutils import fill, elwise_op, accumulate_op
4
4
from basalt.utils.math_util import add
5
5
6
- from utils.static_tuple import StaticTuple
7
-
8
- alias StaticIntTuple = StaticTuple[Int, _]
6
+ from utils.index import IndexList
9
7
10
8
11
9
fn generate_tensor (* shape : Int) -> Tensor[dtype]:
@@ -18,7 +16,7 @@ fn generate_tensor(*shape: Int) -> Tensor[dtype]:
18
16
19
17
fn generate_expected_tensor [
20
18
size : Int
21
- ](data : StaticIntTuple [size], * shape : Int) -> Tensor[dtype]:
19
+ ](data : IndexList [size], * shape : Int) -> Tensor[dtype]:
22
20
var A = Tensor[dtype](shape)
23
21
for i in range (size):
24
22
A[i] = data[i]
@@ -43,7 +41,7 @@ struct TransposeData:
43
41
@ staticmethod
44
42
fn generate_1_2dim_test_case () -> TransposeData:
45
43
var A = generate_tensor(2 , 3 )
46
- var expected = StaticIntTuple [6 ](1 , 4 , 2 , 5 , 3 , 6 )
44
+ var expected = IndexList [6 ](1 , 4 , 2 , 5 , 3 , 6 )
47
45
var tranpose_dims = VariadicList[Int](1 , 0 )
48
46
var B = generate_expected_tensor(expected, 3 , 2 )
49
47
@@ -52,7 +50,7 @@ struct TransposeData:
52
50
@ staticmethod
53
51
fn generate_2_2dim_test_case () -> TransposeData:
54
52
var A = generate_tensor(2 , 3 , 2 )
55
- var expected = StaticIntTuple [12 ](1 , 7 , 3 , 9 , 5 , 11 , 2 , 8 , 4 , 10 , 6 , 12 )
53
+ var expected = IndexList [12 ](1 , 7 , 3 , 9 , 5 , 11 , 2 , 8 , 4 , 10 , 6 , 12 )
56
54
var tranpose_dims = VariadicList[Int](2 , 1 , 0 )
57
55
var B = generate_expected_tensor(expected, 2 , 3 , 2 )
58
56
@@ -61,7 +59,7 @@ struct TransposeData:
61
59
@ staticmethod
62
60
fn generate_3_2dim_test_case () -> TransposeData:
63
61
var A = generate_tensor(2 , 3 , 2 , 3 )
64
- var expected = StaticIntTuple [36 ](
62
+ var expected = IndexList [36 ](
65
63
1 ,
66
64
2 ,
67
65
3 ,
@@ -107,7 +105,7 @@ struct TransposeData:
107
105
@ staticmethod
108
106
fn generate_4_2dim_test_case () -> TransposeData:
109
107
var A = generate_tensor(3 , 2 , 3 , 2 , 3 )
110
- var expected = StaticIntTuple [108 ](
108
+ var expected = IndexList [108 ](
111
109
1 ,
112
110
2 ,
113
111
3 ,
@@ -225,7 +223,7 @@ struct TransposeData:
225
223
@ staticmethod
226
224
fn generate_1_alldim_test_case () -> TransposeData:
227
225
var A = generate_tensor(2 , 3 , 2 , 3 )
228
- var expected = StaticIntTuple [36 ](
226
+ var expected = IndexList [36 ](
229
227
1 ,
230
228
4 ,
231
229
2 ,
@@ -271,7 +269,7 @@ struct TransposeData:
271
269
@ staticmethod
272
270
fn generate_2_alldim_test_case () -> TransposeData:
273
271
var A = generate_tensor(2 , 3 , 4 )
274
- var expected = StaticIntTuple [24 ](
272
+ var expected = IndexList [24 ](
275
273
1 ,
276
274
13 ,
277
275
2 ,
@@ -306,7 +304,7 @@ struct TransposeData:
306
304
@ staticmethod
307
305
fn generate_1_transpose_test_case () -> TransposeData:
308
306
var A = generate_tensor(2 , 3 , 2 , 3 )
309
- var expected = StaticIntTuple [36 ](
307
+ var expected = IndexList [36 ](
310
308
1 ,
311
309
19 ,
312
310
7 ,
@@ -369,7 +367,7 @@ struct PaddingData:
369
367
fn generate_1d_test_case_after () -> PaddingData:
370
368
var A = generate_tensor(2 )
371
369
372
- var expected = StaticIntTuple [4 ](1 , 2 , 0 , 0 )
370
+ var expected = IndexList [4 ](1 , 2 , 0 , 0 )
373
371
var pad_with = List[Int]()
374
372
pad_with.append(0 ) # before
375
373
pad_with.append(2 ) # after
@@ -382,7 +380,7 @@ struct PaddingData:
382
380
fn generate_1d_test_case_before_after () -> PaddingData:
383
381
var A = generate_tensor(3 )
384
382
385
- var expected = StaticIntTuple [6 ](0 , 0 , 1 , 2 , 3 , 0 )
383
+ var expected = IndexList [6 ](0 , 0 , 1 , 2 , 3 , 0 )
386
384
var pad_with = List[Int]()
387
385
pad_with.append(2 ) # before
388
386
pad_with.append(1 ) # after
@@ -395,7 +393,7 @@ struct PaddingData:
395
393
fn generate_2d_test_case () -> PaddingData:
396
394
var A = generate_tensor(2 , 2 )
397
395
398
- var expected = StaticIntTuple [45 ](
396
+ var expected = IndexList [45 ](
399
397
0 ,
400
398
0 ,
401
399
0 ,
@@ -456,7 +454,7 @@ struct PaddingData:
456
454
fn generate_3d_test_case_simple () -> PaddingData:
457
455
var A = generate_tensor(2 , 2 , 2 )
458
456
459
- var expected = StaticIntTuple [16 ](
457
+ var expected = IndexList [16 ](
460
458
0 , 0 , 1 , 2 , 3 , 4 , 0 , 0 , 0 , 0 , 5 , 6 , 7 , 8 , 0 , 0
461
459
)
462
460
var pad_with = List[Int]()
@@ -475,7 +473,7 @@ struct PaddingData:
475
473
fn generate_3d_test_case () -> PaddingData:
476
474
var A = generate_tensor(1 , 2 , 3 )
477
475
478
- var expected = StaticIntTuple [45 ](
476
+ var expected = IndexList [45 ](
479
477
0 ,
480
478
0 ,
481
479
0 ,
@@ -538,7 +536,7 @@ struct PaddingData:
538
536
fn generate_4d_test_case () -> PaddingData:
539
537
var A = generate_tensor(2 , 2 , 2 , 2 )
540
538
541
- var expected = StaticIntTuple [81 ](
539
+ var expected = IndexList [81 ](
542
540
1 ,
543
541
2 ,
544
542
0 ,
@@ -663,7 +661,7 @@ struct SumMeanStdData:
663
661
var A = generate_tensor(3 , 4 , 5 )
664
662
var axis = 0
665
663
666
- var expected_sum = StaticIntTuple [20 ](
664
+ var expected_sum = IndexList [20 ](
667
665
63 ,
668
666
66 ,
669
667
69 ,
@@ -686,7 +684,7 @@ struct SumMeanStdData:
686
684
120 ,
687
685
)
688
686
689
- var expected_mean = StaticIntTuple [20 ](
687
+ var expected_mean = IndexList [20 ](
690
688
21 ,
691
689
22 ,
692
690
23 ,
@@ -722,7 +720,7 @@ struct SumMeanStdData:
722
720
var A = generate_tensor(3 , 4 , 5 )
723
721
var axis = 1
724
722
725
- var expected_sum = StaticIntTuple [15 ](
723
+ var expected_sum = IndexList [15 ](
726
724
34 ,
727
725
38 ,
728
726
42 ,
@@ -740,7 +738,7 @@ struct SumMeanStdData:
740
738
210 ,
741
739
)
742
740
743
- var expected_mean = StaticIntTuple [15 ](
741
+ var expected_mean = IndexList [15 ](
744
742
8 ,
745
743
9 ,
746
744
10 ,
@@ -772,7 +770,7 @@ struct SumMeanStdData:
772
770
var A = generate_tensor(3 , 4 , 5 )
773
771
var axis = 2
774
772
775
- var expected_sum = StaticIntTuple [12 ](
773
+ var expected_sum = IndexList [12 ](
776
774
15 ,
777
775
40 ,
778
776
65 ,
@@ -787,7 +785,7 @@ struct SumMeanStdData:
787
785
290 ,
788
786
)
789
787
790
- var expected_mean = StaticIntTuple [12 ](
788
+ var expected_mean = IndexList [12 ](
791
789
3 ,
792
790
8 ,
793
791
13 ,
0 commit comments