Skip to content

Commit cbfcfc8

Browse files
committed
[mlir][nvvm] Support predicates in BasicPtxBuilder
This PR enhances `BasicPtxBuilder` to support predicates in PTX code generation. The `BasicPtxBuilder` interface was initially introduced for generating PTX code automatically for Ops that aren't supported by LLVM core. Predicates, which are typically not supported in LLVM core, are now supported using the same mechanism. In PTX programming, instructions can be guarded by predicates as shown below:. Here `@p` is a predicate register and guard the execution of the instruction. ``` @p ptx.code op1, op2, op3 ``` This PR introduces the `getPredicate` function in the `BasicPtxBuilder` interface to set an optional predicate. When a predicate is provided, the instruction is generated with predicate and guarded, otherwise, predicate is not genearted. Note that the predicate value must always appear as the last argument on the Op definition. Additionally, this PR implements predicate usage for the following ops: - mbarrier.init - mbarrier.init.shared - mbarrier.arrive.expect_tx - mbarrier.arrive.expect_tx.shared - cp.async.bulk.tensor.shared.cluster.global - cp.async.bulk.tensor.global.shared.cta See for more detail in PTX programing model https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#ptx-instructions
1 parent 282ea28 commit cbfcfc8

File tree

6 files changed

+154
-70
lines changed

6 files changed

+154
-70
lines changed

mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
2222
// Basic PTX Builder Interface
2323
//===----------------------------------------------------------------------===//
2424

25+
def PtxPredicate : Optional<I1>;
26+
2527
def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
2628
let description = [{
2729
This interface is used to generate inline assembly with PTX for basic
@@ -62,6 +64,22 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
6264
}];
6365
let cppNamespace = "::mlir::NVVM";
6466
let methods = [
67+
InterfaceMethod<
68+
/*desc=*/[{
69+
Optional function for setting a predicate, which
70+
always returns a `PtxPredicate` value of type i1. If no predicate is
71+
provided, the instruction is unguarded; otherwise, it's guarded by the
72+
predicate value. The `PtxPredicate` value must always be the last argument.
73+
The provided PTX code by `getPtx` should not include the predicate usage.
74+
The interface automatically handles predicate usage in the generated
75+
PTX code when necessary.
76+
}],
77+
/*retType=*/"std::optional<::mlir::Value>",
78+
/*methodName=*/"getPredicate",
79+
/*args=*/(ins),
80+
/*methodBody=*/"",
81+
/*defaultImplementation=*/"return {};"
82+
>,
6583
InterfaceMethod<
6684
/*desc=*/[{ Returns PTX assembly with operand number. }],
6785
/*retType=*/"std::string",

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 62 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ class NVVM_Op<string mnemonic, list<Trait> traits = []> :
7474
LLVM_OpBase<NVVM_Dialect, mnemonic, traits> {
7575
}
7676

77+
/// Base class that defines BasicPtxBuilderOpInterface.
78+
class NVVM_PTXBuilder_Op<string mnemonic,
79+
list<Trait> traits = [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> :
80+
LLVM_OpBase<NVVM_Dialect, mnemonic, traits> {
81+
}
82+
7783
//===----------------------------------------------------------------------===//
7884
// NVVM attribute definitions
7985
//===----------------------------------------------------------------------===//
@@ -206,21 +212,31 @@ def NVVM_ReduxOp :
206212
//===----------------------------------------------------------------------===//
207213

208214
/// mbarrier.init instruction with generic pointer type
209-
def NVVM_MBarrierInitOp : NVVM_Op<"mbarrier.init">,
210-
Arguments<(ins LLVM_i64ptr_any:$addr, I32:$count)> {
215+
def NVVM_MBarrierInitOp : NVVM_PTXBuilder_Op<"mbarrier.init">,
216+
Arguments<(ins LLVM_i64ptr_any:$addr, I32:$count, PtxPredicate:$predicate)> {
211217
string llvmBuilder = [{
212218
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init, {$addr, $count});
213219
}];
214-
let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
220+
let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
221+
let extraClassDeclaration = [{
222+
bool hasIntrinsic() { if(getPredicate()) return false; return true; }
223+
}];
224+
let extraClassDefinition = [{
225+
std::string $cppClass::getPtx() { return std::string("mbarrier.init.b64 [%0], %1;"); }
226+
}];
215227
}
216228

217229
/// mbarrier.init instruction with shared pointer type
218-
def NVVM_MBarrierInitSharedOp : NVVM_Op<"mbarrier.init.shared">,
219-
Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$count)> {
230+
def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op<"mbarrier.init.shared">,
231+
Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$count, PtxPredicate:$predicate)> {
220232
string llvmBuilder = [{
221233
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init_shared, {$addr, $count});
222234
}];
223-
let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
235+
let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
236+
let extraClassDeclaration = "bool hasIntrinsic() { return !getPredicate(); }";
237+
let extraClassDefinition = [{
238+
std::string $cppClass::getPtx() { return std::string("mbarrier.init.shared.b64 [%0], %1;"); }
239+
}];
224240
}
225241

226242
def NVVM_MBarrierInvalOp : NVVM_Op<"mbarrier.inval">,
@@ -275,26 +291,23 @@ def NVVM_MBarrierArriveNocompleteSharedOp : NVVM_Op<"mbarrier.arrive.nocomplete.
275291
let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)";
276292
}
277293

278-
def NVVM_MBarrierArriveExpectTxOp : NVVM_Op<"mbarrier.arrive.expect_tx",
279-
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
280-
Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> {
281-
let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
294+
def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx">,
295+
Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount, PtxPredicate:$predicate)> {
296+
let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
282297
let extraClassDefinition = [{
283298
std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); }
284299
}];
285300
}
286301

287-
def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_Op<"mbarrier.arrive.expect_tx.shared",
288-
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
289-
Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount)> {
290-
let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
302+
def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx.shared">,
303+
Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount, PtxPredicate:$predicate)> {
304+
let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
291305
let extraClassDefinition = [{
292306
std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); }
293307
}];
294308
}
295309

296-
def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity",
297-
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
310+
def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,
298311
Arguments<(ins LLVM_i64ptr_any:$addr, I32:$phase, I32:$ticks)> {
299312
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
300313
let extraClassDefinition = [{
@@ -313,8 +326,7 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity",
313326
}];
314327
}
315328

316-
def NVVM_MBarrierTryWaitParitySharedOp : NVVM_Op<"mbarrier.try_wait.parity.shared",
317-
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
329+
def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,
318330
Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$phase, I32:$ticks)> {
319331
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
320332
let extraClassDefinition = [{
@@ -488,7 +500,7 @@ def LoadCacheModifierKind : I32EnumAttr<"LoadCacheModifierKind",
488500

489501
def LoadCacheModifierAttr : EnumAttr<NVVM_Dialect, LoadCacheModifierKind, "load_cache_modifier">;
490502

491-
def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
503+
def NVVM_CpAsyncOp : NVVM_PTXBuilder_Op<"cp.async.shared.global">,
492504
Arguments<(ins LLVM_i8Ptr_shared:$dst,
493505
LLVM_i8Ptr_global:$src,
494506
I32Attr:$size,
@@ -1359,12 +1371,24 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
13591371
// NVVM TMA Ops
13601372
//===----------------------------------------------------------------------===//
13611373

1362-
def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
1374+
def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
1375+
NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
1376+
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
1377+
AttrSizedOperandSegments]>,
13631378
Arguments<(ins LLVM_i64ptr_shared:$dstMem,
13641379
LLVM_i64ptr_any:$tmaDescriptor,
13651380
LLVM_i64ptr_shared:$mbar,
1366-
Variadic<I32>:$coordinates)> {
1367-
let assemblyFormat = "$dstMem `,` $tmaDescriptor `,` $mbar `,` `box` `[`$coordinates `]` attr-dict `:` type(operands)";
1381+
Variadic<I32>:$coordinates,
1382+
PtxPredicate:$predicate)> {
1383+
let assemblyFormat = [{
1384+
$dstMem `,`
1385+
$tmaDescriptor `,`
1386+
$mbar `,`
1387+
`box` `[`$coordinates `]`
1388+
(`,` `predicate` `=` $predicate^)?
1389+
attr-dict `:` type(operands)
1390+
}];
1391+
13681392
let extraClassDefinition = [{
13691393
std::string $cppClass::getPtx() {
13701394
int dim = getCoordinates().size();
@@ -1382,11 +1406,21 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tenso
13821406
let hasVerifier = 1;
13831407
}
13841408

1385-
def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : NVVM_Op<"cp.async.bulk.tensor.global.shared.cta", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
1409+
def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
1410+
NVVM_Op<"cp.async.bulk.tensor.global.shared.cta",
1411+
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
1412+
AttrSizedOperandSegments]>,
13861413
Arguments<(ins LLVM_i64ptr_any:$tmaDescriptor,
13871414
LLVM_i64ptr_shared:$srcMem,
1388-
Variadic<I32>:$coordinates)> {
1389-
let assemblyFormat = "$tmaDescriptor `,` $srcMem `,` `box` `[`$coordinates `]` attr-dict `:` type(operands)";
1415+
Variadic<I32>:$coordinates,
1416+
PtxPredicate:$predicate)> {
1417+
let assemblyFormat = [{
1418+
$tmaDescriptor `,`
1419+
$srcMem `,`
1420+
`box` `[`$coordinates `]`
1421+
(`,` `predicate` `=` $predicate^)?
1422+
attr-dict `:` type(operands)
1423+
}];
13901424
let extraClassDefinition = [{
13911425
std::string $cppClass::getPtx() {
13921426
int dim = getCoordinates().size();
@@ -1408,8 +1442,7 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : NVVM_Op<"cp.async.bulk.tensor.gl
14081442
// NVVM Wgmma Ops
14091443
//===----------------------------------------------------------------------===//
14101444

1411-
def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
1412-
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
1445+
def NVVM_WgmmaFenceAlignedOp : NVVM_PTXBuilder_Op<"wgmma.fence.aligned"> {
14131446
let arguments = (ins);
14141447
let description = [{
14151448
Enforce an ordering of register accesses between warpgroup level matrix
@@ -1423,8 +1456,7 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
14231456
}];
14241457
}
14251458

1426-
def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
1427-
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
1459+
def NVVM_WgmmaGroupSyncAlignedOp : NVVM_PTXBuilder_Op<"wgmma.commit.group.sync.aligned">,
14281460
Arguments<(ins )> {
14291461
let assemblyFormat = "attr-dict";
14301462
let description = [{
@@ -1437,8 +1469,7 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
14371469
}];
14381470
}
14391471

1440-
def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned",
1441-
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>{
1472+
def NVVM_WgmmaWaitGroupSyncOp : NVVM_PTXBuilder_Op<"wgmma.wait.group.sync.aligned">{
14421473
let arguments = (ins I32Attr:$group);
14431474
let assemblyFormat = "attr-dict $group";
14441475
let description = [{

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "llvm/Support/Debug.h"
2929
#include "llvm/Support/ErrorHandling.h"
3030
#include "llvm/Support/raw_ostream.h"
31+
#include <optional>
3132

3233
#define DEBUG_TYPE "nvgpu-to-nvvm"
3334
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
@@ -812,9 +813,10 @@ struct NVGPUMBarrierInitLowering
812813
Value count = truncToI32(b, adaptor.getCount());
813814
if (isMbarrierShared(mbarrierType)) {
814815
rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(op, barrier,
815-
count);
816+
count, Value());
816817
} else {
817-
rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count);
818+
rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
819+
Value());
818820
}
819821
return success();
820822
}
@@ -909,12 +911,12 @@ struct NVGPUMBarrierArriveExpectTxLowering
909911

910912
if (isMbarrierShared(op.getBarriers().getType())) {
911913
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
912-
op, barrier, txcount);
914+
op, barrier, txcount, Value());
913915
return success();
914916
}
915917

916-
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(op, barrier,
917-
txcount);
918+
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
919+
op, barrier, txcount, Value());
918920
return success();
919921
}
920922
};
@@ -965,7 +967,7 @@ struct NVGPUTmaAsyncLoadOpLowering
965967
}
966968

967969
rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
968-
op, dest, adaptor.getTensorMapDescriptor(), barrier, coords);
970+
op, dest, adaptor.getTensorMapDescriptor(), barrier, coords, Value());
969971
return success();
970972
}
971973
};

mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ using namespace mlir;
4141
using namespace NVVM;
4242

4343
namespace {
44+
4445
struct PtxLowering
4546
: public OpInterfaceRewritePattern<BasicPtxBuilderInterface> {
4647
using OpInterfaceRewritePattern<

mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,14 @@ LLVM::InlineAsmOp PtxBuilder::build() {
123123

124124
std::string ptxInstruction = interfaceOp.getPtx();
125125

126+
// Add the predicate to the asm string.
127+
if (interfaceOp.getPredicate().has_value() &&
128+
interfaceOp.getPredicate().value()) {
129+
std::string predicateStr = "@%";
130+
predicateStr += std::to_string((ptxOperands.size() - 1));
131+
ptxInstruction = predicateStr + " " + ptxInstruction;
132+
}
133+
126134
// Tablegen doesn't accept $, so we use %, but inline assembly uses $.
127135
// Replace all % with $
128136
std::replace(ptxInstruction.begin(), ptxInstruction.end(), '%', '$');

0 commit comments

Comments
 (0)