Skip to content

Commit c7fe252

Browse files
[mlir][vector] LoadOp/StoreOp: Allow 0-D vectors
Similar to `vector.transfer_read`/`vector.transfer_write`, allow 0-D vectors. This commit fixes `mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir` when verifying the IR after each pattern (llvm#74270). That test produces a temporary 0-D load/store op.
1 parent 7022a24 commit c7fe252

File tree

3 files changed

+67
-15
lines changed

3 files changed

+67
-15
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,22 +1582,27 @@ def Vector_LoadOp : Vector_Op<"load"> {
15821582
vector. If the memref element type is vector, it should match the result
15831583
vector type.
15841584

1585-
Example 1: 1-D vector load on a scalar memref.
1585+
Example: 0-D vector load on a scalar memref.
1586+
```mlir
1587+
%result = vector.load %base[%i, %j] : memref<100x100xf32>, vector<f32>
1588+
```
1589+
1590+
Example: 1-D vector load on a scalar memref.
15861591
```mlir
15871592
%result = vector.load %base[%i, %j] : memref<100x100xf32>, vector<8xf32>
15881593
```
15891594

1590-
Example 2: 1-D vector load on a vector memref.
1595+
Example: 1-D vector load on a vector memref.
15911596
```mlir
15921597
%result = vector.load %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
15931598
```
15941599

1595-
Example 3: 2-D vector load on a scalar memref.
1600+
Example: 2-D vector load on a scalar memref.
15961601
```mlir
15971602
%result = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
15981603
```
15991604

1600-
Example 4: 2-D vector load on a vector memref.
1605+
Example: 2-D vector load on a vector memref.
16011606
```mlir
16021607
%result = vector.load %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
16031608
```
@@ -1608,12 +1613,12 @@ def Vector_LoadOp : Vector_Op<"load"> {
16081613
loaded out of bounds. Not all targets may support out-of-bounds vector
16091614
loads.
16101615

1611-
Example 5: Potential out-of-bound vector load.
1616+
Example: Potential out-of-bound vector load.
16121617
```mlir
16131618
%result = vector.load %memref[%index] : memref<?xf32>, vector<8xf32>
16141619
```
16151620

1616-
Example 6: Explicit out-of-bound vector load.
1621+
Example: Explicit out-of-bound vector load.
16171622
```mlir
16181623
%result = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32>
16191624
```
@@ -1622,7 +1627,7 @@ def Vector_LoadOp : Vector_Op<"load"> {
16221627
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
16231628
[MemRead]>:$base,
16241629
Variadic<Index>:$indices);
1625-
let results = (outs AnyVector:$result);
1630+
let results = (outs AnyVectorOfAnyRank:$result);
16261631

16271632
let extraClassDeclaration = [{
16281633
MemRefType getMemRefType() {
@@ -1660,22 +1665,27 @@ def Vector_StoreOp : Vector_Op<"store"> {
16601665
to store. If the memref element type is vector, it should match the type
16611666
of the value to store.
16621667

1663-
Example 1: 1-D vector store on a scalar memref.
1668+
Example: 0-D vector store on a scalar memref.
1669+
```mlir
1670+
vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<f32>
1671+
```
1672+
1673+
Example: 1-D vector store on a scalar memref.
16641674
```mlir
16651675
vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
16661676
```
16671677

1668-
Example 2: 1-D vector store on a vector memref.
1678+
Example: 1-D vector store on a vector memref.
16691679
```mlir
16701680
vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
16711681
```
16721682

1673-
Example 3: 2-D vector store on a scalar memref.
1683+
Example: 2-D vector store on a scalar memref.
16741684
```mlir
16751685
vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
16761686
```
16771687

1678-
Example 4: 2-D vector store on a vector memref.
1688+
Example: 2-D vector store on a vector memref.
16791689
```mlir
16801690
vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
16811691
```
@@ -1685,21 +1695,23 @@ def Vector_StoreOp : Vector_Op<"store"> {
16851695
target-specific. No assumptions should be made on the memory written out of
16861696
bounds. Not all targets may support out-of-bounds vector stores.
16871697

1688-
Example 5: Potential out-of-bounds vector store.
1698+
Example: Potential out-of-bounds vector store.
16891699
```mlir
16901700
vector.store %valueToStore, %memref[%index] : memref<?xf32>, vector<8xf32>
16911701
```
16921702

1693-
Example 6: Explicit out-of-bounds vector store.
1703+
Example: Explicit out-of-bounds vector store.
16941704
```mlir
16951705
vector.store %valueToStore, %memref[%c0] : memref<7xf32>, vector<8xf32>
16961706
```
16971707
}];
16981708

1699-
let arguments = (ins AnyVector:$valueToStore,
1709+
let arguments = (ins
1710+
AnyVectorOfAnyRank:$valueToStore,
17001711
Arg<AnyMemRef, "the reference to store to",
17011712
[MemWrite]>:$base,
1702-
Variadic<Index>:$indices);
1713+
Variadic<Index>:$indices
1714+
);
17031715

17041716
let extraClassDeclaration = [{
17051717
MemRefType getMemRefType() {

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2059,6 +2059,36 @@ func.func @vector_store_op_index(%memref : memref<200x100xindex>, %i : index, %j
20592059

20602060
// -----
20612061

2062+
func.func @vector_load_op_0d(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<f32> {
2063+
%0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<f32>
2064+
return %0 : vector<f32>
2065+
}
2066+
2067+
// CHECK-LABEL: func @vector_load_op_0d
2068+
// CHECK: %[[load:.*]] = memref.load %{{.*}}[%{{.*}}, %{{.*}}]
2069+
// CHECK: %[[vec:.*]] = llvm.mlir.undef : vector<1xf32>
2070+
// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i32) : i32
2071+
// CHECK: %[[inserted:.*]] = llvm.insertelement %[[load]], %[[vec]][%[[c0]] : i32] : vector<1xf32>
2072+
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[inserted]] : vector<1xf32> to vector<f32>
2073+
// CHECK: return %[[cast]] : vector<f32>
2074+
2075+
// -----
2076+
2077+
func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
2078+
%val = arith.constant dense<11.0> : vector<f32>
2079+
vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<f32>
2080+
return
2081+
}
2082+
2083+
// CHECK-LABEL: func @vector_store_op_0d
2084+
// CHECK: %[[val:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
2085+
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[val]] : vector<f32> to vector<1xf32>
2086+
// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64
2087+
// CHECK: %[[extracted:.*]] = llvm.extractelement %[[cast]][%[[c0]] : i64] : vector<1xf32>
2088+
// CHECK: memref.store %[[extracted]], %{{.*}}[%{{.*}}, %{{.*}}]
2089+
2090+
// -----
2091+
20622092
func.func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
20632093
%c0 = arith.constant 0: index
20642094
%0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,16 @@ func.func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> {
714714
return %0 : vector<16xi32>
715715
}
716716

717+
// CHECK-LABEL: @vector_load_and_store_0d_scalar_memref
718+
func.func @vector_load_and_store_0d_scalar_memref(%memref : memref<200x100xf32>,
719+
%i : index, %j : index) {
720+
// CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<f32>
721+
%0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<f32>
722+
// CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<f32>
723+
vector.store %0, %memref[%i, %j] : memref<200x100xf32>, vector<f32>
724+
return
725+
}
726+
717727
// CHECK-LABEL: @vector_load_and_store_1d_scalar_memref
718728
func.func @vector_load_and_store_1d_scalar_memref(%memref : memref<200x100xf32>,
719729
%i : index, %j : index) {

0 commit comments

Comments
 (0)