Skip to content

Commit 511fa23

Browse files
committed
[mlir][linalg] Use ub.poison in data layout propagation if a packed operand requires padding.
In the past, it was hard to set padding values because we did not have ub.poison. It is not always correct if we set zeros as padding values. Now we can use `ub.poison` in this case. The revision adds support for setting the padding value using `ub.poison` when padding is required in the propagation. Otherwise, it creates an invalid pack op. Signed-off-by: hanhanW <hanhan0912@gmail.com>
1 parent 6af5b41 commit 511fa23

File tree

2 files changed

+37
-7
lines changed

2 files changed

+37
-7
lines changed

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Dialect/UB/IR/UBOps.h"
1515
#include "mlir/Dialect/Utils/IndexingUtils.h"
1616
#include "mlir/IR/Dominance.h"
17+
#include "mlir/IR/TypeUtilities.h"
1718
#include "llvm/ADT/SetOperations.h"
1819
#include "llvm/ADT/SetVector.h"
1920
#include "llvm/ADT/TypeSwitch.h"
@@ -289,9 +290,11 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
289290

290291
auto empty = linalg::PackOp::createDestinationTensor(
291292
b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
292-
auto packedOperand = linalg::PackOp::create(
293-
b, loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
294-
/*padding=*/std::nullopt, outerDimsPerm);
293+
auto poison = ub::PoisonOp::create(
294+
b, loc, getElementTypeOrSelf(opOperand->get().getType()));
295+
auto packedOperand =
296+
linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos,
297+
innerTileSizes, poison, outerDimsPerm);
295298
return std::make_tuple(packedOperand, indexingMap);
296299
}
297300

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,6 +1450,33 @@ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %ar
14501450

14511451
// -----
14521452

1453+
#map = affine_map<(d0, d1) -> (d0, d1)>
1454+
func.func @push_unpack_in_padded_domain_multiple_inputs(%arg0: tensor<1x4x16x16xf32>, %arg1: tensor<8x64xf32>, %arg2: tensor<8x64xf32>) -> tensor<8x64xf32> {
1455+
%0 = tensor.empty() : tensor<8x64xf32>
1456+
%unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<1x4x16x16xf32> -> tensor<8x64xf32>
1457+
%1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg1, %unpack : tensor<8x64xf32>, tensor<8x64xf32>) outs(%arg2 : tensor<8x64xf32>) {
1458+
^bb0(%in: f32, %in_0: f32, %out: f32):
1459+
%2 = arith.addf %in, %in_0 : f32
1460+
linalg.yield %2 : f32
1461+
} -> tensor<8x64xf32>
1462+
return %1 : tensor<8x64xf32>
1463+
}
1464+
// CHECK-LABEL: func.func @push_unpack_in_padded_domain_multiple_inputs
1465+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1466+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
1467+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
1468+
// CHECK-DAG: %[[POISON:.+]] = ub.poison : f32
1469+
// CHECK: %[[PACK:.+]] = linalg.pack %[[ARG1]] padding_value(%[[POISON]] : f32)
1470+
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [16, 16]
1471+
// CHECK: %[[ELEM:.+]] = linalg.generic
1472+
// CHECK: ins(%[[PACK]], %[[ARG0]]
1473+
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ELEM]]
1474+
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [16, 16]
1475+
// CHECK-SAME: into %[[ARG2]]
1476+
// CHECK: return %[[UNPACK]]
1477+
1478+
// -----
1479+
14531480
module {
14541481
func.func @push_extract_through_generic(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
14551482
%extracted_slice = tensor.extract_slice %arg0[0, 0, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
@@ -1473,7 +1500,7 @@ module {
14731500
// CHECK: } : tensor<?x5x3x128xf32> to tensor<?x5x3x128xf32>
14741501
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<128x5x128xbf16>
14751502
// CHECK: %[[GENERIC:.+]] = linalg.generic
1476-
// CHECK-SAME: ins(%[[ARG0]], %[[PADDED]]
1503+
// CHECK-SAME: ins(%[[ARG0]], %[[PADDED]]
14771504
// CHECK-SAME: outs(%[[EMPTY]]
14781505
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %3[%[[ARG3]], 0, 0] [%[[ARG3]], 5, 128] [1, 1, 1] : tensor<128x5x128xbf16> to tensor<?x5x128xbf16>
14791506
// CHECK: return %[[EXTRACT]]
@@ -1492,7 +1519,7 @@ func.func @nopush_extract_through_generic_nodimexpr1(%arg0: tensor<128x7x128xf32
14921519

14931520
// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr1
14941521
// CHECK: %[[GENERIC:.+]] = linalg.generic
1495-
// CHECK: return %[[GENERIC]]
1522+
// CHECK: return %[[GENERIC]]
14961523

14971524
// -----
14981525

@@ -1508,7 +1535,7 @@ func.func @nopush_extract_through_generic_nodimexpr2(%arg0: tensor<128x?x128xf32
15081535

15091536
// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr2
15101537
// CHECK: %[[GENERIC:.+]] = linalg.generic
1511-
// CHECK: return %[[GENERIC]]
1538+
// CHECK: return %[[GENERIC]]
15121539

15131540
// -----
15141541

@@ -1575,7 +1602,7 @@ func.func @push_extract_through_generic_rank0_operand(%arg0: tensor<128x128xf32>
15751602

15761603
// CHECK-LABEL: func.func @push_extract_through_generic_rank0_operand
15771604
// CHECK: %[[GENERIC:.+]] = linalg.generic
1578-
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]]
1605+
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]]
15791606
// CHECK: return %[[EXTRACT]]
15801607

15811608
// -----

0 commit comments

Comments
 (0)