Skip to content

Commit f0a401f

Browse files
committed
Upstream Refactored vector type constraints
1 parent 77d3586 commit f0a401f

File tree

5 files changed

+78
-50
lines changed

5 files changed

+78
-50
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1464,9 +1464,13 @@ def ShiftOp : CIR_Op<"shift", [Pure]> {
14641464
```
14651465
}];
14661466

1467-
let results = (outs CIR_AnyIntOrVecOfInt:$result);
1468-
let arguments = (ins CIR_AnyIntOrVecOfInt:$value, CIR_AnyIntOrVecOfInt:$amount,
1469-
UnitAttr:$isShiftleft);
1467+
let arguments = (ins
1468+
CIR_AnyIntOrVecOfIntType:$value,
1469+
CIR_AnyIntOrVecOfIntType:$amount,
1470+
UnitAttr:$isShiftleft
1471+
);
1472+
1473+
let results = (outs CIR_AnyIntOrVecOfIntType:$result);
14701474

14711475
let assemblyFormat = [{
14721476
`(`
@@ -2050,7 +2054,7 @@ def VecCreateOp : CIR_Op<"vec.create", [Pure]> {
20502054
in the vector type.
20512055
}];
20522056

2053-
let arguments = (ins Variadic<CIR_AnyType>:$elements);
2057+
let arguments = (ins Variadic<CIR_VectorElementType>:$elements);
20542058
let results = (outs CIR_VectorType:$result);
20552059

20562060
let assemblyFormat = [{
@@ -2085,7 +2089,7 @@ def VecInsertOp : CIR_Op<"vec.insert", [Pure,
20852089

20862090
let arguments = (ins
20872091
CIR_VectorType:$vec,
2088-
AnyType:$value,
2092+
CIR_VectorElementType:$value,
20892093
CIR_AnyFundamentalIntType:$index
20902094
);
20912095

@@ -2118,7 +2122,7 @@ def VecExtractOp : CIR_Op<"vec.extract", [Pure,
21182122
}];
21192123

21202124
let arguments = (ins CIR_VectorType:$vec, CIR_AnyFundamentalIntType:$index);
2121-
let results = (outs CIR_AnyType:$result);
2125+
let results = (outs CIR_VectorElementType:$result);
21222126

21232127
let assemblyFormat = [{
21242128
$vec `[` $index `:` type($index) `]` attr-dict `:` qualified(type($vec))
@@ -2180,7 +2184,7 @@ def VecShuffleDynamicOp : CIR_Op<"vec.shuffle.dynamic",
21802184
```
21812185
}];
21822186

2183-
let arguments = (ins CIR_VectorType:$vec, IntegerVector:$indices);
2187+
let arguments = (ins CIR_VectorType:$vec, CIR_VectorOfIntType:$indices);
21842188
let results = (outs CIR_VectorType:$result);
21852189
let assemblyFormat = [{
21862190
$vec `:` qualified(type($vec)) `,` $indices `:` qualified(type($indices))

clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td

Lines changed: 46 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,22 @@ def CIR_PtrToVoidPtrType
198198

199199
def CIR_AnyVectorType : CIR_TypeBase<"::cir::VectorType", "vector type">;
200200

201+
def CIR_VectorElementType : AnyTypeOf<[CIR_AnyIntOrFloatType, CIR_AnyPtrType],
202+
"any cir integer, floating point or pointer type"
203+
> {
204+
let cppFunctionName = "isValidVectorTypeElementType";
205+
}
206+
207+
class CIR_ElementTypePred<Pred pred> : SubstLeaves<"$_self",
208+
"::mlir::cast<::cir::VectorType>($_self).getElementType()", pred>;
209+
210+
class CIR_VectorTypeOf<list<Type> types, string summary = "">
211+
: CIR_ConfinedType<CIR_AnyVectorType,
212+
[Or<!foreach(type, types, CIR_ElementTypePred<type.predicate>)>],
213+
!if(!empty(summary),
214+
"vector of " # CIR_TypeSummaries<types>.value,
215+
summary)>;
216+
201217
// Vector of integral type
202218
def IntegerVector : Type<
203219
And<[
@@ -210,8 +226,36 @@ def IntegerVector : Type<
210226
]>, "!cir.vector of !cir.int"> {
211227
}
212228

213-
// Any Integer or Vector of Integer Constraints
214-
def CIR_AnyIntOrVecOfInt: AnyTypeOf<[CIR_AnyIntType, IntegerVector]>;
229+
// Vector of type constraints
230+
def CIR_VectorOfIntType : CIR_VectorTypeOf<[CIR_AnyIntType]>;
231+
def CIR_VectorOfUIntType : CIR_VectorTypeOf<[CIR_AnyUIntType]>;
232+
def CIR_VectorOfSIntType : CIR_VectorTypeOf<[CIR_AnySIntType]>;
233+
def CIR_VectorOfFloatType : CIR_VectorTypeOf<[CIR_AnyFloatType]>;
234+
235+
// Vector or Scalar type constraints
236+
def CIR_AnyIntOrVecOfIntType
237+
: AnyTypeOf<[CIR_AnyIntType, CIR_VectorOfIntType],
238+
"integer or vector of integer type"> {
239+
let cppFunctionName = "isIntOrVectorOfIntType";
240+
}
241+
242+
def CIR_AnySIntOrVecOfSIntType
243+
: AnyTypeOf<[CIR_AnySIntType, CIR_VectorOfSIntType],
244+
"signed integer or vector of signed integer type"> {
245+
let cppFunctionName = "isSIntOrVectorOfSIntType";
246+
}
247+
248+
def CIR_AnyUIntOrVecOfUIntType
249+
: AnyTypeOf<[CIR_AnyUIntType, CIR_VectorOfUIntType],
250+
"unsigned integer or vector of unsigned integer type"> {
251+
let cppFunctionName = "isUIntOrVectorOfUIntType";
252+
}
253+
254+
def CIR_AnyFloatOrVecOfFloatType
255+
: AnyTypeOf<[CIR_AnyFloatType, CIR_VectorOfFloatType],
256+
"floating point or vector of floating point type"> {
257+
let cppFunctionName = "isFPOrVectorOfFPType";
258+
}
215259

216260
//===----------------------------------------------------------------------===//
217261
// Scalar Type predicates
@@ -225,27 +269,4 @@ def CIR_AnyScalarType : AnyTypeOf<CIR_ScalarTypes, "cir scalar type"> {
225269
let cppFunctionName = "isScalarType";
226270
}
227271

228-
//===----------------------------------------------------------------------===//
229-
// Element type constraint bases
230-
//===----------------------------------------------------------------------===//
231-
232-
class CIR_ElementTypePred<Pred pred> : SubstLeaves<"$_self",
233-
"::mlir::cast<::cir::VectorType>($_self).getElementType()", pred>;
234-
235-
class CIR_VectorTypeOf<list<Type> types, string summary = "">
236-
: CIR_ConfinedType<CIR_AnyVectorType,
237-
[Or<!foreach(type, types, CIR_ElementTypePred<type.predicate>)>],
238-
!if(!empty(summary),
239-
"vector of " # CIR_TypeSummaries<types>.value,
240-
summary)>;
241-
242-
// Vector of type constraints
243-
def CIR_VectorOfFloatType : CIR_VectorTypeOf<[CIR_AnyFloatType]>;
244-
245-
def CIR_AnyFloatOrVecOfFloatType
246-
: AnyTypeOf<[CIR_AnyFloatType, CIR_VectorOfFloatType],
247-
"floating point or vector of floating point type"> {
248-
let cppFunctionName = "isFPOrVectorOfFPType";
249-
}
250-
251272
#endif // CLANG_CIR_DIALECT_IR_CIRTYPECONSTRAINTS_TD

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -275,18 +275,31 @@ def CIR_VectorType : CIR_Type<"Vector", "vector",
275275

276276
let summary = "CIR vector type";
277277
let description = [{
278-
`!cir.vector' represents fixed-size vector types, parameterized
279-
by the element type and the number of elements.
278+
The `!cir.vector` type represents a fixed-size, one-dimensional vector.
279+
It takes two parameters: the element type and the number of elements.
280280

281-
Example:
281+
Syntax:
282282

283283
```mlir
284-
!cir.vector<!u64i x 2>
285-
!cir.vector<!cir.float x 4>
284+
vector-type ::= !cir.vector<size x element-type>
285+
element-type ::= float-type | integer-type | pointer-type
286+
```
287+
288+
The `element-type` must be a scalar CIR type. Zero-sized vectors are not
289+
allowed. The `size` must be a positive integer.
290+
291+
Examples:
292+
293+
```mlir
294+
!cir.vector<4 x !cir.int<u, 8>>
295+
!cir.vector<2 x !cir.float>
286296
```
287297
}];
288298

289-
let parameters = (ins "mlir::Type":$elementType, "uint64_t":$size);
299+
let parameters = (ins
300+
CIR_VectorElementType:$elementType,
301+
"uint64_t":$size
302+
);
290303

291304
let assemblyFormat = [{
292305
`<` $size `x` $elementType `>`

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -684,17 +684,7 @@ mlir::LogicalResult cir::VectorType::verify(
684684
mlir::Type elementType, uint64_t size) {
685685
if (size == 0)
686686
return emitError() << "the number of vector elements must be non-zero";
687-
688-
// Check if it a valid FixedVectorType
689-
if (mlir::isa<cir::PointerType, cir::FP128Type>(elementType))
690-
return success();
691-
692-
// Check if it a valid VectorType
693-
if (mlir::isa<cir::IntType>(elementType) ||
694-
isAnyFloatingPointType(elementType))
695-
return success();
696-
697-
return emitError() << "unsupported element type for CIR vector";
687+
return success();
698688
}
699689

700690
//===----------------------------------------------------------------------===//

clang/test/CIR/IR/invalid-vector.cir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
module {
66

7-
// expected-error @below {{unsupported element type for CIR vector}}
7+
// expected-error @below {{failed to verify 'elementType'}}
88
cir.global external @vec_b = #cir.zero : !cir.vector<4 x !cir.array<!s32i x 10>>
99

1010
}

0 commit comments

Comments
 (0)