Skip to content

Commit fb8eb42

Browse files
authored
[mlir][ArmSME] Fix loop bounds of masked loads/stores (#78983)
Previously, for masked tile loads/stores we directly used the dimension size from the `vector.create_mask` operation as the upper bound of the `scf.for` over the tile slices. This was not correct, as `create_mask` allows operands to be greater than the size of the vector dimension, in which case the for loop bounds should be clamped to the number of tile slices.
1 parent f6290e0 commit fb8eb42

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,18 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
8585
auto maskDim0 = createMaskOp.getOperands()[0];
8686
auto maskDim1 = createMaskOp.getOperands()[1];
8787

88-
upperBound = maskDim0;
88+
// The upper bound of the loop must be clamped at `numTileSlices` as
89+
// `vector.create_mask` allows operands to be greater than the size of a
90+
// dimension.
91+
auto numRowI64 = rewriter.create<arith::IndexCastOp>(
92+
loc, rewriter.getI64Type(), maskDim0);
93+
auto numTileSlicesI64 = rewriter.create<arith::IndexCastOp>(
94+
loc, rewriter.getI64Type(), numTileSlices);
95+
auto upperBoundI64 =
96+
rewriter.create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64);
97+
upperBound = rewriter.create<arith::IndexCastOp>(
98+
loc, rewriter.getIndexType(), upperBoundI64);
99+
89100
predicate =
90101
rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
91102
} else {

mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,17 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
3939
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>) {
4040
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
4141
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
42+
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
4243
// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index
44+
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
45+
// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
46+
// CHECK-DAG: %[[NUM_ROWS_I64:.*]] = arith.index_cast %[[NUM_ROWS]] : index to i64
47+
// CHECK-DAG: %[[NUM_TILE_SLICES_I64:.*]] = arith.index_cast %[[NUM_TILE_SLICES]] : index to i64
48+
// CHECK-DAG: %[[LOOP_UPPER_BOUND_I64:.*]] = arith.minsi %[[NUM_ROWS_I64]], %[[NUM_TILE_SLICES_I64]] : i64
49+
// CHECK-DAG: %[[LOOP_UPPER_BOUND:.*]] = arith.index_cast %[[LOOP_UPPER_BOUND_I64]] : i64 to index
4350
// CHECK-DAG: %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
4451
// CHECK-DAG: %[[TILE_ZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
45-
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[TILE_ZERO]]) -> (vector<[4]x[4]xi32>) {
52+
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[LOOP_UPPER_BOUND]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[TILE_ZERO]]) -> (vector<[4]x[4]xi32>) {
4653
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
4754
// CHECK-NEXT: %[[TILE_UPDATE:.*]] = arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
4855
// CHECK-NEXT: scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>
@@ -150,9 +157,16 @@ func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
150157
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi32>) {
151158
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
152159
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
160+
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
153161
// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index
162+
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
163+
// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
164+
// CHECK-DAG: %[[NUM_ROWS_I64:.*]] = arith.index_cast %[[NUM_ROWS]] : index to i64
165+
// CHECK-DAG: %[[NUM_TILE_SLICES_I64:.*]] = arith.index_cast %[[NUM_TILE_SLICES]] : index to i64
166+
// CHECK-DAG: %[[LOOP_UPPER_BOUND_I64:.*]] = arith.minsi %[[NUM_ROWS_I64]], %[[NUM_TILE_SLICES_I64]] : i64
167+
// CHECK-DAG: %[[LOOP_UPPER_BOUND:.*]] = arith.index_cast %[[LOOP_UPPER_BOUND_I64]] : i64 to index
154168
// CHECK-DAG: %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
155-
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] {
169+
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[LOOP_UPPER_BOUND]] step %[[C1]] {
156170
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
157171
// CHECK-NEXT: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[NUM_COLS]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
158172
func.func @arm_sme_tile_store_hor_with_mask(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {

0 commit comments

Comments
 (0)