11
11
#include " gc/Dialect/LLVMIR/XeVMDialect.h"
12
12
#include " mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
13
13
#include " mlir/Conversion/LLVMCommon/Pattern.h"
14
+ #include " mlir/Dialect/LLVMIR/FunctionCallUtils.h"
14
15
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
15
16
#include " mlir/Pass/Pass.h"
16
17
#include " mlir/Support/LLVM.h"
18
+ #include " llvm/Support/FormatVariadic.h"
19
+
20
+ #include " mlir/IR/BuiltinTypes.h"
21
+ #include " mlir/IR/Types.h"
22
+
23
+ #include " llvm/ADT/STLExtras.h"
24
+ #include " llvm/ADT/TypeSwitch.h"
25
+ #include " llvm/Support/raw_ostream.h"
17
26
18
27
#define DEBUG_TYPE " xevm-to-llvm"
19
28
@@ -26,6 +35,230 @@ using namespace mlir;
26
35
using namespace xevm ;
27
36
28
37
namespace {
38
+ struct LLVMFuncAttributeOptions {
39
+ bool isConvergent = false ;
40
+ bool isNoUnwind = false ;
41
+ bool isWillReturn = false ;
42
+ LLVM::MemoryEffectsAttr memEffectsAttr{};
43
+ };
44
+ static constexpr LLVMFuncAttributeOptions convergentAttrs = {
45
+ true , false , false , {}};
46
+ static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
47
+ false , true , false , {}};
48
+ static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
49
+ false , true , true , {}};
50
+ static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
51
+ true , true , true , {}};
52
+
53
+ std::string getTypeMangling (Type ty, bool isUnsigned = false ) {
54
+ return TypeSwitch<Type, std::string>(ty)
55
+ .Case ([isUnsigned](VectorType ty) -> std::string {
56
+ return " Dv" + std::to_string (ty.getNumElements ()) + " _" +
57
+ getTypeMangling (ty.getElementType (), isUnsigned);
58
+ })
59
+ .Case ([](Float16Type) -> std::string { return " Dh" ; })
60
+ .Case ([](Float32Type) -> std::string { return " f" ; })
61
+ .Case ([](Float64Type) -> std::string { return " d" ; })
62
+ .Case ([isUnsigned](IntegerType ty) -> std::string {
63
+ switch (ty.getWidth ()) {
64
+ case 8 :
65
+ return isUnsigned ? " h" : " c" ;
66
+ case 16 :
67
+ return isUnsigned ? " t" : " s" ;
68
+ case 32 :
69
+ return isUnsigned ? " j" : " i" ;
70
+ case 64 :
71
+ return isUnsigned ? " m" : " l" ;
72
+ default :
73
+ llvm_unreachable (" unhandled integer type" );
74
+ }
75
+ });
76
+ }
77
+
78
+ template <typename OpType>
79
+ static std::optional<ArrayAttr>
80
+ getCacheControlMetadata (ConversionPatternRewriter &rewriter, OpType op,
81
+ const bool isLoad) {
82
+ if ((op.getL1CacheControlAttr () ==
83
+ xevm::L1StoreCacheControlAttr::get (
84
+ rewriter.getContext (), xevm::L1StoreCacheControl::DEFAULT) &&
85
+ op.getL3CacheControlAttr () ==
86
+ xevm::L3StoreCacheControlAttr::get (
87
+ rewriter.getContext (), xevm::L3StoreCacheControl::DEFAULT)) ||
88
+
89
+ (op.getL1CacheControlAttr () ==
90
+ xevm::L1LoadCacheControlAttr::get (
91
+ rewriter.getContext (), xevm::L1LoadCacheControl::DEFAULT) &&
92
+ op.getL3CacheControlAttr () ==
93
+ xevm::L3LoadCacheControlAttr::get (
94
+ rewriter.getContext (), xevm::L3LoadCacheControl::DEFAULT))) {
95
+ return {};
96
+ }
97
+ constexpr int32_t decorationCacheControlArity{4 };
98
+ constexpr int32_t loadCacheControlKey{6442 };
99
+ constexpr int32_t storeCacheControlKey{6443 };
100
+ constexpr int32_t l1Level{0 };
101
+ constexpr int32_t l3Level{1 };
102
+ const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
103
+ SmallVector<int32_t , decorationCacheControlArity> decorationsL1{
104
+ controlKey, l1Level, static_cast <int32_t >(op.getL1CacheControl ()), 0 };
105
+ SmallVector<int32_t , decorationCacheControlArity> decorationsL3{
106
+ controlKey, l3Level, static_cast <int32_t >(op.getL3CacheControl ()), 0 };
107
+ auto arrayAttrL1 = rewriter.getI32ArrayAttr (decorationsL1);
108
+ auto arrayAttrL3 = rewriter.getI32ArrayAttr (decorationsL3);
109
+
110
+ SmallVector<Attribute, 2 > combinedAttrs = {arrayAttrL1, arrayAttrL3};
111
+ return rewriter.getArrayAttr (combinedAttrs);
112
+ }
113
+
114
+ static LLVM::CallOp createDeviceFunctionCall (
115
+ ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
116
+ ArrayRef<Type> argTypes, ArrayRef<Value> args,
117
+ mlir::ArrayRef<std::pair<unsigned , mlir::StringRef>> paramAttrs,
118
+ LLVMFuncAttributeOptions funcAttributeOptions) {
119
+ auto moduleOp = rewriter.getBlock ()->getParent ()->getParentOfType <ModuleOp>();
120
+ MLIRContext *ctx = rewriter.getContext ();
121
+ Location loc = UnknownLoc::get (ctx);
122
+
123
+ LLVM::LLVMFuncOp funcOp =
124
+ LLVM::lookupOrCreateFn (moduleOp, funcName, argTypes, retType);
125
+ funcOp.setCConv (LLVM::cconv::CConv::SPIR_FUNC);
126
+ funcOp.setConvergent (funcAttributeOptions.isConvergent );
127
+ funcOp.setNoUnwind (funcAttributeOptions.isNoUnwind );
128
+ funcOp.setWillReturn (funcAttributeOptions.isWillReturn );
129
+
130
+ if (funcAttributeOptions.memEffectsAttr )
131
+ funcOp.setMemoryEffectsAttr (funcAttributeOptions.memEffectsAttr );
132
+
133
+ for (auto [idx, attrName] : paramAttrs)
134
+ funcOp.setArgAttr (idx, attrName, rewriter.getUnitAttr ());
135
+
136
+ // if (!passthroughAttrs.getFnAttributes().empty())
137
+ // funcOp->setAttrs(passthroughAttrs.getFnAttributes().getDictionary(ctx));
138
+
139
+ auto callOp = rewriter.create <LLVM::CallOp>(loc, funcOp, args);
140
+ callOp->setAttrs (funcOp->getAttrs ());
141
+
142
+ return callOp;
143
+ }
144
+
145
+ template <typename OpType>
146
+ class LoadStorePrefetchNdToOCLPattern : public OpConversionPattern <OpType> {
147
+ using OpConversionPattern<OpType>::OpConversionPattern;
148
+ LogicalResult
149
+ matchAndRewrite (OpType op, typename OpType::Adaptor adaptor,
150
+ ConversionPatternRewriter &rewriter) const override {
151
+ constexpr bool isLoad = std::is_same_v<OpType, xevm::BlockLoad2dOp>;
152
+ constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStore2dOp>;
153
+ constexpr bool isPrefetch = std::is_same_v<OpType, xevm::BlockPrefetch2dOp>;
154
+ auto loc = op.getLoc ();
155
+ VectorType vecType;
156
+ if constexpr (isLoad) {
157
+ vecType = op.getRes ().getType ();
158
+ } else if constexpr (isStore) {
159
+ vecType = op.getStoredVal ().getType ();
160
+ }
161
+
162
+ auto i32Type = rewriter.getI32Type ();
163
+ bool vnni = false ;
164
+ bool transpose = false ;
165
+ if constexpr (isLoad) {
166
+ vnni = op.getVnniTransform ();
167
+ transpose = op.getTranspose ();
168
+ }
169
+
170
+ Value byteCoord =
171
+ rewriter.create <LLVM::UndefOp>(loc, VectorType::get (2 , i32Type));
172
+ Value zero = rewriter.create <LLVM::ConstantOp>(
173
+ loc, i32Type, rewriter.getI32IntegerAttr (0 ));
174
+ Value one = rewriter.create <LLVM::ConstantOp>(
175
+ loc, i32Type, rewriter.getI32IntegerAttr (1 ));
176
+ byteCoord = rewriter.create <LLVM::InsertElementOp>(
177
+ loc, VectorType::get (2 , i32Type), byteCoord, op.getX (), zero);
178
+ byteCoord = rewriter.create <LLVM::InsertElementOp>(
179
+ loc, VectorType::get (2 , i32Type), byteCoord, op.getY (), one);
180
+ SmallVector<Value> args{op.getPtr (), op.getBaseWidth (), op.getBaseHeight (),
181
+ op.getBasePitch (), byteCoord};
182
+ SmallVector<Type> retTypes;
183
+ Value spvLoadDstPtr;
184
+ std::string funcName, bitWidthId;
185
+ SmallVector<std::pair<unsigned , mlir::StringRef>, 4 > paramAttrs;
186
+ if constexpr (isPrefetch) { // Prefetch
187
+ funcName = " intel_sub_group_2d_block_prefetch" ;
188
+ paramAttrs = {std::make_pair (0 , LLVM::LLVMDialect::getNonNullAttrName ())};
189
+ } else {
190
+ auto vecElemType = vecType.getElementType ();
191
+ auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth ();
192
+ Value numElems = rewriter.create <LLVM::ConstantOp>(
193
+ loc, i32Type, vecType.getNumElements ());
194
+ auto dstOrSrcPtr = rewriter.create <LLVM::AllocaOp>(
195
+ loc, LLVM::LLVMPointerType::get (rewriter.getContext ()), vecElemType,
196
+ numElems);
197
+ args.push_back (dstOrSrcPtr);
198
+ if constexpr (isLoad) { // Load
199
+ funcName = " intel_sub_group_2d_block_read" ;
200
+ bitWidthId = getTypeMangling (vecElemType, /* isUnsigned=*/ true );
201
+ if (vnni)
202
+ funcName += " _transform" ;
203
+ else if (transpose)
204
+ funcName += " _transpose" ;
205
+ spvLoadDstPtr = dstOrSrcPtr;
206
+ retTypes.push_back (vecType);
207
+ paramAttrs = {
208
+ std::make_pair (0 , LLVM::LLVMDialect::getNonNullAttrName ()),
209
+ std::make_pair (0 , LLVM::LLVMDialect::getReadonlyAttrName ()),
210
+ std::make_pair (5 , LLVM::LLVMDialect::getNonNullAttrName ()),
211
+ std::make_pair (5 , LLVM::LLVMDialect::getWriteOnlyAttrName ()),
212
+ };
213
+ } else { // Store
214
+ funcName = " intel_sub_group_2d_block_write" ;
215
+ bitWidthId = (vecElemBitWidth == 32 )
216
+ ? " j"
217
+ : ((vecElemBitWidth == 16 ) ? " t" : " h" );
218
+ rewriter.create <LLVM::StoreOp>(loc, op.getStoredVal (), dstOrSrcPtr);
219
+ paramAttrs = {
220
+ std::make_pair (0 , LLVM::LLVMDialect::getNonNullAttrName ()),
221
+ std::make_pair (0 , LLVM::LLVMDialect::getWriteOnlyAttrName ()),
222
+ std::make_pair (5 , LLVM::LLVMDialect::getNonNullAttrName ()),
223
+ std::make_pair (5 , LLVM::LLVMDialect::getReadonlyAttrName ()),
224
+ };
225
+ }
226
+ }
227
+
228
+ // !X = !{i32 %decoration_kind%, i32 %level%, i32 %control%, i32 %operand of
229
+ // the instruction to decorate%}
230
+ funcName =
231
+ llvm::formatv (" {0}_{1}b_{2}r{3}x{4}c" , funcName, op.getElemSizeInBits (),
232
+ op.getTileHeight (), op.getTileWidth (), op.getVBlocks ())
233
+ .str ();
234
+ funcName = llvm::formatv (" _Z{0}{1}PU3AS1viiiDv2_i{2}{3}" , funcName.size (),
235
+ funcName, isPrefetch ? " " : " P" , bitWidthId)
236
+ .str ();
237
+ SmallVector<Type> argTypes;
238
+ for (auto arg : args) {
239
+ argTypes.push_back (arg.getType ());
240
+ }
241
+ LLVM::CallOp call = createDeviceFunctionCall (
242
+ rewriter, funcName, LLVM::LLVMVoidType::get (rewriter.getContext ()),
243
+ argTypes, args, paramAttrs, noUnwindWillReturnAttrs);
244
+ if (std::optional<ArrayAttr> optCacheControls =
245
+ getCacheControlMetadata (rewriter, op, isLoad || isPrefetch)) {
246
+ call->setAttr (xevm::XeVMDialect::getCacheControlsAttrName (),
247
+ *optCacheControls);
248
+ }
249
+ if constexpr (isLoad)
250
+ rewriter.replaceOp (
251
+ op, rewriter.create <LLVM::LoadOp>(loc, vecType, spvLoadDstPtr));
252
+ else
253
+ rewriter.eraseOp (op);
254
+ return success ();
255
+ }
256
+ };
257
+
258
+ // ===----------------------------------------------------------------------===//
259
+ // Pass Definition
260
+ // ===----------------------------------------------------------------------===//
261
+
29
262
struct ConvertXeVMToLLVMPass
30
263
: public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
31
264
using Base::Base;
@@ -37,19 +270,51 @@ struct ConvertXeVMToLLVMPass
37
270
void runOnOperation () override {
38
271
ConversionTarget target (getContext ());
39
272
target.addLegalDialect <::mlir::LLVM::LLVMDialect>();
40
- RewritePatternSet pattern (&getContext ());
41
- mlir::populateXeVMToLLVMConversionPatterns (pattern);
42
- if (failed (
43
- applyPartialConversion (getOperation (), target, std::move (pattern))))
273
+ target.addIllegalDialect <xevm::XeVMDialect>();
274
+ RewritePatternSet patterns (&getContext ());
275
+ mlir::populateXeVMToLLVMConversionPatterns (patterns);
276
+ if (failed (applyPartialConversion (getOperation (), target,
277
+ std::move (patterns))))
44
278
signalPassFailure ();
45
279
}
46
280
};
47
281
} // namespace
48
282
283
+ // ===----------------------------------------------------------------------===//
284
+ // Pattern Population
285
+ // ===----------------------------------------------------------------------===//
286
+
49
287
/// Adds the lowering patterns for the XeVM 2D block load, store and prefetch
/// ops to `patterns`, in that order.
void mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<LoadStorePrefetchNdToOCLPattern<xevm::BlockLoad2dOp>>(ctx);
  patterns.add<LoadStorePrefetchNdToOCLPattern<xevm::BlockStore2dOp>>(ctx);
  patterns.add<LoadStorePrefetchNdToOCLPattern<xevm::BlockPrefetch2dOp>>(ctx);
}
52
293
294
+ // ===----------------------------------------------------------------------===//
295
+ // ConvertToLLVMPatternInterface implementation
296
+ // ===----------------------------------------------------------------------===//
297
+
298
+ namespace {
299
+ // / Implement the interface to convert XeVM to LLVM.
300
+ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
301
+ using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
302
+ void loadDependentDialects (MLIRContext *context) const final {
303
+ context->loadDialect <LLVM::LLVMDialect>();
304
+ }
305
+
306
+ // / Hook for derived dialect interface to provide conversion patterns
307
+ // / and mark dialect legal for the conversion target.
308
+ void populateConvertToLLVMConversionPatterns (
309
+ ConversionTarget &target, LLVMTypeConverter &typeConverter,
310
+ RewritePatternSet &patterns) const final {
311
+ populateXeVMToLLVMConversionPatterns (patterns);
312
+ }
313
+ };
314
+ } // namespace
315
+
53
316
void mlir::registerConvertXeVMToLLVMInterface (DialectRegistry ®istry) {
54
- /* TODO*/
317
+ registry.addExtension (+[](MLIRContext *ctx, xevm::XeVMDialect *dialect) {
318
+ dialect->addInterfaces <XeVMToLLVMDialectInterface>();
319
+ });
55
320
}
0 commit comments