Skip to content

Commit 328a839

Browse files
committed
Add Tests for ExpressionSimplifier (Sonnet 4.5)
1 parent 0374a67 commit 328a839

File tree

1 file changed

+338
-0
lines changed

1 file changed

+338
-0
lines changed
Lines changed: 338 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,338 @@
1+
package liquidjava.rj_language.opt;
2+
3+
import static org.junit.jupiter.api.Assertions.*;
4+
5+
import liquidjava.rj_language.ast.BinaryExpression;
6+
import liquidjava.rj_language.ast.Expression;
7+
import liquidjava.rj_language.ast.LiteralBoolean;
8+
import liquidjava.rj_language.ast.LiteralInt;
9+
import liquidjava.rj_language.ast.UnaryExpression;
10+
import liquidjava.rj_language.ast.Var;
11+
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
12+
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
13+
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
14+
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
15+
import liquidjava.rj_language.opt.derivation_node.VarDerivationNode;
16+
import org.junit.jupiter.api.Test;
17+
18+
/**
19+
* Test suite for expression simplification using constant propagation and folding
20+
*/
21+
class ExpressionSimplifierTest {
22+
23+
@Test
24+
void testNegation() {
25+
// Given: -a && a == 7
26+
// Expected: -7
27+
28+
Expression varA = new Var("a");
29+
Expression negA = new UnaryExpression("-", varA);
30+
Expression seven = new LiteralInt(7);
31+
Expression aEquals7 = new BinaryExpression(varA, "==", seven);
32+
Expression fullExpression = new BinaryExpression(negA, "&&", aEquals7);
33+
34+
// When
35+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
36+
37+
// Then
38+
assertNotNull(result, "Result should not be null");
39+
assertEquals("-7", result.getValue().toString(), "Expected result to be -7");
40+
41+
// 7 from variable a
42+
ValDerivationNode val7 = new ValDerivationNode(new LiteralInt(7), new VarDerivationNode("a"));
43+
44+
// -7
45+
UnaryDerivationNode negation = new UnaryDerivationNode(val7, "-");
46+
ValDerivationNode expected = new ValDerivationNode(new LiteralInt(-7), negation);
47+
48+
// Compare the derivation trees
49+
assertDerivationEquals(expected, result, "");
50+
}
51+
52+
@Test
53+
void testSimpleAddition() {
54+
// Given: a + b && a == 3 && b == 5
55+
// Expected: 8 (3 + 5)
56+
57+
Expression varA = new Var("a");
58+
Expression varB = new Var("b");
59+
Expression addition = new BinaryExpression(varA, "+", varB);
60+
61+
Expression three = new LiteralInt(3);
62+
Expression aEquals3 = new BinaryExpression(varA, "==", three);
63+
64+
Expression five = new LiteralInt(5);
65+
Expression bEquals5 = new BinaryExpression(varB, "==", five);
66+
67+
Expression conditions = new BinaryExpression(aEquals3, "&&", bEquals5);
68+
Expression fullExpression = new BinaryExpression(addition, "&&", conditions);
69+
70+
// When
71+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
72+
73+
// Then
74+
assertNotNull(result, "Result should not be null");
75+
assertEquals("8", result.getValue().toString(), "Expected result to be 8");
76+
77+
// 3 from variable a
78+
ValDerivationNode val3 = new ValDerivationNode(new LiteralInt(3), new VarDerivationNode("a"));
79+
80+
// 5 from variable b
81+
ValDerivationNode val5 = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("b"));
82+
83+
// 3 + 5
84+
BinaryDerivationNode add3Plus5 = new BinaryDerivationNode(val3, val5, "+");
85+
ValDerivationNode expected = new ValDerivationNode(new LiteralInt(8), add3Plus5);
86+
87+
// Compare the derivation trees
88+
assertDerivationEquals(expected, result, "");
89+
}
90+
91+
@Test
92+
void testSimpleComparison() {
93+
// Given: (y || true) && !true && y == false
94+
// Expected: false (true && false)
95+
96+
Expression varY = new Var("y");
97+
Expression trueExp = new LiteralBoolean(true);
98+
Expression yOrTrue = new BinaryExpression(varY, "||", trueExp);
99+
100+
Expression notTrue = new UnaryExpression("!", trueExp);
101+
102+
Expression falseExp = new LiteralBoolean(false);
103+
Expression yEqualsFalse = new BinaryExpression(varY, "==", falseExp);
104+
105+
Expression firstAnd = new BinaryExpression(yOrTrue, "&&", notTrue);
106+
Expression fullExpression = new BinaryExpression(firstAnd, "&&", yEqualsFalse);
107+
108+
// When
109+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
110+
111+
// Then
112+
assertNotNull(result, "Result should not be null");
113+
assertTrue(result.getValue() instanceof LiteralBoolean, "Result should be a boolean");
114+
assertFalse(((LiteralBoolean) result.getValue()).isBooleanTrue(), "Expected result to befalse");
115+
116+
// (y || true) && y == false => false || true = true
117+
ValDerivationNode valFalseForY = new ValDerivationNode(new LiteralBoolean(false), new VarDerivationNode("y"));
118+
ValDerivationNode valTrue1 = new ValDerivationNode(new LiteralBoolean(true), null);
119+
BinaryDerivationNode orFalseTrue = new BinaryDerivationNode(valFalseForY, valTrue1, "||");
120+
ValDerivationNode trueFromOr = new ValDerivationNode(new LiteralBoolean(true), orFalseTrue);
121+
122+
// !true = false
123+
ValDerivationNode valTrue2 = new ValDerivationNode(new LiteralBoolean(true), null);
124+
UnaryDerivationNode notOp = new UnaryDerivationNode(valTrue2, "!");
125+
ValDerivationNode falseFromNot = new ValDerivationNode(new LiteralBoolean(false), notOp);
126+
127+
// true && false = false
128+
BinaryDerivationNode andTrueFalse = new BinaryDerivationNode(trueFromOr, falseFromNot, "&&");
129+
ValDerivationNode falseFromFirstAnd = new ValDerivationNode(new LiteralBoolean(false), andTrueFalse);
130+
131+
// y == false
132+
ValDerivationNode valFalseForY2 = new ValDerivationNode(new LiteralBoolean(false), new VarDerivationNode("y"));
133+
ValDerivationNode valFalse2 = new ValDerivationNode(new LiteralBoolean(false), null);
134+
BinaryDerivationNode compareFalseFalse = new BinaryDerivationNode(valFalseForY2, valFalse2, "==");
135+
ValDerivationNode trueFromCompare = new ValDerivationNode(new LiteralBoolean(true), compareFalseFalse);
136+
137+
// false && true = false
138+
BinaryDerivationNode finalAnd = new BinaryDerivationNode(falseFromFirstAnd, trueFromCompare, "&&");
139+
ValDerivationNode expected = new ValDerivationNode(new LiteralBoolean(false), finalAnd);
140+
141+
// Compare the derivation trees
142+
assertDerivationEquals(expected, result, "");
143+
}
144+
145+
@Test
146+
void testArithmeticWithConstants() {
147+
// Given: (a / b + (-5)) + x && a == 6 && b == 2
148+
// Expected: -2 + x (6 / 2 = 3, 3 + (-5) = -2)
149+
150+
Expression varA = new Var("a");
151+
Expression varB = new Var("b");
152+
Expression division = new BinaryExpression(varA, "/", varB);
153+
154+
Expression five = new LiteralInt(5);
155+
Expression negFive = new UnaryExpression("-", five);
156+
157+
Expression firstSum = new BinaryExpression(division, "+", negFive);
158+
Expression varX = new Var("x");
159+
Expression fullArithmetic = new BinaryExpression(firstSum, "+", varX);
160+
161+
Expression six = new LiteralInt(6);
162+
Expression aEquals6 = new BinaryExpression(varA, "==", six);
163+
164+
Expression two = new LiteralInt(2);
165+
Expression bEquals2 = new BinaryExpression(varB, "==", two);
166+
167+
Expression allConditions = new BinaryExpression(aEquals6, "&&", bEquals2);
168+
Expression fullExpression = new BinaryExpression(fullArithmetic, "&&", allConditions);
169+
170+
// When
171+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
172+
173+
// Then
174+
assertNotNull(result, "Result should not be null");
175+
assertNotNull(result.getValue(), "Result value should not be null");
176+
177+
String resultStr = result.getValue().toString();
178+
assertEquals("-2 + x", resultStr, "Expected result to be -2 + x");
179+
180+
// 6 from variable a
181+
ValDerivationNode val6 = new ValDerivationNode(new LiteralInt(6), new VarDerivationNode("a"));
182+
183+
// 2 from variable b
184+
ValDerivationNode val2 = new ValDerivationNode(new LiteralInt(2), new VarDerivationNode("b"));
185+
186+
// 6 / 2 = 3
187+
BinaryDerivationNode div6By2 = new BinaryDerivationNode(val6, val2, "/");
188+
ValDerivationNode val3 = new ValDerivationNode(new LiteralInt(3), div6By2);
189+
190+
// -5 from unary negation of 5
191+
ValDerivationNode val5 = new ValDerivationNode(new LiteralInt(5), null);
192+
UnaryDerivationNode unaryNeg5 = new UnaryDerivationNode(val5, "-");
193+
ValDerivationNode valNeg5 = new ValDerivationNode(new LiteralInt(-5), unaryNeg5);
194+
195+
// 3 + (-5) = -2
196+
BinaryDerivationNode add3AndNeg5 = new BinaryDerivationNode(val3, valNeg5, "+");
197+
ValDerivationNode valNeg2 = new ValDerivationNode(new LiteralInt(-2), add3AndNeg5);
198+
199+
// x (variable with null origin)
200+
ValDerivationNode valX = new ValDerivationNode(new Var("x"), null);
201+
202+
// -2 + x
203+
BinaryDerivationNode addNeg2AndX = new BinaryDerivationNode(valNeg2, valX, "+");
204+
Expression expectedResultExpr = new BinaryExpression(new LiteralInt(-2), "+", new Var("x"));
205+
ValDerivationNode expected = new ValDerivationNode(expectedResultExpr, addNeg2AndX);
206+
207+
// Compare the derivation trees
208+
assertDerivationEquals(expected, result, "");
209+
}
210+
211+
@Test
212+
void testComplexArithmeticWithMultipleOperations() {
213+
// Given: (a * 2 + b - 3) == c && a == 5 && b == 7 && c == 14
214+
// Expected: (5 * 2 + 7 - 3) == 14 => 14 == 14 => true
215+
216+
Expression varA = new Var("a");
217+
Expression varB = new Var("b");
218+
Expression varC = new Var("c");
219+
220+
Expression two = new LiteralInt(2);
221+
Expression aTimes2 = new BinaryExpression(varA, "*", two);
222+
223+
Expression sum = new BinaryExpression(aTimes2, "+", varB);
224+
225+
Expression three = new LiteralInt(3);
226+
Expression arithmetic = new BinaryExpression(sum, "-", three);
227+
228+
Expression comparison = new BinaryExpression(arithmetic, "==", varC);
229+
230+
Expression five = new LiteralInt(5);
231+
Expression aEquals5 = new BinaryExpression(varA, "==", five);
232+
233+
Expression seven = new LiteralInt(7);
234+
Expression bEquals7 = new BinaryExpression(varB, "==", seven);
235+
236+
Expression fourteen = new LiteralInt(14);
237+
Expression cEquals14 = new BinaryExpression(varC, "==", fourteen);
238+
239+
Expression conj1 = new BinaryExpression(aEquals5, "&&", bEquals7);
240+
Expression allConditions = new BinaryExpression(conj1, "&&", cEquals14);
241+
Expression fullExpression = new BinaryExpression(comparison, "&&", allConditions);
242+
243+
// When
244+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
245+
246+
// Then
247+
assertNotNull(result, "Result should not be null");
248+
assertNotNull(result.getValue(), "Result value should not be null");
249+
assertTrue(result.getValue() instanceof LiteralBoolean, "Result should be a boolean literal");
250+
assertTrue(((LiteralBoolean) result.getValue()).isBooleanTrue(), "Expected result to be true");
251+
252+
// 5 * 2 + 7 - 3
253+
ValDerivationNode val5 = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("a"));
254+
ValDerivationNode val2 = new ValDerivationNode(new LiteralInt(2), null);
255+
BinaryDerivationNode mult5Times2 = new BinaryDerivationNode(val5, val2, "*");
256+
ValDerivationNode val10 = new ValDerivationNode(new LiteralInt(10), mult5Times2);
257+
258+
ValDerivationNode val7 = new ValDerivationNode(new LiteralInt(7), new VarDerivationNode("b"));
259+
BinaryDerivationNode add10Plus7 = new BinaryDerivationNode(val10, val7, "+");
260+
ValDerivationNode val17 = new ValDerivationNode(new LiteralInt(17), add10Plus7);
261+
262+
ValDerivationNode val3 = new ValDerivationNode(new LiteralInt(3), null);
263+
BinaryDerivationNode sub17Minus3 = new BinaryDerivationNode(val17, val3, "-");
264+
ValDerivationNode val14Left = new ValDerivationNode(new LiteralInt(14), sub17Minus3);
265+
266+
// 14 from variable c
267+
ValDerivationNode val14Right = new ValDerivationNode(new LiteralInt(14), new VarDerivationNode("c"));
268+
269+
// 14 == 14
270+
BinaryDerivationNode compare14 = new BinaryDerivationNode(val14Left, val14Right, "==");
271+
ValDerivationNode trueFromComparison = new ValDerivationNode(new LiteralBoolean(true), compare14);
272+
273+
// a == 5 => true
274+
ValDerivationNode val5ForCompA = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("a"));
275+
ValDerivationNode val5Literal = new ValDerivationNode(new LiteralInt(5), null);
276+
BinaryDerivationNode compareA5 = new BinaryDerivationNode(val5ForCompA, val5Literal, "==");
277+
ValDerivationNode trueFromA = new ValDerivationNode(new LiteralBoolean(true), compareA5);
278+
279+
// b == 7 => true
280+
ValDerivationNode val7ForCompB = new ValDerivationNode(new LiteralInt(7), new VarDerivationNode("b"));
281+
ValDerivationNode val7Literal = new ValDerivationNode(new LiteralInt(7), null);
282+
BinaryDerivationNode compareB7 = new BinaryDerivationNode(val7ForCompB, val7Literal, "==");
283+
ValDerivationNode trueFromB = new ValDerivationNode(new LiteralBoolean(true), compareB7);
284+
285+
// (a == 5) && (b == 7) => true
286+
BinaryDerivationNode andAB = new BinaryDerivationNode(trueFromA, trueFromB, "&&");
287+
ValDerivationNode trueFromAB = new ValDerivationNode(new LiteralBoolean(true), andAB);
288+
289+
// c == 14 => true
290+
ValDerivationNode val14ForCompC = new ValDerivationNode(new LiteralInt(14), new VarDerivationNode("c"));
291+
ValDerivationNode val14Literal = new ValDerivationNode(new LiteralInt(14), null);
292+
BinaryDerivationNode compareC14 = new BinaryDerivationNode(val14ForCompC, val14Literal, "==");
293+
ValDerivationNode trueFromC = new ValDerivationNode(new LiteralBoolean(true), compareC14);
294+
295+
// ((a == 5) && (b == 7)) && (c == 14) => true
296+
BinaryDerivationNode andABC = new BinaryDerivationNode(trueFromAB, trueFromC, "&&");
297+
ValDerivationNode trueFromAllConditions = new ValDerivationNode(new LiteralBoolean(true), andABC);
298+
299+
// 14 == 14 => true
300+
BinaryDerivationNode finalAnd = new BinaryDerivationNode(trueFromComparison, trueFromAllConditions, "&&");
301+
ValDerivationNode expected = new ValDerivationNode(new LiteralBoolean(true), finalAnd);
302+
303+
// Compare the derivation trees
304+
assertDerivationEquals(expected, result, "");
305+
}
306+
307+
/**
308+
* Helper method to compare two derivation nodes recursively
309+
*/
310+
private void assertDerivationEquals(DerivationNode expected, DerivationNode actual, String message) {
311+
if (expected == null && actual == null)
312+
return;
313+
314+
assertEquals(expected.getClass(), actual.getClass(), message + ": node types should match");
315+
if (expected instanceof ValDerivationNode) {
316+
ValDerivationNode expectedVal = (ValDerivationNode) expected;
317+
ValDerivationNode actualVal = (ValDerivationNode) actual;
318+
assertEquals(expectedVal.getValue().toString(), actualVal.getValue().toString(),
319+
message + ": values should match");
320+
assertDerivationEquals(expectedVal.getOrigin(), actualVal.getOrigin(), message + " > origin");
321+
} else if (expected instanceof BinaryDerivationNode) {
322+
BinaryDerivationNode expectedBin = (BinaryDerivationNode) expected;
323+
BinaryDerivationNode actualBin = (BinaryDerivationNode) actual;
324+
assertEquals(expectedBin.getOp(), actualBin.getOp(), message + ": operators should match");
325+
assertDerivationEquals(expectedBin.getLeft(), actualBin.getLeft(), message + " > left");
326+
assertDerivationEquals(expectedBin.getRight(), actualBin.getRight(), message + " > right");
327+
} else if (expected instanceof VarDerivationNode) {
328+
VarDerivationNode expectedVar = (VarDerivationNode) expected;
329+
VarDerivationNode actualVar = (VarDerivationNode) actual;
330+
assertEquals(expectedVar.getVar(), actualVar.getVar(), message + ": variables should match");
331+
} else if (expected instanceof UnaryDerivationNode) {
332+
UnaryDerivationNode expectedUnary = (UnaryDerivationNode) expected;
333+
UnaryDerivationNode actualUnary = (UnaryDerivationNode) actual;
334+
assertEquals(expectedUnary.getOperator(), actualUnary.getOperator(), message + ": operators should match");
335+
assertDerivationEquals(expectedUnary.getOperand(), actualUnary.getOperand(), message + " > operand");
336+
}
337+
}
338+
}

0 commit comments

Comments
 (0)