5
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
6
//
7
7
//===----------------------------------------------------------------------===//
8
+ //
9
+ // This is the operation definition file for SYCL dialect operations.
10
+ //
11
+ //===----------------------------------------------------------------------===//
8
12
9
13
#ifndef SYCL_OPS
10
14
#define SYCL_OPS
11
15
12
16
include "mlir/IR/OpBase.td"
13
17
include "mlir/IR/AttrTypeBase.td"
14
18
19
+ include "mlir/Dialect/SYCL/IR/SYCLOpInterfaces.td"
15
20
include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td"
16
21
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
17
22
include "mlir/IR/BuiltinTypeInterfaces.td"
23
+ include "mlir/Interfaces/CallInterfaces.td"
18
24
include "mlir/Interfaces/CastInterfaces.td"
19
25
include "mlir/Interfaces/SideEffectInterfaces.td"
20
26
21
- include "SYCLOpInterfaces.td"
22
-
23
27
////////////////////////////////////////////////////////////////////////////////
24
28
// TYPE DECLARATIONS
25
29
////////////////////////////////////////////////////////////////////////////////
@@ -77,7 +81,7 @@ class SYCL_Op<string mnemonic, list<Trait> traits = []>
77
81
78
82
class SYCLMethodOpInterfaceImpl<
79
83
string mnemonic, string type, list<string> methodNames, list<Trait> traits = []>
80
- : SYCL_Op<mnemonic, !listconcat(traits, [SYCLMethodOpInterface])> {
84
+ : SYCL_Op<mnemonic, !listconcat(traits, [SYCLMethodOpInterface, CallOpInterface ])> {
81
85
string baseType = type;
82
86
list<string> memberFunctionNames = methodNames;
83
87
int arrSize = !size(memberFunctionNames);
@@ -88,10 +92,25 @@ class SYCLMethodOpInterfaceImpl<
88
92
}};
89
93
static ::mlir::TypeID getTypeID() { return ::mlir::sycl::}] # type # [{::getTypeID(); }
90
94
static constexpr llvm::ArrayRef<llvm::StringLiteral> getMethodNames() { return methods; }
95
+
96
+ Operation::operand_iterator arg_operand_begin() { return (*this)->operand_begin(); }
97
+ Operation::operand_iterator arg_operand_end() { return (*this)->operand_end(); }
98
+
99
+ /// Return the callee of the generic SYCL call operation, this is required by
100
+ /// the call interface.
101
+ CallInterfaceCallable getCallableForCallee() {
102
+ return (*this)->getAttrOfType<FlatSymbolRefAttr>(getMangledFunctionNameAttrName());
103
+ }
104
+
105
+ /// Get the argument operands to the called function, this is required by the
106
+ /// call interface.
107
+ Operation::operand_range getArgOperands() {
108
+ return {arg_operand_begin(), arg_operand_end()};
109
+ }
91
110
}];
92
111
93
112
let extraClassDefinition = [{
94
- constexpr std::array<llvm::StringLiteral, }] # arrSize # [{> $cppClass::methods;
113
+ constexpr std::array<llvm::StringLiteral, }] # arrSize # [{> $cppClass::methods;
95
114
}];
96
115
}
97
116
@@ -283,11 +302,10 @@ def SYCLCastOp : SYCL_Op<"cast", [DeclareOpInterfaceMethods<CastOpInterface>,
283
302
// CALL OPERATION
284
303
////////////////////////////////////////////////////////////////////////////////
285
304
286
- def SYCLCallOp : SYCL_Op<"call", []> {
305
+ def SYCLCallOp : SYCL_Op<"call", [CallOpInterface ]> {
287
306
let summary = "Generic call operation";
288
307
let description = [{
289
- This operation represent the call to any function part of the sycl's
290
- namespace.
308
+ This operation represent a call to any function part of the sycl's namespace.
291
309
}];
292
310
293
311
let arguments = (ins
@@ -318,6 +336,23 @@ def SYCLCallOp : SYCL_Op<"call", []> {
318
336
}]>
319
337
];
320
338
339
+ let extraClassDeclaration = [{
340
+ operand_iterator arg_operand_begin() { return operand_begin(); }
341
+ operand_iterator arg_operand_end() { return operand_end(); }
342
+
343
+ /// Return the callee of the generic SYCL call operation, this is required by
344
+ /// the call interface.
345
+ CallInterfaceCallable getCallableForCallee() {
346
+ return (*this)->getAttrOfType<FlatSymbolRefAttr>(getMangledFunctionNameAttrName());
347
+ }
348
+
349
+ /// Get the argument operands to the called function, this is required by the
350
+ /// call interface.
351
+ operand_range getArgOperands() {
352
+ return {arg_operand_begin(), arg_operand_end()};
353
+ }
354
+ }];
355
+
321
356
let assemblyFormat = [{
322
357
`(` $Args `)` attr-dict `:` functional-type($Args, results)
323
358
}];
0 commit comments