Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pipeline] Use dedicated simplification pipeline for TorchDynamo frontend #3376

Merged
merged 1 commit into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ externals/pytorch/
libtorch*

/build/
.build-cache/
/setup_build/
__pycache__
*.pyc
Expand Down
5 changes: 5 additions & 0 deletions include/torch-mlir/Dialect/Torch/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ struct TorchLoweringPipelineOptions
void createTorchScriptModuleToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options);

/// Creates a pipeline that lowers the graph IR that is produced by
/// TorchDynamo export into the form expected by torch-verify-backend-contract.
void createTorchDynamoExportToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options);

/// Creates a pipeline that lowers a flat list of funcs and global slots
/// with the torch and aten dialects and mutable arrays and converts it to
/// the form required by torch-verify-backend-contract.
Expand Down
16 changes: 16 additions & 0 deletions lib/Dialect/Torch/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ void mlir::torch::registerTorchPasses() {
"torchscript-module-to-torch-backend-pipeline",
"Pipeline lowering TorchScript object graph IR to Torch backend form.",
mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline);
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torchdynamo-export-to-torch-backend-pipeline",
"Pipeline lowering TorchDynamo exported graph IR to Torch backend form.",
mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline);
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torch-function-to-torch-backend-pipeline",
"Pipeline lowering a Torch function to Torch backend form.",
Expand Down Expand Up @@ -59,6 +63,18 @@ void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline(
createTorchFunctionToTorchBackendPipeline(pm, options);
}

void mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
pm.addNestedPass<func::FuncOp>(
createReduceOpVariantsPass(options.extraLibrary));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
if (options.decompose) {
pm.addNestedPass<func::FuncOp>(
Torch::createDecomposeComplexOpsPass(options.backendLegalOps));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
}
}

void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
// Incorporate user annotations and remove signature Python-isms.
Expand Down
17 changes: 3 additions & 14 deletions python/torch_mlir/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,20 @@ def _module_lowering(
verbose,
output_type,
torch_mod,
backend_legal_ops=None,
extra_library_file_name=None,
):

if output_type == OutputType.TORCH:
if verbose:
print(torch_mod)
return torch_mod
# TODO: pass backend_legal_ops/extra_library_file_name by caller
if backend_legal_ops is None:
backend_legal_ops = []
# TODO: pass extra_library_file_name by caller
if extra_library_file_name is None:
extra_library_file_name = ""
option_string = (
"{backend-legal-ops="
+ ",".join(backend_legal_ops)
+ " extra-library="
+ extra_library_file_name
+ " shape-dtype-refine="
+ ("false" if not backend_legal_ops and not extra_library_file_name else "true")
+ "}"
)
option_string = "{extra-library=" + extra_library_file_name + "}"
run_pipeline_with_repro_report(
torch_mod,
f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})",
f"builtin.module(torchdynamo-export-to-torch-backend-pipeline{option_string})",
"Lowering TorchFX IR -> Torch Backend IR",
enable_ir_printing=verbose,
)
Expand Down
Loading