diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRuleTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRuleTest.java index 94aa3ab47056321..a6092eea160ea42 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRuleTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRuleTest.java @@ -65,89 +65,120 @@ public void testSimplify() { DecimalV3Type.createDecimalV3Type(9, 0)); Expression tinyIntLiteral = new TinyIntLiteral((byte) 12); + // cast tinyint as tinyint assertRewrite(new Cast(tinyIntLiteral, TinyIntType.INSTANCE), tinyIntLiteral); + // cast tinyint as decimalv2(3,0) assertRewrite(new Cast(tinyIntLiteral, DecimalV2Type.forType(TinyIntType.INSTANCE)), new DecimalLiteral(new BigDecimal(12))); // TODO case failed, cast(12 as DecimalV2(5, 1)) -> 12 decimalv2(2, 0) + // cast tinyint as decimalv2(5,1) // assertRewrite(new Cast(tinyIntLiteral, DecimalV2Type.createDecimalV2Type(5,1)), new DecimalLiteral(DecimalV2Type.createDecimalV2Type(5,1), new BigDecimal("12.0"))); + // cast tinyint as decimalv3(3,0) assertRewrite(new Cast(tinyIntLiteral, DecimalV3Type.forType(TinyIntType.INSTANCE)), new DecimalV3Literal(new BigDecimal(12))); + // cast tinyint as decimalv3(5,1) assertRewrite(new Cast(tinyIntLiteral, DecimalV3Type.createDecimalV3Type(5, 1)), new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(5, 1), new BigDecimal("12.0"))); // TODO cast(12 as decimalv3(2,1)) -> 12.0 decimalv3(2,1), and 12.0 decimalv3(2,1) == 12.0 decimalv3(3,1) ?? + // cast tinyint as decimalv3(2,1) assertRewrite(new Cast(tinyIntLiteral, DecimalV3Type.createDecimalV3Type(2, 1)), new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(3, 1), new BigDecimal("12.0"))); // TODO cast(12 as decimalv2(2,1)) -> 12 decimalv2(2,0) ?? + // cast tinyint as decimalv2(2,1) assertRewrite(new Cast(tinyIntLiteral, DecimalV2Type.createDecimalV2Type(2, 1)), new DecimalLiteral(DecimalV2Type.createDecimalV2Type(2, 0), new BigDecimal("12"))); Expression smallIntLiteral = new SmallIntLiteral((short) 30000); + // cast smallint as smallint assertRewrite(new Cast(smallIntLiteral, SmallIntType.INSTANCE), smallIntLiteral); + // cast smallint as decimalv2 assertRewrite(new Cast(smallIntLiteral, DecimalV2Type.forType(SmallIntType.INSTANCE)), new DecimalLiteral(new BigDecimal(30000))); + // cast smallint as decimalv3 assertRewrite(new Cast(smallIntLiteral, DecimalV3Type.forType(SmallIntType.INSTANCE)), new DecimalV3Literal(new BigDecimal(30000))); Expression intLiteral = new IntegerLiteral(30000000); + // cast int as int assertRewrite(new Cast(intLiteral, IntegerType.INSTANCE), intLiteral); + // cast int as decimalv2 assertRewrite(new Cast(intLiteral, DecimalV2Type.forType(IntegerType.INSTANCE)), new DecimalLiteral(new BigDecimal(30000000))); + // cast int as decimalv3 assertRewrite(new Cast(intLiteral, DecimalV3Type.forType(IntegerType.INSTANCE)), new DecimalV3Literal(new BigDecimal(30000000))); Expression bigIntLiteral = new BigIntLiteral(30000000000L); + // cast bigint as bigint assertRewrite(new Cast(bigIntLiteral, BigIntType.INSTANCE), bigIntLiteral); + // cast bigint as decimalv2 assertRewrite(new Cast(bigIntLiteral, DecimalV2Type.forType(BigIntType.INSTANCE)), new DecimalLiteral(new BigDecimal(30000000000L))); + // cast bigint as decimalv3 assertRewrite(new Cast(bigIntLiteral, DecimalV3Type.forType(BigIntType.INSTANCE)), new DecimalV3Literal(new BigDecimal(30000000000L))); Expression varcharLiteral = new VarcharLiteral("12345"); + // cast varchar(5) as varchar(3) assertRewrite(new Cast(varcharLiteral, VarcharType.createVarcharType(3)), new VarcharLiteral("123")); + // cast varchar(5) as varchar(10) assertRewrite(new Cast(varcharLiteral, VarcharType.createVarcharType(10)), new VarcharLiteral("12345")); + // cast varchar(5) as string assertRewrite(new Cast(varcharLiteral, StringType.INSTANCE), new StringLiteral("12345")); Expression charLiteral = new CharLiteral("12345", 5); + // cast char(5) as varchar(3) assertRewrite(new Cast(charLiteral, VarcharType.createVarcharType(3)), new VarcharLiteral("123")); + // cast char(5) as varchar(10) assertRewrite(new Cast(charLiteral, VarcharType.createVarcharType(10)), new VarcharLiteral("12345")); + // cast char(5) as string assertRewrite(new Cast(charLiteral, StringType.INSTANCE), new StringLiteral("12345")); Expression decimalV3Literal = new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(3, 1), new BigDecimal("12.0")); + // cast decimalv3(3,1) as decimalv3(5,1) assertRewrite(new Cast(decimalV3Literal, DecimalV3Type.createDecimalV3Type(5, 1)), new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(5, 1), new BigDecimal("12.0"))); // TODO this is different from cast(12 as decimalv3(2,1)) + // cast decimalv3(3,1) as decimalv3(2,1) assertRewrite(new Cast(decimalV3Literal, DecimalV3Type.createDecimalV3Type(2, 1)), new Cast(decimalV3Literal, DecimalV3Type.createDecimalV3Type(2, 1))); // TODO unsupported but should? + // cast tinyint as smallint assertRewrite(new Cast(tinyIntLiteral, SmallIntType.INSTANCE), new Cast(tinyIntLiteral, SmallIntType.INSTANCE)); + // cast tinyint as int assertRewrite(new Cast(tinyIntLiteral, IntegerType.INSTANCE), new Cast(tinyIntLiteral, IntegerType.INSTANCE)); + // cast tinyint as bigint assertRewrite(new Cast(tinyIntLiteral, BigIntType.INSTANCE), new Cast(tinyIntLiteral, BigIntType.INSTANCE)); // unsupported + // cast bigint as int assertRewrite(new Cast(bigIntLiteral, IntegerType.INSTANCE), new Cast(bigIntLiteral, IntegerType.INSTANCE)); + // cast bigint as smallint assertRewrite(new Cast(bigIntLiteral, SmallIntType.INSTANCE), new Cast(bigIntLiteral, SmallIntType.INSTANCE)); - assertRewrite(new Cast(bigIntLiteral, IntegerType.INSTANCE), - new Cast(bigIntLiteral, IntegerType.INSTANCE)); + // cast bigint as tinyint + assertRewrite(new Cast(bigIntLiteral, TinyIntType.INSTANCE), + new Cast(bigIntLiteral, TinyIntType.INSTANCE)); + // nested cast assertRewrite( new Cast( new Cast(