Skip to content

Commit cdb12b8

Browse files
committed
[mlir][ArmSME] Audit arm_sme.tile_store
Makes that the following cases are rejected (as opposed causing the mlir-opt to crash): ```mlir arm_sme.tile_store %arg0, %arg1[%c0] : memref<?x4xi8>, vector<[4]x[4]xi32> arm_sme.tile_store %arg0, %arg1[%c0] : memref<?xi8>, vector<[4]x[4]xi32> ``` Instead, we should be loading from a rank-2 MemRef, using 2 indices, e.g.: ```mlir arm_sme.tile_store %arg0, %arg1[%c0, %c0] : memref<?x?xi8>, vector<[4]x[4]xi32> ``` Fixes #118769
1 parent 596034d commit cdb12b8

File tree

3 files changed

+26
-9
lines changed

3 files changed

+26
-9
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
369369
```
370370
}];
371371
let arguments = (ins
372-
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
372+
Arg<MemRefRankOf<[AnyType], [2]>, "the reference to load from", [MemRead]>:$base,
373373
Variadic<Index>:$indices,
374374
Optional<AnyType>:$padding, Optional<AnyVectorOfNonZeroRank>:$mask,
375375
ArmSME_TileSliceLayoutAttr:$layout
@@ -443,7 +443,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
443443
```
444444
}];
445445
let arguments = (ins SMETile:$valueToStore,
446-
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
446+
Arg<MemRefRankOf<[AnyType], [2]>, "the reference to store to", [MemWrite]>:$base,
447447
Variadic<Index>:$indices, Optional<AnyVectorOfNonZeroRank>:$mask,
448448
ArmSME_TileSliceLayoutAttr:$layout
449449
);

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,20 +33,15 @@ SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
3333
Value tileSliceIndex,
3434
Value tileSliceNumElts, Location loc,
3535
PatternRewriter &rewriter) {
36-
assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
36+
assert(rank == 2 && "memref has unexpected rank!");
3737
SmallVector<Value, 2> outIndices;
3838

3939
auto tileSliceOffset = tileSliceIndex;
40-
if (rank == 1)
41-
tileSliceOffset =
42-
rewriter.create<arith::MulIOp>(loc, tileSliceOffset, tileSliceNumElts);
4340

4441
auto baseIndexPlusTileSliceOffset =
4542
rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
4643
outIndices.push_back(baseIndexPlusTileSliceOffset);
47-
48-
if (rank == 2)
49-
outIndices.push_back(indices[1]);
44+
outIndices.push_back(indices[1]);
5045

5146
return outIndices;
5247
}
@@ -60,6 +55,10 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
6055
makeLoopBody) {
6156
PatternRewriter::InsertionGuard guard(rewriter);
6257

58+
// TODO: This case should be captured and rejected by a verifier.
59+
if (memrefIndices.size() != 2)
60+
return rewriter.notifyMatchFailure(loc, "invalid number of indices");
61+
6362
auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
6463
loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
6564
auto vscale =

mlir/test/Dialect/ArmSME/invalid.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,15 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64
111111
return
112112
}
113113

114+
// -----
115+
116+
func.func @arm_sme_tile_load__bad_memref_rank(%src : memref<?xf64>, %pad : f64) {
117+
%c0 = arith.constant 0 : index
118+
// expected-error@+1 {{op operand #0 must be 2D memref of any type values, but got 'memref<?xf64>'}}
119+
%tile = arm_sme.tile_load %src[%c0], %pad, : memref<?xf64>, vector<[2]x[2]xf64>
120+
return
121+
}
122+
114123
//===----------------------------------------------------------------------===//
115124
// arm_sme.load_tile_slice
116125
//===----------------------------------------------------------------------===//
@@ -138,6 +147,15 @@ func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask
138147
return
139148
}
140149

150+
// -----
151+
152+
func.func @arm_sme_tile_store__bad_memref_rank(%tile : vector<[16]x[16]xi8>, %dest : memref<?xi8>) {
153+
%c0 = arith.constant 0 : index
154+
// expected-error@+1 {{op operand #1 must be 2D memref of any type values, but got 'memref<?xi8>'}}
155+
arm_sme.tile_store %tile, %dest[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
156+
return
157+
}
158+
141159
//===----------------------------------------------------------------------===//
142160
// arm_sme.store_tile_slice
143161
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)