[NVPTX] Add NVPTX intrinsics for TMA copies #95289

Closed (wants to merge 1 commit)
24 changes: 24 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1448,6 +1448,26 @@
defm int_nvvm_cp_async_ca_shared_global_8 : CP_ASYNC_SHARED_GLOBAL<"8", "ca">;
defm int_nvvm_cp_async_ca_shared_global_16 : CP_ASYNC_SHARED_GLOBAL<"16", "ca">;
defm int_nvvm_cp_async_cg_shared_global_16 : CP_ASYNC_SHARED_GLOBAL<"16", "cg">;

// TODO(apaszke): Multicast TMA loads
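// Bulk tensor (TMA) copy intrinsics for 1-5 dimensional tiles. The
// shared_cluster.global flavour loads a tile from global memory into
// shared::cluster memory and signals completion through an mbarrier; the
// global.shared_cta flavour stores a tile from shared::cta memory back to
// global memory and completes as part of a bulk async-group.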
foreach dim = [1, 2, 3, 4, 5] in {
def int_nvvm_cp_async_bulk_tensor_ # dim # d_shared_cluster_global_tile_mbarrier_complete_tx_bytes :
Intrinsic<
[],
[llvm_shared_ptr_ty, llvm_anyptr_ty] # !listsplat(llvm_i32_ty, dim) # [llvm_anyptr_ty],
[IntrArgMemOnly, IntrNoCallback,
NoAlias<ArgIndex<0>>, NoAlias<ArgIndex<1>>, NoAlias<ArgIndex<!add(2, dim)>>,
WriteOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>],
"llvm.nvvm.cp.async.bulk.tensor." # dim # "d.shared_cluster.global.tile.mbarrier_complete_tx_bytes">;
def int_nvvm_cp_async_bulk_tensor_ # dim # d_global_shared_cta_tile_bulk_group :
Intrinsic<
[],
[llvm_anyptr_ty] # !listsplat(llvm_i32_ty, dim) # [llvm_shared_ptr_ty],
[IntrNoCallback,
NoAlias<ArgIndex<0>>, NoAlias<ArgIndex<!add(1, dim)>>,
ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<!add(1, dim)>>],
"llvm.nvvm.cp.async.bulk.tensor." # dim # "d.global.shared_cta.tile.bulk_group">;
}

def int_nvvm_cp_async_commit_group :
ClangBuiltin<"__nvvm_cp_async_commit_group">,
Intrinsic<[],[],[]>;
@@ -1595,6 +1615,10 @@
def int_nvvm_ptr_gen_to_param: Intrinsic<[llvm_anyptr_ty],
[llvm_anyptr_ty],
[IntrNoMem, IntrSpeculatable, IntrNoCallback],
"llvm.nvvm.ptr.gen.to.param">;
// Converts a pointer in the param address space to a generic pointer (the
// inverse direction of llvm.nvvm.ptr.gen.to.param above). NVPTXLowerArgs uses
// it to pass param-space TMA descriptors to intrinsics that expect a generic
// pointer.
def int_nvvm_ptr_param_to_gen: Intrinsic<[llvm_anyptr_ty],
[llvm_anyptr_ty],
[IntrNoMem, IntrSpeculatable, IntrNoCallback],
"llvm.nvvm.ptr.param.to.gen">;

// Move intrinsics, used in nvvm internally

28 changes: 28 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -403,6 +403,33 @@
defm CP_ASYNC_CG_SHARED_GLOBAL_16 :
CP_ASYNC_SHARED_GLOBAL_I<"cg", "16", int_nvvm_cp_async_cg_shared_global_16,
int_nvvm_cp_async_cg_shared_global_16_s>;

foreach dim = [1, 2, 3, 4, 5] in {
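// Shared helpers for the per-dimension coordinates: idx_ptx is the
// "$idx0, $idx1, ..." fragment spliced into the PTX asm string, and idx_dag is
// the matching `ins` dag concatenated into each instruction's operand list and
// selection pattern.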
defvar idx_ptx = !interleave(!foreach(i, !range(dim), "$idx" # i), ", ");
defvar idx_dag = !dag(ins, !listsplat(Int32Regs, dim), !foreach(i, !range(dim), "idx" # i));
defvar intrinsic_g2s = !cast<Intrinsic>("int_nvvm_cp_async_bulk_tensor_" # dim # "d_shared_cluster_global_tile_mbarrier_complete_tx_bytes");
def CP_ASYNC_BULK_TENSOR_ # dim # D_SHARED_CLUSTER_GLOBAL_TILE_MBARRIER_COMPLETE_TX_BYTES_64 :
NVPTXInst<
(outs),
!con((ins Int64Regs:$dst, Int64Regs:$desc), idx_dag, (ins Int64Regs:$mbar)),
"cp.async.bulk.tensor." # dim # "d.shared::cluster.global.tile.mbarrier::complete_tx::bytes [$dst], [$desc, {{" # idx_ptx # "}}], [$mbar];",
[!con((intrinsic_g2s Int64Regs:$dst, Int64Regs:$desc),
!setdagop(idx_dag, intrinsic_g2s),
(intrinsic_g2s Int64Regs:$mbar))]
>,
Requires<[hasPTX<80>, hasSM<90>]>;
defvar intrinsic_s2g = !cast<Intrinsic>("int_nvvm_cp_async_bulk_tensor_" # dim # "d_global_shared_cta_tile_bulk_group");
def CP_ASYNC_BULK_TENSOR_ # dim # D_GLOBAL_SHARED_CTA_TILE_BULK_GROUP_64 :
NVPTXInst<
(outs),
!con((ins Int64Regs:$desc), idx_dag, (ins Int64Regs:$dst)),
"cp.async.bulk.tensor." # dim # "d.global.shared::cta.tile.bulk_group [$desc, {{" # idx_ptx # "}}], [$dst];",
[!con((intrinsic_s2g Int64Regs:$desc),
!setdagop(idx_dag, intrinsic_s2g),
(intrinsic_s2g Int64Regs:$dst))]
>,
Requires<[hasPTX<80>, hasSM<90>]>;
}

def CP_ASYNC_COMMIT_GROUP :
NVPTXInst<(outs), (ins), "cp.async.commit_group;", [(int_nvvm_cp_async_commit_group)]>,
Requires<[hasPTX<70>, hasSM<80>]>;
@@ -2475,6 +2502,7 @@
defm cvta_local : NG_TO_G<"local", int_nvvm_ptr_local_to_gen, useShortPtrLocal>;
defm cvta_shared : NG_TO_G<"shared", int_nvvm_ptr_shared_to_gen, useShortPtrShared>;
defm cvta_global : NG_TO_G<"global", int_nvvm_ptr_global_to_gen, False>;
defm cvta_const : NG_TO_G<"const", int_nvvm_ptr_constant_to_gen, useShortPtrConst>;
defm cvta_param : NG_TO_G<"param", int_nvvm_ptr_param_to_gen, False>;

defm cvta_to_local : G_TO_NG<"local", int_nvvm_ptr_gen_to_local, useShortPtrLocal>;
defm cvta_to_shared : G_TO_NG<"shared", int_nvvm_ptr_gen_to_shared, useShortPtrShared>;
91 changes: 78 additions & 13 deletions llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -94,12 +94,17 @@
#include "NVPTXUtilities.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include <cassert>
#include <numeric>
#include <queue>

@@ -146,6 +151,28 @@
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
"Lower arguments (NVPTX)", false, false)

// Returns the operand index of the TMA descriptor (tensor-map pointer) if \p I
// is one of the bulk tensor copy intrinsics: operand 1 for the
// global -> shared::cluster loads, operand 0 for the shared::cta -> global
// stores. Returns std::nullopt for anything else.
static std::optional<int> tmaDescriptorOperandIndex(Instruction *I) {
if (auto *II = dyn_cast<IntrinsicInst>(I)) {
switch (II->getIntrinsicID()) {
case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_1d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_2d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_3d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_4d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_5d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
return 1;
case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_1d_global_shared_cta_tile_bulk_group:
case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_2d_global_shared_cta_tile_bulk_group:
case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_3d_global_shared_cta_tile_bulk_group:
case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_4d_global_shared_cta_tile_bulk_group:
case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_5d_global_shared_cta_tile_bulk_group:
return 0;
default:
return std::nullopt;
}
}
return std::nullopt;
}

// =============================================================================
// If the function had a byval struct ptr arg, say foo(%struct.x* byval %d),
// and we can't guarantee that the only accesses are loads,
@@ -166,14 +193,15 @@

// Replaces the \p OldUser instruction with the same in parameter AS.
// Only Load and GEP are supported.
static void convertToParamAS(Value *OldUser, Value *Param) {
static void convertToParamAS(Value *OldUser, Value *OldParam, Value *NewParam) {
Instruction *I = dyn_cast<Instruction>(OldUser);
assert(I && "OldUser must be an instruction");
struct IP {
Instruction *OldInstruction;
Value *OldParam;
Value *NewParam;
};
SmallVector<IP> ItemsToConvert = {{I, Param}};
SmallVector<IP> ItemsToConvert = {{I, OldParam, NewParam}};
SmallVector<Instruction *> InstructionsToDelete;

auto CloneInstInParamAS = [](const IP &I) -> Value * {
Expand All @@ -200,6 +228,28 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
// Just pass through the argument, the old ASC is no longer needed.
return I.NewParam;
}
if (auto *II = dyn_cast<IntrinsicInst>(I.OldInstruction)) {
// Assert that this is a TMA intrinsic.
assert(tmaDescriptorOperandIndex(II).has_value());
assert(I.OldInstruction->getOperand(*tmaDescriptorOperandIndex(II)) ==
I.OldParam);
// TMA descriptors can remain in param memory space, but need to be passed
// in the generic address space.
Type *ParamPtr = PointerType::get(II->getContext(), ADDRESS_SPACE_PARAM);
Type *GenericPtr =
PointerType::get(II->getContext(), ADDRESS_SPACE_GENERIC);
FunctionType *cast_func_ty =
FunctionType::get(GenericPtr, {ParamPtr}, false);
Module *M = I.OldInstruction->getModule();
FunctionCallee func =
M->getOrInsertFunction(getName(llvm::Intrinsic::nvvm_ptr_param_to_gen,
{GenericPtr, ParamPtr}, M),
cast_func_ty);
Instruction *NewInGeneric =
CallInst::Create(func, {I.NewParam}, "", II->getIterator());
II->replaceUsesOfWith(I.OldParam, NewInGeneric);
return II;
}
llvm_unreachable("Unsupported instruction");
};

@@ -212,7 +262,8 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
// be converted and the instruction itself to be deleted. We can't delete
// the old instruction yet, because it's still in use by a load somewhere.
for (Value *V : I.OldInstruction->users())
ItemsToConvert.push_back({cast<Instruction>(V), NewInst});
ItemsToConvert.push_back(
{cast<Instruction>(V), I.OldInstruction, NewInst});

InstructionsToDelete.push_back(I.OldInstruction);
}
@@ -300,9 +351,13 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
Worklist.push({I, Ctx.Offset + Offset});
continue;
}
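// TMA intrinsics consume the pointer operand directly; there is no load
// through it here whose alignment would need adjusting.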
if (auto *II = dyn_cast<IntrinsicInst>(CurUser)) {
assert(tmaDescriptorOperandIndex(II).has_value());
continue;
}

llvm_unreachable("All users must be one of: load, "
"bitcast, getelementptr.");
"bitcast, getelementptr, TMA intrinsic.");
}
}

@@ -321,28 +376,38 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
assert(StructType && "Missing byval type");

auto IsALoadChain = [&](Value *Start) {
SmallVector<Value *, 16> ValuesToCheck = {Start};
auto IsALoadChainInstr = [](Value *V) -> bool {
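// Walk the uses of the argument chain (rather than the user values) so that,
// for intrinsic users, we can check which operand the pointer feeds.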
SmallVector<Use*, 16> UsesToCheck;
for (Use& u : Start->uses())
UsesToCheck.push_back(&u);
auto IsSupportedUse = [](Use *U) -> bool {
Value *V = U->get();
if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
return true;
// ASC to param space are OK, too -- we'll just strip them.
if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
return true;
}
// TMA descriptors passed to TMA intrinsics are OK, too.
if (auto *II = dyn_cast<IntrinsicInst>(V)) {
auto OI = tmaDescriptorOperandIndex(II);
return OI.has_value() && *OI == U->getOperandNo();
}
return false;
};

while (!ValuesToCheck.empty()) {
Value *V = ValuesToCheck.pop_back_val();
if (!IsALoadChainInstr(V)) {
LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V
while (!UsesToCheck.empty()) {
Use* U = UsesToCheck.pop_back_val();
if (!IsSupportedUse(U)) {
LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << U
<< "\n");
(void)Arg;
return false;
}
if (!isa<LoadInst>(V))
llvm::append_range(ValuesToCheck, V->users());
if (!isa<LoadInst>(U)) {
for (Use& u : U->getUser()->uses())
UsesToCheck.push_back(&u);
}
}
return true;
};
@@ -355,7 +420,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
FirstInst);
for (Value *V : UsersToUpdate)
convertToParamAS(V, ArgInParamAS);
convertToParamAS(V, Arg, ArgInParamAS);
LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");

const auto *TLI =