
Commit 51207bc

[NVPTX] Add TMA bulk tensor reduction intrinsics
This patch adds NVVM intrinsics and NVPTX codegen for:

* cp.async.bulk.tensor.reduce.[1-5]d variants, supporting both Tile and Im2Col modes.
* These intrinsics optionally support cache_hints, as indicated by a boolean flag argument.
* Lit tests are added for all combinations of these intrinsics in cp-async-bulk-tensor-reduce.ll.
* The generated PTX is verified with a 12.3 ptxas executable.
* Docs for these intrinsics are added to the NVPTXUsage.rst file.

PTX Spec reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
1 parent: def22f4

File tree

9 files changed: +693 -16 lines changed

llvm/docs/NVPTXUsage.rst

Lines changed: 78 additions & 0 deletions
@@ -663,6 +663,84 @@ the same functionality as described in the ``tile`` mode intrinsics above.
 For more information, refer to the PTX ISA
 `<https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor>`_.

+'``llvm.nvvm.cp.async.bulk.tensor.reduce.tile.[1-5]d``'
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+  declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.tile.1d(ptr addrspace(3) %src, ptr %tensor_map, i32 %d0, i64 %ch, i1 %flag_ch, i8 %flag_red_op)
+  declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.tile.2d(..., i32 %d0, i32 %d1, ...)
+  declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.tile.3d(..., i32 %d0, i32 %d1, i32 %d2, ...)
+  declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.tile.4d(..., i32 %d0, i32 %d1, i32 %d2, i32 %d3, ...)
+  declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.tile.5d(..., i32 %d0, i32 %d1, i32 %d2, i32 %d3, i32 %d4, ...)
+
+Overview:
+"""""""""
+
+The '``@llvm.nvvm.cp.async.bulk.tensor.reduce.tile.[1-5]d``' intrinsics
+correspond to the ``cp.reduce.async.bulk.tensor.[1-5]d.*`` set of PTX instructions.
+These instructions initiate an asynchronous reduction of tensor data
+in global memory with tensor data in shared::cta memory, using ``tile`` mode.
+The dimension of the tensor data ranges from 1d to 5d, with the coordinates
+specified by the ``i32 %d0 ... i32 %d4`` arguments.
+
+* The last two arguments to these intrinsics are flags.
+  These flag arguments must be compile-time constants. The backend
+  looks through these flags and lowers the intrinsics appropriately.
+
+* The last argument (denoted by ``i8 %flag_red_op``) indicates the
+  kind of reduction operation performed. It must be in the range
+  [0, 7], representing the following reduction operations:
+
+  ========== =============
+  Enum Value Reduction Op
+  ========== =============
+  ``0``      ADD
+  ``1``      MIN
+  ``2``      MAX
+  ``3``      INC
+  ``4``      DEC
+  ``5``      AND
+  ``6``      OR
+  ``7``      XOR
+  ========== =============
+
+* The second-to-last argument (denoted by ``i1 %flag_ch``), when set,
+  indicates the presence of a valid cache_hint (``i64 %ch``) and generates
+  the ``.L2::cache_hint`` variant of the PTX instruction.
+
+For more information, refer to the PTX ISA
+`<https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor>`_.
+
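As a usage illustration (editor's sketch, not part of this commit; the function and value names are hypothetical), the tile-mode 1D variant can be called with and without a cache hint:

    ; Sketch: %flag_red_op = 2 selects MAX per the table above.
    define void @reduce_max_1d(ptr addrspace(3) %src, ptr %tmap, i32 %d0, i64 %ch) {
      ; No cache hint: %flag_ch = 0, so the i64 cache-hint operand is a don't-care.
      call void @llvm.nvvm.cp.async.bulk.tensor.reduce.tile.1d(ptr addrspace(3) %src, ptr %tmap, i32 %d0, i64 0, i1 0, i8 2)
      ; With a cache hint: %flag_ch = 1 requests the .L2::cache_hint variant.
      call void @llvm.nvvm.cp.async.bulk.tensor.reduce.tile.1d(ptr addrspace(3) %src, ptr %tmap, i32 %d0, i64 %ch, i1 1, i8 2)
      ret void
    }

    declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.tile.1d(ptr addrspace(3), ptr, i32, i64, i1, i8)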
+'``llvm.nvvm.cp.async.bulk.tensor.reduce.im2col.[3-5]d``'
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+  declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.im2col.3d(ptr addrspace(3) %src, ptr %tensor_map, i32 %d0, i32 %d1, i32 %d2, i64 %ch, i1 %flag_ch, i8 %flag_red_op)
+  declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.im2col.4d(..., i32 %d0, i32 %d1, i32 %d2, i32 %d3, ...)
+  declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.im2col.5d(..., i32 %d0, i32 %d1, i32 %d2, i32 %d3, i32 %d4, ...)
+
+Overview:
+"""""""""
+
+The '``@llvm.nvvm.cp.async.bulk.tensor.reduce.im2col.[3-5]d``' intrinsics
+correspond to the ``cp.reduce.async.bulk.tensor.[3-5]d.*`` set of PTX instructions.
+These instructions initiate an asynchronous reduction of tensor data
+in global memory with tensor data in shared::cta memory, using ``im2col`` mode.
+In this mode, the tensor has to be at least three-dimensional.
+The last two arguments of these intrinsics are compile-time flags,
+with the same functionality as described in the ``tile`` mode intrinsics above.
+
+For more information, refer to the PTX ISA
+`<https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor>`_.
+
 Other Intrinsics
 ----------------
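A matching sketch for im2col mode (editor's illustration, hypothetical names): the 3D variant below uses ADD (i8 0) without a cache hint.

    define void @reduce_add_im2col_3d(ptr addrspace(3) %src, ptr %tmap, i32 %d0, i32 %d1, i32 %d2) {
      call void @llvm.nvvm.cp.async.bulk.tensor.reduce.im2col.3d(ptr addrspace(3) %src, ptr %tmap, i32 %d0, i32 %d1, i32 %d2, i64 0, i1 0, i8 0)
      ret void
    }

    declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.im2col.3d(ptr addrspace(3), ptr, i32, i32, i32, i64, i1, i8)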

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 22 additions & 0 deletions
@@ -635,6 +635,26 @@ class CP_ASYNC_BULK_TENSOR_PREFETCH_INTR<int dim, string mode> {
                      ImmArg<ArgIndex<FlagsStartIdx>>];
 }

+class CP_ASYNC_BULK_TENSOR_REDUCE_INTR<int dim, string mode> {
+  string Name = "int_nvvm_cp_async_bulk_tensor_reduce_" # mode # "_" # dim # "d";
+
+  list<LLVMType> TensorDimsTy = !listsplat(llvm_i32_ty, dim);
+  list<LLVMType> ArgsTy = !listconcat(
+      [llvm_shared_ptr_ty, // src_smem_ptr
+       llvm_ptr_ty],       // tensormap_ptr
+      TensorDimsTy,        // actual tensor dims
+      [llvm_i64_ty,        // cache_hint
+       llvm_i1_ty,         // Flag for cache_hint
+       llvm_i8_ty]         // Flag for Reduction Op
+  );
+  int FlagsStartIdx = !add(dim, 3);
+  list<IntrinsicProperty> IntrProp = [IntrConvergent,
+      ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>,
+      NoCapture<ArgIndex<0>>, NoCapture<ArgIndex<1>>,
+      ImmArg<ArgIndex<FlagsStartIdx>>,
+      ImmArg<ArgIndex<!add(FlagsStartIdx, 1)>>];
+}
+
 let TargetPrefix = "nvvm" in {
   def int_nvvm_prmt : ClangBuiltin<"__nvvm_prmt">,
     DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
@@ -4926,6 +4946,8 @@ foreach dim = [1, 2, 3, 4, 5] in {
     def s2g.Name : DefaultAttrsIntrinsic<[], s2g.ArgsTy, s2g.IntrProp>;
     foreach prefetch = [CP_ASYNC_BULK_TENSOR_PREFETCH_INTR<dim, mode>] in
       def prefetch.Name : DefaultAttrsIntrinsic<[], prefetch.ArgsTy, prefetch.IntrProp>;
+    foreach reduce = [CP_ASYNC_BULK_TENSOR_REDUCE_INTR<dim, mode>] in
+      def reduce.Name : DefaultAttrsIntrinsic<[], reduce.ArgsTy, reduce.IntrProp>;
   }
 }
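For orientation (editor's sketch of what the TableGen class yields; the rendering below is an assumption, not generated output): with dim = 2 and mode = "tile", FlagsStartIdx = !add(2, 3) = 5, so the ImmArg attributes land on the two flag operands of a signature shaped like:

    ; Hypothetical rendering of int_nvvm_cp_async_bulk_tensor_reduce_tile_2d.
    declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.tile.2d(
        ptr addrspace(3) %src,   ; arg 0: ReadOnly, NoCapture
        ptr %tensor_map,         ; arg 1: ReadOnly, NoCapture
        i32 %d0, i32 %d1,        ; args 2-3: tensor dims
        i64 %ch,                 ; arg 4: cache_hint
        i1 %flag_ch,             ; arg 5: ImmArg (FlagsStartIdx)
        i8 %flag_red_op)         ; arg 6: ImmArg (FlagsStartIdx + 1)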

llvm/include/llvm/IR/NVVMIntrinsicFlags.h

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+//===--- NVVMIntrinsicFlags.h -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// This file contains the definitions of the enumerations and flags
+/// associated with NVVM Intrinsics.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_SUPPORT_NVVMINTRINSICFLAGS_H
+#define LLVM_SUPPORT_NVVMINTRINSICFLAGS_H
+
+#include <stdint.h> // uint8_t (editorial addition: the extracted text lacks this include, which the enum below needs)
+
+namespace llvm {
+namespace nvvm {
+
+// Reduction Ops supported with TMA Copy from Shared
+// to Global Memory for the "cp.reduce.async.bulk.tensor.*"
+// family of PTX instructions.
+enum class TMAReductionOp : uint8_t {
+  ADD = 0,
+  MIN = 1,
+  MAX = 2,
+  INC = 3,
+  DEC = 4,
+  AND = 5,
+  OR = 6,
+  XOR = 7,
+};
+
+} // namespace nvvm
+} // namespace llvm
+#endif // LLVM_SUPPORT_NVVMINTRINSICFLAGS_H

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp

Lines changed: 37 additions & 0 deletions
@@ -14,6 +14,7 @@
 #include "NVPTX.h"
 #include "NVPTXUtilities.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/IR/NVVMIntrinsicFlags.h"
 #include "llvm/MC/MCExpr.h"
 #include "llvm/MC/MCInst.h"
 #include "llvm/MC/MCInstrInfo.h"
@@ -416,3 +417,39 @@ void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum,
     return;
   }
 }
+
+void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum,
+                                             raw_ostream &O,
+                                             const char *Modifier) {
+  const MCOperand &MO = MI->getOperand(OpNum);
+
+  switch (static_cast<nvvm::TMAReductionOp>(MO.getImm())) {
+  case nvvm::TMAReductionOp::ADD:
+    O << ".add";
+    return;
+  case nvvm::TMAReductionOp::MIN:
+    O << ".min";
+    return;
+  case nvvm::TMAReductionOp::MAX:
+    O << ".max";
+    return;
+  case nvvm::TMAReductionOp::INC:
+    O << ".inc";
+    return;
+  case nvvm::TMAReductionOp::DEC:
+    O << ".dec";
+    return;
+  case nvvm::TMAReductionOp::AND:
+    O << ".and";
+    return;
+  case nvvm::TMAReductionOp::OR:
+    O << ".or";
+    return;
+  case nvvm::TMAReductionOp::XOR:
+    O << ".xor";
+    return;
+  default:
+    llvm_unreachable("Invalid Reduction Op in printTmaReductionMode");
+  }
+}
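A lowering sketch tying the printer to the instruction definitions (editor's illustration; the register numbers and the exact PTX line are assumptions, not CHECK-verified output): an i8 immediate of 2 should make printTmaReductionMode emit the ".max" suffix.

    ; Expected shape of the PTX (illustrative):
    ;   cp.reduce.async.bulk.tensor.1d.global.shared::cta.max.tile.bulk_group [%rd2, {%r1}], [%rd1];
    define void @lowering_sketch(ptr addrspace(3) %src, ptr %tmap, i32 %d0) {
      call void @llvm.nvvm.cp.async.bulk.tensor.reduce.tile.1d(ptr addrspace(3) %src, ptr %tmap, i32 %d0, i64 0, i1 0, i8 2)
      ret void
    }

    declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.tile.1d(ptr addrspace(3), ptr, i32, i64, i1, i8)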

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h

Lines changed: 2 additions & 0 deletions
@@ -54,6 +54,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
                     raw_ostream &O, const char *Modifier = nullptr);
   void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O,
                      const char *Modifier = nullptr);
+  void printTmaReductionMode(const MCInst *MI, int OpNum, raw_ostream &O,
+                             const char *Modifier = nullptr);
 };

 }

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 57 additions & 12 deletions
@@ -4157,9 +4157,9 @@ bool NVPTXScopes::empty() const { return Scopes.size() == 0; }
         ? NVPTX::CP_ASYNC_BULK_TENSOR_##dir##_##dim##_SHARED32_##mode##suffix \
         : NVPTX::CP_ASYNC_BULK_TENSOR_##dir##_##dim##_##mode##suffix)

-#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(dim, mode)                        \
-  (IsCacheHint ? (CP_ASYNC_BULK_TENSOR_OPCODE(S2G, dim, mode, _CH))           \
-               : (CP_ASYNC_BULK_TENSOR_OPCODE(S2G, dim, mode, )))
+#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_CH(op, dim, mode)                     \
+  (IsCacheHint ? (CP_ASYNC_BULK_TENSOR_OPCODE(op, dim, mode, _CH))            \
+               : (CP_ASYNC_BULK_TENSOR_OPCODE(op, dim, mode, )))

 #define GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(dim, mode)                        \
   [&]() -> auto {                                                             \
@@ -4177,31 +4177,40 @@ bool NVPTXScopes::empty() const { return Scopes.size() == 0; }
         : NVPTX::CP_ASYNC_BULK_TENSOR_PREFETCH_##dim##_##mode)

 static unsigned GetCpAsyncBulkTensorS2GOpcode(size_t Dim, bool IsShared32,
-                                              bool IsCacheHint, bool IsIm2Col) {
+                                              bool IsCacheHint, bool IsIm2Col,
+                                              bool IsReduce = false) {
   if (IsIm2Col) {
     switch (Dim) {
     case 3:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(3D, IM2COL);
+      return IsReduce ? GET_CP_ASYNC_BULK_TENSOR_OPCODE_CH(RED, 3D, IM2COL)
+                      : GET_CP_ASYNC_BULK_TENSOR_OPCODE_CH(S2G, 3D, IM2COL);
     case 4:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(4D, IM2COL);
+      return IsReduce ? GET_CP_ASYNC_BULK_TENSOR_OPCODE_CH(RED, 4D, IM2COL)
+                      : GET_CP_ASYNC_BULK_TENSOR_OPCODE_CH(S2G, 4D, IM2COL);
     case 5:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(5D, IM2COL);
+      return IsReduce ? GET_CP_ASYNC_BULK_TENSOR_OPCODE_CH(RED, 5D, IM2COL)
+                      : GET_CP_ASYNC_BULK_TENSOR_OPCODE_CH(S2G, 5D, IM2COL);
     default:
       llvm_unreachable("Invalid Dimension in im2col mode for "
                        "GetCpAsyncBulkTensorS2GOpcode.");
     }
   } else {
     switch (Dim) {
     case 1:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(1D, TILE);
+      return IsReduce ? GET_CP_ASYNC_BULK_TENSOR_OPCODE_CH(RED, 1D, TILE)
+                      : GET_CP_ASYNC_BULK_TENSOR_OPCODE_CH(S2G, 1D, TILE);
     case 2:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(2D, TILE);
+      return IsReduce ? GET_CP_ASYNC_BULK_TENSOR_OPCODE_CH(RED, 2D, TILE)
+                      : GET_CP_ASYNC_BULK_TENSOR_OPCODE_CH(S2G, 2D, TILE);
     case 3:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(3D, TILE);
+      return IsReduce ? GET_CP_ASYNC_BULK_TENSOR_OPCODE_CH(RED, 3D, TILE)
+                      : GET_CP_ASYNC_BULK_TENSOR_OPCODE_CH(S2G, 3D, TILE);
     case 4:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(4D, TILE);
+      return IsReduce ? GET_CP_ASYNC_BULK_TENSOR_OPCODE_CH(RED, 4D, TILE)
+                      : GET_CP_ASYNC_BULK_TENSOR_OPCODE_CH(S2G, 4D, TILE);
     case 5:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(5D, TILE);
+      return IsReduce ? GET_CP_ASYNC_BULK_TENSOR_OPCODE_CH(RED, 5D, TILE)
+                      : GET_CP_ASYNC_BULK_TENSOR_OPCODE_CH(S2G, 5D, TILE);
     default:
       llvm_unreachable(
           "Invalid Dimension in tile mode for GetCpAsyncBulkTensorS2GOpcode.");
@@ -4377,6 +4386,30 @@ void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorPrefetchCommon(SDNode *N,
   ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
 }

+void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorReduceCommon(SDNode *N,
+                                                            bool IsIm2Col) {
+  // We have {Chain, Intrinsic-ID} followed by the actual intrinsic args:
+  // src, dst, dims{d0...dN}, cache_hint, cache_hint_flag, reduction_kind_flag
+  // NumOperands = {Chain, IID} + {Actual intrinsic args}
+  //             = {2} + {5 + dims}
+  size_t NumOps = N->getNumOperands();
+  size_t NumDims = NumOps - 7;
+  unsigned ReductionKind = N->getConstantOperandVal(NumOps - 1);
+  bool IsCacheHint = N->getConstantOperandVal(NumOps - 2) == 1;
+  size_t NumArgs = NumDims + (IsCacheHint ? 3 : 2); // src, dst, cache_hint
+
+  SDLoc DL(N);
+  SmallVector<SDValue, 12> Ops(N->ops().slice(2, NumArgs));
+  Ops.push_back(getI32Imm(ReductionKind, DL)); // Reduction Op
+  Ops.push_back(N->getOperand(0));             // Chain operand
+
+  bool IsShared32 =
+      CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32;
+  unsigned Opcode = GetCpAsyncBulkTensorS2GOpcode(
+      NumDims, IsShared32, IsCacheHint, IsIm2Col, /*IsReduce=*/true);
+  ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
+}
+
 bool NVPTXDAGToDAGISel::tryIntrinsicVoid(SDNode *N) {
   unsigned IID = N->getConstantOperandVal(1);
   switch (IID) {
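An operand-layout sketch for the arithmetic above (editor's annotation; the indices in the comments refer to the SDNode operand list, which prepends {Chain, IID}): the 3D tile variant carries 8 intrinsic arguments, so NumOps = 10 and NumDims = 10 - 7 = 3.

    define void @operand_layout_3d(ptr addrspace(3) %src, ptr %tmap, i32 %d0, i32 %d1, i32 %d2, i64 %ch) {
      call void @llvm.nvvm.cp.async.bulk.tensor.reduce.tile.3d(
          ptr addrspace(3) %src,      ; ops[2]
          ptr %tmap,                  ; ops[3]
          i32 %d0, i32 %d1, i32 %d2,  ; ops[4..6] tensor dims
          i64 %ch,                    ; ops[7] copied only when IsCacheHint
          i1 1,                       ; ops[8] = NumOps - 2: cache-hint flag
          i8 0)                       ; ops[9] = NumOps - 1: reduction kind (ADD)
      ret void
    }

    declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.tile.3d(ptr addrspace(3), ptr, i32, i32, i32, i64, i1, i8)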
@@ -4418,5 +4451,17 @@ bool NVPTXDAGToDAGISel::tryIntrinsicVoid(SDNode *N) {
   case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d:
     SelectCpAsyncBulkTensorPrefetchCommon(N, /*IsIm2Col=*/true);
     return true;
+  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_tile_1d:
+  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_tile_2d:
+  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_tile_3d:
+  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_tile_4d:
+  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_tile_5d:
+    SelectCpAsyncBulkTensorReduceCommon(N);
+    return true;
+  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_im2col_3d:
+  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_im2col_4d:
+  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_im2col_5d:
+    SelectCpAsyncBulkTensorReduceCommon(N, /*IsIm2Col=*/true);
+    return true;
   }
 }

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 2 additions & 0 deletions
@@ -95,6 +95,8 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   void SelectCpAsyncBulkTensorG2SCommon(SDNode *N, bool IsIm2Col = false);
   void SelectCpAsyncBulkTensorS2GCommon(SDNode *N, bool IsIm2Col = false);
   void SelectCpAsyncBulkTensorPrefetchCommon(SDNode *N, bool IsIm2Col = false);
+  void SelectCpAsyncBulkTensorReduceCommon(SDNode *N, bool IsIm2Col = false);
+
   inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
     return CurDAG->getTargetConstant(Imm, DL, MVT::i32);
   }

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 32 additions & 4 deletions
@@ -564,17 +564,19 @@ foreach dim = [1, 2, 3, 4, 5] in {
 }

 // From Shared to Global memory (S2G)
-class S2G_STRINGS<int dim, string mode, bit ch, bit is_shared32 = 0> {
-  string prefix = "cp.async.bulk.tensor";
+class S2G_STRINGS<int dim, string mode, bit ch,
+                  bit is_shared32 = 0, bit is_reduce = 0> {
   string dir = "global.shared::cta";
   string completion = "bulk_group";
-  string inst_name = prefix
+  string inst_name = !if(is_reduce, "cp.reduce", "cp")
+                     # ".async.bulk.tensor"
                      # "." # dim # "d"
                      # "." # dir
                      # "." # mode
                      # "." # completion
                      # !if(ch, ".L2::cache_hint", "");
-  string intr_name = "CP_ASYNC_BULK_TENSOR_S2G_"
+  string intr_name = "CP_ASYNC_BULK_TENSOR_"
+                     # !if(is_reduce, "RED_", "S2G_")
                      # dim # "D"
                      # !if(is_shared32, "_SHARED32", "")
                      # !if(!eq(mode, "tile"), "_TILE", "_IM2COL");
@@ -596,11 +598,37 @@ multiclass CP_ASYNC_BULK_TENSOR_S2G_INTR<int dim, bit shared32, string mode> {
             Requires<[hasPTX<80>, hasSM<90>]>;
 }

+def TMAReductionFlags : Operand<i32> {
+  let PrintMethod = "printTmaReductionMode";
+}
+
+// TMA Copy from Shared to Global memory with Reduction
+multiclass CP_ASYNC_BULK_TENSOR_REDUCE_INTR<int dim, bit shared32, string mode> {
+  defvar dims_dag = !dag(ins, !listsplat(Int32Regs, dim), !foreach(i, !range(dim), "d" # i));
+  defvar dims_str = !interleave(!foreach(i, !range(dim), "$d" # i), ", ");
+  defvar asm_str = " [$tmap, {{" # dims_str # "}}], [$src]";
+  defvar rc = !if(shared32, Int32Regs, Int64Regs);
+
+  defvar prefix = "cp.reduce.async.bulk.tensor" # "." # dim # "d" # ".global.shared::cta";
+  defvar suffix = "." # mode # ".bulk_group";
+
+  def "" : NVPTXInst<(outs),
+             !con((ins rc:$src, Int64Regs:$tmap), dims_dag, (ins TMAReductionFlags:$red_op)),
+             !strconcat(prefix, "${red_op}", suffix, asm_str, ";"), []>,
+             Requires<[hasPTX<80>, hasSM<90>]>;
+  def _CH : NVPTXInst<(outs),
+             !con((ins rc:$src, Int64Regs:$tmap), dims_dag, (ins Int64Regs:$ch, TMAReductionFlags:$red_op)),
+             !strconcat(prefix, "${red_op}", suffix, ".L2::cache_hint", asm_str, ", $ch;"), []>,
+             Requires<[hasPTX<80>, hasSM<90>]>;
+}
+
 foreach dim = [1, 2, 3, 4, 5] in {
   foreach shared32 = [true, false] in {
     foreach mode = !if(!ge(dim, 3), ["tile", "im2col_no_offs"], ["tile"]) in {
       defm S2G_STRINGS<dim, mode, 0, shared32>.intr_name :
         CP_ASYNC_BULK_TENSOR_S2G_INTR<dim, shared32, mode>;
+      defm S2G_STRINGS<dim, mode, 0, shared32, 1>.intr_name :
+        CP_ASYNC_BULK_TENSOR_REDUCE_INTR<dim, shared32, mode>;
     }
   }
 }
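To close the loop, an end-to-end, lit-style sketch (editor's illustration; this is not the committed cp-async-bulk-tensor-reduce.ll test, and the registers in the CHECK line are assumptions based on the asm strings defined above):

    ; CHECK: cp.reduce.async.bulk.tensor.2d.global.shared::cta.min.tile.bulk_group [%rd2, {%r1, %r2}], [%rd1];
    define void @tma_reduce_min_2d(ptr addrspace(3) %src, ptr %tmap, i32 %d0, i32 %d1) {
      call void @llvm.nvvm.cp.async.bulk.tensor.reduce.tile.2d(ptr addrspace(3) %src, ptr %tmap, i32 %d0, i32 %d1, i64 0, i1 0, i8 1)
      ret void
    }

    declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.tile.2d(ptr addrspace(3), ptr, i32, i32, i64, i1, i8)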
