Skip to content

[CIR] Infer MLIR context in type builders when possible #1570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {

mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real,
mlir::Value imag) {
auto resultComplexTy = cir::ComplexType::get(getContext(), real.getType());
auto resultComplexTy = cir::ComplexType::get(real.getType());
return create<cir::ComplexCreateOp>(loc, resultComplexTy, real, imag);
}

Expand Down
38 changes: 38 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,12 @@ def CIR_ComplexType : CIR_Type<"Complex", "complex",

let parameters = (ins "mlir::Type":$elementTy);

let builders = [
TypeBuilderWithInferredContext<(ins "mlir::Type":$elementTy), [{
return $_get(elementTy.getContext(), elementTy);
}]>,
];

let assemblyFormat = [{
`<` $elementTy `>`
}];
Expand Down Expand Up @@ -301,6 +307,14 @@ def CIR_DataMemberType : CIR_Type<"DataMember", "data_member",
let parameters = (ins "mlir::Type":$memberTy,
"cir::RecordType":$clsTy);

let builders = [
TypeBuilderWithInferredContext<(ins
"mlir::Type":$memberTy, "cir::RecordType":$clsTy
), [{
return $_get(memberTy.getContext(), memberTy, clsTy);
}]>,
];

let assemblyFormat = [{
`<` $memberTy `in` $clsTy `>`
}];
Expand Down Expand Up @@ -338,6 +352,14 @@ def CIR_ArrayType : CIR_Type<"Array", "array",

let parameters = (ins "mlir::Type":$eltType, "uint64_t":$size);

let builders = [
TypeBuilderWithInferredContext<(ins
"mlir::Type":$eltType, "uint64_t":$size
), [{
return $_get(eltType.getContext(), eltType, size);
}]>,
];

let assemblyFormat = [{
`<` $eltType `x` $size `>`
}];
Expand All @@ -358,6 +380,14 @@ def CIR_VectorType : CIR_Type<"Vector", "vector",

let parameters = (ins "mlir::Type":$eltType, "uint64_t":$size);

let builders = [
TypeBuilderWithInferredContext<(ins
"mlir::Type":$eltType, "uint64_t":$size
), [{
return $_get(eltType.getContext(), eltType, size);
}]>,
];

let assemblyFormat = [{
`<` $eltType `x` $size `>`
}];
Expand Down Expand Up @@ -452,6 +482,14 @@ def CIR_MethodType : CIR_Type<"Method", "method",
let parameters = (ins "cir::FuncType":$memberFuncTy,
"cir::RecordType":$clsTy);

let builders = [
TypeBuilderWithInferredContext<(ins
"cir::FuncType":$memberFuncTy, "cir::RecordType":$clsTy
), [{
return $_get(memberFuncTy.getContext(), memberFuncTy, clsTy);
}]>,
];

let assemblyFormat = [{
`<` qualified($memberFuncTy) `in` $clsTy `>`
}];
Expand Down
13 changes: 4 additions & 9 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,16 +168,16 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
// If the string is full of null bytes, emit a #cir.zero rather than
// a #cir.const_array.
if (lastNonZeroPos == llvm::StringRef::npos) {
auto arrayTy = cir::ArrayType::get(getContext(), eltTy, finalSize);
auto arrayTy = cir::ArrayType::get(eltTy, finalSize);
return getZeroAttr(arrayTy);
}
// We will use trailing zeros only if there are more than one zero
// at the end
int trailingZerosNum =
finalSize > lastNonZeroPos + 2 ? finalSize - lastNonZeroPos - 1 : 0;
auto truncatedArrayTy =
cir::ArrayType::get(getContext(), eltTy, finalSize - trailingZerosNum);
auto fullArrayTy = cir::ArrayType::get(getContext(), eltTy, finalSize);
cir::ArrayType::get(eltTy, finalSize - trailingZerosNum);
auto fullArrayTy = cir::ArrayType::get(eltTy, finalSize);
return cir::ConstArrayAttr::get(
getContext(), fullArrayTy,
mlir::StringAttr::get(str.drop_back(trailingZerosNum),
Expand Down Expand Up @@ -407,8 +407,7 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
bool isSigned = false) {
auto elementTy = mlir::dyn_cast_or_null<cir::IntType>(vt.getEltType());
assert(elementTy && "expected int vector");
return cir::VectorType::get(getContext(),
isExtended
return cir::VectorType::get(isExtended
? getExtendedIntTy(elementTy, isSigned)
: getTruncatedIntTy(elementTy, isSigned),
vt.getSize());
Expand Down Expand Up @@ -530,10 +529,6 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
return getCompleteRecordTy(members, name, packed, padded, ast);
}

cir::ArrayType getArrayType(mlir::Type eltType, unsigned size) {
return cir::ArrayType::get(getContext(), eltType, size);
}

bool isSized(mlir::Type ty) {
if (mlir::isa<cir::PointerType, cir::RecordType, cir::ArrayType,
cir::BoolType, cir::IntType, cir::CIRFPTypeInterface>(ty))
Expand Down
91 changes: 31 additions & 60 deletions clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1863,14 +1863,12 @@ static cir::VectorType GetNeonType(CIRGenFunction *CGF, NeonTypeFlags TypeFlags,
switch (TypeFlags.getEltType()) {
case NeonTypeFlags::Int8:
case NeonTypeFlags::Poly8:
return cir::VectorType::get(CGF->getBuilder().getContext(),
TypeFlags.isUnsigned() ? CGF->UInt8Ty
return cir::VectorType::get(TypeFlags.isUnsigned() ? CGF->UInt8Ty
: CGF->SInt8Ty,
V1Ty ? 1 : (8 << IsQuad));
case NeonTypeFlags::Int16:
case NeonTypeFlags::Poly16:
return cir::VectorType::get(CGF->getBuilder().getContext(),
TypeFlags.isUnsigned() ? CGF->UInt16Ty
return cir::VectorType::get(TypeFlags.isUnsigned() ? CGF->UInt16Ty
: CGF->SInt16Ty,
V1Ty ? 1 : (4 << IsQuad));
case NeonTypeFlags::BFloat16:
Expand All @@ -1884,14 +1882,12 @@ static cir::VectorType GetNeonType(CIRGenFunction *CGF, NeonTypeFlags TypeFlags,
else
llvm_unreachable("NeonTypeFlags::Float16 NYI");
case NeonTypeFlags::Int32:
return cir::VectorType::get(CGF->getBuilder().getContext(),
TypeFlags.isUnsigned() ? CGF->UInt32Ty
return cir::VectorType::get(TypeFlags.isUnsigned() ? CGF->UInt32Ty
: CGF->SInt32Ty,
V1Ty ? 1 : (2 << IsQuad));
case NeonTypeFlags::Int64:
case NeonTypeFlags::Poly64:
return cir::VectorType::get(CGF->getBuilder().getContext(),
TypeFlags.isUnsigned() ? CGF->UInt64Ty
return cir::VectorType::get(TypeFlags.isUnsigned() ? CGF->UInt64Ty
: CGF->SInt64Ty,
V1Ty ? 1 : (1 << IsQuad));
case NeonTypeFlags::Poly128:
Expand All @@ -1900,12 +1896,10 @@ static cir::VectorType GetNeonType(CIRGenFunction *CGF, NeonTypeFlags TypeFlags,
// so we use v16i8 to represent poly128 and get pattern matched.
llvm_unreachable("NeonTypeFlags::Poly128 NYI");
case NeonTypeFlags::Float32:
return cir::VectorType::get(CGF->getBuilder().getContext(),
CGF->getCIRGenModule().FloatTy,
return cir::VectorType::get(CGF->getCIRGenModule().FloatTy,
V1Ty ? 1 : (2 << IsQuad));
case NeonTypeFlags::Float64:
return cir::VectorType::get(CGF->getBuilder().getContext(),
CGF->getCIRGenModule().DoubleTy,
return cir::VectorType::get(CGF->getCIRGenModule().DoubleTy,
V1Ty ? 1 : (1 << IsQuad));
}
llvm_unreachable("Unknown vector element type!");
Expand Down Expand Up @@ -2102,7 +2096,7 @@ static cir::VectorType getSignChangedVectorType(CIRGenBuilderTy &builder,
auto elemTy = mlir::cast<cir::IntType>(vecTy.getEltType());
elemTy = elemTy.isSigned() ? builder.getUIntNTy(elemTy.getWidth())
: builder.getSIntNTy(elemTy.getWidth());
return cir::VectorType::get(builder.getContext(), elemTy, vecTy.getSize());
return cir::VectorType::get(elemTy, vecTy.getSize());
}

static cir::VectorType
Expand All @@ -2111,19 +2105,16 @@ getHalfEltSizeTwiceNumElemsVecType(CIRGenBuilderTy &builder,
auto elemTy = mlir::cast<cir::IntType>(vecTy.getEltType());
elemTy = elemTy.isSigned() ? builder.getSIntNTy(elemTy.getWidth() / 2)
: builder.getUIntNTy(elemTy.getWidth() / 2);
return cir::VectorType::get(builder.getContext(), elemTy,
vecTy.getSize() * 2);
return cir::VectorType::get(elemTy, vecTy.getSize() * 2);
}

static cir::VectorType
castVecOfFPTypeToVecOfIntWithSameWidth(CIRGenBuilderTy &builder,
cir::VectorType vecTy) {
if (mlir::isa<cir::SingleType>(vecTy.getEltType()))
return cir::VectorType::get(builder.getContext(), builder.getSInt32Ty(),
vecTy.getSize());
return cir::VectorType::get(builder.getSInt32Ty(), vecTy.getSize());
if (mlir::isa<cir::DoubleType>(vecTy.getEltType()))
return cir::VectorType::get(builder.getContext(), builder.getSInt64Ty(),
vecTy.getSize());
return cir::VectorType::get(builder.getSInt64Ty(), vecTy.getSize());
llvm_unreachable(
"Unsupported element type in getVecOfIntTypeWithSameEltWidth");
}
Expand Down Expand Up @@ -2315,8 +2306,7 @@ static mlir::Value emitCommonNeonVecAcrossCall(CIRGenFunction &cgf,
const clang::CallExpr *e) {
CIRGenBuilderTy &builder = cgf.getBuilder();
mlir::Value op = cgf.emitScalarExpr(e->getArg(0));
cir::VectorType vTy =
cir::VectorType::get(&cgf.getMLIRContext(), eltTy, vecLen);
cir::VectorType vTy = cir::VectorType::get(eltTy, vecLen);
llvm::SmallVector<mlir::Value, 1> args{op};
return emitNeonCall(builder, {vTy}, args, intrincsName, eltTy,
cgf.getLoc(e->getExprLoc()));
Expand Down Expand Up @@ -2447,8 +2437,7 @@ mlir::Value CIRGenFunction::emitCommonNeonBuiltinExpr(
cir::VectorType resTy =
(builtinID == NEON::BI__builtin_neon_vqdmulhq_lane_v ||
builtinID == NEON::BI__builtin_neon_vqrdmulhq_lane_v)
? cir::VectorType::get(&getMLIRContext(), vTy.getEltType(),
vTy.getSize() * 2)
? cir::VectorType::get(vTy.getEltType(), vTy.getSize() * 2)
: vTy;
cir::VectorType mulVecT =
GetNeonType(this, NeonTypeFlags(neonType.getEltType(), false,
Expand Down Expand Up @@ -2888,10 +2877,8 @@ static mlir::Value emitCommonNeonSISDBuiltinExpr(
llvm_unreachable(" neon_vqmovnh_u16 NYI ");
case NEON::BI__builtin_neon_vqmovns_s32: {
mlir::Location loc = cgf.getLoc(expr->getExprLoc());
cir::VectorType argVecTy =
cir::VectorType::get(&(cgf.getMLIRContext()), cgf.SInt32Ty, 4);
cir::VectorType resVecTy =
cir::VectorType::get(&(cgf.getMLIRContext()), cgf.SInt16Ty, 4);
cir::VectorType argVecTy = cir::VectorType::get(cgf.SInt32Ty, 4);
cir::VectorType resVecTy = cir::VectorType::get(cgf.SInt16Ty, 4);
vecExtendIntValue(cgf, argVecTy, ops[0], loc);
mlir::Value result = emitNeonCall(builder, {argVecTy}, ops,
"aarch64.neon.sqxtn", resVecTy, loc);
Expand Down Expand Up @@ -3706,88 +3693,74 @@ CIRGenFunction::emitAArch64BuiltinExpr(unsigned BuiltinID, const CallExpr *E,

case NEON::BI__builtin_neon_vset_lane_f64: {
Ops.push_back(emitScalarExpr(E->getArg(2)));
Ops[1] = builder.createBitcast(
Ops[1], cir::VectorType::get(&getMLIRContext(), DoubleTy, 1));
Ops[1] = builder.createBitcast(Ops[1], cir::VectorType::get(DoubleTy, 1));
return builder.create<cir::VecInsertOp>(getLoc(E->getExprLoc()), Ops[1],
Ops[0], Ops[2]);
}
case NEON::BI__builtin_neon_vsetq_lane_f64: {
Ops.push_back(emitScalarExpr(E->getArg(2)));
Ops[1] = builder.createBitcast(
Ops[1], cir::VectorType::get(&getMLIRContext(), DoubleTy, 2));
Ops[1] = builder.createBitcast(Ops[1], cir::VectorType::get(DoubleTy, 2));
return builder.create<cir::VecInsertOp>(getLoc(E->getExprLoc()), Ops[1],
Ops[0], Ops[2]);
}
case NEON::BI__builtin_neon_vget_lane_i8:
case NEON::BI__builtin_neon_vdupb_lane_i8:
Ops[0] = builder.createBitcast(
Ops[0], cir::VectorType::get(&getMLIRContext(), UInt8Ty, 8));
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt8Ty, 8));
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
emitScalarExpr(E->getArg(1)));
case NEON::BI__builtin_neon_vgetq_lane_i8:
case NEON::BI__builtin_neon_vdupb_laneq_i8:
Ops[0] = builder.createBitcast(
Ops[0], cir::VectorType::get(&getMLIRContext(), UInt8Ty, 16));
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt8Ty, 16));
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
emitScalarExpr(E->getArg(1)));
case NEON::BI__builtin_neon_vget_lane_i16:
case NEON::BI__builtin_neon_vduph_lane_i16:
Ops[0] = builder.createBitcast(
Ops[0], cir::VectorType::get(&getMLIRContext(), UInt16Ty, 4));
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt16Ty, 4));
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
emitScalarExpr(E->getArg(1)));
case NEON::BI__builtin_neon_vgetq_lane_i16:
case NEON::BI__builtin_neon_vduph_laneq_i16:
Ops[0] = builder.createBitcast(
Ops[0], cir::VectorType::get(&getMLIRContext(), UInt16Ty, 8));
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt16Ty, 8));
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
emitScalarExpr(E->getArg(1)));
case NEON::BI__builtin_neon_vget_lane_i32:
case NEON::BI__builtin_neon_vdups_lane_i32:
Ops[0] = builder.createBitcast(
Ops[0], cir::VectorType::get(&getMLIRContext(), UInt32Ty, 2));
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt32Ty, 2));
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
emitScalarExpr(E->getArg(1)));
case NEON::BI__builtin_neon_vget_lane_f32:
case NEON::BI__builtin_neon_vdups_lane_f32:
Ops[0] = builder.createBitcast(
Ops[0], cir::VectorType::get(&getMLIRContext(), FloatTy, 2));
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(FloatTy, 2));
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
emitScalarExpr(E->getArg(1)));
case NEON::BI__builtin_neon_vgetq_lane_i32:
case NEON::BI__builtin_neon_vdups_laneq_i32:
Ops[0] = builder.createBitcast(
Ops[0], cir::VectorType::get(&getMLIRContext(), UInt32Ty, 4));
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt32Ty, 4));
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
emitScalarExpr(E->getArg(1)));
case NEON::BI__builtin_neon_vget_lane_i64:
case NEON::BI__builtin_neon_vdupd_lane_i64:
Ops[0] = builder.createBitcast(
Ops[0], cir::VectorType::get(&getMLIRContext(), UInt64Ty, 1));
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt64Ty, 1));
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
emitScalarExpr(E->getArg(1)));
case NEON::BI__builtin_neon_vdupd_lane_f64:
case NEON::BI__builtin_neon_vget_lane_f64:
Ops[0] = builder.createBitcast(
Ops[0], cir::VectorType::get(&getMLIRContext(), DoubleTy, 1));
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(DoubleTy, 1));
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
emitScalarExpr(E->getArg(1)));
case NEON::BI__builtin_neon_vgetq_lane_i64:
case NEON::BI__builtin_neon_vdupd_laneq_i64:
Ops[0] = builder.createBitcast(
Ops[0], cir::VectorType::get(&getMLIRContext(), UInt64Ty, 2));
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt64Ty, 2));
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
emitScalarExpr(E->getArg(1)));
case NEON::BI__builtin_neon_vgetq_lane_f32:
case NEON::BI__builtin_neon_vdups_laneq_f32:
Ops[0] = builder.createBitcast(
Ops[0], cir::VectorType::get(&getMLIRContext(), FloatTy, 4));
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(FloatTy, 4));
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
emitScalarExpr(E->getArg(1)));
case NEON::BI__builtin_neon_vgetq_lane_f64:
case NEON::BI__builtin_neon_vdupd_laneq_f64:
Ops[0] = builder.createBitcast(
Ops[0], cir::VectorType::get(&getMLIRContext(), DoubleTy, 2));
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(DoubleTy, 2));
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
emitScalarExpr(E->getArg(1)));
case NEON::BI__builtin_neon_vaddh_f16: {
Expand Down Expand Up @@ -4318,7 +4291,7 @@ CIRGenFunction::emitAArch64BuiltinExpr(unsigned BuiltinID, const CallExpr *E,
[[fallthrough]];
case NEON::BI__builtin_neon_vaddv_s16: {
cir::IntType eltTy = usgn ? UInt16Ty : SInt16Ty;
cir::VectorType vTy = cir::VectorType::get(builder.getContext(), eltTy, 4);
cir::VectorType vTy = cir::VectorType::get(eltTy, 4);
Ops.push_back(emitScalarExpr(E->getArg(0)));
// This is to add across the vector elements, so wider result type needed.
Ops[0] = emitNeonCall(builder, {vTy}, Ops,
Expand Down Expand Up @@ -4427,8 +4400,7 @@ CIRGenFunction::emitAArch64BuiltinExpr(unsigned BuiltinID, const CallExpr *E,
usgn = true;
[[fallthrough]];
case NEON::BI__builtin_neon_vaddlvq_s16: {
mlir::Type argTy = cir::VectorType::get(builder.getContext(),
usgn ? UInt16Ty : SInt16Ty, 8);
mlir::Type argTy = cir::VectorType::get(usgn ? UInt16Ty : SInt16Ty, 8);
llvm::SmallVector<mlir::Value, 1> argOps = {emitScalarExpr(E->getArg(0))};
return emitNeonCall(builder, {argTy}, argOps,
usgn ? "aarch64.neon.uaddlv" : "aarch64.neon.saddlv",
Expand All @@ -4441,8 +4413,7 @@ CIRGenFunction::emitAArch64BuiltinExpr(unsigned BuiltinID, const CallExpr *E,
usgn = true;
[[fallthrough]];
case NEON::BI__builtin_neon_vaddlv_s16: {
mlir::Type argTy = cir::VectorType::get(builder.getContext(),
usgn ? UInt16Ty : SInt16Ty, 4);
mlir::Type argTy = cir::VectorType::get(usgn ? UInt16Ty : SInt16Ty, 4);
llvm::SmallVector<mlir::Value, 1> argOps = {emitScalarExpr(E->getArg(0))};
return emitNeonCall(builder, {argTy}, argOps,
usgn ? "aarch64.neon.uaddlv" : "aarch64.neon.saddlv",
Expand Down
3 changes: 1 addition & 2 deletions clang/lib/CIR/CodeGen/CIRGenCUDANV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ void CIRGenNVCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
// we need to pass it as `void *args[2] = { &a, &b }`.

auto loc = fn.getLoc();
auto voidPtrArrayTy =
cir::ArrayType::get(&cgm.getMLIRContext(), cgm.VoidPtrTy, args.size());
auto voidPtrArrayTy = cir::ArrayType::get(cgm.VoidPtrTy, args.size());
mlir::Value kernelArgs = builder.createAlloca(
loc, cir::PointerType::get(voidPtrArrayTy), voidPtrArrayTy, "kernel_args",
CharUnits::fromQuantity(16));
Expand Down
Loading
Loading