Skip to content

Commit a9c8724

Browse files
micmelesse and keryell authored
[ROCM] Enable ROCM Backend #1: Empty Kernel (triton-lang#1312)
This PR is a first in a series of PRs to import the changes that we have made to enable ROCM on [our fork](https://github.com/ROCmSoftwarePlatform/triton) of triton. The PR contains the major changes to the python frontend and enough changes to the c++ backend to allow compilation and running of the empty kernel. We use the ROCM ci added a few weeks ago to verify things. --------- Co-authored-by: Ronan Keryell <ronan@keryell.fr>
1 parent 89d8fe6 commit a9c8724

33 files changed

+1602
-131
lines changed

.github/workflows/integration-tests.yml

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
id: set-matrix
2323
run: |
2424
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
25-
echo '::set-output name=matrix::[["self-hosted", "A100"], ["self-hosted", "V100"], "macos-10.15"]'
25+
echo '::set-output name=matrix::[["self-hosted", "A100"], ["self-hosted", "V100"], ["self-hosted", "gfx908"], "macos-10.15"]'
2626
else
2727
echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]'
2828
fi
@@ -40,6 +40,16 @@ jobs:
4040
- name: Checkout
4141
uses: actions/checkout@v2
4242

43+
- name: Set CUDA ENV
44+
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100')}}
45+
run: |
46+
echo "BACKEND=CUDA" >> $GITHUB_ENV
47+
48+
- name: Set ROCM ENV
49+
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'gfx908')}}
50+
run: |
51+
echo "BACKEND=ROCM" >> $GITHUB_ENV
52+
4353
- name: Clear cache
4454
run: |
4555
rm -rf ~/.triton/
@@ -74,12 +84,22 @@ jobs:
7484
flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )
7585
7686
- name: Install Triton
87+
if: ${{ env.BACKEND != 'ROCM'}}
7788
run: |
7889
cd python
7990
pip3 install cmake==3.24
8091
TRITON_USE_ASSERT_ENABLED_LLVM=TRUE pip3 install -e '.[tests]'
8192
93+
- name: Install Triton on ROCM
94+
if: ${{ env.BACKEND == 'ROCM'}}
95+
run: |
96+
cd python
97+
pip3 uninstall --yes torch torchvision torchaudio
98+
pip3 install --no-cache-dir --force-reinstall torch==1.13.1 --extra-index-url https://download.pytorch.org/whl/rocm5.2
99+
TRITON_USE_ASSERT_ENABLED_LLVM=TRUE pip3 install -e '.[tests]'
100+
82101
- name: Run lit tests
102+
if: ${{ env.BACKEND != 'ROCM'}}
83103
run: |
84104
pip3 install lit
85105
cd python
@@ -89,13 +109,20 @@ jobs:
89109
fi
90110
lit -v "$LIT_TEST_DIR"
91111
92-
- name: Run python tests
93-
if: ${{matrix.runner[0] == 'self-hosted'}}
112+
- name: Run python tests on CUDA
113+
if: ${{ env.BACKEND == 'CUDA'}}
94114
run: |
95115
cd python/test/unit/
96116
pytest
117+
118+
- name: Run python tests on ROCM
119+
if: ${{ env.BACKEND == 'ROCM'}}
120+
run: |
121+
cd python/test/unit/language/
122+
pytest --capture=tee-sys -rfs --verbose "test_core.py::test_empty_kernel"
97123
98124
- name: Run CXX unittests
125+
if: ${{ env.BACKEND != 'ROCM'}}
99126
run: |
100127
cd python/
101128
cd "build/$(ls build)"

.gitignore

100644 → 100755
File mode changed.

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
216216
TritonGPUTransforms
217217
TritonLLVMIR
218218
TritonPTX
219+
TritonHSACO
219220
${dialect_libs}
220221
${conversion_libs}
221222

@@ -228,6 +229,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
228229
MLIRExecutionEngine
229230
MLIRMathToLLVM
230231
MLIRNVVMToLLVMIRTranslation
232+
MLIRROCDLToLLVMIRTranslation
231233
MLIRIR
232234
)
233235

bin/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ llvm_update_compile_flags(triton-translate)
3030
TritonGPUTransforms
3131
TritonLLVMIR
3232
TritonPTX
33+
TritonHSACO
3334
${dialect_libs}
3435
${conversion_libs}
3536
# tests
@@ -53,5 +54,6 @@ llvm_update_compile_flags(triton-translate)
5354
MLIRTransformUtils
5455
MLIRLLVMToLLVMIRTranslation
5556
MLIRNVVMToLLVMIRTranslation
57+
MLIRROCDLToLLVMIRTranslation
5658
)
5759
mlir_check_all_link_libraries(triton-translate)

bin/triton-opt.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
#include "triton/Dialect/Triton/Transforms/Passes.h"
55
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
66

7-
#include "triton/Conversion/Passes.h"
7+
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
8+
#include "triton/Conversion/TritonToTritonGPU/Passes.h"
89

910
#include "mlir/IR/Dialect.h"
1011
#include "mlir/InitAllPasses.h"

bin/triton-translate.cpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
1515
#include "triton/Dialect/Triton/IR/Dialect.h"
1616
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
17+
#include "triton/Target/HSACO/HSACOTranslation.h"
1718
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
1819
#include "triton/Target/PTX/PTXTranslation.h"
1920
#include "llvm/IR/LLVMContext.h"
2021
#include "llvm/Support/CommandLine.h"
2122
#include "llvm/Support/InitLLVM.h"
2223
#include "llvm/Support/SourceMgr.h"
2324
#include "llvm/Support/ToolOutputFile.h"
24-
#include <iostream>
2525

2626
namespace mlir {
2727
namespace triton {
@@ -79,7 +79,8 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
7979
llvm::cl::init("-"));
8080

8181
static llvm::cl::opt<std::string> targetKind(
82-
"target", llvm::cl::desc("<translation target, options: llvmir/ptx>"),
82+
"target",
83+
llvm::cl::desc("<translation target, options: llvmir/ptx/hsaco>"),
8384
llvm::cl::value_desc("target"), llvm::cl::init("llvmir"));
8485

8586
static llvm::cl::opt<int> SMArch("sm", llvm::cl::desc("sm arch"),
@@ -88,6 +89,18 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
8889
static llvm::cl::opt<int> ptxVersion(
8990
"ptx-version", llvm::cl::desc("PTX version"), llvm::cl::init(10000));
9091

92+
static llvm::cl::opt<std::string> GCNArch(
93+
"gfx", llvm::cl::desc("AMDGCN target. e.g. '90a'"),
94+
llvm::cl::value_desc("architecture"), llvm::cl::init("90a"));
95+
96+
static llvm::cl::opt<std::string> GCNTriple(
97+
"amdgcn", llvm::cl::desc("AMDGCN triple. e.g. '-amd-amdhsa'"),
98+
llvm::cl::value_desc("target triple"), llvm::cl::init("-amd-amdhsa"));
99+
100+
static llvm::cl::opt<std::string> GCNFeatures(
101+
"", llvm::cl::desc("AMDGCN features. e.g. '+sramecc,-xnack'"),
102+
llvm::cl::value_desc("features"), llvm::cl::init("+sramecc,-xnack"));
103+
91104
llvm::InitLLVM y(argc, argv);
92105

93106
registerAsmPrinterCLOptions();
@@ -119,6 +132,15 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
119132
else if (targetKind == "ptx")
120133
llvm::outs() << ::triton::translateLLVMIRToPTX(*llvmir, SMArch.getValue(),
121134
ptxVersion.getValue());
135+
else if (targetKind == "hsaco") {
136+
auto [module, hsaco] = ::triton::translateLLVMIRToHSACO(
137+
*llvmir, GCNArch.getValue(), GCNTriple.getValue(),
138+
GCNFeatures.getValue());
139+
llvm::outs() << hsaco;
140+
} else {
141+
llvm::errs() << "Error: Unknown target specified: " << targetKind << "\n";
142+
return failure();
143+
}
122144

123145
return success();
124146
}
Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
1-
set(LLVM_TARGET_DEFINITIONS Passes.td)
2-
mlir_tablegen(Passes.h.inc -gen-pass-decls)
3-
add_public_tablegen_target(TritonConversionPassIncGen)
1+
add_subdirectory(TritonToTritonGPU)
2+
add_subdirectory(TritonGPUToLLVM)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
2+
#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
3+
4+
#include "mlir/IR/Value.h"
5+
#include "triton/Dialect/Triton/IR/Dialect.h"
6+
#include "llvm/ADT/SmallVector.h"
7+
#include "llvm/ADT/StringRef.h"
8+
#include <memory>
9+
#include <string>
10+
11+
namespace mlir {
12+
class ConversionPatternRewriter;
13+
class Location;
14+
15+
namespace triton {
16+
using llvm::StringRef;
17+
18+
/// Concatenates the entries of `strs` into a single string, writing
/// `delimiter` between each pair of adjacent entries.
///
/// \param strs      the strings to join, in order.
/// \param delimiter the separator emitted between consecutive entries; it is
///                  never emitted before the first or after the last entry.
/// \returns the joined string; empty when `strs` is empty.
inline std::string strJoin(llvm::ArrayRef<std::string> strs,
                           llvm::StringRef delimiter) {
  std::string joined;
  llvm::raw_string_ostream stream(joined);
  bool isFirst = true;
  for (const std::string &piece : strs) {
    // Every entry except the first is preceded by the delimiter.
    if (!isFirst)
      stream << delimiter;
    stream << piece;
    isFirst = false;
  }
  // Make sure all buffered output has reached `joined` before returning it.
  stream.flush();
  return joined;
}
29+
30+
} // namespace triton
31+
} // namespace mlir
32+
33+
#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
set(LLVM_TARGET_DEFINITIONS Passes.td)
2+
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonGPUToLLVM)
3+
add_public_tablegen_target(TritonGPUConversionPassIncGen)

0 commit comments

Comments
 (0)