-
Notifications
You must be signed in to change notification settings - Fork 518
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com> Co-authored-by: Jiawei Wu <xremold@gmail.com> Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com> Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com> Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
- Loading branch information
1 parent
21f905a
commit b90a76f
Showing
21 changed files
with
351 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
[submodule "external/llvm-project"] | ||
path = externals/llvm-project | ||
url = https://github.com/llvm/llvm-project.git | ||
[submodule "externals/mlir-hlo"] | ||
path = externals/mlir-hlo | ||
url = https://github.com/tensorflow/mlir-hlo.git |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,9 @@ | ||
set(LLVM_TARGET_DEFINITIONS Passes.td) | ||
mlir_tablegen(Passes.h.inc -gen-pass-decls) | ||
if(TORCH_MLIR_ENABLE_MHLO) | ||
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO) | ||
else() | ||
mlir_tablegen(Passes.h.inc -gen-pass-decls) | ||
endif() | ||
add_public_tablegen_target(TorchMLIRConversionPassIncGen) | ||
|
||
add_mlir_doc(Passes TorchMLIRConversionPasses ./ -gen-pass-doc) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
//===------------------------------------------------------------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H | ||
#define TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H | ||
|
||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include <memory> | ||
|
||
namespace mlir { | ||
namespace torch { | ||
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToMhloPass(); | ||
} // namespace torch | ||
} // namespace mlir | ||
|
||
#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" | ||
|
||
#include "../PassDetail.h" | ||
#include "./PopulatePatterns.h" | ||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" | ||
#include "torch-mlir/Conversion/Utils/Utils.h" | ||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" | ||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" | ||
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" | ||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" | ||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" | ||
#include <iostream> | ||
#include <numeric> | ||
|
||
using namespace mlir; | ||
using namespace mlir::torch; | ||
using namespace mlir::torch::Torch; | ||
|
||
|
||
namespace { | ||
template <typename AtenOpT> | ||
class ConvertAtenOp : public OpConversionPattern<AtenOpT> { | ||
public: | ||
using OpConversionPattern<AtenOpT>::OpConversionPattern; | ||
using OpAdaptor = typename AtenOpT::Adaptor; | ||
LogicalResult | ||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override; | ||
}; | ||
} // namespace | ||
|
||
// AtenTanhOp | ||
namespace { | ||
template <> | ||
LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite( | ||
AtenTanhOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const { | ||
Value self = adaptor.self(); | ||
auto selfTy = self.getType().cast<TensorType>(); | ||
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) { | ||
rewriter.replaceOpWithNewOp<mhlo::TanhOp>( | ||
op, getTypeConverter()->convertType(op.getType()), self); | ||
return success(); | ||
} else { | ||
return op.emitError( | ||
"Only floating-point datatype legalization currently supported"); | ||
} | ||
} | ||
} // namespace | ||
|
||
void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( | ||
TypeConverter &typeConverter, RewritePatternSet &patterns, | ||
ConversionTarget &target) { | ||
MLIRContext *context = patterns.getContext(); | ||
|
||
#define INSERT_ATENOP_PATTERN(AtenOp) \ | ||
target.addIllegalOp<AtenOp>(); \ | ||
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context); | ||
INSERT_ATENOP_PATTERN(AtenTanhOp); | ||
#undef INSERT_ATENOP_PATTERN | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
add_mlir_conversion_library(TorchMLIRTorchToMhlo | ||
TorchToMhlo.cpp | ||
BasicOp.cpp | ||
|
||
ADDITIONAL_HEADER_DIRS | ||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo | ||
|
||
DEPENDS | ||
MhloDialect | ||
TorchMLIRConversionPassIncGen | ||
|
||
LINK_COMPONENTS | ||
Core | ||
|
||
LINK_LIBS PUBLIC | ||
MLIRIR | ||
MLIRPass | ||
MhloDialect | ||
TorchMLIRTorchDialect | ||
) | ||
|
||
torch_mlir_target_includes(TorchMLIRTorchToMhlo) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
//===------------------------------------------------------------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H | ||
#define TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H | ||
|
||
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
namespace mlir { | ||
namespace torch { | ||
namespace torch_to_mhlo { | ||
|
||
void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter, | ||
RewritePatternSet &patterns, | ||
ConversionTarget &target); | ||
|
||
} // namespace torch_to_mhlo | ||
} // namespace torch | ||
} // namespace mlir | ||
|
||
#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" | ||
|
||
#include "../PassDetail.h" | ||
#include "./PopulatePatterns.h" | ||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" | ||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "mlir/Dialect/Traits.h" | ||
#include "mlir/IR/Matchers.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" | ||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" | ||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" | ||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" | ||
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" | ||
|
||
using namespace mlir; | ||
using namespace mlir::torch; | ||
using namespace mlir::torch::Torch; | ||
|
||
namespace { | ||
|
||
class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> { | ||
public: | ||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<mhlo::MhloDialect>(); | ||
registry.insert<tensor::TensorDialect>(); | ||
registry.insert<arith::ArithmeticDialect>(); | ||
TorchConversion::getBackendTypeConversionDependentDialects(registry); | ||
} | ||
void runOnOperation() override { | ||
MLIRContext *context = &getContext(); | ||
ConversionTarget target(*context); | ||
target.addLegalDialect<mhlo::MhloDialect, tensor::TensorDialect, | ||
arith::ArithmeticDialect, Torch::TorchDialect>(); | ||
|
||
TypeConverter typeConverter; | ||
typeConverter.addConversion([](Type type) { return type; }); | ||
TorchConversion::setupBackendTypeConversion(target, typeConverter); | ||
|
||
RewritePatternSet patterns(context); | ||
|
||
torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns, | ||
target); | ||
if (failed(applyPartialConversion(getOperation(), target, | ||
std::move(patterns)))) { | ||
return signalPassFailure(); | ||
} | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
std::unique_ptr<OperationPass<func::FuncOp>> | ||
mlir::torch::createConvertTorchToMhloPass() { | ||
return std::make_unique<ConvertTorchToMhlo>(); | ||
} |
Oops, something went wrong.