@@ -2295,6 +2295,49 @@ def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
}];
}

//===----------------------------------------------------------------------===//
// FlattenElementwiseLinalgOp
//===----------------------------------------------------------------------===//

def FlattenElementwiseLinalgOp : Op<Transform_Dialect,
"structured.flatten_elementwise",
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
TransformOpInterface,
TransformEachOpTrait,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Flattens the iteration space and all applicable operands of elementwise
linalg ops to a single dimension.

Returns one handle:
- the flattened linalg operation.

#### Return modes:

Returns a definite failure if the target is not isolated from above.
Returns a silenceable failure if the pattern application failed.
}];

let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$transformed);

let assemblyFormat =
"$target attr-dict `:` functional-type($target, results)";

let builders = [
OpBuilder<(ins "Value":$target)>
];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::linalg::LinalgOp target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}
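For reference, a sketch of the intended effect on a payload op, distilled from the tests added in this patch (function names and shapes are illustrative, not part of the op definition):

// Before flattening: a 2-D elementwise op.
func.func @before(%cst: f32, %arg: memref<32x7xf32>) {
  linalg.fill ins(%cst : f32) outs(%arg : memref<32x7xf32>)
  return
}
// After applying transform.structured.flatten_elementwise: the operand is
// collapsed and the op iterates over a single flat dimension.
func.func @after(%cst: f32, %arg: memref<32x7xf32>) {
  %flat = memref.collapse_shape %arg [[0, 1]] : memref<32x7xf32> into memref<224xf32>
  linalg.fill ins(%cst : f32) outs(%flat : memref<224xf32>)
  return
}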

//===----------------------------------------------------------------------===//
// Transpose Conv2D
//===----------------------------------------------------------------------===//
10 changes: 7 additions & 3 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1074,16 +1074,20 @@ bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
ArrayRef<ReassociationIndices> dimSequences);

+struct CollapseResult {
+  SmallVector<Value> results;
+  LinalgOp collapsedOp;
+};

/// Collapses dimensions of a linalg.generic/linalg.copy operation. A
/// precondition to calling this method is that, for each list in
/// `foldedIterationDims`, the sequence of dimensions is contiguous in the
/// domains of all `indexing_maps` of the `linalgOp`. This can be checked
/// using the `areDimSequencesPreserved` method. When valid, the method also
/// collapses the operands of the op. Returns the replacement values for the
/// results of the original `linalgOp`, inserting reshapes to get back values
/// of compatible types, along with the collapsed op.
-template <typename LinalgType>
-FailureOr<SmallVector<Value>>
-collapseOpIterationDims(LinalgType op,
+FailureOr<CollapseResult>
+collapseOpIterationDims(LinalgOp op,
ArrayRef<ReassociationIndices> foldedIterationDims,
RewriterBase &rewriter);
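Under tensor semantics, the returned `CollapseResult::results` are the re-expanded values; e.g. collapsing a fill with the single reassociation group [0, 1] yields IR along these lines (a sketch mirroring the tests added in this patch; %cst and %init assumed defined):

// Operands are collapsed to 1-D, the op is recreated on the flat types, and
// an expand_shape recovers the original result type.
%collapsed = tensor.collapse_shape %init [[0, 1]] : tensor<32x7xf32> into tensor<224xf32>
%filled = linalg.fill ins(%cst : f32) outs(%collapsed : tensor<224xf32>) -> tensor<224xf32>
%replacement = tensor.expand_shape %filled [[0, 1]] : tensor<224xf32> into tensor<32x7xf32>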

25 changes: 25 additions & 0 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3244,6 +3244,31 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// FlattenElementwiseLinalgOp.
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
transform::TransformRewriter &rewriter, linalg::LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
  // Nothing to flatten for rank <= 1; forward the target op unchanged so the
  // result list stays in sync with the targets.
  if (target.getNumLoops() <= 1) {
    results.push_back(target);
    return DiagnosedSilenceableFailure::success();
  }
ReassociationIndices reassociation(target.getNumLoops());
std::iota(reassociation.begin(), reassociation.end(), 0);
auto maybeFlattened =
(isElementwise(target))
? collapseOpIterationDims(target, reassociation, rewriter)
: FailureOr<CollapseResult>(rewriter.notifyMatchFailure(
target, "only elementwise flattening is supported"));
if (failed(maybeFlattened))
return emitDefaultSilenceableFailure(target);
results.push_back(maybeFlattened->collapsedOp);
rewriter.replaceOp(target, maybeFlattened->results);
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// TransposeConv2DOp
//===----------------------------------------------------------------------===//
114 changes: 67 additions & 47 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1446,24 +1446,20 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
}
}

-template <typename LinalgType>
-Operation *createCollapsedOp(LinalgType op,
-                             const CollapsingInfo &collapsingInfo,
-                             RewriterBase &rewriter) {
-  static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
-                "unsupported linalg op type to create");
+void collapseOperandsAndResults(LinalgOp op,
+                                const CollapsingInfo &collapsingInfo,
+                                RewriterBase &rewriter,
+                                SmallVectorImpl<Value> &inputOperands,
+                                SmallVectorImpl<Value> &outputOperands,
+                                SmallVectorImpl<Type> &resultTypes) {
Location loc = op->getLoc();

// Get the input operands.
-  SmallVector<Value> inputOperands =
+  inputOperands =
llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
rewriter);
});

// Get the output operands and result types.
-  SmallVector<Type> resultTypes;
-  SmallVector<Value> outputOperands;
resultTypes.reserve(op.getNumDpsInits());
outputOperands.reserve(op.getNumDpsInits());
for (OpOperand &output : op.getDpsInitsMutable()) {
Expand All @@ -1475,41 +1471,69 @@ Operation *createCollapsedOp(LinalgType op,
if (!op.hasPureBufferSemantics())
resultTypes.push_back(newOutput.getType());
}
+}

-  if (isa<linalg::CopyOp>(op)) {
-    return rewriter.create<linalg::CopyOp>(loc, inputOperands[0],
-                                           outputOperands[0]);
-  }
+/// Clone a `LinalgOp` to a collapsed version of the same name. The primary
+/// template is a placeholder; the specializations below do the actual work.
+template <typename OpTy>
+OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
+                        const CollapsingInfo &collapsingInfo) {
+  return nullptr;
+}

-  // Get the iterator types for the operand.
-  SmallVector<utils::IteratorType> iteratorTypes =
-      getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);
+/// Collapse any `LinalgOp` that does not require any specialization, such as
+/// collapsed indexing_maps or iterator_types.
+template <>
+LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
+                                      const CollapsingInfo &collapsingInfo) {
+  SmallVector<Value> inputOperands, outputOperands;
+  SmallVector<Type> resultTypes;
+  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
+                             outputOperands, resultTypes);
+  return cast<LinalgOp>(clone(
+      rewriter, origOp, resultTypes,
+      llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands))));
+}

-  // Get the indexing maps.
-  auto indexingMaps =
-      llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) {
+/// Collapse a `GenericOp`, collapsing its indexing maps and iterator types.
+template <>
+GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
+                                        GenericOp origOp,
+                                        const CollapsingInfo &collapsingInfo) {
+  SmallVector<Value> inputOperands, outputOperands;
+  SmallVector<Type> resultTypes;
+  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
+                             outputOperands, resultTypes);
+  SmallVector<AffineMap> indexingMaps(
+      llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
return getCollapsedOpIndexingMap(map, collapsingInfo);
-      });
+      }));

+  SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
+      origOp.getIteratorTypesArray(), collapsingInfo));

-  Operation *collapsedOp = rewriter.create<linalg::GenericOp>(
-      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
+  GenericOp collapsedOp = rewriter.create<linalg::GenericOp>(
+      origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
-  Block *origOpBlock = &op->getRegion(0).front();
+  Block *origOpBlock = &origOp->getRegion(0).front();
Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
collapsedOpBlock->getArguments());

return collapsedOp;
}
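Concretely, for the 2-D parallel generic exercised in the tests below, collapsing with the reassociation [[0, 1]] rewrites the op's attributes roughly as follows (a sketch, not verbatim output of this code path):

// Before: indexing_maps  = [affine_map<(d0, d1) -> (d0, d1)>, ...]
//         iterator_types = ["parallel", "parallel"]
// After:  indexing_maps  = [affine_map<(d0) -> (d0)>, ...]
//         iterator_types = ["parallel"]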

+LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
+                           RewriterBase &rewriter) {
+  if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
+    return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
+  } else {
+    return cloneToCollapsedOp(rewriter, op, collapsingInfo);
+  }
+}

/// Implementation of fusion with reshape operation by collapsing dimensions.
-template <typename LinalgType>
-FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
-    LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
+FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
+    LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
RewriterBase &rewriter) {
-  static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
-                "unsupported linalg op type to collapse");

// Bail on trivial no-op cases.
if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
Expand Down Expand Up @@ -1538,8 +1562,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
}

// Bail on non-canonical ranges.
-  SmallVector<Range> loopRanges =
-      cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc());
+  SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
return cast<IntegerAttr>(attr).getInt() == value;
@@ -1555,8 +1578,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
op, "expected all loop ranges to have zero start and unit stride");
}

-  LinalgType collapsedOp = cast<LinalgType>(
-      createCollapsedOp<LinalgType>(op, collapsingInfo, rewriter));
+  LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);

Location loc = op->getLoc();
if (collapsedOp.hasIndexSemantics()) {
@@ -1597,7 +1619,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
results.push_back(collapsedOpResult);
}
}
-  return results;
+  return CollapseResult{results, collapsedOp};
}

namespace {
@@ -1629,15 +1651,14 @@ class FoldWithProducerReshapeOpByCollapsing
continue;
}

-      std::optional<SmallVector<Value>> replacements =
-          collapseOpIterationDims<linalg::GenericOp>(
-              genericOp, collapsableIterationDims, rewriter);
-      if (!replacements) {
+      std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
+          genericOp, collapsableIterationDims, rewriter);
+      if (!collapseResult) {
return rewriter.notifyMatchFailure(
genericOp, "failed to do the fusion by collapsing transformation");
}

-      rewriter.replaceOp(genericOp, *replacements);
+      rewriter.replaceOp(genericOp, collapseResult->results);
return success();
}
return failure();
@@ -1671,13 +1692,12 @@ class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
op, "specified dimensions cannot be collapsed");
}

-    std::optional<SmallVector<Value>> replacements =
-        collapseOpIterationDims<LinalgType>(op, collapsableIterationDims,
-                                            rewriter);
-    if (!replacements) {
+    std::optional<CollapseResult> collapseResult =
+        collapseOpIterationDims(op, collapsableIterationDims, rewriter);
+    if (!collapseResult) {
return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
}
-    rewriter.replaceOp(op, *replacements);
+    rewriter.replaceOp(op, collapseResult->results);
return success();
}

99 changes: 99 additions & 0 deletions mlir/test/Dialect/Linalg/flatten-elementwise.mlir
@@ -0,0 +1,99 @@
// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s

// CHECK-LABEL: func.func @fill(
// CHECK-SAME: %[[ARG0:.*]]: f32,
// CHECK-SAME: %[[ARG1:.*]]: memref<32x7xf32>
// CHECK-NEXT: %[[FLATTENED:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
// CHECK-NEXT: linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : memref<224xf32>)
func.func @fill(%cst: f32, %arg: memref<32x7xf32>) {
linalg.fill ins(%cst: f32) outs(%arg: memref<32x7xf32>)
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
%flattened = transform.structured.flatten_elementwise %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// CHECK-LABEL: func.func @fill_tensor(
// CHECK-SAME: %[[ARG0:.*]]: f32,
// CHECK-SAME: %[[ARG1:.*]]: tensor<32x7xf32>
// CHECK-NEXT: %[[FLATTENED:.*]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : tensor<224xf32>)
// CHECK-NEXT: %[[RESULT:.*]] = tensor.expand_shape %[[FLATTENED_RESULT]] {{\[}}[0, 1]]
func.func @fill_tensor(%cst: f32, %arg: tensor<32x7xf32>) -> tensor<32x7xf32> {
%0 = linalg.fill ins(%cst: f32) outs(%arg: tensor<32x7xf32>) -> tensor<32x7xf32>
return %0 : tensor<32x7xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
%flattened = transform.structured.flatten_elementwise %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// CHECK-LABEL: func.func @map(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32>
// CHECK-NEXT: %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]]
// CHECK-NEXT: linalg.map { arith.addf } ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>)
func.func @map(%arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) {
linalg.map {arith.addf} ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>)
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
%flattened = transform.structured.flatten_elementwise %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @generic
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32>
// CHECK-NEXT: %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]]
// CHECK-NEXT: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>)
// CHECK-NEXT: ^bb0(%[[A:.*]]: f32, %[[B:.*]]: f32, %[[C:.*]]: f32)
// CHECK-NEXT: %[[SUM:.*]] = arith.addf %[[A]], %[[B]]
// CHECK-NEXT: linalg.yield %[[SUM]]
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @generic(%arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) {
linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>) {
^bb0(%a: f32, %b: f32, %c: f32):
%0 = arith.addf %a, %b : f32
linalg.yield %0 : f32
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
%flattened = transform.structured.flatten_elementwise %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}