Skip to content

Commit 9f4bd46

Browse files
committed
fuse tensor.pack as consumer
1 parent bf68e90 commit 9f4bd46

File tree

2 files changed

+150
-0
lines changed

2 files changed

+150
-0
lines changed

mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,97 @@ struct PackOpTiling
246246
return failure();
247247
return tilingResult.value();
248248
}
249+
250+
/// Method to return the position of iteration domain tile computed by the
251+
/// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
252+
/// `resultSizes` only cover outer dimensions.
253+
LogicalResult getIterationDomainTileFromOperandTile(
254+
Operation *op, OpBuilder &b, unsigned operandNumber,
255+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
256+
SmallVectorImpl<OpFoldResult> &resultOffsets,
257+
SmallVectorImpl<OpFoldResult> &resultSizes) const {
258+
auto packOp = cast<PackOp>(op);
259+
Location loc = packOp.getLoc();
260+
261+
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
262+
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
263+
packOp.getDimAndTileMapping();
264+
for (auto dim : packOp.getOuterDimsPerm()) {
265+
if (dimAndTileMapping.count(dim)) {
266+
FailureOr<int64_t> cstSize =
267+
ValueBoundsConstraintSet::computeConstantBound(
268+
presburger::BoundType::UB, sizes[dim],
269+
/*stopCondition=*/nullptr, /*closedUB=*/true);
270+
std::optional<int64_t> cstInnerSize =
271+
getConstantIntValue(dimAndTileMapping[dim]);
272+
// Currently only expect perfect tiling cases.
273+
if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
274+
return failure();
275+
}
276+
277+
using AV = affine::AffineValueExpr;
278+
affine::AffineBuilder ab(b, loc);
279+
AffineExpr dim0, sym;
280+
bindDims(b.getContext(), dim0);
281+
bindSymbols(b.getContext(), sym);
282+
auto avOffset = AV(dim0).bind(offsets[dim]);
283+
auto avSize = AV(dim0).bind(sizes[dim]);
284+
auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
285+
outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
286+
outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
287+
} else {
288+
outerDimOffsets.push_back(offsets[dim]);
289+
outerDimSizes.push_back(sizes[dim]);
290+
}
291+
}
292+
293+
resultOffsets = outerDimOffsets;
294+
resultSizes = outerDimSizes;
295+
return success();
296+
}
297+
298+
/// Method to return the tiled implementation of tensor.pack as a consumer.
299+
FailureOr<TilingResult> getTiledImplementationFromOperandTile(
300+
Operation *op, OpBuilder &b, unsigned operandNumber,
301+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
302+
auto packOp = cast<PackOp>(op);
303+
Location loc = packOp.getLoc();
304+
305+
int64_t inputRank = packOp.getSourceRank();
306+
auto oneAttr = b.getI64IntegerAttr(1);
307+
SmallVector<OpFoldResult> strides(inputRank, oneAttr);
308+
309+
SmallVector<Value> tiledOperands;
310+
tiledOperands.push_back(b.create<ExtractSliceOp>(loc, packOp.getSource(),
311+
offsets, sizes, strides));
312+
313+
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
314+
if (failed(getIterationDomainTileFromOperandTile(
315+
op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
316+
outerDimSizes)))
317+
return failure();
318+
319+
SmallVector<OpFoldResult> outputOffsets, outputSizes;
320+
if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
321+
outputOffsets, outputSizes)))
322+
return failure();
323+
324+
strides.append(packOp.getDestRank() - inputRank, oneAttr);
325+
auto extractSlice = b.create<ExtractSliceOp>(
326+
loc, packOp.getDest(), outputOffsets, outputSizes, strides);
327+
tiledOperands.push_back(extractSlice);
328+
329+
if (auto val = packOp.getPaddingValue())
330+
tiledOperands.push_back(val);
331+
for (auto tile : packOp.getInnerTiles())
332+
tiledOperands.push_back(tile);
333+
334+
Operation *tiledPackOp = b.create<PackOp>(
335+
loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
336+
337+
return TilingResult{{tiledPackOp},
338+
SmallVector<Value>(tiledPackOp->getResults())};
339+
}
249340
};
250341

251342
struct UnpackTileDimInfo {

mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,3 +315,62 @@ module attributes {transform.with_named_sequence} {
315315
// CHECK: }
316316
// CHECK: }
317317
// CHECK: return %[[FINAL_RESULT]]#1 :
318+
319+
// -----

// Fuses a trailing tensor.pack (consumer of the scf.forall result) into the
// loop: the pack is recreated on per-iteration slices and its result is
// written via an additional shared_out / parallel_insert_slice.
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
    %c4 = arith.constant 4 : index
    %c64 = arith.constant 64 : index
    %c0 = arith.constant 0 : index
    %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
      ^bb0(%in: f32, %in_16: f32, %out: f32):
        %13 = arith.mulf %in, %in_16 : f32
        %14 = arith.addf %out, %13 : f32
        linalg.yield %14 : f32
      } -> tensor<32x32xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
      }
    }
    %output = tensor.empty() : tensor<4x32x16xf32>
    %pack = tensor.pack %1 outer_dims_perm = [0, 1] inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32>
    return %pack : tensor<4x32x16xf32>
  }
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    // The parallel_insert_slice of the generic's result is the fusion anchor.
    %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b = transform.test.fuse_consumer %slice_op
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
// CHECK: func.func @fuse_pack_consumer_into_scf_forall(
// CHECK-SAME:  %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME:  %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<4x32x16xf32>
// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
// CHECK-SAME:  shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
// CHECK-SAME:  {
// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
// CHECK-SAME:  outs(%[[GENERIC_OUT_SLICE]] :
// CHECK: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV1]])
// CHECK: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
// CHECK: %[[TILED_PACK_OUT:.*]] = tensor.pack %[[GENERIC_OUT]]
// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0] inner_tiles = [16]
// CHECK-SAME:  into %[[TILED_PACK_DEST]]
// CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: }
// CHECK: }
// CHECK: return %[[FINAL_RESULT]]#1 :

0 commit comments

Comments
 (0)