@@ -494,6 +494,141 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
                         reshapedResult.getOperation());
 }
 
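+// Im2col rewrite for convolutions with an NHWC input and an FHWC filter: the
+// input is gathered into an {n, oh*ow, fh*fw*ic} column tensor, contracted
+// with the filter collapsed to {oc, fh*fw*ic}, and the result is expanded
+// back to the NHWC output shape.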
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
+  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
+  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
+  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
+
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+
+  // TODO: Support dilation.
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+
+  MLIRContext *context = rewriter.getContext();
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+
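+  // The output is NHWC ({n, oh, ow, oc}) and the filter is FHWC
+  // ({oc, fh, fw, ic}); oc is read from the output shape.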
+  int64_t n = outputShape[0];
+  int64_t oh = outputShape[1];
+  int64_t ow = outputShape[2];
+  int64_t oc = outputShape[3];
+  int64_t fh = filterShape[1];
+  int64_t fw = filterShape[2];
+  int64_t ic = filterShape[3];
+
+  Location loc = convOp.getLoc();
+
+  // Reshape the filter and the output to the RHS and result of a "row-wise"
+  // matrix multiplication.
+  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
+  auto reshapedFilterType =
+      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
+  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedFilterType, filter, filterReassocIndices);
+
+  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
+  RankedTensorType reshapedOutputType =
+      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
+  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedOutputType, output, outputReassocIndices);
+
+  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
+  Value colTensor = rewriter.create<tensor::EmptyOp>(
+      loc, colTensorShape, inputType.getElementType());
+
+  // Convert the input to a (BMK) column tensor.
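+  // The gather is a linalg.generic with no inputs that fills the empty column
+  // tensor; its body computes which input element each (b, m, k) position
+  // reads and extracts it with tensor.extract.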
+  auto nloops = colTensorShape.size();
+
+  auto parallel = utils::IteratorType::parallel;
+  auto reduction = utils::IteratorType::reduction;
+  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+
+  SmallVector<AffineMap> img2colIndexingMaps = {
+      AffineMap::getMultiDimIdentityMap(nloops, context)};
+
+  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+      loc, colTensor.getType(),
+      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+      img2colIterators,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        // Get the iterators named based on the matmul (batch, m, k).
+        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
+        Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
+        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
+
+        // Recover the original iteration indices from the problem/input sizes.
+        SmallVector<Value> mIndices = unrollIndex(
+            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
+        auto ohIndex = mIndices[0];
+        auto owIndex = mIndices[1];
+
+        SmallVector<Value> kIndices = unrollIndex(
+            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
+        auto fhIndex = kIndices[0];
+        auto fwIndex = kIndices[1];
+        auto icIndex = kIndices[2];
+
+        // Extract the input element corresponding to the expanded indices.
+        Value hIndex =
+            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
+                              convOp.getStrides().getValues<int64_t>()[0]);
+        Value wIndex =
+            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
+                              convOp.getStrides().getValues<int64_t>()[1]);
+
+        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
+        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
+        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
+            loc, input, extractionIndices);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
+      });
+
+  // Because we didn't transpose the filters we don't actually have a batched
+  // matrix multiply. Instead, we have an operation consisting of "row-wise"
+  // dot products.
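+  // Indexing maps: im2col LHS (b, m, k), collapsed filter RHS (n, k), and
+  // result (b, m, n); k is the only reduction dimension.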
+  AffineExpr bDim, mDim, nDim, kDim;
+  bindDims(context, bDim, mDim, nDim, kDim);
+  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
+  auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);
+  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
+  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
+                                                       parallel, reduction};
+
+  auto genericOp = rewriter.create<linalg::GenericOp>(
+      loc, reshapedOutputType,
+      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
+      /*outputs=*/ValueRange{reshapedOutput},
+      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value mul =
+            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
+        Value add = createAdd(loc, mul, args[2], nestedBuilder);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+      });
+  Value result = genericOp.getResults().front();
+
+  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
+      loc, outputType, result, outputReassocIndices);
+
+  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
+
+  return std::make_pair(img2ColTensor.getOperation(),
+                        reshapedResult.getOperation());
+}
+
 namespace {
 
 class ConvertConv2DNhwcHwcf final
@@ -534,12 +669,25 @@ class ConvertConv2DNchwFchw final
     return success();
   }
 };
+
+class ConvertConv2DNhwcFhwc final
+    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(rewriteInIm2Col(rewriter, convOp)))
+      return failure();
+    return success();
+  }
+};
 } // end anonymous namespace
 
 void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
   patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
-                  ConvertConv2DNchwFchw>(context);
+                  ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
 }
 } // end namespace linalg
 } // end namespace mlir