Skip to content

[mlir][nvvm] Support predicates in BasicPtxBuilder #67102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
// Basic PTX Builder Interface
//===----------------------------------------------------------------------===//

// Optional i1 operand that ops implementing BasicPtxBuilderOpInterface can
// accept as a predicate. When the operand is present, the generated PTX
// instruction is guarded by it; when absent, the instruction is unguarded.
def PtxPredicate : Optional<I1>;

def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
let description = [{
This interface is used to generate inline assembly with PTX for basic
Expand Down Expand Up @@ -62,6 +64,22 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
}];
let cppNamespace = "::mlir::NVVM";
let methods = [
InterfaceMethod<
/*desc=*/[{
Optional function for returning the predicate that guards the
instruction. When a predicate is provided, it is a `PtxPredicate` value
of type i1 and the instruction is guarded by it; if no predicate is
provided, the instruction is unguarded. The `PtxPredicate` value must
always be the last argument.
The PTX code provided by `getPtx` should not include the predicate usage.
The interface automatically handles predicate usage in the generated
PTX code when necessary.
}],
/*retType=*/"std::optional<::mlir::Value>",
/*methodName=*/"getPredicate",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/"return {};"
>,
InterfaceMethod<
/*desc=*/[{ Returns PTX assembly with operand number. }],
/*retType=*/"std::string",
Expand Down
93 changes: 62 additions & 31 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ class NVVM_Op<string mnemonic, list<Trait> traits = []> :
LLVM_OpBase<NVVM_Dialect, mnemonic, traits> {
}

/// Base class for NVVM ops that implement BasicPtxBuilderOpInterface.
/// The interface is attached through the default value of `traits`
/// (`DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>`), so an op that
/// passes its own `traits` list replaces that default and must list the
/// interface explicitly if it still needs it.
class NVVM_PTXBuilder_Op<string mnemonic,
list<Trait> traits = [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> :
LLVM_OpBase<NVVM_Dialect, mnemonic, traits> {
}

//===----------------------------------------------------------------------===//
// NVVM attribute definitions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -206,21 +212,31 @@ def NVVM_ReduxOp :
//===----------------------------------------------------------------------===//

/// mbarrier.init instruction with generic pointer type
def NVVM_MBarrierInitOp : NVVM_Op<"mbarrier.init">,
Arguments<(ins LLVM_i64ptr_any:$addr, I32:$count)> {
def NVVM_MBarrierInitOp : NVVM_PTXBuilder_Op<"mbarrier.init">,
Arguments<(ins LLVM_i64ptr_any:$addr, I32:$count, PtxPredicate:$predicate)> {
string llvmBuilder = [{
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init, {$addr, $count});
}];
let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
let extraClassDeclaration = [{
// Returns true only when no predicate operand is present; with a predicate
// the op reports that it has no intrinsic.
bool hasIntrinsic() { return !getPredicate(); }
}];
let extraClassDefinition = [{
std::string $cppClass::getPtx() { return std::string("mbarrier.init.b64 [%0], %1;"); }
}];
}

/// mbarrier.init instruction with shared pointer type
def NVVM_MBarrierInitSharedOp : NVVM_Op<"mbarrier.init.shared">,
Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$count)> {
def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op<"mbarrier.init.shared">,
Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$count, PtxPredicate:$predicate)> {
string llvmBuilder = [{
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init_shared, {$addr, $count});
}];
let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
let extraClassDeclaration = "bool hasIntrinsic() { return !getPredicate(); }";
let extraClassDefinition = [{
std::string $cppClass::getPtx() { return std::string("mbarrier.init.shared.b64 [%0], %1;"); }
}];
}

def NVVM_MBarrierInvalOp : NVVM_Op<"mbarrier.inval">,
Expand Down Expand Up @@ -275,26 +291,23 @@ def NVVM_MBarrierArriveNocompleteSharedOp : NVVM_Op<"mbarrier.arrive.nocomplete.
let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)";
}

def NVVM_MBarrierArriveExpectTxOp : NVVM_Op<"mbarrier.arrive.expect_tx",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> {
let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx">,
Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount, PtxPredicate:$predicate)> {
let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
let extraClassDefinition = [{
std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); }
}];
}

def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_Op<"mbarrier.arrive.expect_tx.shared",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount)> {
let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx.shared">,
Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount, PtxPredicate:$predicate)> {
let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
let extraClassDefinition = [{
std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); }
}];
}

def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,
Arguments<(ins LLVM_i64ptr_any:$addr, I32:$phase, I32:$ticks)> {
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
let extraClassDefinition = [{
Expand All @@ -313,8 +326,7 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity",
}];
}

def NVVM_MBarrierTryWaitParitySharedOp : NVVM_Op<"mbarrier.try_wait.parity.shared",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,
Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$phase, I32:$ticks)> {
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
let extraClassDefinition = [{
Expand Down Expand Up @@ -488,7 +500,7 @@ def LoadCacheModifierKind : I32EnumAttr<"LoadCacheModifierKind",

def LoadCacheModifierAttr : EnumAttr<NVVM_Dialect, LoadCacheModifierKind, "load_cache_modifier">;

def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
def NVVM_CpAsyncOp : NVVM_PTXBuilder_Op<"cp.async.shared.global">,
Arguments<(ins LLVM_i8Ptr_shared:$dst,
LLVM_i8Ptr_global:$src,
I32Attr:$size,
Expand Down Expand Up @@ -1359,12 +1371,24 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
// NVVM TMA Ops
//===----------------------------------------------------------------------===//

def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
AttrSizedOperandSegments]>,
Arguments<(ins LLVM_i64ptr_shared:$dstMem,
LLVM_i64ptr_any:$tmaDescriptor,
LLVM_i64ptr_shared:$mbar,
Variadic<I32>:$coordinates)> {
let assemblyFormat = "$dstMem `,` $tmaDescriptor `,` $mbar `,` `box` `[`$coordinates `]` attr-dict `:` type(operands)";
Variadic<I32>:$coordinates,
PtxPredicate:$predicate)> {
let assemblyFormat = [{
$dstMem `,`
$tmaDescriptor `,`
$mbar `,`
`box` `[`$coordinates `]`
(`,` `predicate` `=` $predicate^)?
attr-dict `:` type(operands)
}];

let extraClassDefinition = [{
std::string $cppClass::getPtx() {
int dim = getCoordinates().size();
Expand All @@ -1382,11 +1406,21 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tenso
let hasVerifier = 1;
}

def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : NVVM_Op<"cp.async.bulk.tensor.global.shared.cta", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
NVVM_Op<"cp.async.bulk.tensor.global.shared.cta",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
AttrSizedOperandSegments]>,
Arguments<(ins LLVM_i64ptr_any:$tmaDescriptor,
LLVM_i64ptr_shared:$srcMem,
Variadic<I32>:$coordinates)> {
let assemblyFormat = "$tmaDescriptor `,` $srcMem `,` `box` `[`$coordinates `]` attr-dict `:` type(operands)";
Variadic<I32>:$coordinates,
PtxPredicate:$predicate)> {
let assemblyFormat = [{
$tmaDescriptor `,`
$srcMem `,`
`box` `[`$coordinates `]`
(`,` `predicate` `=` $predicate^)?
attr-dict `:` type(operands)
}];
let extraClassDefinition = [{
std::string $cppClass::getPtx() {
int dim = getCoordinates().size();
Expand All @@ -1408,8 +1442,7 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : NVVM_Op<"cp.async.bulk.tensor.gl
// NVVM Wgmma Ops
//===----------------------------------------------------------------------===//

def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
def NVVM_WgmmaFenceAlignedOp : NVVM_PTXBuilder_Op<"wgmma.fence.aligned"> {
let arguments = (ins);
let description = [{
Enforce an ordering of register accesses between warpgroup level matrix
Expand All @@ -1423,8 +1456,7 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
}];
}

def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
def NVVM_WgmmaGroupSyncAlignedOp : NVVM_PTXBuilder_Op<"wgmma.commit.group.sync.aligned">,
Arguments<(ins )> {
let assemblyFormat = "attr-dict";
let description = [{
Expand All @@ -1437,8 +1469,7 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
}];
}

def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>{
def NVVM_WgmmaWaitGroupSyncOp : NVVM_PTXBuilder_Op<"wgmma.wait.group.sync.aligned">{
let arguments = (ins I32Attr:$group);
let assemblyFormat = "attr-dict $group";
let description = [{
Expand Down
14 changes: 8 additions & 6 deletions mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

#define DEBUG_TYPE "nvgpu-to-nvvm"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
Expand Down Expand Up @@ -812,9 +813,10 @@ struct NVGPUMBarrierInitLowering
Value count = truncToI32(b, adaptor.getCount());
if (isMbarrierShared(mbarrierType)) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(op, barrier,
count);
count, Value());
} else {
rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count);
rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
Value());
}
return success();
}
Expand Down Expand Up @@ -909,12 +911,12 @@ struct NVGPUMBarrierArriveExpectTxLowering

if (isMbarrierShared(op.getBarriers().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
op, barrier, txcount);
op, barrier, txcount, Value());
return success();
}

rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(op, barrier,
txcount);
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
op, barrier, txcount, Value());
return success();
}
};
Expand Down Expand Up @@ -965,7 +967,7 @@ struct NVGPUTmaAsyncLoadOpLowering
}

rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
op, dest, adaptor.getTensorMapDescriptor(), barrier, coords);
op, dest, adaptor.getTensorMapDescriptor(), barrier, coords, Value());
return success();
}
};
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ using namespace mlir;
using namespace NVVM;

namespace {

struct PtxLowering
: public OpInterfaceRewritePattern<BasicPtxBuilderInterface> {
using OpInterfaceRewritePattern<
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@ LLVM::InlineAsmOp PtxBuilder::build() {

std::string ptxInstruction = interfaceOp.getPtx();

// Add the predicate to the asm string.
if (interfaceOp.getPredicate().has_value() &&
interfaceOp.getPredicate().value()) {
std::string predicateStr = "@%";
predicateStr += std::to_string((ptxOperands.size() - 1));
ptxInstruction = predicateStr + " " + ptxInstruction;
}

// Tablegen doesn't accept $, so we use %, but inline assembly uses $.
// Replace all % with $
std::replace(ptxInstruction.begin(), ptxInstruction.end(), '%', '$');
Expand Down
Loading