Skip to content

[NVPTX] Add TMA bulk tensor reduction intrinsics #116854

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 72 additions & 2 deletions llvm/docs/NVPTXUsage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ specified by the ``i32 %d0 ... i32 %d4`` arguments.
For more information, refer PTX ISA
`<https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor>`_.

'``llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.[1-5]d``'
'``llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.[3-5]d``'
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
Expand All @@ -648,7 +648,7 @@ Syntax:
Overview:
"""""""""

The '``@llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.[1-5]d``' intrinsics
The '``@llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.[3-5]d``' intrinsics
correspond to the ``cp.async.bulk.prefetch.tensor.[1-5]d.L2.global*`` set
of PTX instructions. These instructions initiate an asynchronous prefetch
of tensor data from global memory to the L2 cache. In im2col mode, some
Expand All @@ -663,6 +663,76 @@ the same functionality as described in the ``tile`` mode intrinsics above.
For more information, refer PTX ISA
`<https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor>`_.

'``llvm.nvvm.cp.async.bulk.tensor.reduce.[red_op].tile.[1-5]d``'
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""

.. code-block:: llvm

declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.tile.1d(ptr addrspace(3) %src, ptr %tensor_map, i32 %d0, i64 %ch, i1 %flag_ch)
declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.min.tile.1d(ptr addrspace(3) %src, ptr %tensor_map, i32 %d0, i64 %ch, i1 %flag_ch)
declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.max.tile.1d(ptr addrspace(3) %src, ptr %tensor_map, i32 %d0, i64 %ch, i1 %flag_ch)
declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.inc.tile.1d(ptr addrspace(3) %src, ptr %tensor_map, i32 %d0, i64 %ch, i1 %flag_ch)
declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.dec.tile.1d(ptr addrspace(3) %src, ptr %tensor_map, i32 %d0, i64 %ch, i1 %flag_ch)
declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.and.tile.1d(ptr addrspace(3) %src, ptr %tensor_map, i32 %d0, i64 %ch, i1 %flag_ch)
declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.or.tile.1d(ptr addrspace(3) %src, ptr %tensor_map, i32 %d0, i64 %ch, i1 %flag_ch)
declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.xor.tile.1d(ptr addrspace(3) %src, ptr %tensor_map, i32 %d0, i64 %ch, i1 %flag_ch)

declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.<red_op>.tile.2d(..., i32 %d0, i32 %d1, ...)
declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.<red_op>.tile.3d(..., i32 %d0, i32 %d1, i32 %d2, ...)
declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.<red_op>.tile.4d(..., i32 %d0, i32 %d1, i32 %d2, i32 %d3, ...)
declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.<red_op>.tile.5d(..., i32 %d0, i32 %d1, i32 %d2, i32 %d3, i32 %d4, ...)

Overview:
"""""""""

The '``@llvm.nvvm.cp.async.bulk.tensor.reduce.<red_op>.tile.[1-5]d``' intrinsics
correspond to the ``cp.reduce.async.bulk.tensor.[1-5]d.*`` set of PTX instructions.
These instructions initiate an asynchronous reduction operation of tensor data
in global memory with the tensor data in shared{::cta} memory, using ``tile`` mode.
The dimension of the tensor data ranges from 1d to 5d with the coordinates
specified by the ``i32 %d0 ... i32 %d4`` arguments. The supported reduction
operations are {add, min, max, inc, dec, and, or, xor} as described in the
``tile.1d`` intrinsics.

* The last argument to these intrinsics is a boolean flag
indicating support for cache_hint. This flag argument must
be a compile-time constant. When set, it indicates a valid
cache_hint (``i64 %ch``) and generates the ``.L2::cache_hint``
variant of the PTX instruction.

For more information, refer PTX ISA
`<https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor>`_.

'``llvm.nvvm.cp.async.bulk.tensor.reduce.[red_op].im2col.[3-5]d``'
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""

.. code-block:: llvm

declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.<red_op>.im2col.3d(ptr addrspace(3) %src, ptr %tensor_map, i32 %d0, i32 %d1, i32 %d2, i64 %ch, i1 %flag_ch)
declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.<red_op>.im2col.4d(..., i32 %d0, i32 %d1, i32 %d2, i32 %d3, ...)
declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.<red_op>.im2col.5d(..., i32 %d0, i32 %d1, i32 %d2, i32 %d3, i32 %d4, ...)

Overview:
"""""""""

The '``@llvm.nvvm.cp.async.bulk.tensor.reduce.<red_op>.im2col.[3-5]d``' intrinsics
correspond to the ``cp.reduce.async.bulk.tensor.[3-5]d.*`` set of PTX instructions.
These instructions initiate an asynchronous reduction operation of tensor data
in global memory with the tensor data in shared{::cta} memory, using ``im2col`` mode.
In this mode, the tensor has to be at least three-dimensional. The supported reduction
operations supported are the same as the ones in the tile mode. The last argument to
these intrinsics is a boolean flag, with the same functionality as described in the
``tile`` mode intrinsics above.

For more information, refer PTX ISA
`<https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor>`_.

Other Intrinsics
----------------

Expand Down
29 changes: 29 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,25 @@ class CP_ASYNC_BULK_TENSOR_PREFETCH_INTR<int dim, string mode> {
ImmArg<ArgIndex<FlagsStartIdx>>];
}

class CP_ASYNC_BULK_TENSOR_REDUCE_INTR<int dim, string mode, string op> {
string Suffix = op # "_" # mode # "_" # dim # "d";
string Name = "int_nvvm_cp_async_bulk_tensor_reduce_" # Suffix;

list<LLVMType> TensorDimsTy = !listsplat(llvm_i32_ty, dim);
list<LLVMType> ArgsTy = !listconcat(
[llvm_shared_ptr_ty, // src_smem_ptr
llvm_ptr_ty], // tensormap_ptr
TensorDimsTy, // actual tensor dims
[llvm_i64_ty, // cache_hint
llvm_i1_ty] // Flag for cache_hint
);
int FlagsStartIdx = !add(dim, 3);
list<IntrinsicProperty> IntrProp = [IntrConvergent,
ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>,
NoCapture<ArgIndex<0>>, NoCapture<ArgIndex<1>>,
ImmArg<ArgIndex<FlagsStartIdx>>];
}

let TargetPrefix = "nvvm" in {
def int_nvvm_prmt : ClangBuiltin<"__nvvm_prmt">,
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
Expand Down Expand Up @@ -4929,4 +4948,14 @@ foreach dim = [1, 2, 3, 4, 5] in {
}
}

// Intrinsics for TMA Copy with reduction
foreach dim = [1, 2, 3, 4, 5] in {
foreach mode = !if(!ge(dim, 3), ["tile", "im2col"], ["tile"]) in {
foreach red_op = ["add", "min", "max", "inc", "dec", "and", "or", "xor"] in {
foreach reduce = [CP_ASYNC_BULK_TENSOR_REDUCE_INTR<dim, mode, red_op>] in
def reduce.Name : DefaultAttrsIntrinsic<[], reduce.ArgsTy, reduce.IntrProp>;
}
}
}

} // let TargetPrefix = "nvvm"
37 changes: 37 additions & 0 deletions llvm/include/llvm/IR/NVVMIntrinsicFlags.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//===--- NVVMIntrinsicFlags.h -----------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This file contains the definitions of the enumerations and flags
/// associated with NVVM Intrinsics.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_IR_NVVMINTRINSICFLAGS_H
#define LLVM_IR_NVVMINTRINSICFLAGS_H

namespace llvm {
namespace nvvm {

// Reduction Ops supported with TMA Copy from Shared
// to Global Memory for the "cp.reduce.async.bulk.tensor.*"
// family of PTX instructions.
enum class TMAReductionOp : uint8_t {
ADD = 0,
MIN = 1,
MAX = 2,
INC = 3,
DEC = 4,
AND = 5,
OR = 6,
XOR = 7,
};

} // namespace nvvm
} // namespace llvm
#endif // LLVM_IR_NVVMINTRINSICFLAGS_H
38 changes: 38 additions & 0 deletions llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "NVPTX.h"
#include "NVPTXUtilities.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/NVVMIntrinsicFlags.h"
#include "llvm/MC/MCExpr.h"
#include "llvm/MC/MCInst.h"
#include "llvm/MC/MCInstrInfo.h"
Expand Down Expand Up @@ -416,3 +417,40 @@ void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum,
return;
}
}

void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum,
raw_ostream &O,
const char *Modifier) {
const MCOperand &MO = MI->getOperand(OpNum);
using RedTy = llvm::nvvm::TMAReductionOp;

switch (static_cast<RedTy>(MO.getImm())) {
case RedTy::ADD:
O << ".add";
return;
case RedTy::MIN:
O << ".min";
return;
case RedTy::MAX:
O << ".max";
return;
case RedTy::INC:
O << ".inc";
return;
case RedTy::DEC:
O << ".dec";
return;
case RedTy::AND:
O << ".and";
return;
case RedTy::OR:
O << ".or";
return;
case RedTy::XOR:
O << ".xor";
return;
default:
llvm_unreachable(
"Invalid Reduction Op in printCpAsyncBulkTensorReductionMode");
}
}
2 changes: 2 additions & 0 deletions llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
raw_ostream &O, const char *Modifier = nullptr);
void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O,
const char *Modifier = nullptr);
void printTmaReductionMode(const MCInst *MI, int OpNum, raw_ostream &O,
const char *Modifier = nullptr);
};

}
Expand Down
Loading