@@ -74,6 +74,12 @@ class NVVM_Op<string mnemonic, list<Trait> traits = []> :
74
74
LLVM_OpBase<NVVM_Dialect, mnemonic, traits> {
75
75
}
76
76
77
+ /// Base class that defines BasicPtxBuilderOpInterface.
78
+ class NVVM_PTXBuilder_Op<string mnemonic,
79
+ list<Trait> traits = [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> :
80
+ LLVM_OpBase<NVVM_Dialect, mnemonic, traits> {
81
+ }
82
+
77
83
//===----------------------------------------------------------------------===//
78
84
// NVVM attribute definitions
79
85
//===----------------------------------------------------------------------===//
@@ -206,21 +212,31 @@ def NVVM_ReduxOp :
206
212
//===----------------------------------------------------------------------===//
207
213
208
214
/// mbarrier.init instruction with generic pointer type
209
- def NVVM_MBarrierInitOp : NVVM_Op <"mbarrier.init">,
210
- Arguments<(ins LLVM_i64ptr_any:$addr, I32:$count)> {
215
+ def NVVM_MBarrierInitOp : NVVM_PTXBuilder_Op <"mbarrier.init">,
216
+ Arguments<(ins LLVM_i64ptr_any:$addr, I32:$count, PtxPredicate:$predicate )> {
211
217
string llvmBuilder = [{
212
218
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init, {$addr, $count});
213
219
}];
214
- let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
220
+ let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
221
+ let extraClassDeclaration = [{
222
+ bool hasIntrinsic() { if(getPredicate()) return false; return true; }
223
+ }];
224
+ let extraClassDefinition = [{
225
+ std::string $cppClass::getPtx() { return std::string("mbarrier.init.b64 [%0], %1;"); }
226
+ }];
215
227
}
216
228
217
229
/// mbarrier.init instruction with shared pointer type
218
- def NVVM_MBarrierInitSharedOp : NVVM_Op <"mbarrier.init.shared">,
219
- Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$count)> {
230
+ def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op <"mbarrier.init.shared">,
231
+ Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$count, PtxPredicate:$predicate )> {
220
232
string llvmBuilder = [{
221
233
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init_shared, {$addr, $count});
222
234
}];
223
- let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
235
+ let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
236
+ let extraClassDeclaration = "bool hasIntrinsic() { return !getPredicate(); }";
237
+ let extraClassDefinition = [{
238
+ std::string $cppClass::getPtx() { return std::string("mbarrier.init.shared.b64 [%0], %1;"); }
239
+ }];
224
240
}
225
241
226
242
def NVVM_MBarrierInvalOp : NVVM_Op<"mbarrier.inval">,
@@ -275,26 +291,23 @@ def NVVM_MBarrierArriveNocompleteSharedOp : NVVM_Op<"mbarrier.arrive.nocomplete.
275
291
let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)";
276
292
}
277
293
278
- def NVVM_MBarrierArriveExpectTxOp : NVVM_Op<"mbarrier.arrive.expect_tx",
279
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
280
- Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> {
281
- let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
294
+ def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx">,
295
+ Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount, PtxPredicate:$predicate)> {
296
+ let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
282
297
let extraClassDefinition = [{
283
298
std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); }
284
299
}];
285
300
}
286
301
287
- def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_Op<"mbarrier.arrive.expect_tx.shared",
288
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
289
- Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount)> {
290
- let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
302
+ def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx.shared">,
303
+ Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount, PtxPredicate:$predicate)> {
304
+ let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
291
305
let extraClassDefinition = [{
292
306
std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); }
293
307
}];
294
308
}
295
309
296
- def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity",
297
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
310
+ def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,
298
311
Arguments<(ins LLVM_i64ptr_any:$addr, I32:$phase, I32:$ticks)> {
299
312
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
300
313
let extraClassDefinition = [{
@@ -313,8 +326,7 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity",
313
326
}];
314
327
}
315
328
316
- def NVVM_MBarrierTryWaitParitySharedOp : NVVM_Op<"mbarrier.try_wait.parity.shared",
317
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
329
+ def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,
318
330
Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$phase, I32:$ticks)> {
319
331
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
320
332
let extraClassDefinition = [{
@@ -488,7 +500,7 @@ def LoadCacheModifierKind : I32EnumAttr<"LoadCacheModifierKind",
488
500
489
501
def LoadCacheModifierAttr : EnumAttr<NVVM_Dialect, LoadCacheModifierKind, "load_cache_modifier">;
490
502
491
- def NVVM_CpAsyncOp : NVVM_Op <"cp.async.shared.global", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>] >,
503
+ def NVVM_CpAsyncOp : NVVM_PTXBuilder_Op <"cp.async.shared.global">,
492
504
Arguments<(ins LLVM_i8Ptr_shared:$dst,
493
505
LLVM_i8Ptr_global:$src,
494
506
I32Attr:$size,
@@ -1359,12 +1371,24 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
1359
1371
// NVVM TMA Ops
1360
1372
//===----------------------------------------------------------------------===//
1361
1373
1362
- def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
1374
+ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
1375
+ NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
1376
+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
1377
+ AttrSizedOperandSegments]>,
1363
1378
Arguments<(ins LLVM_i64ptr_shared:$dstMem,
1364
1379
LLVM_i64ptr_any:$tmaDescriptor,
1365
1380
LLVM_i64ptr_shared:$mbar,
1366
- Variadic<I32>:$coordinates)> {
1367
- let assemblyFormat = "$dstMem `,` $tmaDescriptor `,` $mbar `,` `box` `[`$coordinates `]` attr-dict `:` type(operands)";
1381
+ Variadic<I32>:$coordinates,
1382
+ PtxPredicate:$predicate)> {
1383
+ let assemblyFormat = [{
1384
+ $dstMem `,`
1385
+ $tmaDescriptor `,`
1386
+ $mbar `,`
1387
+ `box` `[`$coordinates `]`
1388
+ (`,` `predicate` `=` $predicate^)?
1389
+ attr-dict `:` type(operands)
1390
+ }];
1391
+
1368
1392
let extraClassDefinition = [{
1369
1393
std::string $cppClass::getPtx() {
1370
1394
int dim = getCoordinates().size();
@@ -1382,11 +1406,21 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tenso
1382
1406
let hasVerifier = 1;
1383
1407
}
1384
1408
1385
- def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : NVVM_Op<"cp.async.bulk.tensor.global.shared.cta", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
1409
+ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
1410
+ NVVM_Op<"cp.async.bulk.tensor.global.shared.cta",
1411
+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
1412
+ AttrSizedOperandSegments]>,
1386
1413
Arguments<(ins LLVM_i64ptr_any:$tmaDescriptor,
1387
1414
LLVM_i64ptr_shared:$srcMem,
1388
- Variadic<I32>:$coordinates)> {
1389
- let assemblyFormat = "$tmaDescriptor `,` $srcMem `,` `box` `[`$coordinates `]` attr-dict `:` type(operands)";
1415
+ Variadic<I32>:$coordinates,
1416
+ PtxPredicate:$predicate)> {
1417
+ let assemblyFormat = [{
1418
+ $tmaDescriptor `,`
1419
+ $srcMem `,`
1420
+ `box` `[`$coordinates `]`
1421
+ (`,` `predicate` `=` $predicate^)?
1422
+ attr-dict `:` type(operands)
1423
+ }];
1390
1424
let extraClassDefinition = [{
1391
1425
std::string $cppClass::getPtx() {
1392
1426
int dim = getCoordinates().size();
@@ -1408,8 +1442,7 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : NVVM_Op<"cp.async.bulk.tensor.gl
1408
1442
// NVVM Wgmma Ops
1409
1443
//===----------------------------------------------------------------------===//
1410
1444
1411
- def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
1412
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
1445
+ def NVVM_WgmmaFenceAlignedOp : NVVM_PTXBuilder_Op<"wgmma.fence.aligned"> {
1413
1446
let arguments = (ins);
1414
1447
let description = [{
1415
1448
Enforce an ordering of register accesses between warpgroup level matrix
@@ -1423,8 +1456,7 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
1423
1456
}];
1424
1457
}
1425
1458
1426
- def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
1427
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
1459
+ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_PTXBuilder_Op<"wgmma.commit.group.sync.aligned">,
1428
1460
Arguments<(ins )> {
1429
1461
let assemblyFormat = "attr-dict";
1430
1462
let description = [{
@@ -1437,8 +1469,7 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
1437
1469
}];
1438
1470
}
1439
1471
1440
- def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned",
1441
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>{
1472
+ def NVVM_WgmmaWaitGroupSyncOp : NVVM_PTXBuilder_Op<"wgmma.wait.group.sync.aligned">{
1442
1473
let arguments = (ins I32Attr:$group);
1443
1474
let assemblyFormat = "attr-dict $group";
1444
1475
let description = [{
0 commit comments