Skip to content

Commit 63d2fad

Browse files
committed
finish TODO and simplify test
1 parent 256f2e6 commit 63d2fad

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,18 @@ include "mlir/Dialect/Arith/IR/ArithOps.td"
1414
include "mlir/IR/OpBase.td"
1515
include "mlir/IR/PatternBase.td"
1616

17-
// TODO: get the proper scalar type from the operand polynomial ring attribute
17+
// Get a -1 integer attribute of the same type as the polynomial SSA value's
18+
// ring coefficient type.
19+
def getMinusOne
20+
: NativeCodeCall<
21+
"$_builder.getIntegerAttr("
22+
"cast<PolynomialType>($0.getType()).getRing().getCoefficientType(), -1)">;
23+
1824
def SubAsAdd : Pat<
1925
(Polynomial_SubOp $f, $g),
2026
(Polynomial_AddOp $f,
2127
(Polynomial_MulScalarOp $g,
22-
(Arith_ConstantOp
23-
ConstantAttr<I32Attr, "-1">)))>;
28+
(Arith_ConstantOp (getMinusOne $g))))>;
2429

2530
def INTTAfterNTT : Pat<
2631
(Polynomial_INTTOp (Polynomial_NTTOp $poly)),

mlir/test/Dialect/Polynomial/canonicalization.mlir

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,14 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
3232

3333
#cycl_2048 = #polynomial.int_polynomial<1 + x**1024>
3434
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256:i32, polynomialModulus=#cycl_2048>
35-
#one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>
36-
#one_minus_x_squared = #polynomial.int_polynomial<1 + -1x**2>
3735
!sub_ty = !polynomial.polynomial<ring=#ring>
3836

39-
// CHECK-LABEL: test_canonicalize_sub_power_of_two_cmod
40-
func.func @test_canonicalize_sub_power_of_two_cmod() -> !sub_ty {
41-
%poly0 = polynomial.constant {value=#one_plus_x_squared} : !sub_ty
42-
%poly1 = polynomial.constant {value=#one_minus_x_squared} : !sub_ty
37+
// CHECK-LABEL: test_canonicalize_sub
38+
// CHECK-SAME: (%[[p0:.*]]: [[T:.*]], %[[p1:.*]]: [[T]]) -> [[T]] {
39+
func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty {
4340
%0 = polynomial.sub %poly0, %poly1 : !sub_ty
4441
// CHECK: %[[minus_one:.+]] = arith.constant -1 : i32
45-
// CHECK: %[[p1:.+]] = polynomial.constant
46-
// CHECK: %[[p2:.+]] = polynomial.constant
47-
// CHECK: %[[p2neg:.+]] = polynomial.mul_scalar %[[p2]], %[[minus_one]]
48-
// CHECK: [[ADD:%.+]] = polynomial.add %[[p1]], %[[p2neg]]
42+
// CHECK: %[[p1neg:.+]] = polynomial.mul_scalar %[[p1]], %[[minus_one]]
43+
// CHECK: [[ADD:%.+]] = polynomial.add %[[p0]], %[[p1neg]]
4944
return %0 : !sub_ty
5045
}

0 commit comments

Comments
 (0)