Skip to content
This repository was archived by the owner on Mar 2, 2025. It is now read-only.

Commit 7f78067

Browse files
committed
Updated test_tensorutils
1 parent 1af8d20 commit 7f78067

File tree

5 files changed

+22
-30
lines changed

5 files changed

+22
-30
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,6 @@ flamegraph.svg
1515

1616
.magic
1717

18-
examples/data/yolov8n.onnx
18+
examples/data/yolov8n.onnx
19+
20+
*.DS_Store

basalt/nn/tensor.mojo

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ struct Tensor[dtype: DType](Stringable, Movable, CollectionElement):
113113
memset_zero(self._data, shape.num_elements())
114114
self._shape = shape
115115

116+
fn __init__(inout self, shapes: VariadicList[Int]):
117+
self._shape = TensorShape(shapes)
118+
self._data = UnsafePointer[Scalar[dtype]].alloc(self._shape.num_elements())
119+
memset_zero(self._data, self._shape.num_elements())
120+
116121
fn __init__(
117122
inout self, owned data: UnsafePointer[Scalar[dtype]], owned shape: TensorShape
118123
):

test_mult.mojo

Lines changed: 0 additions & 23 deletions
This file was deleted.

tests/mojo/test_tensorutils.mojo

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ from random import rand
22
from testing import assert_equal, assert_almost_equal
33
from math import sqrt, exp
44

5+
from utils.static_tuple import StaticTuple
6+
7+
alias StaticIntTuple = StaticTuple[Int, _]
8+
59
from basalt import dtype, nelts
610
from basalt.autograd.ops.matmul import dot
711
from basalt.utils.tensorutils import (
@@ -70,7 +74,7 @@ fn test_elwise_transform() raises:
7074
var D = Tensor[dtype](2, 10)
7175
fill(A, 4)
7276
fill(B, 2)
73-
fill(C, exp[dtype, 1](2))
77+
fill(C, exp(SIMD[dtype, 1](2.0)))
7478
fill(D, 7)
7579

7680
var A_res = Tensor[dtype](2, 10)
@@ -178,7 +182,7 @@ fn test_elwise_broadcast_tensor() raises:
178182
for i in range(40):
179183
for j in range(3):
180184
var index = (i % 4) + ((i // 4) * 12) + j * 4
181-
result1_expected[index] = 3.0 + (i + 1)
185+
result1_expected[index] = Float32(3.0) + (i + 1)
182186
assert_tensors_equal(result1, result1_expected)
183187

184188

@@ -197,7 +201,7 @@ fn test_sum_mean_std() raises:
197201
assert_equal(tensor_sum, s)
198202

199203
var tensor_mean = tmean(t)
200-
assert_equal(tensor_mean, s / 20)
204+
assert_equal(tensor_mean, Float32(s) / 20)
201205

202206
var tensor_std = tstd(t)
203207
var expected_std: Scalar[dtype] = 0
@@ -262,7 +266,7 @@ fn test_sum_mean_std_n() raises:
262266
assert_equal(tensor_sum, s)
263267

264268
var tensor_mean = tmean(t)
265-
assert_equal(tensor_mean, s / 60)
269+
assert_equal(tensor_mean, Float32(s) / 60)
266270

267271
var tensor_std = tstd(t)
268272
var expected_std: Scalar[dtype] = 0

tests/mojo/test_tensorutils_data.mojo

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
from basalt import dtype, nelts
22
from basalt.nn import Tensor, TensorShape
3-
from basalt.utils.tensorutils import fill, elwise_op
3+
from basalt.utils.tensorutils import fill, elwise_op, accumulate_op
44
from basalt.utils.math_util import add
55

6+
from utils.static_tuple import StaticTuple
7+
8+
alias StaticIntTuple = StaticTuple[Int, _]
9+
610

711
fn generate_tensor(*shape: Int) -> Tensor[dtype]:
812
var A = Tensor[dtype](shape)
@@ -759,7 +763,7 @@ struct SumMeanStdData:
759763

760764
var B = generate_expected_tensor[15](expected_sum, 3, 1, 5)
761765
var C = generate_expected_tensor[15](expected_mean, 3, 1, 5)
762-
elwise_op[add](C, C, 0.5)
766+
accumulate_op[add](C, 0.5)
763767

764768
return SumMeanStdData(A, axis, B, C, expected_std)
765769

0 commit comments

Comments
 (0)