Skip to content

[mlir][vector] Tighten the semantics of vector.{load|store} #135151

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5099,6 +5099,10 @@ LogicalResult vector::LoadOp::verify() {
if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
return failure();

if (memRefTy.getRank() < resVecTy.getRank())
return emitOpError(
"destination memref has lower rank than the result vector");

// Checks for vector memrefs.
Type memElemTy = memRefTy.getElementType();
if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
Expand Down Expand Up @@ -5131,6 +5135,9 @@ LogicalResult vector::StoreOp::verify() {
if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
return failure();

if (memRefTy.getRank() < valueVecTy.getRank())
return emitOpError("source memref has lower rank than the vector to store");

// Checks for vector memrefs.
Type memElemTy = memRefTy.getElementType();
if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
Expand Down
12 changes: 0 additions & 12 deletions mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -718,18 +718,6 @@ func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16

// -----

// CHECK-LABEL: @vector_load_i8_from_rank_1_memref(
// CHECK-SAME: %[[MEMREF:.*]]: memref<?xi8>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: arm_sme.tile_load %[[MEMREF]][%[[C0]]] : memref<?xi8>, vector<[16]x[16]xi8>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could follow up to apply the same changes to ArmSME. There's some unused code in ArmSMEToSCF.cpp for dealing with rank-1 memrefs (e.g. in getMemrefIndices) and it is somewhat broken: #118769.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need a few more cycles for this, but have posted a draft here: #135396

Let me know if you plan to review this further. If not, I will land this to unblock #134389.

Thanks for the pointer, Ben! 🙏🏻

// NOTE(review): this test loads a rank-2 scalable vector from a rank-1 memref.
// It is deleted by this change, which tightens vector.load so that the memref
// rank must be >= the result vector rank (see the new verifier check).
func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref<?xi8>) -> vector<[16]x[16]xi8> {
%c0 = arith.constant 0 : index
%tile = vector.load %arg0[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
return %tile : vector<[16]x[16]xi8>
}

// -----

// CHECK-LABEL: @vector_load_i16(
// CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
func.func @vector_load_i16(%arg0 : memref<?x?xi16>) -> vector<[8]x[8]xi16> {
Expand Down
67 changes: 45 additions & 22 deletions mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -819,18 +819,29 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind

// -----

func.func @fold_vector_load_subview(
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
%1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
return %1 : vector<12x32xf32>
// Folding test: a vector.load through a rank-preserving (2-D -> 2-D)
// memref.subview should fold into a vector.load on the original memref,
// with the subview offsets and the load indices combined via affine.apply.
// (Rewritten from a 0-D subview so the memref rank matches the vector rank.)
func.func @fold_vector_load_subview(%src : memref<24x64xf32>,
%off1 : index,
%off2 : index,
%dim1 : index,
%dim2 : index,
%idx : index) -> vector<12x32xf32> {

%0 = memref.subview %src[%off1, %off2][%dim1, %dim2][1, 1] : memref<24x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
%1 = vector.load %0[%idx, %idx] : memref<?x?xf32, strided<[64, 1], offset: ?>>, vector<12x32xf32>
return %1 : vector<12x32xf32>
}

// CHECK: func @fold_vector_load_subview
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK: vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<12x32xf32>, vector<12x32xf32>
// CHECK: #[[$ATTR_46:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK-LABEL: func.func @fold_vector_load_subview(
// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]*]]: memref<24x64xf32>,
// CHECK-SAME: %[[OFF_1:[a-zA-Z0-9$._-]*]]: index,
// CHECK-SAME: %[[OFF_2:[a-zA-Z0-9$._-]*]]: index,
// CHECK-SAME: %[[DIM_1:[a-zA-Z0-9$._-]*]]: index,
// CHECK-SAME: %[[DIM_2:[a-zA-Z0-9$._-]*]]: index,
// CHECK-SAME: %[[IDX:[a-zA-Z0-9$._-]*]]: index) -> vector<12x32xf32> {
// CHECK: %[[VAL_6:.*]] = affine.apply #[[$ATTR_46]](){{\[}}%[[OFF_1]], %[[IDX]]]
// CHECK: %[[VAL_7:.*]] = affine.apply #[[$ATTR_46]](){{\[}}%[[OFF_2]], %[[IDX]]]
// CHECK: %[[VAL_8:.*]] = vector.load %[[SRC]]{{\[}}%[[VAL_6]], %[[VAL_7]]] : memref<24x64xf32>, vector<12x32xf32>

// -----

Expand All @@ -851,20 +862,32 @@ func.func @fold_vector_maskedload_subview(

// -----

func.func @fold_vector_store_subview(
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<2x32xf32>) -> () {
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
vector.store %arg3, %0[] : memref<f32, strided<[], offset: ?>>, vector<2x32xf32>
return
// Folding test: a vector.store through a rank-preserving (2-D -> 2-D)
// memref.subview should fold into a vector.store on the original memref,
// with the subview offsets and the store indices combined via affine.apply.
// (Rewritten from a 0-D subview so the memref rank matches the vector rank.)
func.func @fold_vector_store_subview(%src : memref<24x64xf32>,
%off1 : index,
%off2 : index,
%vec: vector<2x32xf32>,
%idx : index,
%dim1 : index,
%dim2 : index) -> () {

%0 = memref.subview %src[%off1, %off2][%dim1, %dim2][1, 1] : memref<24x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
vector.store %vec, %0[%idx, %idx] : memref<?x?xf32, strided<[64, 1], offset: ?>> , vector<2x32xf32>
return
}

// CHECK: func @fold_vector_store_subview
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<2x32xf32>
// CHECK: vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<12x32xf32>, vector<2x32xf32>
// CHECK: return
// CHECK: #[[$ATTR_47:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>

// CHECK-LABEL: func.func @fold_vector_store_subview(
// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]*]]: memref<24x64xf32>,
// CHECK-SAME: %[[OFF1:[a-zA-Z0-9$._-]*]]: index,
// CHECK-SAME: %[[OFF_2:[a-zA-Z0-9$._-]*]]: index,
// CHECK-SAME: %[[VEC:[a-zA-Z0-9$._-]*]]: vector<2x32xf32>,
// CHECK-SAME: %[[IDX:[a-zA-Z0-9$._-]*]]: index,
// CHECK-SAME: %[[VAL_5:[a-zA-Z0-9$._-]*]]: index,
// CHECK-SAME: %[[VAL_6:[a-zA-Z0-9$._-]*]]: index) {
// CHECK: %[[VAL_7:.*]] = affine.apply #[[$ATTR_47]](){{\[}}%[[OFF1]], %[[IDX]]]
// CHECK: %[[VAL_8:.*]] = affine.apply #[[$ATTR_47]](){{\[}}%[[OFF_2]], %[[IDX]]]
// CHECK: vector.store %[[VEC]], %[[SRC]]{{\[}}%[[VAL_7]], %[[VAL_8]]] : memref<24x64xf32>, vector<2x32xf32>

// -----

Expand Down
32 changes: 28 additions & 4 deletions mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1743,13 +1743,11 @@ func.func @invalid_outerproduct(%src : memref<?xf32>) {

// -----

func.func @invalid_outerproduct1(%src : memref<?xf32>) {
func.func @invalid_outerproduct1(%src : memref<?xf32>, %lhs : vector<[4]x[4]xf32>, %rhs : vector<[4]xf32>) {
%idx = arith.constant 0 : index
%0 = vector.load %src[%idx] : memref<?xf32>, vector<[4]x[4]xf32>
%1 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>

// expected-error @+1 {{'vector.outerproduct' op expected 1-d vector for operand #1}}
%op = vector.outerproduct %0, %1 : vector<[4]x[4]xf32>, vector<[4]xf32>
%op = vector.outerproduct %lhs, %rhs : vector<[4]x[4]xf32>, vector<[4]xf32>
}

// -----
Expand Down Expand Up @@ -1870,3 +1868,29 @@ func.func @flat_transpose_scalable(%arg0: vector<[16]xf32>) -> vector<[16]xf32>
: vector<[16]xf32> -> vector<[16]xf32>
return %0 : vector<[16]xf32>
}

// -----

//===----------------------------------------------------------------------===//
// vector.load
//===----------------------------------------------------------------------===//

// Negative test: the destination memref (rank 1) has lower rank than the
// result vector (rank 2), so the verifier must reject the op. The comment
// directly above the load is the diagnostic anchor (@+1) — keep it adjacent.
func.func @vector_load(%src : memref<?xi8>) {
%c0 = arith.constant 0 : index
// expected-error @+1 {{'vector.load' op destination memref has lower rank than the result vector}}
%0 = vector.load %src[%c0] : memref<?xi8>, vector<16x16xi8>
return
}

// -----

//===----------------------------------------------------------------------===//
// vector.store
//===----------------------------------------------------------------------===//

// Negative test: the source memref (rank 1) has lower rank than the vector
// being stored (rank 2), so the verifier must reject the op. The comment
// directly above the store is the diagnostic anchor (@+1) — keep it adjacent.
func.func @vector_store(%dest : memref<?xi8>, %vec : vector<16x16xi8>) {
%c0 = arith.constant 0 : index
// expected-error @+1 {{'vector.store' op source memref has lower rank than the vector to store}}
vector.store %vec, %dest[%c0] : memref<?xi8>, vector<16x16xi8>
return
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

// CHECK-LABEL: func @vector_transfer_ops_0d_memref(
// CHECK-SAME: %[[MEM:.*]]: memref<f32>
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x1xf32>
func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf32>) {
// CHECK-SAME: %[[VEC:.*]]: vector<f32>
func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<f32>) {
%f0 = arith.constant 0.0 : f32

// CHECK-NEXT: %[[S:.*]] = vector.load %[[MEM]][] : memref<f32>, vector<f32>
Expand All @@ -12,8 +12,8 @@ func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf
// CHECK-NEXT: vector.store %[[S]], %[[MEM]][] : memref<f32>, vector<f32>
vector.transfer_write %0, %mem[] : vector<f32>, memref<f32>

// CHECK-NEXT: vector.store %[[VEC]], %[[MEM]][] : memref<f32>, vector<1x1x1xf32>
vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>
// CHECK-NEXT: vector.store %[[VEC]], %[[MEM]][] : memref<f32>, vector<f32>
vector.store %vec, %mem[] : memref<f32>, vector<f32>

return
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@ func.func @entry() {

// Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
%svl_s = arm_sme.streaming_vl <word>
%za_s_size = arith.muli %svl_s, %svl_s : index

// Allocate memory.
%mem1 = memref.alloca(%za_s_size) : memref<?xi32>
%mem1 = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>

// Fill each "row" of "mem1" with row number.
//
Expand All @@ -29,15 +28,15 @@ func.func @entry() {
// 3, 3, 3, 3
//
%init_0 = arith.constant 0 : i32
scf.for %i = %c0 to %za_s_size step %svl_s iter_args(%val = %init_0) -> (i32) {
scf.for %i = %c0 to %svl_s step %c1 iter_args(%val = %init_0) -> (i32) {
%splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
vector.store %splat_val, %mem1[%i, %c0] : memref<?x?xi32>, vector<[4]xi32>
%val_next = arith.addi %val, %c1_i32 : i32
scf.yield %val_next : i32
}

// Load tile from "mem1".
%tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
%tile = vector.load %mem1[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>

// Transpose tile.
%transposed_tile = vector.transpose %tile, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
Expand Down
Loading