|
| 1 | +package liquidjava.rj_language.opt; |
| 2 | + |
| 3 | +import static org.junit.jupiter.api.Assertions.*; |
| 4 | + |
| 5 | +import liquidjava.rj_language.ast.BinaryExpression; |
| 6 | +import liquidjava.rj_language.ast.Expression; |
| 7 | +import liquidjava.rj_language.ast.LiteralBoolean; |
| 8 | +import liquidjava.rj_language.ast.LiteralInt; |
| 9 | +import liquidjava.rj_language.ast.UnaryExpression; |
| 10 | +import liquidjava.rj_language.ast.Var; |
| 11 | +import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode; |
| 12 | +import liquidjava.rj_language.opt.derivation_node.DerivationNode; |
| 13 | +import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode; |
| 14 | +import liquidjava.rj_language.opt.derivation_node.ValDerivationNode; |
| 15 | +import liquidjava.rj_language.opt.derivation_node.VarDerivationNode; |
| 16 | +import org.junit.jupiter.api.Test; |
| 17 | + |
| 18 | +/** |
| 19 | + * Test suite for expression simplification using constant propagation and folding |
| 20 | + */ |
| 21 | +class ExpressionSimplifierTest { |
| 22 | + |
| 23 | + @Test |
| 24 | + void testNegation() { |
| 25 | + // Given: -a && a == 7 |
| 26 | + // Expected: -7 |
| 27 | + |
| 28 | + Expression varA = new Var("a"); |
| 29 | + Expression negA = new UnaryExpression("-", varA); |
| 30 | + Expression seven = new LiteralInt(7); |
| 31 | + Expression aEquals7 = new BinaryExpression(varA, "==", seven); |
| 32 | + Expression fullExpression = new BinaryExpression(negA, "&&", aEquals7); |
| 33 | + |
| 34 | + // When |
| 35 | + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); |
| 36 | + |
| 37 | + // Then |
| 38 | + assertNotNull(result, "Result should not be null"); |
| 39 | + assertEquals("-7", result.getValue().toString(), "Expected result to be -7"); |
| 40 | + |
| 41 | + // 7 from variable a |
| 42 | + ValDerivationNode val7 = new ValDerivationNode(new LiteralInt(7), new VarDerivationNode("a")); |
| 43 | + |
| 44 | + // -7 |
| 45 | + UnaryDerivationNode negation = new UnaryDerivationNode(val7, "-"); |
| 46 | + ValDerivationNode expected = new ValDerivationNode(new LiteralInt(-7), negation); |
| 47 | + |
| 48 | + // Compare the derivation trees |
| 49 | + assertDerivationEquals(expected, result, ""); |
| 50 | + } |
| 51 | + |
| 52 | + @Test |
| 53 | + void testSimpleAddition() { |
| 54 | + // Given: a + b && a == 3 && b == 5 |
| 55 | + // Expected: 8 (3 + 5) |
| 56 | + |
| 57 | + Expression varA = new Var("a"); |
| 58 | + Expression varB = new Var("b"); |
| 59 | + Expression addition = new BinaryExpression(varA, "+", varB); |
| 60 | + |
| 61 | + Expression three = new LiteralInt(3); |
| 62 | + Expression aEquals3 = new BinaryExpression(varA, "==", three); |
| 63 | + |
| 64 | + Expression five = new LiteralInt(5); |
| 65 | + Expression bEquals5 = new BinaryExpression(varB, "==", five); |
| 66 | + |
| 67 | + Expression conditions = new BinaryExpression(aEquals3, "&&", bEquals5); |
| 68 | + Expression fullExpression = new BinaryExpression(addition, "&&", conditions); |
| 69 | + |
| 70 | + // When |
| 71 | + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); |
| 72 | + |
| 73 | + // Then |
| 74 | + assertNotNull(result, "Result should not be null"); |
| 75 | + assertEquals("8", result.getValue().toString(), "Expected result to be 8"); |
| 76 | + |
| 77 | + // 3 from variable a |
| 78 | + ValDerivationNode val3 = new ValDerivationNode(new LiteralInt(3), new VarDerivationNode("a")); |
| 79 | + |
| 80 | + // 5 from variable b |
| 81 | + ValDerivationNode val5 = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("b")); |
| 82 | + |
| 83 | + // 3 + 5 |
| 84 | + BinaryDerivationNode add3Plus5 = new BinaryDerivationNode(val3, val5, "+"); |
| 85 | + ValDerivationNode expected = new ValDerivationNode(new LiteralInt(8), add3Plus5); |
| 86 | + |
| 87 | + // Compare the derivation trees |
| 88 | + assertDerivationEquals(expected, result, ""); |
| 89 | + } |
| 90 | + |
| 91 | + @Test |
| 92 | + void testSimpleComparison() { |
| 93 | + // Given: (y || true) && !true && y == false |
| 94 | + // Expected: false (true && false) |
| 95 | + |
| 96 | + Expression varY = new Var("y"); |
| 97 | + Expression trueExp = new LiteralBoolean(true); |
| 98 | + Expression yOrTrue = new BinaryExpression(varY, "||", trueExp); |
| 99 | + |
| 100 | + Expression notTrue = new UnaryExpression("!", trueExp); |
| 101 | + |
| 102 | + Expression falseExp = new LiteralBoolean(false); |
| 103 | + Expression yEqualsFalse = new BinaryExpression(varY, "==", falseExp); |
| 104 | + |
| 105 | + Expression firstAnd = new BinaryExpression(yOrTrue, "&&", notTrue); |
| 106 | + Expression fullExpression = new BinaryExpression(firstAnd, "&&", yEqualsFalse); |
| 107 | + |
| 108 | + // When |
| 109 | + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); |
| 110 | + |
| 111 | + // Then |
| 112 | + assertNotNull(result, "Result should not be null"); |
| 113 | + assertTrue(result.getValue() instanceof LiteralBoolean, "Result should be a boolean"); |
| 114 | + assertFalse(((LiteralBoolean) result.getValue()).isBooleanTrue(), "Expected result to befalse"); |
| 115 | + |
| 116 | + // (y || true) && y == false => false || true = true |
| 117 | + ValDerivationNode valFalseForY = new ValDerivationNode(new LiteralBoolean(false), new VarDerivationNode("y")); |
| 118 | + ValDerivationNode valTrue1 = new ValDerivationNode(new LiteralBoolean(true), null); |
| 119 | + BinaryDerivationNode orFalseTrue = new BinaryDerivationNode(valFalseForY, valTrue1, "||"); |
| 120 | + ValDerivationNode trueFromOr = new ValDerivationNode(new LiteralBoolean(true), orFalseTrue); |
| 121 | + |
| 122 | + // !true = false |
| 123 | + ValDerivationNode valTrue2 = new ValDerivationNode(new LiteralBoolean(true), null); |
| 124 | + UnaryDerivationNode notOp = new UnaryDerivationNode(valTrue2, "!"); |
| 125 | + ValDerivationNode falseFromNot = new ValDerivationNode(new LiteralBoolean(false), notOp); |
| 126 | + |
| 127 | + // true && false = false |
| 128 | + BinaryDerivationNode andTrueFalse = new BinaryDerivationNode(trueFromOr, falseFromNot, "&&"); |
| 129 | + ValDerivationNode falseFromFirstAnd = new ValDerivationNode(new LiteralBoolean(false), andTrueFalse); |
| 130 | + |
| 131 | + // y == false |
| 132 | + ValDerivationNode valFalseForY2 = new ValDerivationNode(new LiteralBoolean(false), new VarDerivationNode("y")); |
| 133 | + ValDerivationNode valFalse2 = new ValDerivationNode(new LiteralBoolean(false), null); |
| 134 | + BinaryDerivationNode compareFalseFalse = new BinaryDerivationNode(valFalseForY2, valFalse2, "=="); |
| 135 | + ValDerivationNode trueFromCompare = new ValDerivationNode(new LiteralBoolean(true), compareFalseFalse); |
| 136 | + |
| 137 | + // false && true = false |
| 138 | + BinaryDerivationNode finalAnd = new BinaryDerivationNode(falseFromFirstAnd, trueFromCompare, "&&"); |
| 139 | + ValDerivationNode expected = new ValDerivationNode(new LiteralBoolean(false), finalAnd); |
| 140 | + |
| 141 | + // Compare the derivation trees |
| 142 | + assertDerivationEquals(expected, result, ""); |
| 143 | + } |
| 144 | + |
| 145 | + @Test |
| 146 | + void testArithmeticWithConstants() { |
| 147 | + // Given: (a / b + (-5)) + x && a == 6 && b == 2 |
| 148 | + // Expected: -2 + x (6 / 2 = 3, 3 + (-5) = -2) |
| 149 | + |
| 150 | + Expression varA = new Var("a"); |
| 151 | + Expression varB = new Var("b"); |
| 152 | + Expression division = new BinaryExpression(varA, "/", varB); |
| 153 | + |
| 154 | + Expression five = new LiteralInt(5); |
| 155 | + Expression negFive = new UnaryExpression("-", five); |
| 156 | + |
| 157 | + Expression firstSum = new BinaryExpression(division, "+", negFive); |
| 158 | + Expression varX = new Var("x"); |
| 159 | + Expression fullArithmetic = new BinaryExpression(firstSum, "+", varX); |
| 160 | + |
| 161 | + Expression six = new LiteralInt(6); |
| 162 | + Expression aEquals6 = new BinaryExpression(varA, "==", six); |
| 163 | + |
| 164 | + Expression two = new LiteralInt(2); |
| 165 | + Expression bEquals2 = new BinaryExpression(varB, "==", two); |
| 166 | + |
| 167 | + Expression allConditions = new BinaryExpression(aEquals6, "&&", bEquals2); |
| 168 | + Expression fullExpression = new BinaryExpression(fullArithmetic, "&&", allConditions); |
| 169 | + |
| 170 | + // When |
| 171 | + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); |
| 172 | + |
| 173 | + // Then |
| 174 | + assertNotNull(result, "Result should not be null"); |
| 175 | + assertNotNull(result.getValue(), "Result value should not be null"); |
| 176 | + |
| 177 | + String resultStr = result.getValue().toString(); |
| 178 | + assertEquals("-2 + x", resultStr, "Expected result to be -2 + x"); |
| 179 | + |
| 180 | + // 6 from variable a |
| 181 | + ValDerivationNode val6 = new ValDerivationNode(new LiteralInt(6), new VarDerivationNode("a")); |
| 182 | + |
| 183 | + // 2 from variable b |
| 184 | + ValDerivationNode val2 = new ValDerivationNode(new LiteralInt(2), new VarDerivationNode("b")); |
| 185 | + |
| 186 | + // 6 / 2 = 3 |
| 187 | + BinaryDerivationNode div6By2 = new BinaryDerivationNode(val6, val2, "/"); |
| 188 | + ValDerivationNode val3 = new ValDerivationNode(new LiteralInt(3), div6By2); |
| 189 | + |
| 190 | + // -5 from unary negation of 5 |
| 191 | + ValDerivationNode val5 = new ValDerivationNode(new LiteralInt(5), null); |
| 192 | + UnaryDerivationNode unaryNeg5 = new UnaryDerivationNode(val5, "-"); |
| 193 | + ValDerivationNode valNeg5 = new ValDerivationNode(new LiteralInt(-5), unaryNeg5); |
| 194 | + |
| 195 | + // 3 + (-5) = -2 |
| 196 | + BinaryDerivationNode add3AndNeg5 = new BinaryDerivationNode(val3, valNeg5, "+"); |
| 197 | + ValDerivationNode valNeg2 = new ValDerivationNode(new LiteralInt(-2), add3AndNeg5); |
| 198 | + |
| 199 | + // x (variable with null origin) |
| 200 | + ValDerivationNode valX = new ValDerivationNode(new Var("x"), null); |
| 201 | + |
| 202 | + // -2 + x |
| 203 | + BinaryDerivationNode addNeg2AndX = new BinaryDerivationNode(valNeg2, valX, "+"); |
| 204 | + Expression expectedResultExpr = new BinaryExpression(new LiteralInt(-2), "+", new Var("x")); |
| 205 | + ValDerivationNode expected = new ValDerivationNode(expectedResultExpr, addNeg2AndX); |
| 206 | + |
| 207 | + // Compare the derivation trees |
| 208 | + assertDerivationEquals(expected, result, ""); |
| 209 | + } |
| 210 | + |
| 211 | + @Test |
| 212 | + void testComplexArithmeticWithMultipleOperations() { |
| 213 | + // Given: (a * 2 + b - 3) == c && a == 5 && b == 7 && c == 14 |
| 214 | + // Expected: (5 * 2 + 7 - 3) == 14 => 14 == 14 => true |
| 215 | + |
| 216 | + Expression varA = new Var("a"); |
| 217 | + Expression varB = new Var("b"); |
| 218 | + Expression varC = new Var("c"); |
| 219 | + |
| 220 | + Expression two = new LiteralInt(2); |
| 221 | + Expression aTimes2 = new BinaryExpression(varA, "*", two); |
| 222 | + |
| 223 | + Expression sum = new BinaryExpression(aTimes2, "+", varB); |
| 224 | + |
| 225 | + Expression three = new LiteralInt(3); |
| 226 | + Expression arithmetic = new BinaryExpression(sum, "-", three); |
| 227 | + |
| 228 | + Expression comparison = new BinaryExpression(arithmetic, "==", varC); |
| 229 | + |
| 230 | + Expression five = new LiteralInt(5); |
| 231 | + Expression aEquals5 = new BinaryExpression(varA, "==", five); |
| 232 | + |
| 233 | + Expression seven = new LiteralInt(7); |
| 234 | + Expression bEquals7 = new BinaryExpression(varB, "==", seven); |
| 235 | + |
| 236 | + Expression fourteen = new LiteralInt(14); |
| 237 | + Expression cEquals14 = new BinaryExpression(varC, "==", fourteen); |
| 238 | + |
| 239 | + Expression conj1 = new BinaryExpression(aEquals5, "&&", bEquals7); |
| 240 | + Expression allConditions = new BinaryExpression(conj1, "&&", cEquals14); |
| 241 | + Expression fullExpression = new BinaryExpression(comparison, "&&", allConditions); |
| 242 | + |
| 243 | + // When |
| 244 | + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); |
| 245 | + |
| 246 | + // Then |
| 247 | + assertNotNull(result, "Result should not be null"); |
| 248 | + assertNotNull(result.getValue(), "Result value should not be null"); |
| 249 | + assertTrue(result.getValue() instanceof LiteralBoolean, "Result should be a boolean literal"); |
| 250 | + assertTrue(((LiteralBoolean) result.getValue()).isBooleanTrue(), "Expected result to be true"); |
| 251 | + |
| 252 | + // 5 * 2 + 7 - 3 |
| 253 | + ValDerivationNode val5 = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("a")); |
| 254 | + ValDerivationNode val2 = new ValDerivationNode(new LiteralInt(2), null); |
| 255 | + BinaryDerivationNode mult5Times2 = new BinaryDerivationNode(val5, val2, "*"); |
| 256 | + ValDerivationNode val10 = new ValDerivationNode(new LiteralInt(10), mult5Times2); |
| 257 | + |
| 258 | + ValDerivationNode val7 = new ValDerivationNode(new LiteralInt(7), new VarDerivationNode("b")); |
| 259 | + BinaryDerivationNode add10Plus7 = new BinaryDerivationNode(val10, val7, "+"); |
| 260 | + ValDerivationNode val17 = new ValDerivationNode(new LiteralInt(17), add10Plus7); |
| 261 | + |
| 262 | + ValDerivationNode val3 = new ValDerivationNode(new LiteralInt(3), null); |
| 263 | + BinaryDerivationNode sub17Minus3 = new BinaryDerivationNode(val17, val3, "-"); |
| 264 | + ValDerivationNode val14Left = new ValDerivationNode(new LiteralInt(14), sub17Minus3); |
| 265 | + |
| 266 | + // 14 from variable c |
| 267 | + ValDerivationNode val14Right = new ValDerivationNode(new LiteralInt(14), new VarDerivationNode("c")); |
| 268 | + |
| 269 | + // 14 == 14 |
| 270 | + BinaryDerivationNode compare14 = new BinaryDerivationNode(val14Left, val14Right, "=="); |
| 271 | + ValDerivationNode trueFromComparison = new ValDerivationNode(new LiteralBoolean(true), compare14); |
| 272 | + |
| 273 | + // a == 5 => true |
| 274 | + ValDerivationNode val5ForCompA = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("a")); |
| 275 | + ValDerivationNode val5Literal = new ValDerivationNode(new LiteralInt(5), null); |
| 276 | + BinaryDerivationNode compareA5 = new BinaryDerivationNode(val5ForCompA, val5Literal, "=="); |
| 277 | + ValDerivationNode trueFromA = new ValDerivationNode(new LiteralBoolean(true), compareA5); |
| 278 | + |
| 279 | + // b == 7 => true |
| 280 | + ValDerivationNode val7ForCompB = new ValDerivationNode(new LiteralInt(7), new VarDerivationNode("b")); |
| 281 | + ValDerivationNode val7Literal = new ValDerivationNode(new LiteralInt(7), null); |
| 282 | + BinaryDerivationNode compareB7 = new BinaryDerivationNode(val7ForCompB, val7Literal, "=="); |
| 283 | + ValDerivationNode trueFromB = new ValDerivationNode(new LiteralBoolean(true), compareB7); |
| 284 | + |
| 285 | + // (a == 5) && (b == 7) => true |
| 286 | + BinaryDerivationNode andAB = new BinaryDerivationNode(trueFromA, trueFromB, "&&"); |
| 287 | + ValDerivationNode trueFromAB = new ValDerivationNode(new LiteralBoolean(true), andAB); |
| 288 | + |
| 289 | + // c == 14 => true |
| 290 | + ValDerivationNode val14ForCompC = new ValDerivationNode(new LiteralInt(14), new VarDerivationNode("c")); |
| 291 | + ValDerivationNode val14Literal = new ValDerivationNode(new LiteralInt(14), null); |
| 292 | + BinaryDerivationNode compareC14 = new BinaryDerivationNode(val14ForCompC, val14Literal, "=="); |
| 293 | + ValDerivationNode trueFromC = new ValDerivationNode(new LiteralBoolean(true), compareC14); |
| 294 | + |
| 295 | + // ((a == 5) && (b == 7)) && (c == 14) => true |
| 296 | + BinaryDerivationNode andABC = new BinaryDerivationNode(trueFromAB, trueFromC, "&&"); |
| 297 | + ValDerivationNode trueFromAllConditions = new ValDerivationNode(new LiteralBoolean(true), andABC); |
| 298 | + |
| 299 | + // 14 == 14 => true |
| 300 | + BinaryDerivationNode finalAnd = new BinaryDerivationNode(trueFromComparison, trueFromAllConditions, "&&"); |
| 301 | + ValDerivationNode expected = new ValDerivationNode(new LiteralBoolean(true), finalAnd); |
| 302 | + |
| 303 | + // Compare the derivation trees |
| 304 | + assertDerivationEquals(expected, result, ""); |
| 305 | + } |
| 306 | + |
| 307 | + /** |
| 308 | + * Helper method to compare two derivation nodes recursively |
| 309 | + */ |
| 310 | + private void assertDerivationEquals(DerivationNode expected, DerivationNode actual, String message) { |
| 311 | + if (expected == null && actual == null) |
| 312 | + return; |
| 313 | + |
| 314 | + assertEquals(expected.getClass(), actual.getClass(), message + ": node types should match"); |
| 315 | + if (expected instanceof ValDerivationNode) { |
| 316 | + ValDerivationNode expectedVal = (ValDerivationNode) expected; |
| 317 | + ValDerivationNode actualVal = (ValDerivationNode) actual; |
| 318 | + assertEquals(expectedVal.getValue().toString(), actualVal.getValue().toString(), |
| 319 | + message + ": values should match"); |
| 320 | + assertDerivationEquals(expectedVal.getOrigin(), actualVal.getOrigin(), message + " > origin"); |
| 321 | + } else if (expected instanceof BinaryDerivationNode) { |
| 322 | + BinaryDerivationNode expectedBin = (BinaryDerivationNode) expected; |
| 323 | + BinaryDerivationNode actualBin = (BinaryDerivationNode) actual; |
| 324 | + assertEquals(expectedBin.getOp(), actualBin.getOp(), message + ": operators should match"); |
| 325 | + assertDerivationEquals(expectedBin.getLeft(), actualBin.getLeft(), message + " > left"); |
| 326 | + assertDerivationEquals(expectedBin.getRight(), actualBin.getRight(), message + " > right"); |
| 327 | + } else if (expected instanceof VarDerivationNode) { |
| 328 | + VarDerivationNode expectedVar = (VarDerivationNode) expected; |
| 329 | + VarDerivationNode actualVar = (VarDerivationNode) actual; |
| 330 | + assertEquals(expectedVar.getVar(), actualVar.getVar(), message + ": variables should match"); |
| 331 | + } else if (expected instanceof UnaryDerivationNode) { |
| 332 | + UnaryDerivationNode expectedUnary = (UnaryDerivationNode) expected; |
| 333 | + UnaryDerivationNode actualUnary = (UnaryDerivationNode) actual; |
| 334 | + assertEquals(expectedUnary.getOperator(), actualUnary.getOperator(), message + ": operators should match"); |
| 335 | + assertDerivationEquals(expectedUnary.getOperand(), actualUnary.getOperand(), message + " > operand"); |
| 336 | + } |
| 337 | + } |
| 338 | +} |
0 commit comments