Skip to content

Commit a99c839

Browse files
authored
[CIR] Refactor vector type constraints (#1626)
1 parent b7bc94c commit a99c839

File tree

11 files changed

+128
-116
lines changed

11 files changed

+128
-116
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,9 +1237,13 @@ def ShiftOp : CIR_Op<"shift", [Pure]> {
12371237
```
12381238
}];
12391239

1240-
let results = (outs CIR_AnyIntOrVecOfInt:$result);
1241-
let arguments = (ins CIR_AnyIntOrVecOfInt:$value, CIR_AnyIntOrVecOfInt:$amount,
1242-
UnitAttr:$isShiftleft);
1240+
let arguments = (ins
1241+
CIR_AnyIntOrVecOfIntType:$value,
1242+
CIR_AnyIntOrVecOfIntType:$amount,
1243+
UnitAttr:$isShiftleft
1244+
);
1245+
1246+
let results = (outs CIR_AnyIntOrVecOfIntType:$result);
12431247

12441248
let assemblyFormat = [{
12451249
`(`
@@ -3125,7 +3129,7 @@ def VecInsertOp : CIR_Op<"vec.insert", [Pure,
31253129

31263130
let arguments = (ins
31273131
CIR_VectorType:$vec,
3128-
AnyType:$value,
3132+
CIR_VectorElementType:$value,
31293133
CIR_AnyFundamentalIntType:$index
31303134
);
31313135

@@ -3158,7 +3162,7 @@ def VecExtractOp : CIR_Op<"vec.extract", [Pure,
31583162
CIR_AnyFundamentalIntType:$index
31593163
);
31603164

3161-
let results = (outs CIR_AnyType:$result);
3165+
let results = (outs CIR_VectorElementType:$result);
31623166

31633167
let assemblyFormat = [{
31643168
$vec `[` $index `:` type($index) `]` attr-dict `:` qualified(type($vec))
@@ -3181,7 +3185,7 @@ def VecCreateOp : CIR_Op<"vec.create", [Pure]> {
31813185
in the vector type.
31823186
}];
31833187

3184-
let arguments = (ins Variadic<CIR_AnyType>:$elements);
3188+
let arguments = (ins Variadic<CIR_VectorElementType>:$elements);
31853189
let results = (outs CIR_VectorType:$result);
31863190

31873191
let assemblyFormat = [{
@@ -3211,7 +3215,7 @@ def VecSplatOp : CIR_Op<"vec.splat", [Pure,
32113215
All elements of the vector have the same value, that of the given scalar.
32123216
}];
32133217

3214-
let arguments = (ins CIR_AnyType:$value);
3218+
let arguments = (ins CIR_VectorElementType:$value);
32153219
let results = (outs CIR_VectorType:$result);
32163220

32173221
let assemblyFormat = [{
@@ -3264,8 +3268,13 @@ def VecTernaryOp : CIR_Op<"vec.ternary",
32643268
The result is a vector of the same type as the second and third arguments.
32653269
Each element of the result is `(bool)a[n] ? b[n] : c[n]`.
32663270
}];
3267-
let arguments = (ins IntegerVector:$cond, CIR_VectorType:$vec1,
3268-
CIR_VectorType:$vec2);
3271+
3272+
let arguments = (ins
3273+
CIR_VectorOfIntType:$cond,
3274+
CIR_VectorType:$vec1,
3275+
CIR_VectorType:$vec2
3276+
);
3277+
32693278
let results = (outs CIR_VectorType:$result);
32703279
let assemblyFormat = [{
32713280
`(` $cond `,` $vec1 `,` $vec2 `)` `:` qualified(type($cond)) `,`
@@ -3328,7 +3337,7 @@ def VecShuffleDynamicOp : CIR_Op<"vec.shuffle.dynamic",
33283337
result vector is constructed by taking the elements from the first input
33293338
vector from the indices indicated by the elements of the second vector.
33303339
}];
3331-
let arguments = (ins CIR_VectorType:$vec, IntegerVector:$indices);
3340+
let arguments = (ins CIR_VectorType:$vec, CIR_VectorOfIntType:$indices);
33323341
let results = (outs CIR_VectorType:$result);
33333342
let assemblyFormat = [{
33343343
$vec `:` qualified(type($vec)) `,` $indices `:` qualified(type($indices))
@@ -4712,8 +4721,8 @@ def LLrintOp : UnaryFPToIntBuiltinOp<"llrint", "LlrintOp">;
47124721

47134722
class UnaryFPToFPBuiltinOp<string mnemonic, string llvmOpName>
47144723
: CIR_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
4715-
let arguments = (ins CIR_AnyFloatOrVecOfFloat:$src);
4716-
let results = (outs CIR_AnyFloatOrVecOfFloat:$result);
4724+
let arguments = (ins CIR_AnyFloatOrVecOfFloatType:$src);
4725+
let results = (outs CIR_AnyFloatOrVecOfFloatType:$result);
47174726
let summary = "libc builtin equivalent ignoring "
47184727
"floating point exceptions and errno";
47194728
let assemblyFormat = "$src `:` type($src) attr-dict";
@@ -4743,8 +4752,6 @@ def TanOp : UnaryFPToFPBuiltinOp<"tan", "TanOp">;
47434752
def TruncOp : UnaryFPToFPBuiltinOp<"trunc", "FTruncOp">;
47444753

47454754
def AbsOp : CIR_Op<"abs", [Pure, SameOperandsAndResultType]> {
4746-
let arguments = (ins CIR_AnySignedIntOrVecOfSignedInt:$src, UnitAttr:$poison);
4747-
let results = (outs CIR_AnySignedIntOrVecOfSignedInt:$result);
47484755
let summary = [{
47494756
libc builtin equivalent abs, labs, llabs
47504757

@@ -4760,6 +4767,14 @@ def AbsOp : CIR_Op<"abs", [Pure, SameOperandsAndResultType]> {
47604767
%2 = cir.abs %3 : !cir.vector<!s32i x 4>
47614768
```
47624769
}];
4770+
4771+
let arguments = (ins
4772+
CIR_AnySIntOrVecOfSIntType:$src,
4773+
UnitAttr:$poison
4774+
);
4775+
4776+
let results = (outs CIR_AnySIntOrVecOfSIntType:$result);
4777+
47634778
let assemblyFormat = "$src ( `poison` $poison^ )? `:` type($src) attr-dict";
47644779
}
47654780

@@ -4769,9 +4784,12 @@ class BinaryFPToFPBuiltinOp<string mnemonic, string llvmOpName>
47694784
libc builtin equivalent ignoring floating-point exceptions and errno.
47704785
}];
47714786

4772-
let arguments = (ins CIR_AnyFloatOrVecOfFloat:$lhs,
4773-
CIR_AnyFloatOrVecOfFloat:$rhs);
4774-
let results = (outs CIR_AnyFloatOrVecOfFloat:$result);
4787+
let arguments = (ins
4788+
CIR_AnyFloatOrVecOfFloatType:$lhs,
4789+
CIR_AnyFloatOrVecOfFloatType:$rhs
4790+
);
4791+
4792+
let results = (outs CIR_AnyFloatOrVecOfFloatType:$result);
47754793

47764794
let assemblyFormat = [{
47774795
$lhs `,` $rhs `:` qualified(type($lhs)) attr-dict

clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,4 +275,59 @@ defvar CIR_ScalarTypes = [
275275
def CIR_AnyScalarType : AnyTypeOf<CIR_ScalarTypes, "cir scalar type"> {
276276
let cppFunctionName = "isScalarType";
277277
}
278+
279+
//===----------------------------------------------------------------------===//
280+
// Vector Type predicates
281+
//===----------------------------------------------------------------------===//
282+
283+
def CIR_AnyVectorType : CIR_TypeBase<"::cir::VectorType", "vector type">;
284+
285+
def CIR_VectorElementType : AnyTypeOf<[CIR_AnyIntOrFloatType, CIR_AnyPtrType],
286+
"any cir integer, floating point or pointer type"
287+
> {
288+
let cppFunctionName = "isValidVectorTypeElementType";
289+
}
290+
291+
// Element type constraint bases
292+
class CIR_ElementTypePred<Pred pred> : SubstLeaves<"$_self",
293+
"::mlir::cast<::cir::VectorType>($_self).getElementType()", pred>;
294+
295+
class CIR_VectorTypeOf<list<Type> types, string summary = "">
296+
: CIR_ConfinedType<CIR_AnyVectorType,
297+
[Or<!foreach(type, types, CIR_ElementTypePred<type.predicate>)>],
298+
!if(!empty(summary),
299+
"vector of " # CIR_TypeSummaries<types>.value,
300+
summary)>;
301+
302+
// Vector of type constraints
303+
def CIR_VectorOfIntType : CIR_VectorTypeOf<[CIR_AnyIntType]>;
304+
def CIR_VectorOfUIntType : CIR_VectorTypeOf<[CIR_AnyUIntType]>;
305+
def CIR_VectorOfSIntType : CIR_VectorTypeOf<[CIR_AnySIntType]>;
306+
def CIR_VectorOfFloatType : CIR_VectorTypeOf<[CIR_AnyFloatType]>;
307+
308+
// Vector or Scalar type constraints
309+
def CIR_AnyIntOrVecOfIntType
310+
: AnyTypeOf<[CIR_AnyIntType, CIR_VectorOfIntType],
311+
"integer or vector of integer type"> {
312+
let cppFunctionName = "isIntOrVectorOfIntType";
313+
}
314+
315+
def CIR_AnySIntOrVecOfSIntType
316+
: AnyTypeOf<[CIR_AnySIntType, CIR_VectorOfSIntType],
317+
"signed integer or vector of signed integer type"> {
318+
let cppFunctionName = "isSIntOrVectorOfSIntType";
319+
}
320+
321+
def CIR_AnyUIntOrVecOfUIntType
322+
: AnyTypeOf<[CIR_AnyUIntType, CIR_VectorOfUIntType],
323+
"unsigned integer or vector of unsigned integer type"> {
324+
let cppFunctionName = "isUIntOrVectorOfUIntType";
325+
}
326+
327+
def CIR_AnyFloatOrVecOfFloatType
328+
: AnyTypeOf<[CIR_AnyFloatType, CIR_VectorOfFloatType],
329+
"floating point or vector of floating point type"> {
330+
let cppFunctionName = "isFPOrVectorOfFPType";
331+
}
332+
278333
#endif // CLANG_CIR_DIALECT_IR_CIRTYPECONSTRAINTS_TD

clang/include/clang/CIR/Dialect/IR/CIRTypes.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ struct RecordTypeStorage;
2626
} // namespace detail
2727

2828
bool isValidFundamentalIntWidth(unsigned width);
29-
bool isFPOrFPVectorTy(mlir::Type);
30-
bool isIntOrIntVectorTy(mlir::Type);
29+
3130
} // namespace cir
3231

3332
mlir::ParseResult parseAddrSpaceAttribute(mlir::AsmParser &p,

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 23 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -345,11 +345,31 @@ def CIR_VectorType : CIR_Type<"Vector", "vector",
345345

346346
let summary = "CIR vector type";
347347
let description = [{
348-
`cir.vector' represents fixed-size vector types. The parameters are the
349-
element type and the number of elements.
348+
The `!cir.vector` type represents a fixed-size, one-dimensional vector.
349+
It takes two parameters: the element type and the number of elements.
350+
351+
Syntax:
352+
353+
```mlir
354+
vector-type ::= !cir.vector<element-type x size>
355+
element-type ::= float-type | integer-type | pointer-type
356+
```
357+
358+
The `element-type` must be a scalar CIR type. Zero-sized vectors are not
359+
allowed. The `size` must be a positive integer.
360+
361+
Examples:
362+
363+
```mlir
364+
!cir.vector<!cir.int<u, 8> x 4>
365+
!cir.vector<!cir.float x 2>
366+
```
350367
}];
351368

352-
let parameters = (ins "mlir::Type":$elementType, "uint64_t":$size);
369+
let parameters = (ins
370+
CIR_VectorElementType:$elementType,
371+
"uint64_t":$size
372+
);
353373

354374
let builders = [
355375
TypeBuilderWithInferredContext<(ins
@@ -503,50 +523,6 @@ def CIR_VoidType : CIR_Type<"Void", "void"> {
503523
}];
504524
}
505525

506-
// Constraints
507-
508-
// Vector of integral type
509-
def IntegerVector : Type<
510-
And<[
511-
CPred<"::mlir::isa<::cir::VectorType>($_self)">,
512-
CPred<"::mlir::isa<::cir::IntType>("
513-
"::mlir::cast<::cir::VectorType>($_self).getElementType())">,
514-
CPred<"::mlir::cast<::cir::IntType>("
515-
"::mlir::cast<::cir::VectorType>($_self).getElementType())"
516-
".isFundamental()">
517-
]>, "!cir.vector of !cir.int"> {
518-
}
519-
520-
// Vector of signed integral type
521-
def SignedIntegerVector : Type<
522-
And<[
523-
CPred<"::mlir::isa<::cir::VectorType>($_self)">,
524-
CPred<"::mlir::isa<::cir::IntType>("
525-
"::mlir::cast<::cir::VectorType>($_self).getElementType())">,
526-
CPred<"::mlir::cast<::cir::IntType>("
527-
"::mlir::cast<::cir::VectorType>($_self).getElementType())"
528-
".isSignedFundamental()">
529-
]>, "!cir.vector of !cir.int"> {
530-
}
531-
532-
// Vector of Float type
533-
def FPVector : Type<
534-
And<[
535-
CPred<"::mlir::isa<::cir::VectorType>($_self)">,
536-
CPred<"::mlir::isa<::cir::SingleType, ::cir::DoubleType>("
537-
"::mlir::cast<::cir::VectorType>($_self).getElementType())">,
538-
]>, "!cir.vector of !cir.fp"> {
539-
}
540-
541-
// Constraints
542-
def CIR_AnyIntOrVecOfInt: AnyTypeOf<[CIR_IntType, IntegerVector]>;
543-
544-
def CIR_AnySignedIntOrVecOfSignedInt: AnyTypeOf<[
545-
CIR_AnyFundamentalSIntType, SignedIntegerVector
546-
]>;
547-
548-
def CIR_AnyFloatOrVecOfFloat: AnyTypeOf<[CIR_AnyFloatType, FPVector]>;
549-
550526
//===----------------------------------------------------------------------===//
551527
// RecordType
552528
//

clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1482,7 +1482,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
14821482

14831483
case Builtin::BI__builtin_elementwise_abs: {
14841484
mlir::Type cirTy = convertType(E->getArg(0)->getType());
1485-
bool isIntTy = cir::isIntOrIntVectorTy(cirTy);
1485+
bool isIntTy = cir::isIntOrVectorOfIntType(cirTy);
14861486
if (!isIntTy) {
14871487
return emitUnaryFPBuiltin<cir::FAbsOp>(*this, *E);
14881488
}

clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4015,7 +4015,7 @@ CIRGenFunction::emitAArch64BuiltinExpr(unsigned BuiltinID, const CallExpr *E,
40154015
mlir::Location loc = getLoc(E->getExprLoc());
40164016
Ops[0] = builder.createBitcast(Ops[0], ty);
40174017
Ops[1] = builder.createBitcast(Ops[1], ty);
4018-
if (cir::isFPOrFPVectorTy(ty)) {
4018+
if (cir::isFPOrVectorOfFPType(ty)) {
40194019
return builder.create<cir::FMaximumOp>(loc, Ops[0], Ops[1]);
40204020
}
40214021
return builder.create<cir::BinOp>(loc, cir::BinOpKind::Max, Ops[0], Ops[1]);
@@ -4026,7 +4026,7 @@ CIRGenFunction::emitAArch64BuiltinExpr(unsigned BuiltinID, const CallExpr *E,
40264026
case NEON::BI__builtin_neon_vmin_v:
40274027
case NEON::BI__builtin_neon_vminq_v: {
40284028
llvm::StringRef name = usgn ? "aarch64.neon.umin" : "aarch64.neon.smin";
4029-
if (cir::isFPOrFPVectorTy(ty))
4029+
if (cir::isFPOrVectorOfFPType(ty))
40304030
name = "aarch64.neon.fmin";
40314031
return emitNeonCall(builder, {ty, ty}, Ops, name, ty,
40324032
getLoc(E->getExprLoc()));
@@ -4037,7 +4037,7 @@ CIRGenFunction::emitAArch64BuiltinExpr(unsigned BuiltinID, const CallExpr *E,
40374037
case NEON::BI__builtin_neon_vabd_v:
40384038
case NEON::BI__builtin_neon_vabdq_v: {
40394039
llvm::StringRef name = usgn ? "aarch64.neon.uabd" : "aarch64.neon.sabd";
4040-
if (cir::isFPOrFPVectorTy(ty))
4040+
if (cir::isFPOrVectorOfFPType(ty))
40414041
name = "aarch64.neon.fabd";
40424042
return emitNeonCall(builder, {ty, ty}, Ops, name, ty,
40434043
getLoc(E->getExprLoc()));

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,7 +1361,7 @@ mlir::Value ScalarExprEmitter::emitMul(const BinOpInfo &Ops) {
13611361
!CanElideOverflowCheck(CGF.getContext(), Ops))
13621362
llvm_unreachable("NYI");
13631363

1364-
if (cir::isFPOrFPVectorTy(Ops.LHS.getType())) {
1364+
if (cir::isFPOrVectorOfFPType(Ops.LHS.getType())) {
13651365
CIRGenFunction::CIRGenFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
13661366
return Builder.createFMul(Ops.LHS, Ops.RHS);
13671367
}
@@ -1414,7 +1414,7 @@ mlir::Value ScalarExprEmitter::emitAdd(const BinOpInfo &Ops) {
14141414
!CanElideOverflowCheck(CGF.getContext(), Ops))
14151415
llvm_unreachable("NYI");
14161416

1417-
if (cir::isFPOrFPVectorTy(Ops.LHS.getType())) {
1417+
if (cir::isFPOrVectorOfFPType(Ops.LHS.getType())) {
14181418
CIRGenFunction::CIRGenFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
14191419
return Builder.createFAdd(Ops.LHS, Ops.RHS);
14201420
}
@@ -1457,7 +1457,7 @@ mlir::Value ScalarExprEmitter::emitSub(const BinOpInfo &Ops) {
14571457
!CanElideOverflowCheck(CGF.getContext(), Ops))
14581458
llvm_unreachable("NYI");
14591459

1460-
if (cir::isFPOrFPVectorTy(Ops.LHS.getType())) {
1460+
if (cir::isFPOrVectorOfFPType(Ops.LHS.getType())) {
14611461
CIRGenFunction::CIRGenFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
14621462
return Builder.createFSub(Ops.LHS, Ops.RHS);
14631463
}

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -421,17 +421,7 @@ mlir::LogicalResult cir::VectorType::verify(
421421
mlir::Type elementType, uint64_t size) {
422422
if (size == 0)
423423
return emitError() << "the number of vector elements must be non-zero";
424-
425-
// Check if it a valid FixedVectorType
426-
if (mlir::isa<cir::PointerType, cir::FP128Type>(elementType))
427-
return success();
428-
429-
// Check if it a valid VectorType
430-
if (mlir::isa<cir::IntType>(elementType) ||
431-
isAnyFloatingPointType(elementType))
432-
return success();
433-
434-
return emitError() << "unsupported element type for CIR vector";
424+
return success();
435425
}
436426
// TODO(cir): Implement a way to cache the datalayout info calculated below.
437427

@@ -758,32 +748,6 @@ LongDoubleType::getABIAlignment(const mlir::DataLayout &dataLayout,
758748
.getABIAlignment(dataLayout, params);
759749
}
760750

761-
//===----------------------------------------------------------------------===//
762-
// Floating-point and Float-point Vector type helpers
763-
//===----------------------------------------------------------------------===//
764-
765-
bool cir::isFPOrFPVectorTy(mlir::Type t) {
766-
767-
if (isa<cir::VectorType>(t)) {
768-
return isAnyFloatingPointType(
769-
mlir::dyn_cast<cir::VectorType>(t).getElementType());
770-
}
771-
return isAnyFloatingPointType(t);
772-
}
773-
774-
//===----------------------------------------------------------------------===//
775-
// CIR Integer and Integer Vector type helpers
776-
//===----------------------------------------------------------------------===//
777-
778-
bool cir::isIntOrIntVectorTy(mlir::Type t) {
779-
780-
if (isa<cir::VectorType>(t)) {
781-
return isa<cir::IntType>(
782-
mlir::dyn_cast<cir::VectorType>(t).getElementType());
783-
}
784-
return isa<cir::IntType>(t);
785-
}
786-
787751
//===----------------------------------------------------------------------===//
788752
// ComplexType Definitions
789753
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)