Skip to content

Commit c5b35e2

Browse files
committed
[mlir][vector] Update representation of insert/extract_strided_slice
This commit updates the representation of both extract_strided_slice and insert_strided_slice to primitive arrays of int64_ts, rather than ArrayAttrs of IntegerAttrs. This prevents a lot of boilerplate conversions between IntegerAttr and int64_t. This is done by adding a new `StridedSliceAttr` which matches the previous syntax and can be used for both operations. It may also be possible to explore alternate slice syntax for the `StridedSliceAttr` in future.
1 parent dac9042 commit c5b35e2

File tree

13 files changed

+228
-315
lines changed

13 files changed

+228
-315
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
include "mlir/Dialect/Vector/IR/Vector.td"
1717
include "mlir/IR/EnumAttr.td"
1818

19+
class Vector_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
20+
: AttrDef<Vector_Dialect, attrName, traits> {
21+
let mnemonic = attrMnemonic;
22+
}
23+
1924
// The "kind" of combining function for contractions and reductions.
2025
def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">;
2126
def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">;
@@ -82,4 +87,42 @@ def Vector_PrintPunctuation : EnumAttr<Vector_Dialect, PrintPunctuation, "punctu
8287
let assemblyFormat = "`<` $value `>`";
8388
}
8489

90+
def Vector_StridedSliceAttr : Vector_Attr<"StridedSlice", "strided_slice">
91+
{
92+
let summary = "strided vector slice";
93+
94+
let description = [{
95+
An attribute that represents a strided slice of a vector.
96+
97+
*Examples:*
98+
99+
Without sizes:
100+
101+
`{offsets = [0, 0, 2], strides = [1, 1]}`
102+
103+
With sizes (used for extract_strided_slice):
104+
105+
`{offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}`
106+
107+
TODO? Come up with a range syntax (similar to Python slices).
108+
}];
109+
110+
let parameters = (ins
111+
ArrayRefParameter<"int64_t">:$offsets,
112+
OptionalArrayRefParameter<"int64_t">:$sizes,
113+
ArrayRefParameter<"int64_t">:$strides
114+
);
115+
116+
let builders = [AttrBuilder<(ins "ArrayRef<int64_t>":$offsets, "ArrayRef<int64_t>":$strides), [{
117+
return $_get($_ctxt, offsets, ArrayRef<int64_t>{}, strides);
118+
}]>
119+
];
120+
121+
let assemblyFormat = [{
122+
`{` `offsets` `=` `[` $offsets `]` `,`
123+
(`sizes` `=` `[` $sizes^ `]` `,`)?
124+
`strides` `=` `[` $strides `]` `}`
125+
}];
126+
}
127+
85128
#endif // MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,8 +1040,8 @@ def Vector_InsertStridedSliceOp :
10401040
PredOpTrait<"operand #0 and result have same element type",
10411041
TCresVTEtIsSameAsOpBase<0, 0>>,
10421042
AllTypesMatch<["dest", "res"]>]>,
1043-
Arguments<(ins AnyVector:$source, AnyVector:$dest, I64ArrayAttr:$offsets,
1044-
I64ArrayAttr:$strides)>,
1043+
Arguments<(ins AnyVector:$source, AnyVector:$dest,
1044+
Vector_StridedSliceAttr:$strided_slice)>,
10451045
Results<(outs AnyVector:$res)> {
10461046
let summary = "strided_slice operation";
10471047
let description = [{
@@ -1060,13 +1060,13 @@ def Vector_InsertStridedSliceOp :
10601060

10611061
```mlir
10621062
%2 = vector.insert_strided_slice %0, %1
1063-
{offsets = [0, 0, 2], strides = [1, 1]}:
1064-
vector<2x4xf32> into vector<16x4x8xf32>
1063+
{offsets = [0, 0, 2], strides = [1, 1]}
1064+
: vector<2x4xf32> into vector<16x4x8xf32>
10651065
```
10661066
}];
10671067

10681068
let assemblyFormat = [{
1069-
$source `,` $dest attr-dict `:` type($source) `into` type($dest)
1069+
$source `,` $dest $strided_slice attr-dict `:` type($source) `into` type($dest)
10701070
}];
10711071

10721072
let builders = [
@@ -1081,10 +1081,13 @@ def Vector_InsertStridedSliceOp :
10811081
return ::llvm::cast<VectorType>(getDest().getType());
10821082
}
10831083
bool hasNonUnitStrides() {
1084-
return llvm::any_of(getStrides(), [](Attribute attr) {
1085-
return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
1084+
return llvm::any_of(getStrides(), [](int64_t stride) {
1085+
return stride != 1;
10861086
});
10871087
}
1088+
1089+
ArrayRef<int64_t> getOffsets() { return getStridedSlice().getOffsets(); }
1090+
ArrayRef<int64_t> getStrides() { return getStridedSlice().getStrides(); }
10881091
}];
10891092

10901093
let hasFolder = 1;
@@ -1182,8 +1185,7 @@ def Vector_ExtractStridedSliceOp :
11821185
Vector_Op<"extract_strided_slice", [Pure,
11831186
PredOpTrait<"operand and result have same element type",
11841187
TCresVTEtIsSameAsOpBase<0, 0>>]>,
1185-
Arguments<(ins AnyVector:$vector, I64ArrayAttr:$offsets,
1186-
I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
1188+
Arguments<(ins AnyVector:$vector, Vector_StridedSliceAttr:$strided_slice)>,
11871189
Results<(outs AnyVector)> {
11881190
let summary = "extract_strided_slice operation";
11891191
let description = [{
@@ -1201,12 +1203,8 @@ def Vector_ExtractStridedSliceOp :
12011203

12021204
```mlir
12031205
%1 = vector.extract_strided_slice %0
1204-
{offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}:
1205-
vector<4x8x16xf32> to vector<2x4x16xf32>
1206-
1207-
// TODO: Evolve to a range form syntax similar to:
1208-
%1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
1209-
vector<4x8x16xf32> to vector<2x4x16xf32>
1206+
{offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}
1207+
: vector<4x8x16xf32> to vector<2x4x16xf32>
12101208
```
12111209
}];
12121210
let builders = [
@@ -1217,17 +1215,20 @@ def Vector_ExtractStridedSliceOp :
12171215
VectorType getSourceVectorType() {
12181216
return ::llvm::cast<VectorType>(getVector().getType());
12191217
}
1220-
void getOffsets(SmallVectorImpl<int64_t> &results);
12211218
bool hasNonUnitStrides() {
1222-
return llvm::any_of(getStrides(), [](Attribute attr) {
1223-
return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
1219+
return llvm::any_of(getStrides(), [](int64_t stride) {
1220+
return stride != 1;
12241221
});
12251222
}
1223+
1224+
ArrayRef<int64_t> getOffsets() { return getStridedSlice().getOffsets(); }
1225+
ArrayRef<int64_t> getSizes() { return getStridedSlice().getSizes(); }
1226+
ArrayRef<int64_t> getStrides() { return getStridedSlice().getStrides(); }
12261227
}];
12271228
let hasCanonicalizer = 1;
12281229
let hasFolder = 1;
12291230
let hasVerifier = 1;
1230-
let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
1231+
let assemblyFormat = "$vector $strided_slice attr-dict `:` type($vector) `to` type(results)";
12311232
}
12321233

12331234
// TODO: Tighten semantics so that masks and inbounds can't be used

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -940,12 +940,6 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
940940
return success();
941941
}
942942

943-
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
944-
SmallVectorImpl<int64_t> &results) {
945-
for (auto attr : arrayAttr)
946-
results.push_back(cast<IntegerAttr>(attr).getInt());
947-
}
948-
949943
static LogicalResult
950944
convertExtractStridedSlice(RewriterBase &rewriter,
951945
vector::ExtractStridedSliceOp op,
@@ -996,11 +990,8 @@ convertExtractStridedSlice(RewriterBase &rewriter,
996990
auto sourceVector = it->second;
997991

998992
// offset and sizes at warp-level of onwership.
999-
SmallVector<int64_t> offsets;
1000-
populateFromInt64AttrArray(op.getOffsets(), offsets);
993+
ArrayRef<int64_t> offsets = op.getOffsets();
1001994

1002-
SmallVector<int64_t> sizes;
1003-
populateFromInt64AttrArray(op.getSizes(), sizes);
1004995
ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
1005996

1006997
// Compute offset in vector registers. Note that the mma.sync vector registers

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,6 @@ static uint64_t getFirstIntValue(ValueRange values) {
4646
static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
4747
return cast<IntegerAttr>(attr[0]).getInt();
4848
}
49-
static uint64_t getFirstIntValue(ArrayAttr attr) {
50-
return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
51-
}
5249
static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
5350
auto attr = foldResults[0].dyn_cast<Attribute>();
5451
if (attr)
@@ -187,9 +184,9 @@ struct VectorExtractStridedSliceOpConvert final
187184
if (!dstType)
188185
return failure();
189186

190-
uint64_t offset = getFirstIntValue(extractOp.getOffsets());
191-
uint64_t size = getFirstIntValue(extractOp.getSizes());
192-
uint64_t stride = getFirstIntValue(extractOp.getStrides());
187+
int64_t offset = extractOp.getOffsets().front();
188+
int64_t size = extractOp.getSizes().front();
189+
int64_t stride = extractOp.getStrides().front();
193190
if (stride != 1)
194191
return failure();
195192

@@ -323,10 +320,10 @@ struct VectorInsertStridedSliceOpConvert final
323320
Value srcVector = adaptor.getOperands().front();
324321
Value dstVector = adaptor.getOperands().back();
325322

326-
uint64_t stride = getFirstIntValue(insertOp.getStrides());
323+
uint64_t stride = insertOp.getStrides().front();
327324
if (stride != 1)
328325
return failure();
329-
uint64_t offset = getFirstIntValue(insertOp.getOffsets());
326+
uint64_t offset = insertOp.getOffsets().front();
330327

331328
if (isa<spirv::ScalarType>(srcVector.getType())) {
332329
assert(!isa<spirv::ScalarType>(dstVector.getType()));

mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -549,11 +549,8 @@ struct ExtensionOverExtractStridedSlice final
549549
if (failed(ext))
550550
return failure();
551551

552-
VectorType origTy = op.getType();
553-
VectorType extractTy =
554-
origTy.cloneWith(origTy.getShape(), ext->getInElementType());
555552
Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
556-
op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
553+
op.getLoc(), ext->getIn(), op.getOffsets(), op.getSizes(),
557554
op.getStrides());
558555
ext->recreateAndReplace(rewriter, op, newExtract);
559556
return success();

0 commit comments

Comments
 (0)