Skip to content

Commit 1cee446

Browse files
[fixup] Address misc review comments
1 parent 8e7a91d commit 1cee446

File tree

4 files changed

+27
-26
lines changed

4 files changed

+27
-26
lines changed

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -838,8 +838,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
838838
///
839839
bool areTrailingDimsContiguous(int64_t n);
840840

841-
/// Return the maximum number of trailing dimensions that are
842-
/// contiguous.
841+
/// Return the number of trailing dimensions that are contiguous.
843842
///
844843
/// Examples:
845844
/// - memref<5x3x2xi8, strided<[6,2,1]>>, the number of collapsable
@@ -856,7 +855,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
856855
/// trailing dimensions is 2 (dimension 0 is non-contiguous)
857856
/// - memref<5x?x2xi8, strided<[?,2,1]>>, the number of collapsable
858857
/// trailing dimensions is 2 (stride 0 is dynamic)
859-
int64_t getMaxContiguousTrailingDims();
858+
int64_t getNumContiguousTrailingDims();
860859

861860
/// Return a version of this type with identity layout if it can be
862861
/// determined statically that the layout is the canonical contiguous

mlir/lib/Dialect/Utils/IndexingUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
6969
//===----------------------------------------------------------------------===//
7070

7171
SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
72-
assert((sizes.size() == 0 ||
72+
assert((sizes.empty() ||
7373
llvm::all_of(sizes.drop_front(), [](int64_t s) { return s >= 0; })) &&
7474
"sizes must be nonnegative");
7575
int64_t unit = 1;

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -646,10 +646,12 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
646646
}
647647

648648
bool MemRefType::areTrailingDimsContiguous(int64_t n) {
649-
return getMaxContiguousTrailingDims() >= std::min(n, getRank());
649+
assert(n <= getRank() &&
650+
"number of dimensions to check must not exceed rank");
651+
return n <= getNumContiguousTrailingDims();
650652
}
651653

652-
int64_t MemRefType::getMaxContiguousTrailingDims() {
654+
int64_t MemRefType::getNumContiguousTrailingDims() {
653655
const int64_t n = getRank();
654656

655657
// memrefs with identity layout are entirely contiguous.

mlir/unittests/Dialect/MemRef/LayoutTest.cpp

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,75 +27,75 @@ TEST(MemRefLayout, maxContigDim) {
2727

2828
// memref<2x2x2xf32, strided<[4,2,1]>
2929
auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
30-
EXPECT_EQ(m1.getMaxContiguousTrailingDims(), 3);
30+
EXPECT_EQ(m1.getNumContiguousTrailingDims(), 3);
3131

3232
// memref<2x2x2xf32, strided<[8,2,1]>
3333
auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
34-
EXPECT_EQ(m2.getMaxContiguousTrailingDims(), 2);
34+
EXPECT_EQ(m2.getNumContiguousTrailingDims(), 2);
3535

3636
// memref<2x2x2xf32, strided<[8,4,1]>
3737
auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
38-
EXPECT_EQ(m3.getMaxContiguousTrailingDims(), 1);
38+
EXPECT_EQ(m3.getNumContiguousTrailingDims(), 1);
3939

4040
// memref<2x2x2xf32, strided<[8,4,2]>
4141
auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
42-
EXPECT_EQ(m4.getMaxContiguousTrailingDims(), 0);
42+
EXPECT_EQ(m4.getNumContiguousTrailingDims(), 0);
4343

4444
// memref<2x2x?xf32, strided<[?,?,1]>
4545
auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
46-
EXPECT_EQ(m5.getMaxContiguousTrailingDims(), 1);
46+
EXPECT_EQ(m5.getNumContiguousTrailingDims(), 1);
4747

4848
// memref<2x2x?xf32, strided<[?,?,2]>
4949
auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
50-
EXPECT_EQ(m6.getMaxContiguousTrailingDims(), 0);
50+
EXPECT_EQ(m6.getNumContiguousTrailingDims(), 0);
5151

5252
// memref<2x?x2xf32, strided<[?,2,1]>
5353
auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
54-
EXPECT_EQ(m7.getMaxContiguousTrailingDims(), 2);
54+
EXPECT_EQ(m7.getNumContiguousTrailingDims(), 2);
5555

5656
// memref<2x?x2xf32, strided<[?,4,1]>
5757
auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
58-
EXPECT_EQ(m8.getMaxContiguousTrailingDims(), 1);
58+
EXPECT_EQ(m8.getNumContiguousTrailingDims(), 1);
5959

6060
// memref<2x?x2xf32, strided<[?,4,2]>
6161
auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
62-
EXPECT_EQ(m9.getMaxContiguousTrailingDims(), 0);
62+
EXPECT_EQ(m9.getNumContiguousTrailingDims(), 0);
6363

6464
// memref<?x2x2xf32, strided<[4,2,1]>
6565
auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
66-
EXPECT_EQ(m10.getMaxContiguousTrailingDims(), 3);
66+
EXPECT_EQ(m10.getNumContiguousTrailingDims(), 3);
6767

6868
// memref<?x2x2xf32, strided<[8,2,1]>
6969
auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
70-
EXPECT_EQ(m11.getMaxContiguousTrailingDims(), 2);
70+
EXPECT_EQ(m11.getNumContiguousTrailingDims(), 2);
7171

7272
// memref<?x2x2xf32, strided<[8,4,1]>
7373
auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
74-
EXPECT_EQ(m12.getMaxContiguousTrailingDims(), 1);
74+
EXPECT_EQ(m12.getNumContiguousTrailingDims(), 1);
7575

7676
// memref<?x2x2xf32, strided<[8,4,2]>
7777
auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
78-
EXPECT_EQ(m13.getMaxContiguousTrailingDims(), 0);
78+
EXPECT_EQ(m13.getNumContiguousTrailingDims(), 0);
7979

8080
// memref<2x2x1xf32, strided<[2,1,2]>
8181
auto m14 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, 2}));
82-
EXPECT_EQ(m14.getMaxContiguousTrailingDims(), 3);
82+
EXPECT_EQ(m14.getNumContiguousTrailingDims(), 3);
8383

8484
// memref<2x2x1xf32, strided<[2,1,?]>
8585
auto m15 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, _}));
86-
EXPECT_EQ(m15.getMaxContiguousTrailingDims(), 3);
86+
EXPECT_EQ(m15.getNumContiguousTrailingDims(), 3);
8787

8888
// memref<2x2x1xf32, strided<[4,2,2]>
8989
auto m16 = MemRefType::get({2, 2, 1}, f32, strided({4, 2, 2}));
90-
EXPECT_EQ(m16.getMaxContiguousTrailingDims(), 1);
90+
EXPECT_EQ(m16.getNumContiguousTrailingDims(), 1);
9191

9292
// memref<2x1x2xf32, strided<[2,4,1]>
9393
auto m17 = MemRefType::get({2, 1, 2}, f32, strided({2, 4, 1}));
94-
EXPECT_EQ(m17.getMaxContiguousTrailingDims(), 3);
94+
EXPECT_EQ(m17.getNumContiguousTrailingDims(), 3);
9595

9696
// memref<2x1x2xf32, strided<[2,?,1]>
9797
auto m18 = MemRefType::get({2, 1, 2}, f32, strided({2, _, 1}));
98-
EXPECT_EQ(m18.getMaxContiguousTrailingDims(), 3);
98+
EXPECT_EQ(m18.getNumContiguousTrailingDims(), 3);
9999
}
100100

101101
TEST(MemRefLayout, contigTrailingDim) {
@@ -196,14 +196,14 @@ TEST(MemRefLayout, identityMaps) {
196196

197197
// memref<2x2x2xf32>
198198
auto m1 = MemRefType::get({2, 2, 2}, f32);
199-
EXPECT_EQ(m1.getMaxContiguousTrailingDims(), 3);
199+
EXPECT_EQ(m1.getNumContiguousTrailingDims(), 3);
200200
EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
201201
EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
202202
EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
203203

204204
// memref<?x?x?xf32>
205205
auto m2 = MemRefType::get({_, _, _}, f32);
206-
EXPECT_EQ(m2.getMaxContiguousTrailingDims(), 3);
206+
EXPECT_EQ(m2.getNumContiguousTrailingDims(), 3);
207207
EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
208208
EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
209209
EXPECT_TRUE(m2.areTrailingDimsContiguous(3));

0 commit comments

Comments
 (0)