-
Notifications
You must be signed in to change notification settings - Fork 86
/
Copy pathtest_pyrdl.py
321 lines (246 loc) · 9.95 KB
/
test_pyrdl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
"""Unit tests for IRDL."""
import re
from dataclasses import dataclass

import pytest

from xdsl.ir import Attribute, Data, ParametrizedAttribute
from xdsl.irdl import (
    AllOf,
    AnyAttr,
    AttrConstraint,
    BaseAttr,
    ConstraintContext,
    EqAttrConstraint,
    ParamAttrConstraint,
    ParameterDef,
    VarConstraint,
    eq,
    irdl_attr_definition,
)
from xdsl.parser import AttrParser
from xdsl.printer import Printer
from xdsl.utils.exceptions import VerifyException
@irdl_attr_definition
class BoolData(Data[bool]):
    """Test-only attribute wrapping a Python ``bool``.

    The parameter is printed as ``str(value)`` with no surrounding brackets.
    Parsing is not implemented because the tests in this module never parse
    a ``bool`` attribute.
    """

    name = "bool"

    @classmethod
    def parse_parameter(cls, parser: AttrParser) -> bool:
        # Deliberately unsupported: only printing is needed by these tests.
        raise NotImplementedError

    def print_parameter(self, printer: Printer):
        rendered = str(self.data)
        printer.print_string(rendered)
@irdl_attr_definition
class IntData(Data[int]):
    """Test-only attribute wrapping a Python ``int``.

    The integer parameter is both printed and parsed inside angle brackets.
    """

    name = "int"

    @classmethod
    def parse_parameter(cls, parser: AttrParser) -> int:
        # The bracket context consumes the delimiters around the integer.
        with parser.in_angle_brackets():
            return parser.parse_integer()

    def print_parameter(self, printer: Printer):
        with printer.in_angle_brackets():
            printer.print_string(f"{self.data}")
@irdl_attr_definition
class DoubleParamAttr(ParametrizedAttribute):
    """An attribute with two unbounded attribute parameters.

    Serves as the base attribute in the ParamAttrConstraint tests below.
    The definition itself places no restriction on either parameter; any
    restrictions come from the constraint under test.
    """

    name = "param"

    # First parameter: any Attribute is accepted by the definition.
    param1: ParameterDef[Attribute]
    # Second parameter: any Attribute is accepted by the definition.
    param2: ParameterDef[Attribute]
def test_eq_attr_verify():
    """An EqAttrConstraint accepts exactly the attribute it was built from."""
    expected = BoolData(True)
    EqAttrConstraint(expected).verify(expected, ConstraintContext())
def test_eq_attr_verify_wrong_parameters_fail():
    """
    An EqAttrConstraint rejects an attribute that has the expected base
    attribute but different parameters.
    """
    expected, actual = BoolData(True), BoolData(False)
    constraint = EqAttrConstraint(expected)

    with pytest.raises(VerifyException) as exc_info:
        constraint.verify(actual, ConstraintContext())

    message = exc_info.value.args[0]
    assert message == f"Expected attribute {expected} but got {actual}"
def test_eq_attr_verify_wrong_base_fail():
    """
    An EqAttrConstraint rejects an attribute whose base attribute differs
    from the expected one.
    """
    expected, actual = BoolData(True), IntData(0)
    constraint = EqAttrConstraint(expected)

    with pytest.raises(VerifyException) as exc_info:
        constraint.verify(actual, ConstraintContext())

    message = exc_info.value.args[0]
    assert message == f"Expected attribute {expected} but got {actual}"
def test_base_attr_verify():
    """
    A BaseAttr constraint accepts attributes of the given base attribute,
    whatever their parameter values are.
    """
    bool_base = BaseAttr(BoolData)
    for flag in (True, False):
        bool_base.verify(BoolData(flag), ConstraintContext())
def test_base_attr_verify_wrong_base_fail():
    """
    A BaseAttr constraint rejects an attribute of a different base
    attribute and names the expected base in its message.
    """
    bool_base = BaseAttr(BoolData)
    int_zero = IntData(0)

    with pytest.raises(VerifyException) as exc_info:
        bool_base.verify(int_zero, ConstraintContext())

    expected = f"{int_zero} should be of base attribute {BoolData.name}"
    assert exc_info.value.args[0] == expected
def test_any_attr_verify():
    """AnyAttr imposes no restriction, so attributes of any base verify."""
    anything = AnyAttr()
    for attr in (BoolData(True), BoolData(False), IntData(0)):
        anything.verify(attr, ConstraintContext())
@dataclass(frozen=True)
class LessThan(AttrConstraint):
    """Test constraint accepting only ``IntData`` values strictly below ``bound``."""

    bound: int

    def verify(
        self,
        attr: Attribute,
        constraint_context: ConstraintContext,
    ) -> None:
        # Check the base first so that ``attr.data`` is known to be an int.
        if not isinstance(attr, IntData):
            message = f"{attr} should be of base attribute {IntData.name}"
            raise VerifyException(message)
        limit = self.bound
        if attr.data >= limit:
            message = f"{attr} should hold a value less than {limit}"
            raise VerifyException(message)
@dataclass(frozen=True)
class GreaterThan(AttrConstraint):
    """Test constraint accepting only ``IntData`` values strictly above ``bound``."""

    bound: int

    def verify(
        self,
        attr: Attribute,
        constraint_context: ConstraintContext | None = None,
    ) -> None:
        # Check the base first so that ``attr.data`` is known to be an int.
        if not isinstance(attr, IntData):
            message = f"{attr} should be of base attribute {IntData.name}"
            raise VerifyException(message)
        limit = self.bound
        if attr.data <= limit:
            message = f"{attr} should hold a value greater than {limit}"
            raise VerifyException(message)
def test_anyof_verify():
    """
    An AnyOf constraint (built here with ``|``) accepts a value when at
    least one of its alternatives accepts it.
    """
    outside_zero_to_ten = LessThan(0) | GreaterThan(10)
    for value in (-1, -10, 11, 100):
        outside_zero_to_ten.verify(IntData(value), ConstraintContext())
def test_anyof_verify_fail():
    """
    An AnyOf constraint rejects a value that none of its alternatives
    accepts, reporting it as an unexpected attribute.
    """
    outside_zero_to_ten = LessThan(0) | GreaterThan(10)
    for boundary in (IntData(0), IntData(10)):
        with pytest.raises(VerifyException) as exc_info:
            outside_zero_to_ten.verify(boundary, ConstraintContext())
        assert exc_info.value.args[0] == f"Unexpected attribute {boundary}"
def test_allof_verify():
    """An AllOf constraint accepts a value that every sub-constraint accepts."""
    strictly_between = AllOf((LessThan(10), GreaterThan(0)))
    for value in (1, 9, 5):
        strictly_between.verify(IntData(value), ConstraintContext())
def test_allof_verify_fail():
    """
    An AllOf constraint rejects a value when one sub-constraint fails, and
    the message is that sub-constraint's own message.
    """
    strictly_between = AllOf((LessThan(10), GreaterThan(0)))
    cases = (
        (10, "should hold a value less than 10"),
        (0, "should hold a value greater than 0"),
    )
    for value, reason in cases:
        with pytest.raises(VerifyException) as exc_info:
            strictly_between.verify(IntData(value), ConstraintContext())
        assert exc_info.value.args[0] == f"{IntData(value)} {reason}"
def test_allof_verify_multiple_failures():
    """
    Check that an AllOf constraint reports every unsatisfied sub-constraint
    when more than one of them fails.
    """
    constraint = AllOf((LessThan(5), GreaterThan(8)))
    seven = IntData(7)
    expected = (
        "The following constraints were not satisfied:\n"
        f"{seven} should hold a value less than 5\n"
        f"{seven} should hold a value greater than 8"
    )
    # ``match`` is applied with re.search; the message embeds printed
    # attributes, so escape it to match the text literally rather than as a
    # regular expression.
    with pytest.raises(VerifyException, match=re.escape(expected)):
        constraint.verify(seven, ConstraintContext())
def test_param_attr_verify():
    """
    ParamAttrConstraint accepts a DoubleParamAttr whose parameters satisfy
    the per-parameter constraints (first equal to True, second any IntData).
    """
    bool_true = BoolData(True)
    per_param = [EqAttrConstraint(bool_true), BaseAttr(IntData)]
    constraint = ParamAttrConstraint(DoubleParamAttr, per_param)
    for value in (0, 42):
        attr = DoubleParamAttr([bool_true, IntData(value)])
        constraint.verify(attr, ConstraintContext())
def test_param_attr_verify_base_fail():
    """
    ParamAttrConstraint rejects an attribute whose base attribute is not
    the constrained one.
    """
    bool_true = BoolData(True)
    per_param = [EqAttrConstraint(bool_true), BaseAttr(IntData)]
    constraint = ParamAttrConstraint(DoubleParamAttr, per_param)

    with pytest.raises(VerifyException) as exc_info:
        constraint.verify(bool_true, ConstraintContext())

    expected = f"{bool_true} should be of base attribute {DoubleParamAttr.name}"
    assert exc_info.value.args[0] == expected
def test_param_attr_verify_params_num_params_fail():
    """
    ParamAttrConstraint rejects an attribute whose parameter count differs
    from the number of parameter constraints it was given.
    """
    bool_true = BoolData(True)
    single_param_constraint = ParamAttrConstraint(
        DoubleParamAttr, [EqAttrConstraint(bool_true)]
    )
    two_param_attr = DoubleParamAttr([bool_true, IntData(0)])

    with pytest.raises(VerifyException) as exc_info:
        single_param_constraint.verify(two_param_attr, ConstraintContext())

    assert exc_info.value.args[0] == "1 parameters expected, but got 2"
def test_param_attr_verify_params_fail():
    """
    ParamAttrConstraint rejects an attribute when any single parameter
    violates its constraint, surfacing that parameter's error message.
    """
    bool_true = BoolData(True)
    bool_false = BoolData(False)
    per_param = [EqAttrConstraint(bool_true), BaseAttr(IntData)]
    constraint = ParamAttrConstraint(DoubleParamAttr, per_param)

    # Second parameter has the wrong base attribute.
    wrong_second = DoubleParamAttr([bool_true, bool_false])
    with pytest.raises(VerifyException) as exc_info:
        constraint.verify(wrong_second, ConstraintContext())
    expected = f"{bool_false} should be of base attribute {IntData.name}"
    assert exc_info.value.args[0] == expected

    # First parameter is not the required value.
    wrong_first = DoubleParamAttr([bool_false, IntData(0)])
    with pytest.raises(VerifyException) as exc_info:
        constraint.verify(wrong_first, ConstraintContext())
    expected = f"Expected attribute {bool_true} but got {bool_false}"
    assert exc_info.value.args[0] == expected
def test_constraint_vars_success():
    """
    A VarConstraint accepts the same attribute twice within one
    ConstraintContext, for each alternative of its inner constraint.
    """
    constraint = VarConstraint("T", eq(BoolData(False)) | eq(IntData(0)))
    for attr_cls, value in ((BoolData, False), (IntData, 0)):
        # Each alternative gets its own, fresh context.
        context = ConstraintContext()
        constraint.verify(attr_cls(value), context)
        constraint.verify(attr_cls(value), context)
def test_constraint_vars_fail_different():
    """
    A VarConstraint rejects a second, different attribute verified in the
    same ConstraintContext.
    """
    constraint = VarConstraint("T", eq(BoolData(False)) | eq(IntData(0)))
    shared_context = ConstraintContext()
    constraint.verify(IntData(0), shared_context)
    # Allowed by the inner constraint on its own, but differs from the
    # attribute already accepted in this context.
    with pytest.raises(VerifyException):
        constraint.verify(BoolData(False), shared_context)
def test_constraint_vars_fail_underlying_constraint():
    """
    A VarConstraint rejects an attribute that its inner constraint rejects,
    even in a fresh context.
    """
    constraint = VarConstraint("T", eq(BoolData(False)) | eq(IntData(0)))
    fresh_context = ConstraintContext()
    with pytest.raises(VerifyException):
        # IntData(1) matches neither alternative of the inner constraint.
        constraint.verify(IntData(1), fresh_context)