Skip to content

[mlir][vector] Update representation of insert/extract_strided_slice #101850

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
include "mlir/Dialect/Vector/IR/Vector.td"
include "mlir/IR/EnumAttr.td"

class Vector_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
: AttrDef<Vector_Dialect, attrName, traits> {
let mnemonic = attrMnemonic;
}

// The "kind" of combining function for contractions and reductions.
def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">;
def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">;
Expand Down Expand Up @@ -82,4 +87,42 @@ def Vector_PrintPunctuation : EnumAttr<Vector_Dialect, PrintPunctuation, "punctu
let assemblyFormat = "`<` $value `>`";
}

def Vector_StridedSliceAttr : Vector_Attr<"StridedSlice", "strided_slice">
{
let summary = "strided vector slice";

let description = [{
An attribute that represents a strided slice of a vector.

*Examples:*

Without sizes:

`{offsets = [0, 0, 2], strides = [1, 1]}`

With sizes (used for extract_strided_slice):

`{offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}`

TODO? Come up with a range syntax (similar to Python slices).
}];

let parameters = (ins
ArrayRefParameter<"int64_t">:$offsets,
OptionalArrayRefParameter<"int64_t">:$sizes,
ArrayRefParameter<"int64_t">:$strides
);

let builders = [AttrBuilder<(ins "ArrayRef<int64_t>":$offsets, "ArrayRef<int64_t>":$strides), [{
return $_get($_ctxt, offsets, ArrayRef<int64_t>{}, strides);
}]>
];

let assemblyFormat = [{
`{` `offsets` `=` `[` $offsets `]` `,`
(`sizes` `=` `[` $sizes^ `]` `,`)?
`strides` `=` `[` $strides `]` `}`
}];
}

#endif // MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES
39 changes: 20 additions & 19 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1040,8 +1040,8 @@ def Vector_InsertStridedSliceOp :
PredOpTrait<"operand #0 and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
AllTypesMatch<["dest", "res"]>]>,
Arguments<(ins AnyVector:$source, AnyVector:$dest, I64ArrayAttr:$offsets,
I64ArrayAttr:$strides)>,
Arguments<(ins AnyVector:$source, AnyVector:$dest,
Vector_StridedSliceAttr:$strided_slice)>,
Results<(outs AnyVector:$res)> {
let summary = "strided_slice operation";
let description = [{
Expand All @@ -1060,13 +1060,13 @@ def Vector_InsertStridedSliceOp :

```mlir
%2 = vector.insert_strided_slice %0, %1
{offsets = [0, 0, 2], strides = [1, 1]}:
vector<2x4xf32> into vector<16x4x8xf32>
{offsets = [0, 0, 2], strides = [1, 1]}
: vector<2x4xf32> into vector<16x4x8xf32>
```
}];

let assemblyFormat = [{
$source `,` $dest attr-dict `:` type($source) `into` type($dest)
$source `,` $dest $strided_slice attr-dict `:` type($source) `into` type($dest)
}];

let builders = [
Expand All @@ -1081,10 +1081,13 @@ def Vector_InsertStridedSliceOp :
return ::llvm::cast<VectorType>(getDest().getType());
}
bool hasNonUnitStrides() {
return llvm::any_of(getStrides(), [](Attribute attr) {
return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
return llvm::any_of(getStrides(), [](int64_t stride) {
return stride != 1;
});
}

ArrayRef<int64_t> getOffsets() { return getStridedSlice().getOffsets(); }
ArrayRef<int64_t> getStrides() { return getStridedSlice().getStrides(); }
}];

let hasFolder = 1;
Expand Down Expand Up @@ -1182,8 +1185,7 @@ def Vector_ExtractStridedSliceOp :
Vector_Op<"extract_strided_slice", [Pure,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
Arguments<(ins AnyVector:$vector, I64ArrayAttr:$offsets,
I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
Arguments<(ins AnyVector:$vector, Vector_StridedSliceAttr:$strided_slice)>,
Results<(outs AnyVector)> {
let summary = "extract_strided_slice operation";
let description = [{
Expand All @@ -1201,12 +1203,8 @@ def Vector_ExtractStridedSliceOp :

```mlir
%1 = vector.extract_strided_slice %0
{offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}:
vector<4x8x16xf32> to vector<2x4x16xf32>

// TODO: Evolve to a range form syntax similar to:
%1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
vector<4x8x16xf32> to vector<2x4x16xf32>
{offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}
: vector<4x8x16xf32> to vector<2x4x16xf32>
```
}];
let builders = [
Expand All @@ -1217,17 +1215,20 @@ def Vector_ExtractStridedSliceOp :
VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getVector().getType());
}
void getOffsets(SmallVectorImpl<int64_t> &results);
bool hasNonUnitStrides() {
return llvm::any_of(getStrides(), [](Attribute attr) {
return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
return llvm::any_of(getStrides(), [](int64_t stride) {
return stride != 1;
});
}

ArrayRef<int64_t> getOffsets() { return getStridedSlice().getOffsets(); }
ArrayRef<int64_t> getSizes() { return getStridedSlice().getSizes(); }
ArrayRef<int64_t> getStrides() { return getStridedSlice().getStrides(); }
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
let assemblyFormat = "$vector $strided_slice attr-dict `:` type($vector) `to` type(results)";
}

// TODO: Tighten semantics so that masks and inbounds can't be used
Expand Down
11 changes: 1 addition & 10 deletions mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -940,12 +940,6 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
return success();
}

static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
SmallVectorImpl<int64_t> &results) {
for (auto attr : arrayAttr)
results.push_back(cast<IntegerAttr>(attr).getInt());
}

static LogicalResult
convertExtractStridedSlice(RewriterBase &rewriter,
vector::ExtractStridedSliceOp op,
Expand Down Expand Up @@ -996,11 +990,8 @@ convertExtractStridedSlice(RewriterBase &rewriter,
auto sourceVector = it->second;

// offset and sizes at warp-level of onwership.
SmallVector<int64_t> offsets;
populateFromInt64AttrArray(op.getOffsets(), offsets);
ArrayRef<int64_t> offsets = op.getOffsets();

SmallVector<int64_t> sizes;
populateFromInt64AttrArray(op.getSizes(), sizes);
ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();

// Compute offset in vector registers. Note that the mma.sync vector registers
Expand Down
13 changes: 5 additions & 8 deletions mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ static uint64_t getFirstIntValue(ValueRange values) {
static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
return cast<IntegerAttr>(attr[0]).getInt();
}
static uint64_t getFirstIntValue(ArrayAttr attr) {
return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
}
static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
auto attr = foldResults[0].dyn_cast<Attribute>();
if (attr)
Expand Down Expand Up @@ -187,9 +184,9 @@ struct VectorExtractStridedSliceOpConvert final
if (!dstType)
return failure();

uint64_t offset = getFirstIntValue(extractOp.getOffsets());
uint64_t size = getFirstIntValue(extractOp.getSizes());
uint64_t stride = getFirstIntValue(extractOp.getStrides());
int64_t offset = extractOp.getOffsets().front();
int64_t size = extractOp.getSizes().front();
int64_t stride = extractOp.getStrides().front();
if (stride != 1)
return failure();

Expand Down Expand Up @@ -323,10 +320,10 @@ struct VectorInsertStridedSliceOpConvert final
Value srcVector = adaptor.getOperands().front();
Value dstVector = adaptor.getOperands().back();

uint64_t stride = getFirstIntValue(insertOp.getStrides());
uint64_t stride = insertOp.getStrides().front();
if (stride != 1)
return failure();
uint64_t offset = getFirstIntValue(insertOp.getOffsets());
uint64_t offset = insertOp.getOffsets().front();

if (isa<spirv::ScalarType>(srcVector.getType())) {
assert(!isa<spirv::ScalarType>(dstVector.getType()));
Expand Down
5 changes: 1 addition & 4 deletions mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -549,11 +549,8 @@ struct ExtensionOverExtractStridedSlice final
if (failed(ext))
return failure();

VectorType origTy = op.getType();
VectorType extractTy =
origTy.cloneWith(origTy.getShape(), ext->getInElementType());
Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
op.getLoc(), ext->getIn(), op.getOffsets(), op.getSizes(),
op.getStrides());
ext->recreateAndReplace(rewriter, op, newExtract);
return success();
Expand Down
Loading
Loading