Skip to content

Commit 5d210ef

Browse files
committed
revirew comments
1 parent a6cd3f1 commit 5d210ef

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,10 +1088,10 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
10881088
if (rankOffset < 0)
10891089
return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
10901090

1091-
auto resVecType = dyn_cast<VectorType>(op.getResult().getType());
1091+
auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
10921092
int64_t finalRank = 0;
1093-
if (resVecType)
1094-
finalRank = resVecType.getRank();
1093+
if (extractVecType)
1094+
finalRank = extractVecType.getRank();
10951095

10961096
SmallVector<Value> indices = loadOp.getIndices();
10971097
SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
@@ -1113,8 +1113,8 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
11131113
}
11141114

11151115
Value base = loadOp.getBase();
1116-
if (resVecType) {
1117-
rewriter.replaceOpWithNewOp<vector::LoadOp>(op, resVecType, base,
1116+
if (extractVecType) {
1117+
rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
11181118
indices);
11191119
} else {
11201120
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
@@ -1136,7 +1136,7 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
11361136
/// ```
11371137
/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
11381138
/// ```
1139-
class StoreFromSplatOrBroadcast final
1139+
class StoreOpFromSplatOrBroadcast final
11401140
: public OpRewritePattern<vector::StoreOp> {
11411141
public:
11421142
using OpRewritePattern::OpRewritePattern;
@@ -2246,7 +2246,7 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
22462246

22472247
void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
22482248
PatternBenefit benefit) {
2249-
patterns.add<ExtractOpFromLoad, StoreFromSplatOrBroadcast>(
2249+
patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
22502250
patterns.getContext(), benefit);
22512251
}
22522252

mlir/test/Dialect/Vector/vector-sink.mlir

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -551,9 +551,9 @@ func.func @extract_load_scalar_dyn_off(%arg0: memref<?xf32>, %arg1: index, %arg2
551551
return %1 : f32
552552
}
553553

554-
// CHECK-LABEL: @extract_load_vec
554+
// CHECK-LABEL: @extract_load_vec_non_zero_off
555555
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
556-
func.func @extract_load_vec(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> vector<4xf32> {
556+
func.func @extract_load_vec_non_zero_off(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> vector<4xf32> {
557557
// CHECK: %[[C1:.*]] = arith.constant 1 : index
558558
// CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
559559
// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[OFF]], %[[ARG2]]] : memref<?x?xf32>, vector<4xf32>
@@ -563,9 +563,9 @@ func.func @extract_load_vec(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index)
563563
return %1 : vector<4xf32>
564564
}
565565

566-
// CHECK-LABEL: @extract_load_scalar_high_rank
566+
// CHECK-LABEL: @extract_load_scalar_non_zero_off_2d_src_memref
567567
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
568-
func.func @extract_load_scalar_high_rank(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> f32 {
568+
func.func @extract_load_scalar_non_zero_off_2d_src_memref(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> f32 {
569569
// CHECK: %[[C1:.*]] = arith.constant 1 : index
570570
// CHECK: %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
571571
// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[OFF]]] : memref<?x?xf32>
@@ -587,9 +587,9 @@ func.func @extract_load_vec_high_rank(%arg0: memref<?x?x?xf32>, %arg1: index, %a
587587
return %1 : vector<4xf32>
588588
}
589589

590-
// CHECK-LABEL: @negative_extract_load_scalar_from_vec_memref
590+
// CHECK-LABEL: @negative_extract_load_scalar_from_memref_of_vec
591591
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
592-
func.func @negative_extract_load_scalar_from_vec_memref(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
592+
func.func @negative_extract_load_scalar_from_memref_of_vec(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
593593
// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
594594
// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
595595
// CHECK: return %[[EXT]] : f32
@@ -621,7 +621,7 @@ func.func @negative_extract_load_scalable(%arg0: memref<?xf32>, %arg1: index) ->
621621
}
622622

623623
//-----------------------------------------------------------------------------
624-
// [Pattern: StoreFromSplat]
624+
// [Pattern: StoreOpFromSplatOrBroadcast]
625625
//-----------------------------------------------------------------------------
626626

627627
// CHECK-LABEL: @store_splat
@@ -661,9 +661,9 @@ func.func @negative_store_scalable(%arg0: memref<?xf32>, %arg1: index, %arg2: f3
661661
return
662662
}
663663

664-
// CHECK-LABEL: @negative_store_vec_memref
664+
// CHECK-LABEL: @negative_store_memref_of_vec
665665
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xvector<1xf32>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
666-
func.func @negative_store_vec_memref(%arg0: memref<?xvector<1xf32>>, %arg1: index, %arg2: f32) {
666+
func.func @negative_store_memref_of_vec(%arg0: memref<?xvector<1xf32>>, %arg1: index, %arg2: f32) {
667667
// CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
668668
// CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xvector<1xf32>>, vector<1xf32>
669669
%0 = vector.splat %arg2 : vector<1xf32>

0 commit comments

Comments
 (0)