This repository was archived by the owner on Mar 2, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathtest_attributes.mojo
130 lines (101 loc) · 3.64 KB
/
test_attributes.mojo
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from testing import assert_equal, assert_true
from utils.index import IndexList
from basalt.nn import TensorShape
from basalt.autograd.attributes import Attribute
fn test_attribute_key() raises:
    """Checks that an attribute keeps the name it was constructed with."""
    alias a = Attribute(name="test", value=-1)
    # assert_equal (rather than assert_true on `==`) reports both values
    # when the check fails.
    assert_equal(str(a.name), "test")
fn test_attribute_int() raises:
    """Checks that an Int attribute is read back unchanged by `to_int`."""
    alias value: Int = 1
    alias a = Attribute(name="test", value=value)
    # Compare against `value` itself so the expectation cannot drift from the
    # input, and use assert_equal so a failure reports both sides.
    assert_equal(a.to_int(), value)
fn test_attribute_string() raises:
    """Checks that a String attribute is read back unchanged by `to_string`."""
    alias value: String = "hello"
    alias a = Attribute(name="test", value=value)
    # assert_equal reports the actual and expected strings on failure.
    assert_equal(a.to_string(), value)
fn test_attribute_tensor_shape() raises:
    """Checks that a TensorShape attribute is read back unchanged by `to_shape`.
    """
    alias value: TensorShape = TensorShape(1, 2, 3)
    alias a = Attribute(name="test", value=value)
    # A message is passed, as the scalar tests do, so a failure identifies
    # which attribute kind broke.
    assert_true(
        a.to_shape() == value,
        "TensorShape attribute failed",
    )
fn test_attribute_static_int_tuple() raises:
    """Checks that an IndexList[7] attribute is read back unchanged by
    `to_static[7]`.
    """
    alias value: IndexList[7] = IndexList[7](1, 2, 3, 4, 5, 6, 7)
    alias a = Attribute(name="test", value=value)
    # A message is passed, as the scalar tests do, so a failure identifies
    # which attribute kind broke.
    assert_true(
        a.to_static[7]() == value,
        "IndexList[7] attribute failed",
    )
fn test_attribute_scalar() raises:
    """Checks that scalar attributes round-trip through `to_scalar[dtype]`.

    Covers Float32, FloatLiteral, Float64 and Int32 inputs, plus a very small
    and a very large Float32 magnitude. Each check reads the value back with
    `to_scalar` and compares it exactly against the value the attribute was
    built from.
    """

    fn test_float32() raises:
        alias value_a: Float32 = 1.23456
        alias a1 = Attribute(name="test", value=value_a)
        assert_true(
            a1.to_scalar[DType.float32]() == value_a,
            "Float32 scalar attribute failed",
        )

        alias value_b: Float32 = 65151
        alias a2 = Attribute(name="test", value=value_b)
        assert_true(
            a2.to_scalar[DType.float32]() == value_b,
            "Float32 scalar attribute failed",
        )

    fn test_float_literal() raises:
        alias value_c: FloatLiteral = -1.1
        alias a3 = Attribute(name="test", value=value_c)
        assert_true(
            a3.to_scalar[DType.float32]() == value_c,
            "FloatLiteral scalar attribute failed",
        )

    fn test_float64() raises:
        alias value_a: Float64 = -1.23456
        alias a1 = Attribute(name="test", value=value_a)
        assert_true(
            a1.to_scalar[DType.float64]() == value_a,
            "Float64 scalar attribute failed",
        )

        alias value_b: Float64 = 123456
        alias a2 = Attribute(name="test", value=value_b)
        assert_true(
            a2.to_scalar[DType.float64]() == value_b,
            "Float64 scalar attribute failed",
        )

    fn test_int32() raises:
        alias value_a: Int32 = 666
        alias a1 = Attribute(name="test", value=value_a)
        assert_true(
            a1.to_scalar[DType.int32]() == value_a,
            "Int32 scalar attribute failed",
        )

        alias value_b: Int32 = -666
        alias a2 = Attribute(name="test", value=value_b)
        assert_true(
            a2.to_scalar[DType.int32]() == value_b,
            "Int32 scalar attribute failed",
        )

    fn test_attribute_small_scalar() raises:
        alias value_a: Float32 = 1e-18
        alias a = Attribute(name="test", value=value_a)
        assert_true(
            a.to_scalar[DType.float32]() == value_a,
            "SMALL scalar attribute failed",
        )

    fn test_attribute_big_scalar() raises:
        # 1e40 exceeds the largest finite Float32 (~3.4e38), so this alias
        # overflows to +inf; the check below therefore verifies that an
        # infinite value round-trips.
        alias value_a: Float32 = 1e40
        alias a = Attribute(name="test", value=value_a)
        assert_true(
            a.to_scalar[DType.float32]() == value_a,
            "BIG scalar attribute failed",
        )

        # A large but still finite Float32, so a genuinely big magnitude is
        # exercised and not only the overflowed infinity above.
        alias value_b: Float32 = 1e38
        alias b = Attribute(name="test", value=value_b)
        assert_true(
            b.to_scalar[DType.float32]() == value_b,
            "BIG finite scalar attribute failed",
        )

    test_float32()
    test_float_literal()
    test_float64()
    test_int32()
    test_attribute_small_scalar()
    test_attribute_big_scalar()
fn main() raises:
    """Runs every attribute test, printing and then re-raising the first error.

    The error is re-raised after it is reported so it propagates out of `main`
    and the run is seen as failed, instead of ending normally as if every test
    had passed.
    """
    try:
        test_attribute_key()
        test_attribute_int()
        test_attribute_string()
        test_attribute_tensor_shape()
        test_attribute_static_int_tuple()
        test_attribute_scalar()
    except e:
        print("[ERROR] Error in attributes")
        print(e)
        raise e