Skip to content

Commit aa22d0f

Browse files
committed
[mlir][linalg] Add NHWC + FHWC Img2Col
Adds the Img2Col transformation for the FHWC channel ordering in a Conv2D. Because the channel ordering affects the matrix dimensions of the flattened filter, this requires a slightly different implementation of the actual "matrix multiplication": instead of a regular row-column dot-product, this arrangement uses a row-row dot-product; otherwise the filter matrix would first need to be transposed. Also adds a lit test to the transform dialect to check that the semantics of the optimization are correct. Signed-off-by: Jack Frankland <jack.frankland@arm.com>
1 parent 0ce6255 commit aa22d0f

File tree

4 files changed

+230
-1
lines changed

4 files changed

+230
-1
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,6 +1175,14 @@ FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
11751175
FailureOr<std::pair<Operation *, Operation *>>
11761176
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp);
11771177

1178+
/// Same as the above but for Fhwc channel orderings in the filter. In this case
1179+
/// the matrix multiplication is actually a row-wise dot-product rather than a
1180+
/// row-column dot-product. This is to avoid transposing the filter matrix which
1181+
/// would be required for a regular matrix multiplication to produce the correct
1182+
/// output dimensions.
1183+
FailureOr<std::pair<Operation *, Operation *>>
1184+
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp);
1185+
11781186
/// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp except there is no
11791187
/// reduction among the input channels so each convolution can be a
11801188
/// matrix-vector product and by transposing both input filter so channels are

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3118,6 +3118,9 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
31183118
.Case([&](linalg::Conv2DNhwcHwcfOp op) {
31193119
return rewriteInIm2Col(rewriter, op);
31203120
})
3121+
.Case([&](linalg::Conv2DNhwcFhwcOp op) {
3122+
return rewriteInIm2Col(rewriter, op);
3123+
})
31213124
.Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
31223125
return rewriteInIm2Col(rewriter, op);
31233126
})

mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,141 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
494494
reshapedResult.getOperation());
495495
}
496496

497+
/// Convert a linalg.conv_2d_nhwc_fhwc into an im2col "gather" generic followed
/// by a row-wise dot-product generic. Because the FHWC filter keeps the output
/// channel (F) as its leading dimension, the flattened filter has shape
/// (oc, fh*fw*ic); contracting it against the (n, oh*ow, fh*fw*ic) column
/// tensor along their shared trailing dimension avoids a filter transpose.
/// Returns the pair {im2col producer op, final reshape op} on success.
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  // Static shapes are required: the collapsed/expanded tensor types and the
  // index arithmetic below are built from constant extents.
  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  MLIRContext *context = rewriter.getContext();
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  ArrayRef<int64_t> filterShape = filterType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  // Output is NHWC; filter is FHWC, so fh/fw/ic are dims 1/2/3 of the filter.
  int64_t n = outputShape[0];
  int64_t oh = outputShape[1];
  int64_t ow = outputShape[2];
  int64_t oc = outputShape[3];
  int64_t fh = filterShape[1];
  int64_t fw = filterShape[2];
  int64_t ic = filterShape[3];

  Location loc = convOp.getLoc();

  // Reshape the filter and output into the RHS and result of the "row-wise"
  // matrix multiplication: filter (oc, fh, fw, ic) -> (oc, fh*fw*ic).
  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
  auto reshapedFilterType =
      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedFilterType, filter, filterReassocIndices);

  // Output (n, oh, ow, oc) -> (n, oh*ow, oc).
  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
  RankedTensorType reshapedOutputType =
      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedOutputType, output, outputReassocIndices);

  // Destination for the column tensor: (batch, M, K) = (n, oh*ow, fh*fw*ic).
  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
  Value colTensor = rewriter.create<tensor::EmptyOp>(
      loc, colTensorShape, inputType.getElementType());

  // Convert the input to a (BMK) column tensor.
  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

  // Single identity map: the generic only writes the output; the input element
  // is fetched via tensor.extract inside the body.
  SmallVector<AffineMap> img2colIndexingMaps = {
      AffineMap::getMultiDimIdentityMap(nloops, context)};

  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
      loc, colTensor.getType(),
      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Get the iterators named based on the matmul (batch, m, k).
        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
        Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);

        // Recover the original iteration indices from the problem/input sizes:
        // m -> (oh, ow) and k -> (fh, fw, ic).
        SmallVector<Value> mIndices = unrollIndex(
            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
        auto ohIndex = mIndices[0];
        auto owIndex = mIndices[1];

        SmallVector<Value> kIndices = unrollIndex(
            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
        auto fhIndex = kIndices[0];
        auto fwIndex = kIndices[1];
        auto icIndex = kIndices[2];

        // Extract the input element corresponding to the expanded indices,
        // applying the convolution strides along h and w.
        Value hIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
                              convOp.getStrides().getValues<int64_t>()[0]);
        Value wIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
                              convOp.getStrides().getValues<int64_t>()[1]);

        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
            loc, input, extractionIndices);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
      });

  // Because we didn't transpose the filters we don't actually have a batched
  // matrix multiply. Instead, we have an operation consisting of "row-wise" dot
  // products: both operands index K as their trailing dimension.
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};

  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, reshapedOutputType,
      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Multiply-accumulate into the output element (handles type
        // promotion via the helpers).
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  // Restore the original NHWC output shape: (n, oh*ow, oc) -> (n, oh, ow, oc).
  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}
631+
497632
namespace {
498633

499634
class ConvertConv2DNhwcHwcf final
@@ -534,12 +669,25 @@ class ConvertConv2DNchwFchw final
534669
return success();
535670
}
536671
};
672+
673+
class ConvertConv2DNhwcFhwc final
674+
: public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
675+
public:
676+
using OpRewritePattern::OpRewritePattern;
677+
678+
LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
679+
PatternRewriter &rewriter) const override {
680+
if (failed(rewriteInIm2Col(rewriter, convOp)))
681+
return failure();
682+
return success();
683+
}
684+
};
537685
} // end anonymous namespace
538686

539687
/// Register all conv-to-img2col rewrite patterns, including the NHWC + FHWC
/// variant, on the given pattern set.
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
                  ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
}
544692
} // end namespace linalg
545693
} // end namespace mlir

mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,76 @@ transform.sequence failures(propagate) {
279279

280280
// -----
281281

282+
// CHECK: IR printer: tensor_producer
283+
// CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
284+
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
285+
// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
286+
287+
// Collapsed indices.
288+
// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
289+
// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
290+
// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
291+
292+
// Compute input channel/convolved indices.
293+
// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<(d0) -> (d0 mod 4)>(%[[KINDEX]])
294+
// CHECK: %[[CONVH:.+]] = affine.apply affine_map<(d0, d1) -> (d0 floordiv 14 + d1 floordiv 12)>(%[[MINDEX]], %[[KINDEX]])
295+
// CHECK: %[[CONVW:.+]] = affine.apply affine_map<(d0, d1) -> (d0 mod 14 + (d1 mod 12) floordiv 4)>(%[[MINDEX]], %[[KINDEX]])
296+
297+
// Extract from the input tensor.
298+
// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
299+
// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
300+
// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
301+
302+
// CHECK: IR printer: transformed
303+
// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
304+
305+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
306+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
307+
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
308+
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
309+
// CHECK: @conv_2d_nhwc_fhwc
310+
// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
311+
// CHECK-SAME: %[[FILTER:.+]]: tensor<16x3x3x4xf32>
312+
// CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
313+
// CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<16x3x3x4xf32> into tensor<16x36xf32>
314+
// CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
315+
// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
316+
// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
317+
// CHECK-SAME: #[[MAP0]]
318+
// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
319+
// CHECK: linalg.yield %{{.+}} : f32
320+
// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
321+
// CHECK-SAME: #[[MAP1]]
322+
// CHECK-SAME: #[[MAP2]]
323+
// CHECK-SAME: #[[MAP3]]
324+
// CHECK-SAME: ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<16x36xf32>)
325+
// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xf32>)
326+
// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
327+
// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
328+
// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
329+
// CHECK: linalg.yield %[[ADD]] : f32
330+
// CHECK: } -> tensor<1x196x16xf32>
331+
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
332+
// CHECK: return %[[RESULT]]
333+
334+
// Test payload: a unit-stride, unit-dilation NHWC+FHWC conv2d
// (1x16x16x4 input, 16x3x3x4 filter -> 1x14x14x16 output) that the
// img2col transform below rewrites.
func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
    %0 = linalg.conv_2d_nhwc_fhwc
      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
      ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
      outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
    return %0 : tensor<1x14x14x16xf32>
}
341+
342+
// Drive the rewrite: match the conv op, apply convert_conv2d_to_img2col, and
// print the two produced ops so the CHECK lines above can inspect them.
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
  %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
  transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
  transform.print %transformed {name = "transformed"}: !transform.any_op
}
349+
350+
// -----
351+
282352
// Check for signed extend when the input type is smaller than the accumulator type.
283353

284354
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

0 commit comments

Comments
 (0)