Skip to content

Commit 11a8a6d

Browse files
authored
Merge pull request #2 from sahas3/tosa-linalg-pipeline
Add `tosa_linalg` pipeline that lowers to `tosa` prior to `linalg_on_tensors` ops
2 parents 14ef05a + f408f53 commit 11a8a6d

File tree

31 files changed

+861
-334
lines changed

31 files changed

+861
-334
lines changed

include/torch-mlir/Conversion/Passes.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,14 @@ def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
123123
let constructor = "mlir::torch::createConvertTorchToTosaPass()";
124124
}
125125

126+
def ConvertTorchToTosaLinalg : Pass<"convert-torch-to-tosa-linalg", "func::FuncOp"> {
127+
let summary = "Convert Torch ops to a mix of TOSA ops and LINALG_ON_TENSORS ops";
128+
let description = [{
129+
This pass tries to lower torch ops to tosa ops if possible. Otherwise lowers to a mix of linalg, tensor, scf, and other dialects.
130+
}];
131+
let constructor = "mlir::torch::createConvertTorchToTosaLinalgPass()";
132+
}
133+
126134
def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
127135
let summary = "Convert recognized Torch ops to TMTensor/Linalg ops";
128136
let description = [{

include/torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,17 @@
1111
#define TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
1212

1313
#include "mlir/Dialect/Func/IR/FuncOps.h"
14-
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1514
#include "mlir/Pass/Pass.h"
15+
#include "mlir/Transforms/DialectConversion.h"
1616
#include <memory>
1717

1818
namespace mlir {
1919
namespace torch {
20+
void populateTorchToLinalgOnTensorsPatterns(TypeConverter &typeConverter,
21+
RewritePatternSet &patterns);
22+
void populateTorchToLinalgOnTensorsOpsLegality(ConversionTarget &target);
2023
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToLinalgPass();
21-
}
24+
} // namespace torch
2225
} // namespace mlir
2326

2427
#endif // TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H

include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,23 @@
1212

1313
#include "mlir/Dialect/Func/IR/FuncOps.h"
1414
#include "mlir/Pass/Pass.h"
15+
#include "mlir/Transforms/DialectConversion.h"
1516
#include <memory>
1617

1718
namespace mlir {
1819
namespace torch {
20+
21+
/// Collect a set of legal/illegal ops for converting Torch operations to Tosa
22+
/// dialect.
23+
void populateTorchToTosaConversionLegalOps(ConversionTarget &target);
24+
void populateTorchToTosaConversionIllegalOps(ConversionTarget &target);
25+
26+
/// Collect a set of patterns to convert Torch operations to Tosa dialect.
27+
void populateTorchToTosaConversionPatterns(TypeConverter &typeConverter,
28+
RewritePatternSet &patterns);
29+
1930
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
20-
}
31+
} // namespace torch
2132
} // namespace mlir
2233

2334
#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//===------------------------------------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
// Also available under a BSD-style license. See LICENSE.
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#ifndef TORCHMLIR_CONVERSION_ATENTOTOSALINALG_ATENTOTOSALINALG_H
11+
#define TORCHMLIR_CONVERSION_ATENTOTOSALINALG_ATENTOTOSALINALG_H
12+
13+
#include "mlir/Dialect/Func/IR/FuncOps.h"
14+
#include "mlir/Pass/Pass.h"
15+
#include <memory>
16+
17+
namespace mlir {
18+
namespace torch {
19+
std::unique_ptr<OperationPass<func::FuncOp>>
20+
createConvertTorchToTosaLinalgPass();
21+
}
22+
} // namespace mlir
23+
24+
#endif // TORCHMLIR_CONVERSION_ATENTOTOSALINALG_ATENTOTOSALINALG_H

include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);
3030
/// TOSA backend contract.
3131
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);
3232

33+
/// Creates a pipeline that lowers from the torch backend contract to the
34+
/// TOSA + linalg backend contract.
35+
void createTorchBackendToTosaLinalgBackendPipeline(OpPassManager &pm);
36+
3337
// Do not register the stablehlo options if the stablehlo target is disabled
3438
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
3539
struct StablehloBackendPipelineOptions
@@ -79,6 +83,9 @@ createVerifyLinalgOnTensorsBackendContractPass();
7983

8084
std::unique_ptr<OperationPass<ModuleOp>> createVerifyTosaBackendContractPass();
8185

86+
std::unique_ptr<OperationPass<ModuleOp>>
87+
createVerifyTosaLinalgBackendContractPass();
88+
8289
} // namespace TorchConversion
8390

8491
/// Registers all Torch transformation passes.

include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "Modu
6666
let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()";
6767
}
6868

69+
def VerifyTosaLinalgBackendContract : Pass<"torch-verify-tosa-linalg-backend-contract", "ModuleOp"> {
70+
let summary = "Verifies conformity to the tosa + linalg-on-tensors backend contract";
71+
let constructor = "mlir::torch::TorchConversion::createVerifyTosaLinalgBackendContractPass()";
72+
}
73+
6974
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
7075
def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> {
7176
let summary = "Verifies conformity to the stablehlo backend contract";

lib/Conversion/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_subdirectory(TorchToLinalg)
44
add_subdirectory(TorchToSCF)
55
add_subdirectory(TorchToTensor)
66
add_subdirectory(TorchToTosa)
7+
add_subdirectory(TorchToTosaLinalg)
78
if(TORCH_MLIR_ENABLE_STABLEHLO)
89
add_subdirectory(TorchToStablehlo)
910
endif()
@@ -17,6 +18,7 @@ set(linked_libs TorchMLIRTorchToArith
1718
TorchMLIRTorchToSCF
1819
TorchMLIRTorchToTensor
1920
TorchMLIRTorchToTosa
21+
TorchMLIRTorchToTosaLinalg
2022
TorchMLIRTorchToTMTensor
2123
TorchMLIRTorchConversionToMLProgram
2224
TorchMLIRConversionUtils)

lib/Conversion/Passes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
2121
#include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h"
2222
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
23+
#include "torch-mlir/Conversion/TorchToTosaLinalg/TorchToTosaLinalg.h"
2324

2425
//===----------------------------------------------------------------------===//
2526
// Pass registration

lib/Conversion/TorchToLinalg/DataMovement.cpp

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1919
#include "mlir/Dialect/Math/IR/Math.h"
2020
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
21+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2122
#include "mlir/IR/Matchers.h"
2223
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
2324
#include "torch-mlir/Conversion/Utils/Utils.h"
@@ -2610,21 +2611,43 @@ SmallVector<StringRef> ConvertSparseOperatorOp::legalizedNames = {
26102611
};
26112612
} // namespace
26122613

2613-
void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
2614-
TypeConverter &typeConverter, RewritePatternSet &patterns,
2615-
ConversionTarget &target) {
2616-
// Add some legal ops for torch-torch lowering.
2614+
void mlir::torch::torch_to_linalg::populateDataMovementOpsLegality(
2615+
ConversionTarget &target) { // Add some legal ops for torch-torch lowering.
26172616
target.addLegalOp<ConstantIntOp>();
2617+
target.addIllegalOp<AtenReflectionPad1dOp>();
2618+
target.addIllegalOp<AtenReflectionPad2dOp>();
2619+
target.addIllegalOp<AtenFlattenUsingIntsOp>();
2620+
target.addIllegalOp<AtenUnflattenIntOp>();
2621+
target.addIllegalOp<AtenViewOp>();
2622+
target.addIllegalOp<AtenSqueezeOp>();
2623+
target.addIllegalOp<AtenSqueezeDimOp>();
2624+
target.addIllegalOp<AtenUnsqueezeOp>();
2625+
target.addIllegalOp<AtenTransposeIntOp>();
2626+
target.addIllegalOp<AtenPermuteOp>();
2627+
target.addIllegalOp<AtenSliceTensorOp>();
2628+
target.addIllegalOp<AtenCatOp>();
2629+
target.addIllegalOp<AtenBroadcastToOp>();
2630+
target.addIllegalOp<AtenContiguousOp>();
2631+
target.addIllegalOp<AtenCopyOp>();
2632+
target.addIllegalOp<AtenSliceScatterOp>();
2633+
target.addIllegalOp<AtenViewAsComplexOp>();
2634+
target.addIllegalOp<AtenViewAsRealOp>();
2635+
target.addIllegalOp<AtenDiagonalOp>();
2636+
target.addIllegalOp<AtenDiagEmbedOp>();
2637+
target.addDynamicallyLegalOp<OperatorOp>([&](Torch::OperatorOp op) {
2638+
return !ConvertSparseOperatorOp::isSparsePrimitive(op.getNameAttr());
2639+
});
2640+
}
2641+
2642+
void mlir::torch::torch_to_linalg::populateDataMovementPatterns(
2643+
TypeConverter &typeConverter, RewritePatternSet &patterns) {
26182644

26192645
MLIRContext *context = patterns.getContext();
2620-
target.addIllegalOp<AtenReflectionPad1dOp>();
2646+
26212647
patterns.add<ConvertAtenReflectionPad1dOp>(typeConverter, context);
2622-
target.addIllegalOp<AtenReflectionPad2dOp>();
26232648
patterns.add<ConvertAtenReflectionPad2dOp>(typeConverter, context);
2624-
target.addIllegalOp<AtenFlattenUsingIntsOp>();
26252649
patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
26262650
patterns.add<ConvertAtenUnflattenIntOp>(typeConverter, context);
2627-
target.addIllegalOp<AtenUnflattenIntOp>();
26282651

26292652
// View op sadness: In the future, we only want ConvertAtenViewOpStrict,
26302653
// but this requires work upstream to fully generalize reshape handling.
@@ -2635,46 +2658,26 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
26352658
// due to not statically switching between inferred and non-inferred view
26362659
// cases. They are ordered by optimiality of the lowerings they generate
26372660
// when they are able.
2638-
target.addIllegalOp<AtenViewOp>();
26392661
patterns.add<ConvertAtenViewOp>(typeConverter, context, /*benefit=*/300);
26402662
patterns.add<ConvertAtenViewOpStrict>(typeConverter, context,
26412663
/*benefit=*/200);
26422664
patterns.add<ConvertAtenViewOpToReshape>(typeConverter, context,
26432665
/*benefit=*/100);
2644-
2645-
target.addIllegalOp<AtenSqueezeOp>();
26462666
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
2647-
target.addIllegalOp<AtenSqueezeDimOp>();
26482667
patterns.add<ConvertAtenSqueezeDimOp>(typeConverter, context);
2649-
target.addIllegalOp<AtenUnsqueezeOp>();
26502668
patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context);
2651-
target.addIllegalOp<AtenTransposeIntOp>();
26522669
patterns.add<ConvertAtenTransposeIntOp>(typeConverter, context);
2653-
target.addIllegalOp<AtenPermuteOp>();
26542670
patterns.add<ConvertAtenPermuteOp>(typeConverter, context);
2655-
target.addIllegalOp<AtenSliceTensorOp>();
26562671
patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context);
2657-
target.addIllegalOp<AtenCatOp>();
26582672
patterns.add<ConvertAtenCatOp>(typeConverter, context);
2659-
target.addIllegalOp<AtenBroadcastToOp>();
26602673
patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
2661-
target.addIllegalOp<AtenContiguousOp>();
26622674
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
2663-
target.addIllegalOp<AtenCopyOp>();
26642675
patterns.add<ConvertAtenCopyOp>(typeConverter, context);
2665-
target.addIllegalOp<AtenSliceScatterOp>();
26662676
patterns.add<ConvertAtenSliceScatterOp>(typeConverter, context);
2667-
target.addIllegalOp<AtenViewAsComplexOp>();
26682677
patterns.add<ConvertAtenViewAsComplexOp>(typeConverter, context);
2669-
target.addIllegalOp<AtenViewAsRealOp>();
26702678
patterns.add<ConvertAtenViewAsRealOp>(typeConverter, context);
2671-
target.addIllegalOp<AtenDiagonalOp>();
26722679
patterns.add<ConvertAtenDiagonalOp>(typeConverter, context);
2673-
target.addIllegalOp<AtenDiagEmbedOp>();
26742680
patterns.add<ConvertAtenDiagEmbedOp>(typeConverter, context);
26752681
// Rewrite all special sparse conversions hidden as operators.
2676-
target.addDynamicallyLegalOp<OperatorOp>([&](Torch::OperatorOp op) {
2677-
return !ConvertSparseOperatorOp::isSparsePrimitive(op.getNameAttr());
2678-
});
26792682
patterns.add<ConvertSparseOperatorOp>(typeConverter, context);
26802683
}

lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1515
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1616
#include "mlir/Dialect/Math/IR/Math.h"
17+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1718
#include "mlir/IR/Matchers.h"
1819
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
1920
#include "torch-mlir/Conversion/Utils/Utils.h"
@@ -1103,23 +1104,25 @@ class ConvertAtenUpsampleNearest2dBackwardOp
11031104
};
11041105
} // namespace
11051106

1106-
void mlir::torch::torch_to_linalg::
1107-
populateIndirectDataMovementPatternsAndLegality(
1108-
TypeConverter &typeConverter, RewritePatternSet &patterns,
1109-
ConversionTarget &target) {
1110-
MLIRContext *context = patterns.getContext();
1107+
void mlir::torch::torch_to_linalg::populateIndirectDataMovementOpsLegality(
1108+
ConversionTarget &target) {
11111109
target.addIllegalOp<AtenGatherOp>();
1112-
patterns.add<ConvertAtenGatherOp>(typeConverter, context);
11131110
target.addIllegalOp<AtenEmbeddingOp>();
1114-
patterns.add<ConvertAtenEmbeddingOp>(typeConverter, context);
11151111
target.addIllegalOp<AtenIndexSelectOp>();
1116-
patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);
11171112
target.addIllegalOp<AtenIndexTensorHackedTwinOp>();
1118-
patterns.add<ConvertAtenIndexTensorHackedTwinOp>(typeConverter, context);
11191113
target.addIllegalOp<AtenEmbeddingBagPaddingIdxOp>();
1120-
patterns.add<ConvertAtenEmbeddingBagPaddingIdxOp>(typeConverter, context);
11211114
target.addIllegalOp<AtenUpsampleNearest2dOp>();
1122-
patterns.add<ConvertAtenUpsampleNearest2dOp>(typeConverter, context);
11231115
target.addIllegalOp<AtenUpsampleNearest2dBackwardOp>();
1116+
}
1117+
1118+
void mlir::torch::torch_to_linalg::populateIndirectDataMovementPatterns(
1119+
TypeConverter &typeConverter, RewritePatternSet &patterns) {
1120+
MLIRContext *context = patterns.getContext();
1121+
patterns.add<ConvertAtenGatherOp>(typeConverter, context);
1122+
patterns.add<ConvertAtenEmbeddingOp>(typeConverter, context);
1123+
patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);
1124+
patterns.add<ConvertAtenIndexTensorHackedTwinOp>(typeConverter, context);
1125+
patterns.add<ConvertAtenEmbeddingBagPaddingIdxOp>(typeConverter, context);
1126+
patterns.add<ConvertAtenUpsampleNearest2dOp>(typeConverter, context);
11241127
patterns.add<ConvertAtenUpsampleNearest2dBackwardOp>(typeConverter, context);
11251128
}

0 commit comments

Comments
 (0)