Skip to content

Commit 9042844

Browse files
authored
[SYCL-MLIR]: Add AlwaysInliner pass (#7481)
PR #7536 removed the MLIR inlining pass (too greedy). This PR introduces an 'always-inline' pass in the `cgeist` MLIR pipeline. This pass inlines only `sycl.call` call sites with callee that has the 'alwaysinline' mlir function attribute. The pass uses a generic SCC inliner to perform its actions and specializes the inlining heuristic used at outline above. Signed-off-by: Tiotto, Ettore <ettore.tiotto@intel.com>
1 parent 9a2faf2 commit 9042844

File tree

14 files changed

+649
-42
lines changed

14 files changed

+649
-42
lines changed

mlir-sycl/include/mlir/Dialect/SYCL/IR/SYCLOps.td

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,25 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8+
//
9+
// This is the operation definition file for SYCL dialect operations.
10+
//
11+
//===----------------------------------------------------------------------===//
812

913
#ifndef SYCL_OPS
1014
#define SYCL_OPS
1115

1216
include "mlir/IR/OpBase.td"
1317
include "mlir/IR/AttrTypeBase.td"
1418

19+
include "mlir/Dialect/SYCL/IR/SYCLOpInterfaces.td"
1520
include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td"
1621
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
1722
include "mlir/IR/BuiltinTypeInterfaces.td"
23+
include "mlir/Interfaces/CallInterfaces.td"
1824
include "mlir/Interfaces/CastInterfaces.td"
1925
include "mlir/Interfaces/SideEffectInterfaces.td"
2026

21-
include "SYCLOpInterfaces.td"
22-
2327
////////////////////////////////////////////////////////////////////////////////
2428
// TYPE DECLARATIONS
2529
////////////////////////////////////////////////////////////////////////////////
@@ -77,7 +81,7 @@ class SYCL_Op<string mnemonic, list<Trait> traits = []>
7781

7882
class SYCLMethodOpInterfaceImpl<
7983
string mnemonic, string type, list<string> methodNames, list<Trait> traits = []>
80-
: SYCL_Op<mnemonic, !listconcat(traits, [SYCLMethodOpInterface])> {
84+
: SYCL_Op<mnemonic, !listconcat(traits, [SYCLMethodOpInterface, CallOpInterface])> {
8185
string baseType = type;
8286
list<string> memberFunctionNames = methodNames;
8387
int arrSize = !size(memberFunctionNames);
@@ -88,10 +92,25 @@ class SYCLMethodOpInterfaceImpl<
8892
}};
8993
static ::mlir::TypeID getTypeID() { return ::mlir::sycl::}] # type # [{::getTypeID(); }
9094
static constexpr llvm::ArrayRef<llvm::StringLiteral> getMethodNames() { return methods; }
95+
96+
Operation::operand_iterator arg_operand_begin() { return (*this)->operand_begin(); }
97+
Operation::operand_iterator arg_operand_end() { return (*this)->operand_end(); }
98+
99+
/// Return the callee of the generic SYCL call operation, this is required by
100+
/// the call interface.
101+
CallInterfaceCallable getCallableForCallee() {
102+
return (*this)->getAttrOfType<FlatSymbolRefAttr>(getMangledFunctionNameAttrName());
103+
}
104+
105+
/// Get the argument operands to the called function, this is required by the
106+
/// call interface.
107+
Operation::operand_range getArgOperands() {
108+
return {arg_operand_begin(), arg_operand_end()};
109+
}
91110
}];
92111

93112
let extraClassDefinition = [{
94-
constexpr std::array<llvm::StringLiteral, }] # arrSize # [{> $cppClass::methods;
113+
constexpr std::array<llvm::StringLiteral, }] # arrSize # [{> $cppClass::methods;
95114
}];
96115
}
97116

@@ -283,11 +302,10 @@ def SYCLCastOp : SYCL_Op<"cast", [DeclareOpInterfaceMethods<CastOpInterface>,
283302
// CALL OPERATION
284303
////////////////////////////////////////////////////////////////////////////////
285304

286-
def SYCLCallOp : SYCL_Op<"call", []> {
305+
def SYCLCallOp : SYCL_Op<"call", [CallOpInterface]> {
287306
let summary = "Generic call operation";
288307
let description = [{
289-
This operation represent the call to any function part of the sycl's
290-
namespace.
308+
This operation represent a call to any function part of the sycl's namespace.
291309
}];
292310

293311
let arguments = (ins
@@ -318,6 +336,23 @@ def SYCLCallOp : SYCL_Op<"call", []> {
318336
}]>
319337
];
320338

339+
let extraClassDeclaration = [{
340+
operand_iterator arg_operand_begin() { return operand_begin(); }
341+
operand_iterator arg_operand_end() { return operand_end(); }
342+
343+
/// Return the callee of the generic SYCL call operation, this is required by
344+
/// the call interface.
345+
CallInterfaceCallable getCallableForCallee() {
346+
return (*this)->getAttrOfType<FlatSymbolRefAttr>(getMangledFunctionNameAttrName());
347+
}
348+
349+
/// Get the argument operands to the called function, this is required by the
350+
/// call interface.
351+
operand_range getArgOperands() {
352+
return {arg_operand_begin(), arg_operand_end()};
353+
}
354+
}];
355+
321356
let assemblyFormat = [{
322357
`(` $Args `)` attr-dict `:` functional-type($Args, results)
323358
}];

mlir-sycl/include/mlir/Dialect/SYCL/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ namespace sycl {
2424
// Passes
2525
//===----------------------------------------------------------------------===//
2626

27+
std::unique_ptr<Pass> createAlwaysInlinePass();
2728
std::unique_ptr<Pass> createSYCLMethodToSYCLCallPass();
2829

2930
//===----------------------------------------------------------------------===//

mlir-sycl/include/mlir/Dialect/SYCL/Transforms/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@
1111

1212
include "mlir/Pass/PassBase.td"
1313

14+
def AlwaysInlinePass : Pass<"sycl-always-inline"> {
15+
let summary = "Inline SYCL calls to functions marked 'alwaysinline'";
16+
let description = [{
17+
Replace a sycl.call operation with the body of the callee if the callee has
18+
the 'always_inline' attribute.
19+
}];
20+
let constructor = "mlir::sycl::createAlwaysInlinePass()";
21+
}
22+
1423
def SYCLMethodToSYCLCall : Pass<"sycl-method-to-sycl-call", "ModuleOp"> {
1524
let summary = "Convert SYCLMethodOpInterface instances to SYCLCallOps";
1625
let description = [{

mlir-sycl/lib/Dialect/IR/SYCLOpsDialect.cpp

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,76 @@
1-
// Copyright (C) Codeplay Software Limited
2-
3-
//===--- SYCLOpsDialect.cpp -----------------------------------------------===//
1+
//===--- SYCLOpsDialect.cpp - SYCL Dialect registration in MLIR -----------===//
42
//
53
// MLIR-SYCL is under the Apache License v2.0 with LLVM Exceptions.
64
// See https://llvm.org/LICENSE.txt for license information.
75
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
86
//
97
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements the dialect for the SYCL IR.
10+
//
11+
//===----------------------------------------------------------------------===//
1012

1113
#include "mlir/Dialect/SYCL/IR/SYCLOpsDialect.h"
12-
1314
#include "mlir/Dialect/SYCL/IR/SYCLOps.h"
1415
#include "mlir/Dialect/SYCL/IR/SYCLOpsAlias.h"
1516
#include "mlir/Dialect/SYCL/IR/SYCLOpsTypes.h"
17+
#include "mlir/IR/DialectImplementation.h"
18+
#include "mlir/Transforms/InliningUtils.h"
19+
20+
//===----------------------------------------------------------------------===//
21+
// SYCL Dialect Interfaces
22+
//===----------------------------------------------------------------------===//
23+
24+
namespace {
25+
26+
/// This class defines the interface for inlining SYCL operations.
27+
class SYCLInlinerInterface : public mlir::DialectInlinerInterface {
28+
public:
29+
using DialectInlinerInterface::DialectInlinerInterface;
30+
31+
//===--------------------------------------------------------------------===//
32+
// Analysis Hooks
33+
//===--------------------------------------------------------------------===//
34+
35+
/// This hook checks whether is legal to inline the \p Callable operation and
36+
/// replace the \p Call operation with it. For the SYCL dialect we want to
37+
/// allow inlining only SYCLCallOp operations.
38+
bool isLegalToInline(mlir::Operation *Call, mlir::Operation *Callable,
39+
bool WouldBeCloned) const final {
40+
return mlir::isa<mlir::sycl::SYCLCallOp>(Call);
41+
}
42+
43+
/// This hook checks whether is legal to inline the \p Op operation into the
44+
/// \p Dest region. All operations in the SYCL dialect are legal to inline.
45+
bool isLegalToInline(mlir::Operation *Op, mlir::Region *Dest,
46+
bool WouldBeCloned,
47+
mlir::BlockAndValueMapping &ValueMapping) const final {
48+
return true;
49+
}
50+
51+
//===--------------------------------------------------------------------===//
52+
// Transformation Hooks
53+
//===--------------------------------------------------------------------===//
54+
55+
/// Attempts to materialize a conversion for a type mismatch between a call
56+
/// from the SYCL dialect, and a callable region. This method should generate
57+
/// an operation that takes \p Input as the only operand, and produces a
58+
/// single result of \p ResultType. If a conversion cannot be generated,
59+
/// nullptr should be returned.
60+
mlir::Operation *
61+
materializeCallConversion(mlir::OpBuilder &Builder, mlir::Value Input,
62+
mlir::Type ResultType,
63+
mlir::Location ConversionLoc) const final {
64+
return Builder.create<mlir::sycl::SYCLCastOp>(ConversionLoc, ResultType,
65+
Input);
66+
}
67+
};
68+
69+
} // namespace
70+
71+
//===----------------------------------------------------------------------===//
72+
// SYCL Dialect
73+
//===----------------------------------------------------------------------===//
1674

1775
#include "llvm/ADT/TypeSwitch.h"
1876
#include "llvm/Support/Debug.h"
@@ -37,6 +95,7 @@ void mlir::sycl::SYCLDialect::initialize() {
3795
>();
3896

3997
mlir::Dialect::addInterfaces<SYCLOpAsmInterface>();
98+
mlir::Dialect::addInterfaces<SYCLInlinerInterface>();
4099
}
41100

42101
llvm::Optional<llvm::StringRef>

mlir-sycl/lib/Transforms/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_library(MLIRSYCLTransforms
2+
Inliner.cpp
23
SYCLMethodToSYCLCall.cpp
34

45
ADDITIONAL_HEADER_DIRS
@@ -9,6 +10,9 @@ add_mlir_library(MLIRSYCLTransforms
910

1011
LINK_LIBS PUBLIC
1112
MLIRPass
13+
MLIRTransforms
1214
MLIRTransformUtils
1315
MLIRSYCLDialect
16+
MLIRFuncDialect
17+
MLIRGPUOps
1418
)

0 commit comments

Comments
 (0)