1818#include  " mlir/Dialect/Linalg/IR/Linalg.h" 
1919#include  " mlir/Dialect/Math/IR/Math.h" 
2020#include  " mlir/Dialect/SparseTensor/IR/SparseTensor.h" 
21+ #include  " mlir/Dialect/Tensor/IR/Tensor.h" 
2122#include  " mlir/IR/Matchers.h" 
2223#include  " torch-mlir/Conversion/TorchToLinalg/Utils.h" 
2324#include  " torch-mlir/Conversion/Utils/Utils.h" 
@@ -2610,21 +2611,43 @@ SmallVector<StringRef> ConvertSparseOperatorOp::legalizedNames = {
26102611};
26112612} //  namespace
26122613
2613- void  mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality (
2614-     TypeConverter &typeConverter, RewritePatternSet &patterns,
2615-     ConversionTarget &target) {
2616-   //  Add some legal ops for torch-torch lowering.
2614+ void  mlir::torch::torch_to_linalg::populateDataMovementOpsLegality (
2615+     ConversionTarget &target) { //  Add some legal ops for torch-torch lowering.
26172616  target.addLegalOp <ConstantIntOp>();
2617+   target.addIllegalOp <AtenReflectionPad1dOp>();
2618+   target.addIllegalOp <AtenReflectionPad2dOp>();
2619+   target.addIllegalOp <AtenFlattenUsingIntsOp>();
2620+   target.addIllegalOp <AtenUnflattenIntOp>();
2621+   target.addIllegalOp <AtenViewOp>();
2622+   target.addIllegalOp <AtenSqueezeOp>();
2623+   target.addIllegalOp <AtenSqueezeDimOp>();
2624+   target.addIllegalOp <AtenUnsqueezeOp>();
2625+   target.addIllegalOp <AtenTransposeIntOp>();
2626+   target.addIllegalOp <AtenPermuteOp>();
2627+   target.addIllegalOp <AtenSliceTensorOp>();
2628+   target.addIllegalOp <AtenCatOp>();
2629+   target.addIllegalOp <AtenBroadcastToOp>();
2630+   target.addIllegalOp <AtenContiguousOp>();
2631+   target.addIllegalOp <AtenCopyOp>();
2632+   target.addIllegalOp <AtenSliceScatterOp>();
2633+   target.addIllegalOp <AtenViewAsComplexOp>();
2634+   target.addIllegalOp <AtenViewAsRealOp>();
2635+   target.addIllegalOp <AtenDiagonalOp>();
2636+   target.addIllegalOp <AtenDiagEmbedOp>();
2637+   target.addDynamicallyLegalOp <OperatorOp>([&](Torch::OperatorOp op) {
2638+     return  !ConvertSparseOperatorOp::isSparsePrimitive (op.getNameAttr ());
2639+   });
2640+ }
2641+ 
2642+ void  mlir::torch::torch_to_linalg::populateDataMovementPatterns (
2643+     TypeConverter &typeConverter, RewritePatternSet &patterns) {
26182644
26192645  MLIRContext *context = patterns.getContext ();
2620-   target. addIllegalOp <AtenReflectionPad1dOp>(); 
2646+ 
26212647  patterns.add <ConvertAtenReflectionPad1dOp>(typeConverter, context);
2622-   target.addIllegalOp <AtenReflectionPad2dOp>();
26232648  patterns.add <ConvertAtenReflectionPad2dOp>(typeConverter, context);
2624-   target.addIllegalOp <AtenFlattenUsingIntsOp>();
26252649  patterns.add <ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
26262650  patterns.add <ConvertAtenUnflattenIntOp>(typeConverter, context);
2627-   target.addIllegalOp <AtenUnflattenIntOp>();
26282651
26292652  //  View op sadness: In the future, we only want ConvertAtenViewOpStrict,
26302653  //  but this requires work upstream to fully generalize reshape handling.
@@ -2635,46 +2658,26 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
26352658  //  due to not statically switching between inferred and non-inferred view
26362659  //  cases. They are ordered by optimiality of the lowerings they generate
26372660  //  when they are able.
2638-   target.addIllegalOp <AtenViewOp>();
26392661  patterns.add <ConvertAtenViewOp>(typeConverter, context, /* benefit=*/ 300 );
26402662  patterns.add <ConvertAtenViewOpStrict>(typeConverter, context,
26412663                                        /* benefit=*/ 200 );
26422664  patterns.add <ConvertAtenViewOpToReshape>(typeConverter, context,
26432665                                           /* benefit=*/ 100 );
2644- 
2645-   target.addIllegalOp <AtenSqueezeOp>();
26462666  patterns.add <ConvertAtenSqueezeOp>(typeConverter, context);
2647-   target.addIllegalOp <AtenSqueezeDimOp>();
26482667  patterns.add <ConvertAtenSqueezeDimOp>(typeConverter, context);
2649-   target.addIllegalOp <AtenUnsqueezeOp>();
26502668  patterns.add <ConvertAtenUnsqueezeOp>(typeConverter, context);
2651-   target.addIllegalOp <AtenTransposeIntOp>();
26522669  patterns.add <ConvertAtenTransposeIntOp>(typeConverter, context);
2653-   target.addIllegalOp <AtenPermuteOp>();
26542670  patterns.add <ConvertAtenPermuteOp>(typeConverter, context);
2655-   target.addIllegalOp <AtenSliceTensorOp>();
26562671  patterns.add <ConvertAtenSliceTensorOp>(typeConverter, context);
2657-   target.addIllegalOp <AtenCatOp>();
26582672  patterns.add <ConvertAtenCatOp>(typeConverter, context);
2659-   target.addIllegalOp <AtenBroadcastToOp>();
26602673  patterns.add <ConvertAtenBroadcastToOp>(typeConverter, context);
2661-   target.addIllegalOp <AtenContiguousOp>();
26622674  patterns.add <ConvertAtenContiguousOp>(typeConverter, context);
2663-   target.addIllegalOp <AtenCopyOp>();
26642675  patterns.add <ConvertAtenCopyOp>(typeConverter, context);
2665-   target.addIllegalOp <AtenSliceScatterOp>();
26662676  patterns.add <ConvertAtenSliceScatterOp>(typeConverter, context);
2667-   target.addIllegalOp <AtenViewAsComplexOp>();
26682677  patterns.add <ConvertAtenViewAsComplexOp>(typeConverter, context);
2669-   target.addIllegalOp <AtenViewAsRealOp>();
26702678  patterns.add <ConvertAtenViewAsRealOp>(typeConverter, context);
2671-   target.addIllegalOp <AtenDiagonalOp>();
26722679  patterns.add <ConvertAtenDiagonalOp>(typeConverter, context);
2673-   target.addIllegalOp <AtenDiagEmbedOp>();
26742680  patterns.add <ConvertAtenDiagEmbedOp>(typeConverter, context);
26752681  //  Rewrite all special sparse conversions hidden as operators.
2676-   target.addDynamicallyLegalOp <OperatorOp>([&](Torch::OperatorOp op) {
2677-     return  !ConvertSparseOperatorOp::isSparsePrimitive (op.getNameAttr ());
2678-   });
26792682  patterns.add <ConvertSparseOperatorOp>(typeConverter, context);
26802683}
0 commit comments