Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 35 additions & 13 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,33 @@ class CIRUnaryOpLowering : public mlir::OpConversionPattern<cir::UnaryOp> {
public:
using OpConversionPattern<cir::UnaryOp>::OpConversionPattern;

template <typename OpFloat, typename OpInt, bool rev>
mlir::Operation *
replaceImmediateOp(cir::UnaryOp op, mlir::Type type, mlir::Value input,
int64_t n,
mlir::ConversionPatternRewriter &rewriter) const {
if (type.isFloat()) {
auto imm = mlir::arith::ConstantOp::create(
rewriter, op.getLoc(),
mlir::FloatAttr::get(type, static_cast<double>(n)));
if constexpr (rev)
return rewriter.replaceOpWithNewOp<OpFloat>(op, type, imm, input);
else
return rewriter.replaceOpWithNewOp<OpFloat>(op, type, input, imm);
}
if (type.isInteger()) {
auto imm = mlir::arith::ConstantOp::create(
rewriter, op.getLoc(), mlir::IntegerAttr::get(type, n));
if constexpr (rev)
return rewriter.replaceOpWithNewOp<OpInt>(op, type, imm, input);
else
return rewriter.replaceOpWithNewOp<OpInt>(op, type, input, imm);
}
op->emitError("Unsupported type: ") << type << " at " << op->getLoc();
llvm_unreachable("CIRUnaryOpLowering met unsupported type");
return nullptr;
}

mlir::LogicalResult
matchAndRewrite(cir::UnaryOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
Expand All @@ -810,36 +837,31 @@ class CIRUnaryOpLowering : public mlir::OpConversionPattern<cir::UnaryOp> {

switch (op.getKind()) {
case cir::UnaryOpKind::Inc: {
auto One = mlir::arith::ConstantOp::create(
rewriter, op.getLoc(), mlir::IntegerAttr::get(type, 1));
rewriter.replaceOpWithNewOp<mlir::arith::AddIOp>(op, type, input, One);
replaceImmediateOp<mlir::arith::AddFOp, mlir::arith::AddIOp, false>(
op, type, input, 1, rewriter);
break;
}
case cir::UnaryOpKind::Dec: {
auto One = mlir::arith::ConstantOp::create(
rewriter, op.getLoc(), mlir::IntegerAttr::get(type, 1));
rewriter.replaceOpWithNewOp<mlir::arith::SubIOp>(op, type, input, One);
replaceImmediateOp<mlir::arith::AddFOp, mlir::arith::AddIOp, false>(
op, type, input, -1, rewriter);
break;
}
case cir::UnaryOpKind::Plus: {
rewriter.replaceOp(op, op.getInput());
break;
}
case cir::UnaryOpKind::Minus: {
auto Zero = mlir::arith::ConstantOp::create(
rewriter, op.getLoc(), mlir::IntegerAttr::get(type, 0));
rewriter.replaceOpWithNewOp<mlir::arith::SubIOp>(op, type, Zero, input);
replaceImmediateOp<mlir::arith::SubFOp, mlir::arith::SubIOp, true>(
op, type, input, 0, rewriter);
break;
}
case cir::UnaryOpKind::Not: {
auto MinusOne = mlir::arith::ConstantOp::create(
auto o = mlir::arith::ConstantOp::create(
rewriter, op.getLoc(), mlir::IntegerAttr::get(type, -1));
rewriter.replaceOpWithNewOp<mlir::arith::XOrIOp>(op, type, MinusOne,
input);
rewriter.replaceOpWithNewOp<mlir::arith::XOrIOp>(op, type, o, input);
break;
}
}

return mlir::LogicalResult::success();
}
};
Expand Down
8 changes: 4 additions & 4 deletions clang/test/CIR/Lowering/ThroughMLIR/if.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ void foo() {
//CHECK: memref.store %[[SEVEN]], %[[alloca_0]][] : memref<i32>
//CHECK: } else {
//CHECK: %[[SIX:.+]] = memref.load %[[alloca_0]][] : memref<i32>
//CHECK: %[[C1_I32:.+]] = arith.constant 1 : i32
//CHECK: %[[SEVEN:.+]] = arith.subi %[[SIX]], %[[C1_I32]] : i32
//CHECK: %[[C1_I32:.+]] = arith.constant -1 : i32
//CHECK: %[[SEVEN:.+]] = arith.addi %[[SIX]], %[[C1_I32]] : i32
//CHECK: memref.store %[[SEVEN]], %[[alloca_0]][] : memref<i32>
//CHECK: }
//CHECK: }
Expand Down Expand Up @@ -106,8 +106,8 @@ void foo3() {
//CHECK: memref.store %[[THIRTEEN]], %[[alloca_0]][] : memref<i32>
//CHECK: } else {
//CHECK: %[[TWELVE:.+]] = memref.load %[[alloca_0]][] : memref<i32>
//CHECK: %[[C1_I32_5:.+]] = arith.constant 1 : i32
//CHECK: %[[THIRTEEN:.+]] = arith.subi %[[TWELVE]], %[[C1_I32_5]] : i32
//CHECK: %[[C1_I32_5:.+]] = arith.constant -1 : i32
//CHECK: %[[THIRTEEN:.+]] = arith.addi %[[TWELVE]], %[[C1_I32_5]] : i32
//CHECK: memref.store %[[THIRTEEN]], %[[alloca_0]][] : memref<i32>
//CHECK: }
//CHECK: }
Expand Down
30 changes: 24 additions & 6 deletions clang/test/CIR/Lowering/ThroughMLIR/unary-inc-dec.cir
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: cir-opt %s -cir-to-mlir -o - | FileCheck %s -check-prefix=MLIR
// RUN: cir-opt %s -cir-to-mlir -cir-mlir-to-llvm -o - | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM
// RUN: cir-opt %s --cir-to-mlir -o - | FileCheck %s -check-prefix=MLIR
// RUN: cir-opt %s --cir-to-mlir -cir-mlir-to-llvm -o - | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM

!s32i = !cir.int<s, 32>
module {
Expand All @@ -17,14 +17,32 @@ module {
%5 = cir.load %1 : !cir.ptr<!s32i>, !s32i
%6 = cir.unary(dec, %5) : !s32i, !s32i
cir.store %6, %1 : !s32i, !cir.ptr<!s32i>

// test float
%7 = cir.alloca !s32i, !cir.ptr<!s32i>, ["b", init] {alignment = 4 : i64}
cir.return
}
}

// MLIR: = arith.constant 1
// MLIR: = arith.addi
// MLIR: = arith.constant 1
// MLIR: = arith.subi
// MLIR: = arith.constant -1
// MLIR: = arith.addi

// LLVM: = add i32 %[[#]], 1
// LLVM: = sub i32 %[[#]], 1
// LLVM: = add i32 %[[#]], -1


cir.func @floatingPoints(%arg0: !cir.double) {
%0 = cir.alloca !cir.double, !cir.ptr<!cir.double>, ["X", init] {alignment = 8 : i64}
cir.store %arg0, %0 : !cir.double, !cir.ptr<!cir.double>
%1 = cir.load %0 : !cir.ptr<!cir.double>, !cir.double
%2 = cir.unary(inc, %1) : !cir.double, !cir.double
%3 = cir.load %0 : !cir.ptr<!cir.double>, !cir.double
%4 = cir.unary(dec, %3) : !cir.double, !cir.double
cir.return
}
// MLIR: = arith.constant 1.0
// MLIR: = arith.addf
// MLIR: = arith.constant -1.0
// MLIR: = arith.addf
}
10 changes: 10 additions & 0 deletions clang/test/CIR/Lowering/ThroughMLIR/unary-plus-minus.cir
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ module {
cir.store %6, %1 : !s32i, !cir.ptr<!s32i>
cir.return
}

cir.func @floatingPoints(%arg0: !cir.double) {
%0 = cir.alloca !cir.double, !cir.ptr<!cir.double>, ["X", init] {alignment = 8 : i64}
cir.store %arg0, %0 : !cir.double, !cir.ptr<!cir.double>
%1 = cir.load %0 : !cir.ptr<!cir.double>, !cir.double
%2 = cir.unary(plus, %1) : !cir.double, !cir.double
%3 = cir.load %0 : !cir.ptr<!cir.double>, !cir.double
%4 = cir.unary(minus, %3) : !cir.double, !cir.double
cir.return
}
}

// MLIR: %[[#INPUT_PLUS:]] = memref.load
Expand Down