Skip to content

[CIR] Refactor vector type constraints #1626

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 35 additions & 17 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1237,9 +1237,13 @@ def ShiftOp : CIR_Op<"shift", [Pure]> {
```
}];

let results = (outs CIR_AnyIntOrVecOfInt:$result);
let arguments = (ins CIR_AnyIntOrVecOfInt:$value, CIR_AnyIntOrVecOfInt:$amount,
UnitAttr:$isShiftleft);
let arguments = (ins
CIR_AnyIntOrVecOfIntType:$value,
CIR_AnyIntOrVecOfIntType:$amount,
UnitAttr:$isShiftleft
);

let results = (outs CIR_AnyIntOrVecOfIntType:$result);

let assemblyFormat = [{
`(`
Expand Down Expand Up @@ -3125,7 +3129,7 @@ def VecInsertOp : CIR_Op<"vec.insert", [Pure,

let arguments = (ins
CIR_VectorType:$vec,
AnyType:$value,
CIR_VectorElementType:$value,
CIR_AnyFundamentalIntType:$index
);

Expand Down Expand Up @@ -3158,7 +3162,7 @@ def VecExtractOp : CIR_Op<"vec.extract", [Pure,
CIR_AnyFundamentalIntType:$index
);

let results = (outs CIR_AnyType:$result);
let results = (outs CIR_VectorElementType:$result);

let assemblyFormat = [{
$vec `[` $index `:` type($index) `]` attr-dict `:` qualified(type($vec))
Expand All @@ -3181,7 +3185,7 @@ def VecCreateOp : CIR_Op<"vec.create", [Pure]> {
in the vector type.
}];

let arguments = (ins Variadic<CIR_AnyType>:$elements);
let arguments = (ins Variadic<CIR_VectorElementType>:$elements);
let results = (outs CIR_VectorType:$result);

let assemblyFormat = [{
Expand Down Expand Up @@ -3211,7 +3215,7 @@ def VecSplatOp : CIR_Op<"vec.splat", [Pure,
All elements of the vector have the same value, that of the given scalar.
}];

let arguments = (ins CIR_AnyType:$value);
let arguments = (ins CIR_VectorElementType:$value);
let results = (outs CIR_VectorType:$result);

let assemblyFormat = [{
Expand Down Expand Up @@ -3264,8 +3268,13 @@ def VecTernaryOp : CIR_Op<"vec.ternary",
The result is a vector of the same type as the second and third arguments.
Each element of the result is `(bool)a[n] ? b[n] : c[n]`.
}];
let arguments = (ins IntegerVector:$cond, CIR_VectorType:$vec1,
CIR_VectorType:$vec2);

let arguments = (ins
CIR_VectorOfIntType:$cond,
CIR_VectorType:$vec1,
CIR_VectorType:$vec2
);

let results = (outs CIR_VectorType:$result);
let assemblyFormat = [{
`(` $cond `,` $vec1 `,` $vec2 `)` `:` qualified(type($cond)) `,`
Expand Down Expand Up @@ -3328,7 +3337,7 @@ def VecShuffleDynamicOp : CIR_Op<"vec.shuffle.dynamic",
result vector is constructed by taking the elements from the first input
vector from the indices indicated by the elements of the second vector.
}];
let arguments = (ins CIR_VectorType:$vec, IntegerVector:$indices);
let arguments = (ins CIR_VectorType:$vec, CIR_VectorOfIntType:$indices);
let results = (outs CIR_VectorType:$result);
let assemblyFormat = [{
$vec `:` qualified(type($vec)) `,` $indices `:` qualified(type($indices))
Expand Down Expand Up @@ -4712,8 +4721,8 @@ def LLrintOp : UnaryFPToIntBuiltinOp<"llrint", "LlrintOp">;

class UnaryFPToFPBuiltinOp<string mnemonic, string llvmOpName>
: CIR_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
let arguments = (ins CIR_AnyFloatOrVecOfFloat:$src);
let results = (outs CIR_AnyFloatOrVecOfFloat:$result);
let arguments = (ins CIR_AnyFloatOrVecOfFloatType:$src);
let results = (outs CIR_AnyFloatOrVecOfFloatType:$result);
let summary = "libc builtin equivalent ignoring "
"floating point exceptions and errno";
let assemblyFormat = "$src `:` type($src) attr-dict";
Expand Down Expand Up @@ -4743,8 +4752,6 @@ def TanOp : UnaryFPToFPBuiltinOp<"tan", "TanOp">;
def TruncOp : UnaryFPToFPBuiltinOp<"trunc", "FTruncOp">;

def AbsOp : CIR_Op<"abs", [Pure, SameOperandsAndResultType]> {
let arguments = (ins CIR_AnySignedIntOrVecOfSignedInt:$src, UnitAttr:$poison);
let results = (outs CIR_AnySignedIntOrVecOfSignedInt:$result);
let summary = [{
libc builtin equivalent abs, labs, llabs

Expand All @@ -4760,6 +4767,14 @@ def AbsOp : CIR_Op<"abs", [Pure, SameOperandsAndResultType]> {
%2 = cir.abs %3 : !cir.vector<!s32i x 4>
```
}];

let arguments = (ins
CIR_AnySIntOrVecOfSIntType:$src,
UnitAttr:$poison
);

let results = (outs CIR_AnySIntOrVecOfSIntType:$result);

let assemblyFormat = "$src ( `poison` $poison^ )? `:` type($src) attr-dict";
}

Expand All @@ -4769,9 +4784,12 @@ class BinaryFPToFPBuiltinOp<string mnemonic, string llvmOpName>
libc builtin equivalent ignoring floating-point exceptions and errno.
}];

let arguments = (ins CIR_AnyFloatOrVecOfFloat:$lhs,
CIR_AnyFloatOrVecOfFloat:$rhs);
let results = (outs CIR_AnyFloatOrVecOfFloat:$result);
let arguments = (ins
CIR_AnyFloatOrVecOfFloatType:$lhs,
CIR_AnyFloatOrVecOfFloatType:$rhs
);

let results = (outs CIR_AnyFloatOrVecOfFloatType:$result);

let assemblyFormat = [{
$lhs `,` $rhs `:` qualified(type($lhs)) attr-dict
Expand Down
55 changes: 55 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -275,4 +275,59 @@ defvar CIR_ScalarTypes = [
def CIR_AnyScalarType : AnyTypeOf<CIR_ScalarTypes, "cir scalar type"> {
let cppFunctionName = "isScalarType";
}

//===----------------------------------------------------------------------===//
// Vector Type predicates
//===----------------------------------------------------------------------===//

def CIR_AnyVectorType : CIR_TypeBase<"::cir::VectorType", "vector type">;

def CIR_VectorElementType : AnyTypeOf<[CIR_AnyIntOrFloatType, CIR_AnyPtrType],
"any cir integer, floating point or pointer type"
> {
let cppFunctionName = "isValidVectorTypeElementType";
}

// Element type constraint bases
class CIR_ElementTypePred<Pred pred> : SubstLeaves<"$_self",
"::mlir::cast<::cir::VectorType>($_self).getElementType()", pred>;

class CIR_VectorTypeOf<list<Type> types, string summary = "">
: CIR_ConfinedType<CIR_AnyVectorType,
[Or<!foreach(type, types, CIR_ElementTypePred<type.predicate>)>],
!if(!empty(summary),
"vector of " # CIR_TypeSummaries<types>.value,
summary)>;

// Vector of type constraints
def CIR_VectorOfIntType : CIR_VectorTypeOf<[CIR_AnyIntType]>;
def CIR_VectorOfUIntType : CIR_VectorTypeOf<[CIR_AnyUIntType]>;
def CIR_VectorOfSIntType : CIR_VectorTypeOf<[CIR_AnySIntType]>;
def CIR_VectorOfFloatType : CIR_VectorTypeOf<[CIR_AnyFloatType]>;

// Vector or Scalar type constraints
def CIR_AnyIntOrVecOfIntType
: AnyTypeOf<[CIR_AnyIntType, CIR_VectorOfIntType],
"integer or vector of integer type"> {
let cppFunctionName = "isIntOrVectorOfIntType";
}

def CIR_AnySIntOrVecOfSIntType
: AnyTypeOf<[CIR_AnySIntType, CIR_VectorOfSIntType],
"signed integer or vector of signed integer type"> {
let cppFunctionName = "isSIntOrVectorOfSIntType";
}

def CIR_AnyUIntOrVecOfUIntType
: AnyTypeOf<[CIR_AnyUIntType, CIR_VectorOfUIntType],
"unsigned integer or vector of unsigned integer type"> {
let cppFunctionName = "isUIntOrVectorOfUIntType";
}

def CIR_AnyFloatOrVecOfFloatType
: AnyTypeOf<[CIR_AnyFloatType, CIR_VectorOfFloatType],
"floating point or vector of floating point type"> {
let cppFunctionName = "isFPOrVectorOfFPType";
}

#endif // CLANG_CIR_DIALECT_IR_CIRTYPECONSTRAINTS_TD
3 changes: 1 addition & 2 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ struct RecordTypeStorage;
} // namespace detail

bool isValidFundamentalIntWidth(unsigned width);
bool isFPOrFPVectorTy(mlir::Type);
bool isIntOrIntVectorTy(mlir::Type);

} // namespace cir

mlir::ParseResult parseAddrSpaceAttribute(mlir::AsmParser &p,
Expand Down
70 changes: 23 additions & 47 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -345,11 +345,31 @@ def CIR_VectorType : CIR_Type<"Vector", "vector",

let summary = "CIR vector type";
let description = [{
`cir.vector' represents fixed-size vector types. The parameters are the
element type and the number of elements.
The `!cir.vector` type represents a fixed-size, one-dimensional vector.
It takes two parameters: the element type and the number of elements.

Syntax:

```mlir
vector-type ::= !cir.vector<element-type x size>
element-type ::= float-type | integer-type | pointer-type
```

The `element-type` must be a scalar CIR type. Zero-sized vectors are not
allowed. The `size` must be a positive integer.

Examples:

```mlir
!cir.vector<!cir.int<u, 8> x 4>
!cir.vector<!cir.float x 2>
```
}];

let parameters = (ins "mlir::Type":$elementType, "uint64_t":$size);
let parameters = (ins
CIR_VectorElementType:$elementType,
"uint64_t":$size
);

let builders = [
TypeBuilderWithInferredContext<(ins
Expand Down Expand Up @@ -503,50 +523,6 @@ def CIR_VoidType : CIR_Type<"Void", "void"> {
}];
}

// Constraints

// Vector of integral type
def IntegerVector : Type<
And<[
CPred<"::mlir::isa<::cir::VectorType>($_self)">,
CPred<"::mlir::isa<::cir::IntType>("
"::mlir::cast<::cir::VectorType>($_self).getElementType())">,
CPred<"::mlir::cast<::cir::IntType>("
"::mlir::cast<::cir::VectorType>($_self).getElementType())"
".isFundamental()">
]>, "!cir.vector of !cir.int"> {
}

// Vector of signed integral type
def SignedIntegerVector : Type<
And<[
CPred<"::mlir::isa<::cir::VectorType>($_self)">,
CPred<"::mlir::isa<::cir::IntType>("
"::mlir::cast<::cir::VectorType>($_self).getElementType())">,
CPred<"::mlir::cast<::cir::IntType>("
"::mlir::cast<::cir::VectorType>($_self).getElementType())"
".isSignedFundamental()">
]>, "!cir.vector of !cir.int"> {
}

// Vector of Float type
def FPVector : Type<
And<[
CPred<"::mlir::isa<::cir::VectorType>($_self)">,
CPred<"::mlir::isa<::cir::SingleType, ::cir::DoubleType>("
"::mlir::cast<::cir::VectorType>($_self).getElementType())">,
]>, "!cir.vector of !cir.fp"> {
}

// Constraints
def CIR_AnyIntOrVecOfInt: AnyTypeOf<[CIR_IntType, IntegerVector]>;

def CIR_AnySignedIntOrVecOfSignedInt: AnyTypeOf<[
CIR_AnyFundamentalSIntType, SignedIntegerVector
]>;

def CIR_AnyFloatOrVecOfFloat: AnyTypeOf<[CIR_AnyFloatType, FPVector]>;

//===----------------------------------------------------------------------===//
// RecordType
//
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1482,7 +1482,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,

case Builtin::BI__builtin_elementwise_abs: {
mlir::Type cirTy = convertType(E->getArg(0)->getType());
bool isIntTy = cir::isIntOrIntVectorTy(cirTy);
bool isIntTy = cir::isIntOrVectorOfIntType(cirTy);
if (!isIntTy) {
return emitUnaryFPBuiltin<cir::FAbsOp>(*this, *E);
}
Expand Down
6 changes: 3 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4015,7 +4015,7 @@ CIRGenFunction::emitAArch64BuiltinExpr(unsigned BuiltinID, const CallExpr *E,
mlir::Location loc = getLoc(E->getExprLoc());
Ops[0] = builder.createBitcast(Ops[0], ty);
Ops[1] = builder.createBitcast(Ops[1], ty);
if (cir::isFPOrFPVectorTy(ty)) {
if (cir::isFPOrVectorOfFPType(ty)) {
return builder.create<cir::FMaximumOp>(loc, Ops[0], Ops[1]);
}
return builder.create<cir::BinOp>(loc, cir::BinOpKind::Max, Ops[0], Ops[1]);
Expand All @@ -4026,7 +4026,7 @@ CIRGenFunction::emitAArch64BuiltinExpr(unsigned BuiltinID, const CallExpr *E,
case NEON::BI__builtin_neon_vmin_v:
case NEON::BI__builtin_neon_vminq_v: {
llvm::StringRef name = usgn ? "aarch64.neon.umin" : "aarch64.neon.smin";
if (cir::isFPOrFPVectorTy(ty))
if (cir::isFPOrVectorOfFPType(ty))
name = "aarch64.neon.fmin";
return emitNeonCall(builder, {ty, ty}, Ops, name, ty,
getLoc(E->getExprLoc()));
Expand All @@ -4037,7 +4037,7 @@ CIRGenFunction::emitAArch64BuiltinExpr(unsigned BuiltinID, const CallExpr *E,
case NEON::BI__builtin_neon_vabd_v:
case NEON::BI__builtin_neon_vabdq_v: {
llvm::StringRef name = usgn ? "aarch64.neon.uabd" : "aarch64.neon.sabd";
if (cir::isFPOrFPVectorTy(ty))
if (cir::isFPOrVectorOfFPType(ty))
name = "aarch64.neon.fabd";
return emitNeonCall(builder, {ty, ty}, Ops, name, ty,
getLoc(E->getExprLoc()));
Expand Down
6 changes: 3 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1361,7 +1361,7 @@ mlir::Value ScalarExprEmitter::emitMul(const BinOpInfo &Ops) {
!CanElideOverflowCheck(CGF.getContext(), Ops))
llvm_unreachable("NYI");

if (cir::isFPOrFPVectorTy(Ops.LHS.getType())) {
if (cir::isFPOrVectorOfFPType(Ops.LHS.getType())) {
CIRGenFunction::CIRGenFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
return Builder.createFMul(Ops.LHS, Ops.RHS);
}
Expand Down Expand Up @@ -1414,7 +1414,7 @@ mlir::Value ScalarExprEmitter::emitAdd(const BinOpInfo &Ops) {
!CanElideOverflowCheck(CGF.getContext(), Ops))
llvm_unreachable("NYI");

if (cir::isFPOrFPVectorTy(Ops.LHS.getType())) {
if (cir::isFPOrVectorOfFPType(Ops.LHS.getType())) {
CIRGenFunction::CIRGenFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
return Builder.createFAdd(Ops.LHS, Ops.RHS);
}
Expand Down Expand Up @@ -1457,7 +1457,7 @@ mlir::Value ScalarExprEmitter::emitSub(const BinOpInfo &Ops) {
!CanElideOverflowCheck(CGF.getContext(), Ops))
llvm_unreachable("NYI");

if (cir::isFPOrFPVectorTy(Ops.LHS.getType())) {
if (cir::isFPOrVectorOfFPType(Ops.LHS.getType())) {
CIRGenFunction::CIRGenFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
return Builder.createFSub(Ops.LHS, Ops.RHS);
}
Expand Down
38 changes: 1 addition & 37 deletions clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,17 +421,7 @@ mlir::LogicalResult cir::VectorType::verify(
mlir::Type elementType, uint64_t size) {
if (size == 0)
return emitError() << "the number of vector elements must be non-zero";

// Check if it a valid FixedVectorType
if (mlir::isa<cir::PointerType, cir::FP128Type>(elementType))
return success();

// Check if it a valid VectorType
if (mlir::isa<cir::IntType>(elementType) ||
isAnyFloatingPointType(elementType))
return success();

return emitError() << "unsupported element type for CIR vector";
return success();
}
// TODO(cir): Implement a way to cache the datalayout info calculated below.

Expand Down Expand Up @@ -758,32 +748,6 @@ LongDoubleType::getABIAlignment(const mlir::DataLayout &dataLayout,
.getABIAlignment(dataLayout, params);
}

//===----------------------------------------------------------------------===//
// Floating-point and Float-point Vector type helpers
//===----------------------------------------------------------------------===//

bool cir::isFPOrFPVectorTy(mlir::Type t) {

if (isa<cir::VectorType>(t)) {
return isAnyFloatingPointType(
mlir::dyn_cast<cir::VectorType>(t).getElementType());
}
return isAnyFloatingPointType(t);
}

//===----------------------------------------------------------------------===//
// CIR Integer and Integer Vector type helpers
//===----------------------------------------------------------------------===//

bool cir::isIntOrIntVectorTy(mlir::Type t) {

if (isa<cir::VectorType>(t)) {
return isa<cir::IntType>(
mlir::dyn_cast<cir::VectorType>(t).getElementType());
}
return isa<cir::IntType>(t);
}

//===----------------------------------------------------------------------===//
// ComplexType Definitions
//===----------------------------------------------------------------------===//
Expand Down
Loading
Loading