Skip to content

Commit

Permalink
[transform] add transform dialect interpreter pass (alibaba#765)
Browse files Browse the repository at this point in the history
  • Loading branch information
wyzero authored Nov 22, 2022
1 parent 77bf731 commit 437b220
Show file tree
Hide file tree
Showing 10 changed files with 277 additions and 1 deletion.
2 changes: 2 additions & 0 deletions tao_compiler/mlir/disc/glob_lit_test.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def _run_lit_test(name, data, size, tags, driver, features, exec_properties):
# just a placeholder from the copybara rewrite.
data = [d for d in data if d != _default_driver]

if name.startswith('metadata-only'):
return
if name.startswith('gpu-only'):
tags.append('gpu')
if name.startswith('cpu-only'):
Expand Down
8 changes: 8 additions & 0 deletions tao_compiler/mlir/disc/tools/disc-opt/disc-opt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "iree-dialects/Dialect/Input/InputDialect.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
#include "iree-dialects/Dialect/LinalgTransform/Passes.h"
#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/lhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
Expand Down Expand Up @@ -49,6 +54,9 @@ int main(int argc, char** argv) {
registry.insert<mlir::disc_ral::RalDialect>();
registry.insert<mlir::TF::TensorFlowDialect>();
registry.insert<mlir::disc_shape::DISCShapeDialect>();
registry.insert<mlir::iree_compiler::IREE::LinalgExt::IREELinalgExtDialect,
mlir::linalg::transform::LinalgTransformDialect,
mlir::iree_compiler::IREE::Input::IREEInputDialect>();

return failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n",
registry,
Expand Down
30 changes: 30 additions & 0 deletions tao_compiler/mlir/disc/tools/disc-transform/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,35 @@ cc_library(
alwayslink = 1,
)

# Implements the `disc-transform-dialect-interpreter` pass, which drives
# codegen by interpreting transform dialect ops that are either embedded in
# the payload IR or loaded from a standalone schedule file.
cc_library(
    name = "transform_dialect_interpreter",
    srcs = ["transforms/transform_dialect_interpreter.cc"],
    deps = [
        ":pass_details",
        "@iree-dialects//:IREELinalgExtDialect",
        "@iree-dialects//:IREELinalgExtTransformOps",
        "@iree-dialects//:IREELinalgTransformDialect",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:BufferizationDialect",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LLVMDialect",
        "@llvm-project//mlir:LinalgDialect",
        "@llvm-project//mlir:LinalgTransformOps",
        "@llvm-project//mlir:PDLDialect",
        "@llvm-project//mlir:PDLInterpDialect",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TensorDialect",
        "@llvm-project//mlir:TransformDialect",
        "@llvm-project//mlir:Transforms",
        "@llvm-project//mlir:VectorDialect",
    ],
    # Keep the object file linked in: pass registration relies on static
    # initializers that would otherwise be dropped by the linker.
    alwayslink = 1,
)

cc_library(
name = "all_passes",
hdrs = [
Expand All @@ -84,6 +113,7 @@ cc_library(
],
deps = [
":legalize_lmhlo_fusion_to_linalg",
":transform_dialect_interpreter",
"@llvm-project//mlir:Pass",
],
alwayslink = 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#define DISC_TOOLS_DISC_TRANSFORM_TRANSFORMS_PASSES_H_

#include <memory>
#include <string>

namespace mlir {

Expand All @@ -36,6 +37,11 @@ namespace disc_ral {
std::unique_ptr<OperationPass<ModuleOp>>
createDiscLegalizeLmhloFusionToLinalgPass();

// Applies transform dialect ops for codegen.
std::unique_ptr<OperationPass<ModuleOp>>
createDiscTransformDialectInterpreterPass(const std::string& fileName = "",
bool enableExpensiveChecks = false);

} // namespace disc_ral
} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ glob_lit_tests(
filegroup(
name = "test_utilities",
testonly = True,
srcs = glob([
"metadata-only*.mlir",
]),
data = [
"//tensorflow/compiler/mlir/disc:disc-opt",
"@llvm-project//llvm:FileCheck",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// Standalone transform schedule used by the interpreter lit tests: tiles
// every `linalg.matmul` found in the payload with tile sizes [2, 3, 4],
// producing a 3-deep scf.for loop nest. `failures(propagate)` aborts the
// whole sequence on the first failing transform op.
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
  %1, %loops:3 = transform.structured.tile %0 [2, 3, 4]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// RUN: disc-opt --disc-transform-dialect-interpreter -split-input-file %s | FileCheck %s --check-prefix=INLINE
// RUN: disc-opt --disc-transform-dialect-interpreter=transform-file-name=%p/metadata-only-transform-dialect-interpreter-standalone.mlir -split-input-file %s | FileCheck %s --check-prefix=STANDALONE

// Exercises the disc-transform-dialect-interpreter pass in both of its
// modes. The first RUN line (prefix INLINE) picks up the transform schedule
// embedded next to the payload IR below; the second RUN line (prefix
// STANDALONE) loads the same schedule from a separate file via the
// transform-file-name pass option. Both should tile the matmul identically.

// INLINE-LABEL: @matmul_nn
func.func @matmul_nn(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  // INLINE: %{{.*}} = scf.for
  // INLINE-NEXT: %{{.*}} = scf.for
  // INLINE-NEXT: %{{.*}} = scf.for
  // INLINE: tensor.extract_slice
  // INLINE-NEXT: tensor.extract_slice
  // INLINE-NEXT: tensor.extract_slice
  // INLINE-NEXT: linalg.matmul
  %1 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}


// Embedded schedule consumed by the first RUN line: tile linalg.matmul
// with tile sizes [2, 3, 4].
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
  %1, %loops:3 = transform.structured.tile %0 [2, 3, 4]
}

// -----

// STANDALONE-LABEL: @matmul_nn
func.func @matmul_nn(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  // STANDALONE: %{{.*}} = scf.for
  // STANDALONE-NEXT: %{{.*}} = scf.for
  // STANDALONE-NEXT: %{{.*}} = scf.for
  // STANDALONE: tensor.extract_slice
  // STANDALONE-NEXT: tensor.extract_slice
  // STANDALONE-NEXT: tensor.extract_slice
  // STANDALONE-NEXT: linalg.matmul
  %1 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// Copyright 2022 The BladeDISC Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h"
#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
#include "iree-dialects/Dialect/LinalgTransform/TransformInterpreterUtils.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Transforms/Passes.h"
#include "tensorflow/compiler/mlir/disc/tools/disc-transform/transforms/PassDetail.h"

#define DEBUG_TYPE "disc-transform-dialect-interpreter"

// This file implements the logic to apply transform dialect ops for codegen.

namespace mlir {
namespace disc_ral {
namespace {

/// Parses the MLIR module stored in file `transformFileName` into
/// `transformModule`.
///
/// Returns failure (with a diagnostic on stderr) if the file cannot be
/// opened or its content is not a valid MLIR module. `transformModule`
/// holds a valid module only on success.
LogicalResult parseTransformModuleFromFile(
    MLIRContext* context, llvm::StringRef transformFileName,
    OwningOpRef<ModuleOp>& transformModule) {
  // Load the file content into a memory buffer.
  std::string errorMessage;
  auto memoryBuffer = openInputFile(transformFileName, &errorMessage);
  if (!memoryBuffer) {
    // Surface the underlying reason (e.g. file not found) instead of a
    // generic "failed to parse" message.
    llvm::errs() << "failed to open transform file: " << transformFileName
                 << " : " << errorMessage << "\n";
    return failure();
  }
  // Tell sourceMgr about this buffer, the parser will pick it up.
  llvm::SourceMgr sourceMgr;
  sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
  transformModule =
      OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
  // parseSourceFile returns a null module on syntax errors; report that as
  // a failure rather than handing a null module back to the caller (which
  // would crash when dereferenced).
  if (!transformModule) {
    llvm::errs() << "failed to parse transform file: " << transformFileName
                 << "\n";
    return failure();
  }
  return success();
}

// Pass that drives codegen by interpreting transform dialect ops.
//
// The transform schedule is either embedded in the payload module itself
// (when `transformFileName_` is empty) or loaded from a standalone file.
struct DiscTransformDialectInterpreterPass
    : public DiscTransformDialectInterpreterPassBase<
          DiscTransformDialectInterpreterPass> {
  // `fileName`: path to a standalone transform schedule; empty means the
  //   schedule is embedded next to the payload IR.
  // `enableExpensiveChecks`: forwarded to transform::TransformOptions to
  //   better diagnose errors in the transform IR.
  explicit DiscTransformDialectInterpreterPass(const std::string& fileName,
                                               bool enableExpensiveChecks)
      : DiscTransformDialectInterpreterPassBase<
            DiscTransformDialectInterpreterPass>::
            DiscTransformDialectInterpreterPassBase() {
    this->transformFileName_ = fileName;
    this->enableExpensiveChecks_ = enableExpensiveChecks;
  }

  // Registers every dialect the interpreted transform ops may create, plus
  // the bufferization interface models and transform-op extensions they
  // rely on.
  void getDependentDialects(DialectRegistry& registry) const override {
    // TODO: this is only necessary to make registry subset happy when running
    // the lowering to LLVM. The lowering should be changed to stop using the
    // nested pass manager and this will go away.

    // clang-format off
    registry.insert<arith::ArithDialect,
                    AffineDialect,
                    bufferization::BufferizationDialect,
                    iree_compiler::IREE::LinalgExt::IREELinalgExtDialect,
                    func::FuncDialect,
                    linalg::LinalgDialect,
                    linalg::transform::LinalgTransformDialect,
                    LLVM::LLVMDialect,
                    pdl::PDLDialect,
                    pdl_interp::PDLInterpDialect,
                    scf::SCFDialect,
                    tensor::TensorDialect,
                    vector::VectorDialect
    // clang-format on
    >();

    // TODO: these should be registered by the extension instead, but there is
    // no support for it in core currently.
    arith::registerBufferizableOpInterfaceExternalModels(registry);
    linalg::registerBufferizableOpInterfaceExternalModels(registry);
    scf::registerBufferizableOpInterfaceExternalModels(registry);
    bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
        registry);
    tensor::registerBufferizableOpInterfaceExternalModels(registry);
    vector::registerBufferizableOpInterfaceExternalModels(registry);

    // Register the transform-dialect extensions that declare the transform
    // ops this interpreter may encounter in a schedule.
    registry.addExtensions<
        mlir::iree_compiler::IREE::LinalgExt::LinalgExtTransformOpsExtension,
        transform_ext::StructuredTransformOpsExtension>();
    linalg::registerTransformDialectExtension(registry);
  }

  void runOnOperation() override;
};

void DiscTransformDialectInterpreterPass::runOnOperation() {
ModuleOp module = getOperation();
if (transformFileName_.empty()) {
llvm::errs() << "no transform file name specified, assuming the transform "
"module is embedded in the IR next to the top-level\n";
// parse transform ops from the module itself.
for (auto op :
module.getBody()->getOps<transform::TransformOpInterface>()) {
if (failed(transform::applyTransforms(
module, op,
transform::TransformOptions().enableExpensiveChecks(
enableExpensiveChecks_))))
return signalPassFailure();
}
} else {
// parse transform ops from a standalone file.
OwningOpRef<ModuleOp> transformModule;
if (failed(parseTransformModuleFromFile(
module.getContext(), transformFileName_, transformModule))) {
llvm::errs() << "failed to load transform ops from file "
<< transformFileName_ << "\n";
return signalPassFailure();
}
for (auto op : transformModule.get()
.getBody()
->getOps<transform::TransformOpInterface>()) {
if (failed(transform::applyTransforms(
module, op,
transform::TransformOptions().enableExpensiveChecks(
enableExpensiveChecks_))))
return signalPassFailure();
}
}
}

} // namespace

// Creates a pass that interprets a transform dialect schedule: embedded in
// the payload IR when `fileName` is empty, otherwise loaded from that file.
// `enableExpensiveChecks` turns on extra validation of the transform IR.
std::unique_ptr<OperationPass<ModuleOp>>
createDiscTransformDialectInterpreterPass(const std::string& fileName,
                                          bool enableExpensiveChecks) {
  auto pass = std::make_unique<DiscTransformDialectInterpreterPass>(
      fileName, enableExpensiveChecks);
  return pass;
}

} // namespace disc_ral
} // namespace mlir
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,14 @@ def DiscLegalizeLmhloFusionToLinalgPass : Pass<"disc-legalize-lmhlo-fusion-to-li
let summary = "Pass to convert a lmhlo fusion op to linalg on tensor.";
let constructor = "createDiscLegalizeLmhloFusionToLinalgPass()";
}

// Declares the `disc-transform-dialect-interpreter` pass. The schedule is
// read from `transform-file-name` when given, otherwise it is expected to
// be embedded in the module the pass runs on.
def DiscTransformDialectInterpreterPass : Pass<"disc-transform-dialect-interpreter", "ModuleOp"> {
  let summary = "Pass to apply transform dialect operations one by one.";
  let constructor = "createDiscTransformDialectInterpreterPass()";
  let options = [
    Option<"transformFileName_", "transform-file-name", "std::string",
            /*default=*/"\"\"", "Filename of the transform schedule.">,
    Option<"enableExpensiveChecks_", "enable-expensive-checks", "bool",
            /*default=*/"false", "perform expensive checks to better report errors in the transform IR.">,
  ];
}
2 changes: 1 addition & 1 deletion tf_community

0 comments on commit 437b220

Please sign in to comment.