Skip to content

[CIR] Bit manipulation builtin functions #146529

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2481,6 +2481,146 @@ def ComplexEqualOp : CIR_Op<"complex.eq", [Pure, SameTypeOperands]> {
}];
}

//===----------------------------------------------------------------------===//
// Bit Manipulation Operations
//===----------------------------------------------------------------------===//

class CIR_BitOpBase<string mnemonic, TypeConstraint operandTy>
: CIR_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
let arguments = (ins operandTy:$input);
let results = (outs operandTy:$result);

let assemblyFormat = [{
`(` $input `:` type($input) `)` `:` type($result) attr-dict
}];
}

class CIR_BitZeroCountOpBase<string mnemonic, TypeConstraint operandTy>
: CIR_BitOpBase<mnemonic, operandTy> {
let arguments = (ins operandTy:$input, UnitAttr:$poison_zero);

let assemblyFormat = [{
`(` $input `:` type($input) `)` (`poison_zero` $poison_zero^)?
`:` type($result) attr-dict
}];
}

def BitClrsbOp : CIR_BitOpBase<"bit.clrsb", CIR_SIntOfWidths<[32, 64]>> {
let summary = "Get the number of leading redundant sign bits in the input";
let description = [{
Compute the number of leading redundant sign bits in the input integer.

The input integer must be a signed integer. The most significant bit of the
input integer is the sign bit. The `cir.bit.clrsb` operation returns the
number of consecutive bits following the sign bit that are identical to the
sign bit.

The bit width of the input integer must be either 32 or 64.

Examples:

```mlir
// %0 = 0b1101_1110_1010_1101_1011_1110_1110_1111
%0 = cir.const #cir.int<3735928559> : !s32i
// %1 will be 1 because there is 1 bit following the most significant bit
// that is identical to it.
%1 = cir.bit.clrsb(%0 : !s32i) : !s32i

// %2 = 1, 0b0000_0000_0000_0000_0000_0000_0000_0001
%2 = cir.const #cir.int<1> : !s32i
// %3 will be 30 because there are 30 consecutive bits following the sign
// bit that are identical to the sign bit.
%3 = cir.bit.clrsb(%2 : !s32i) : !s32i
```
}];
}

def BitClzOp : CIR_BitZeroCountOpBase<"bit.clz",
CIR_UIntOfWidths<[16, 32, 64]>> {
let summary = "Get the number of leading 0-bits in the input";
let description = [{
Compute the number of leading 0-bits in the input.

The input integer must be an unsigned integer. The `cir.bit.clz` operation
returns the number of consecutive 0-bits at the most significant bit
position in the input.

If the `poison_zero` attribute is present, this operation will have
undefined behavior if the input value is 0.

Example:

```mlir
// %0 = 0b0000_0000_0000_0000_0000_0000_0000_1000
%0 = cir.const #cir.int<8> : !u32i
// %1 will be 28
%1 = cir.bit.clz(%0 : !u32i) poison_zero : !u32i
```
}];
}

def BitCtzOp : CIR_BitZeroCountOpBase<"bit.ctz",
CIR_UIntOfWidths<[16, 32, 64]>> {
let summary = "Get the number of trailing 0-bits in the input";
let description = [{
Compute the number of trailing 0-bits in the input.

The input integer must be an unsigned integer. The `cir.bit.ctz` operation
counts the number of consecutive 0-bits starting from the least significant
bit.

If the `poison_zero` attribute is present, this operation will have
undefined behavior if the input value is 0.

Example:

```mlir
// %0 = 0b1000
%0 = cir.const #cir.int<8> : !u32i
// %1 will be 3
%1 = cir.bit.ctz(%0 : !u32i) poison_zero : !u32i
```
}];
}

def BitParityOp : CIR_BitOpBase<"bit.parity", CIR_UIntOfWidths<[32, 64]>> {
let summary = "Get the parity of input";
let description = [{
Compute the parity of the input. The parity of an integer is the number of
1-bits in it modulo 2.

The input must be an unsigned integer.

Example:

```mlir
// %0 = 0x0110_1000
%0 = cir.const #cir.int<104> : !u32i
// %1 will be 1 since there are three 1-bits in %0
%1 = cir.bit.parity(%0 : !u32i) : !u32i
```
}];
}

def BitPopcountOp : CIR_BitOpBase<"bit.popcnt",
CIR_UIntOfWidths<[16, 32, 64]>> {
let summary = "Get the number of 1-bits in input";
let description = [{
Compute the number of 1-bits in the input.

The input must be an unsigned integer.

Example:

```mlir
// %0 = 0x0110_1000
%0 = cir.const #cir.int<104> : !u32i
// %1 will be 3 since there are 3 1-bits in %0
%1 = cir.bit.popcnt(%0 : !u32i) : !u32i
```
}];
}

//===----------------------------------------------------------------------===//
// Assume Operations
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/CIR/MissingFeatures.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ struct MissingFeatures {
static bool builtinCall() { return false; }
static bool builtinCallF128() { return false; }
static bool builtinCallMathErrno() { return false; }
static bool builtinCheckKind() { return false; }
static bool cgFPOptionsRAII() { return false; }
static bool cirgenABIInfo() { return false; }
static bool cleanupAfterErrorDiags() { return false; }
Expand Down
64 changes: 64 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,29 @@ static RValue emitLibraryCall(CIRGenFunction &cgf, const FunctionDecl *fd,
return cgf.emitCall(e->getCallee()->getType(), callee, e, ReturnValueSlot());
}

template <typename Op>
static RValue emitBuiltinBitOp(CIRGenFunction &cgf, const CallExpr *e,
bool poisonZero = false) {
assert(!cir::MissingFeatures::builtinCheckKind());

mlir::Value arg = cgf.emitScalarExpr(e->getArg(0));
CIRGenBuilderTy &builder = cgf.getBuilder();

Op op;
if constexpr (std::is_same_v<Op, cir::BitClzOp> ||
std::is_same_v<Op, cir::BitCtzOp>)
op = builder.create<Op>(cgf.getLoc(e->getSourceRange()), arg, poisonZero);
else
op = builder.create<Op>(cgf.getLoc(e->getSourceRange()), arg);

mlir::Value result = op.getResult();
mlir::Type exprTy = cgf.convertType(e->getType());
if (exprTy != result.getType())
result = builder.createIntCast(result, exprTy);

return RValue::get(result);
}

RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl &gd, unsigned builtinID,
const CallExpr *e,
ReturnValueSlot returnValue) {
Expand Down Expand Up @@ -101,6 +124,47 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl &gd, unsigned builtinID,
return RValue::get(complex);
}

case Builtin::BI__builtin_clrsb:
case Builtin::BI__builtin_clrsbl:
case Builtin::BI__builtin_clrsbll:
return emitBuiltinBitOp<cir::BitClrsbOp>(*this, e);

case Builtin::BI__builtin_ctzs:
case Builtin::BI__builtin_ctz:
case Builtin::BI__builtin_ctzl:
case Builtin::BI__builtin_ctzll:
case Builtin::BI__builtin_ctzg:
assert(!cir::MissingFeatures::builtinCheckKind());
return emitBuiltinBitOp<cir::BitCtzOp>(*this, e, /*poisonZero=*/true);

case Builtin::BI__builtin_clzs:
case Builtin::BI__builtin_clz:
case Builtin::BI__builtin_clzl:
case Builtin::BI__builtin_clzll:
case Builtin::BI__builtin_clzg:
assert(!cir::MissingFeatures::builtinCheckKind());
return emitBuiltinBitOp<cir::BitClzOp>(*this, e, /*poisonZero=*/true);

case Builtin::BI__builtin_parity:
case Builtin::BI__builtin_parityl:
case Builtin::BI__builtin_parityll:
return emitBuiltinBitOp<cir::BitParityOp>(*this, e);

case Builtin::BI__lzcnt16:
case Builtin::BI__lzcnt:
case Builtin::BI__lzcnt64:
assert(!cir::MissingFeatures::builtinCheckKind());
return emitBuiltinBitOp<cir::BitClzOp>(*this, e, /*poisonZero=*/false);

case Builtin::BI__popcnt16:
case Builtin::BI__popcnt:
case Builtin::BI__popcnt64:
case Builtin::BI__builtin_popcount:
case Builtin::BI__builtin_popcountl:
case Builtin::BI__builtin_popcountll:
case Builtin::BI__builtin_popcountg:
return emitBuiltinBitOp<cir::BitPopcountOp>(*this, e);

case Builtin::BI__builtin_expect:
case Builtin::BI__builtin_expect_with_probability: {
mlir::Value argValue = emitScalarExpr(e->getArg(0));
Expand Down
80 changes: 80 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,81 @@ mlir::LogicalResult CIRToLLVMAssumeOpLowering::matchAndRewrite(
return mlir::success();
}

mlir::LogicalResult CIRToLLVMBitClrsbOpLowering::matchAndRewrite(
cir::BitClrsbOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto zero = rewriter.create<mlir::LLVM::ConstantOp>(
op.getLoc(), adaptor.getInput().getType(), 0);
auto isNeg = rewriter.create<mlir::LLVM::ICmpOp>(
op.getLoc(),
mlir::LLVM::ICmpPredicateAttr::get(rewriter.getContext(),
mlir::LLVM::ICmpPredicate::slt),
adaptor.getInput(), zero);

auto negOne = rewriter.create<mlir::LLVM::ConstantOp>(
op.getLoc(), adaptor.getInput().getType(), -1);
auto flipped = rewriter.create<mlir::LLVM::XOrOp>(op.getLoc(),
adaptor.getInput(), negOne);

auto select = rewriter.create<mlir::LLVM::SelectOp>(
op.getLoc(), isNeg, flipped, adaptor.getInput());

auto resTy = getTypeConverter()->convertType(op.getType());
auto clz = rewriter.create<mlir::LLVM::CountLeadingZerosOp>(
op.getLoc(), resTy, select, /*is_zero_poison=*/false);

auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1);
auto res = rewriter.create<mlir::LLVM::SubOp>(op.getLoc(), clz, one);
rewriter.replaceOp(op, res);

return mlir::LogicalResult::success();
}

mlir::LogicalResult CIRToLLVMBitClzOpLowering::matchAndRewrite(
cir::BitClzOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto resTy = getTypeConverter()->convertType(op.getType());
auto llvmOp = rewriter.create<mlir::LLVM::CountLeadingZerosOp>(
op.getLoc(), resTy, adaptor.getInput(), op.getPoisonZero());
rewriter.replaceOp(op, llvmOp);
return mlir::LogicalResult::success();
}

mlir::LogicalResult CIRToLLVMBitCtzOpLowering::matchAndRewrite(
cir::BitCtzOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto resTy = getTypeConverter()->convertType(op.getType());
auto llvmOp = rewriter.create<mlir::LLVM::CountTrailingZerosOp>(
op.getLoc(), resTy, adaptor.getInput(), op.getPoisonZero());
rewriter.replaceOp(op, llvmOp);
return mlir::LogicalResult::success();
}

mlir::LogicalResult CIRToLLVMBitParityOpLowering::matchAndRewrite(
cir::BitParityOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto resTy = getTypeConverter()->convertType(op.getType());
auto popcnt = rewriter.create<mlir::LLVM::CtPopOp>(op.getLoc(), resTy,
adaptor.getInput());

auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1);
auto popcntMod2 =
rewriter.create<mlir::LLVM::AndOp>(op.getLoc(), popcnt, one);
rewriter.replaceOp(op, popcntMod2);

return mlir::LogicalResult::success();
}

mlir::LogicalResult CIRToLLVMBitPopcountOpLowering::matchAndRewrite(
cir::BitPopcountOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto resTy = getTypeConverter()->convertType(op.getType());
auto llvmOp = rewriter.create<mlir::LLVM::CtPopOp>(op.getLoc(), resTy,
adaptor.getInput());
rewriter.replaceOp(op, llvmOp);
return mlir::LogicalResult::success();
}

mlir::LogicalResult CIRToLLVMBrCondOpLowering::matchAndRewrite(
cir::BrCondOp brOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
Expand Down Expand Up @@ -1896,6 +1971,11 @@ void ConvertCIRToLLVMPass::runOnOperation() {
CIRToLLVMAssumeOpLowering,
CIRToLLVMBaseClassAddrOpLowering,
CIRToLLVMBinOpLowering,
CIRToLLVMBitClrsbOpLowering,
CIRToLLVMBitClzOpLowering,
CIRToLLVMBitCtzOpLowering,
CIRToLLVMBitParityOpLowering,
CIRToLLVMBitPopcountOpLowering,
CIRToLLVMBrCondOpLowering,
CIRToLLVMBrOpLowering,
CIRToLLVMCallOpLowering,
Expand Down
50 changes: 50 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,56 @@ class CIRToLLVMAssumeOpLowering
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMBitClrsbOpLowering
: public mlir::OpConversionPattern<cir::BitClrsbOp> {
public:
using mlir::OpConversionPattern<cir::BitClrsbOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(cir::BitClrsbOp op, OpAdaptor,
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMBitClzOpLowering
: public mlir::OpConversionPattern<cir::BitClzOp> {
public:
using mlir::OpConversionPattern<cir::BitClzOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(cir::BitClzOp op, OpAdaptor,
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMBitCtzOpLowering
: public mlir::OpConversionPattern<cir::BitCtzOp> {
public:
using mlir::OpConversionPattern<cir::BitCtzOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(cir::BitCtzOp op, OpAdaptor,
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMBitParityOpLowering
: public mlir::OpConversionPattern<cir::BitParityOp> {
public:
using mlir::OpConversionPattern<cir::BitParityOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(cir::BitParityOp op, OpAdaptor,
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMBitPopcountOpLowering
: public mlir::OpConversionPattern<cir::BitPopcountOp> {
public:
using mlir::OpConversionPattern<cir::BitPopcountOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(cir::BitPopcountOp op, OpAdaptor,
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMBrCondOpLowering
: public mlir::OpConversionPattern<cir::BrCondOp> {
public:
Expand Down
Loading
Loading