
Commit 8e1b88f

Authored by Alex4210987, xinxyxiao, and LeiWang1999
[CI][AMD] Add AMD GPU CI and fix some related bugs (#694)
* [Enhancement] Refactor buffer index handling for improved precision and clarity (#668) - Enhanced buffer index handling to address precision issues by removing redundant operations. - Streamlined the logic for determining buffer overlaps, ensuring more accurate conflict detection. - Updated related documentation to reflect changes in buffer management practices.
* Remove obsolete test script for AMD example, streamlining the examples directory.
* Remove unused dtype_size variable in AMD example script to streamline code.
* Add input configuration file and update AMD example script for enhanced flexibility - Introduced a new input.txt file for configurable parameters. - Modified the example_amd_flash_attn_fwd.py script to allow for a wider range of configurations, including additional options for num_stages, enable_rasterization, and k_pack. - Streamlined the main function for better clarity and organization. - Added a new test script to facilitate running the example with specified parameters.
* Remove input configuration file and obsolete test script; enhance AMD example with swizzle layout annotations - Deleted input.txt and test.sh files as they are no longer needed. - Updated example_amd_flash_attn_fwd.py to include swizzle layout annotations for shared memory, improving bank conflict avoidance. - Reintroduced swizzle usage in the kernel for better performance.
* Refactor AMD example script for FlashAttention-2 - Updated function names for clarity, changing `get_v2_configs` to `get_configs` and `fast_flashattn_v2` to `fast_flashattn`. - Streamlined the main function by renaming `main_v2` to `main` and adjusting the corresponding calls. - Removed outdated comments and improved code organization for better readability.
* Refactor formatting in AMD FlashAttention example script - Improved code readability by adjusting line breaks and indentation in the `fast_flashattn` function. - Streamlined the `main` function parameter formatting for consistency. - Removed unnecessary blank lines to enhance overall code organization.
* Update example_amd_flash_attn_fwd.py
* Update AMD FlashAttention example and TVM submodule - Added a new example script `example_amd_flash_attn_fwd_k_block.py` for FlashAttention with K-blocking support. - Enhanced `example_amd_flash_attn_fwd.py` by expanding configuration options for block sizes and threads. - Updated the TVM submodule to the latest commit for improved functionality. - Introduced a new test script `test.sh` to facilitate running the new example with specified parameters.
* Add CI workflow for automated format checking and testing - Introduced a new GitHub Actions workflow in `amd_ci.yml` to automate format checks and testing for pull requests. - The workflow includes steps for setting up a Python environment, running format checks, and executing tests. - Removed obsolete example script `example_amd_flash_attn_fwd_k_block.py` and test script `test.sh` to streamline the examples directory.
* Rename CI workflow from "CI" to "AMD CI" for clarity and specificity.
* Update AMD CI workflow to include copying PyTorch, TorchVision, and Torchaudio packages to the virtual environment for improved dependency management.
* Update AMD CI workflow to install pytest directly instead of using requirements-test.txt
* Update AMD CI workflow to remove 'flash-attn' from requirements and install dependencies from requirements-test.txt
* Refactor AMD CI workflow to enhance clarity in removing 'flash-attn' from requirements-test.txt before installation
* Remove Torchaudio package copying from AMD CI workflow to streamline dependency management.
* Refactor AMD CI workflow to remove the format-check job and streamline the build-test process by directly copying PyTorch and TorchVision packages to the virtual environment.
* Add installation of ROCm in AMD CI workflow - Included a step to execute the `install_rocm.sh` script for improved setup. - Removed unnecessary blank line for better readability in the workflow script.
* Remove installation step for ROCm in AMD CI workflow to simplify the setup process.
* Update AMD CI workflow to run specific test file with verbose output instead of all tests.
* Add new tilelang built-in operations for AMD architecture - Introduced `tvm_mfma`, `tvm_mfma_store`, `tvm_rdna_wmma`, and `tvm_rdna_wmma_store` built-in operations to enhance support for matrix multiplication and storage in tilelang. - Each operation is configured with the appropriate number of inputs and marked as opaque in terms of call effects.
* Enhance autotuner configurations and GEMM operations in AMD example - Updated block sizes and num_split_q parameters in `get_configs` for improved autotuning. - Modified `T.gemm` calls in `fast_flashattn` to utilize `GemmWarpPolicy.FullRow`, optimizing performance for matrix multiplications.
* Update autotuner configurations in AMD example for enhanced performance - Refined block sizes, thread counts, and added new parameters in `get_configs` to optimize autotuning. - Adjusted `fast_flashattn` function to incorporate new parameters for panel size and coalesced widths, improving memory access patterns.
* Enhance autotuner configurations and memory handling in AMD example - Expanded block sizes and thread counts in `get_configs` for improved autotuning capabilities. - Updated `fast_flashattn` to utilize a new shared memory allocation strategy, optimizing memory access patterns during GEMM operations.
* Refine autotuner configurations and memory usage in AMD example - Reduced block sizes and adjusted thread counts in `get_configs` for optimized autotuning. - Updated `fast_flashattn` to utilize register fragments for accumulation, minimizing LDS usage and enhancing performance during GEMM operations.
* Update autotuner configurations in AMD example for enhanced performance - Expanded block sizes and thread counts in `get_configs` to improve autotuning capabilities. - Adjusted `num_split_q` and `v_coalesced_width` parameters for better optimization during GEMM operations.
* Enhance autotuner configurations and GEMM operations in AMD example - Expanded thread counts in `get_configs` to include higher values for improved autotuning. - Updated `fast_flashattn` to adjust accumulation logic and ensure proper handling of causal conditions, optimizing performance during matrix multiplications.
* Update AMD CI workflow and remove obsolete test script - Modified the CI workflow to run on multiple environments: self-hosted, amd, and gpu. - Deleted the outdated `test.sh` script from the examples directory, streamlining the project structure.
* Remove TVM subproject from 3rdparty directory
* Refactor configuration generation and accumulation logic in AMD example - Reformatted the `get_configs` function for improved readability by aligning parameters. - Adjusted the `fast_flashattn` function to enhance clarity in the conditional logic for accumulation, ensuring better handling of causal conditions.
* Enhance AMD CI workflow with additional logging and setup steps - Added echo statements to provide feedback during the CI process, indicating when the environment is running on an AMD GPU, copying necessary packages, and installing requirements. - Improved clarity in the workflow by explicitly stating when the project is being installed and when tests are being executed.
* Comment out package copying in AMD CI workflow to prevent potential issues during environment setup
* Update AMD CI workflow to install nightly versions of PyTorch and remove obsolete package copying steps
* Enhance BuildTileLangHIP function by adding whitespace for improved readability
* Refactor kTVMGridConstant definition for clarity and remove unnecessary comment
* Update TVM subproject to latest commit a64a5926a6e59f5417ef2501f9d88b467337cf6a
* lint fix
* Update AMD CI workflow to use requirements-rocm.txt for dependency installation
* fix ci
* Remove dependency on format-check from AMD CI workflow
* fix ci
* fix ci
* fix ci
* Remove format-check job from AMD CI workflow
* Add torch to requirements-rocm.txt and remove explicit pip install commands from AMD CI workflow
* Add dependency on format-check job in AMD CI workflow
* Add format-check job to AMD CI workflow
* Update format-check job in AMD CI workflow to run on self-hosted environment
* Enhance format-check job in AMD CI workflow with improved Python environment setup and automatic commit of lint changes
* Update amd_ci.yml

--------

Co-authored-by: xinxyxiao <xinyxiao@amd.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent d074286 commit 8e1b88f

File tree

7 files changed (+246, -40 lines)


.github/workflows/amd_ci.yml

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
+name: CI Test on AMD
+on: [pull_request]
+
+env:
+  PYTHON_VERSION: '3.12'
+  VENV_DIR: tilelang_ci
+  PYTORCH_INDEX_URL: https://download.pytorch.org/whl/nightly/rocm6.3/
+
+jobs:
+  format-check:
+    runs-on: [self-hosted, amd, gpu]
+
+    permissions:
+      contents: write
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ env.PYTHON_VERSION }}
+
+      - name: Ensure venv (local & persistent)
+        run: |
+          set -e
+          REQS_HASH=$(sha256sum requirements-test.txt 2>/dev/null | awk '{print $1}' || echo "no_requirements")
+          MARKER="${{ runner.tool_cache }}/.venv_marker_${{ env.PYTHON_VERSION }}_${REQS_HASH:0:8}"
+
+          if [[ -f "$MARKER" ]] && [[ -f "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" ]]; then
+            echo "venv exists and hash matches – reuse it"
+          else
+            echo "venv stale or missing – recreating"
+            rm -rf "${{ runner.tool_cache }}/${{ env.VENV_DIR }}" "$MARKER"
+            python -m venv "${{ runner.tool_cache }}/${{ env.VENV_DIR }}"
+            # shellcheck source=/dev/null
+            source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
+            python -m pip install --upgrade pip --no-user
+            [[ -f requirements-test.txt ]] && \
+              PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-test.txt --no-user
+            pip install flash_attn==2.5.8 --no-user --no-build-isolation
+            touch "$MARKER"
+          fi
+
+      - name: Run format check
+        run: |
+          source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
+          if ! output=$(./format.sh 2>&1); then
+            echo "------------------------------------"
+            echo "message:"
+            echo "$output"
+            printf '%s\n' "$output" | grep "Please review and stage the changes."
+            echo "------------------------------------"
+            exit 1
+          fi
+
+      - name: Commit and Push Changes
+        uses: stefanzweifel/git-auto-commit-action@v5
+        with:
+          commit_message: "lint"
+
+  build-test-amd:
+    runs-on: [self-hosted, amd, gpu]
+    needs: format-check
+    permissions:
+      contents: read
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+          repository: ${{ github.event.pull_request.head.repo.full_name }}
+          ref: ${{ github.event.pull_request.head.ref }}
+
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ env.PYTHON_VERSION }}
+
+      - name: Ensure venv (local & persistent)
+        run: |
+          echo "Running on AMD GPU"
+          set -e
+          REQS_HASH=$(sha256sum requirements-rocm.txt | cut -d ' ' -f 1)
+          MARKER="${{ runner.tool_cache }}/.venv_marker_${{ env.PYTHON_VERSION }}_${REQS_HASH:0:8}"
+
+          echo "Installing requirements"
+          if [[ -f "$MARKER" ]] && [[ -f "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" ]]; then
+            echo "venv exists and hash matches – reuse it"
+          else
+            echo "venv stale or missing – recreating"
+            rm -rf "${{ runner.tool_cache }}/${{ env.VENV_DIR }}" "$MARKER"
+            python -m venv "${{ runner.tool_cache }}/${{ env.VENV_DIR }}"
+            source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
+            python -m pip install --upgrade pip --no-user
+            if [[ -f requirements-rocm.txt ]]; then
+              pip install --pre torch torchvision torchaudio --index-url ${{ env.PYTORCH_INDEX_URL }}
+              PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-rocm.txt
+            fi
+
+            USE_ROCM=True pip install . --no-user
+            touch "$MARKER"
+          fi
+
+      - name: Install project (wheel form)
+        run: |
+          echo "Installing project (wheel form)"
+          source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
+          USE_ROCM=True pip install . --no-user
+
+      - name: Run tests
+        run: |
+          echo "Running tests"
+          source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
+          cd testing/python/amd
+          unset PYTHONPATH
+          python -m pytest -v test_tilelang_test_amd.py
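For anyone mirroring the new `build-test-amd` job outside of GitHub Actions, the test invocation reduces to the last step above: clear `PYTHONPATH` and run pytest from `testing/python/amd`. A minimal local sketch, assuming the repository root is the current directory and tilelang was already built with `USE_ROCM=True pip install .` (this snippet is illustrative and not part of the commit):

# Local reproduction of the workflow's "Run tests" step (a sketch, not from the repo).
import os
import subprocess

env = dict(os.environ)
env.pop("PYTHONPATH", None)  # the workflow unsets PYTHONPATH before invoking pytest

subprocess.run(
    ["python", "-m", "pytest", "-v", "test_tilelang_test_amd.py"],
    cwd="testing/python/amd",  # same directory the CI job changes into
    env=env,
    check=True,  # fail loudly, like a failing CI step
)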

3rdparty/tvm

Submodule tvm updated from 5a433cc to a64a592

examples/amd/example_amd_flash_attn_fwd.py

Lines changed: 53 additions & 28 deletions
@@ -2,6 +2,7 @@
 import torch.nn.functional as F
 import tilelang
 import tilelang.language as T
+from tilelang.primitives.gemm.base import GemmWarpPolicy
 import itertools
 import argparse
 from functools import partial
@@ -29,26 +30,35 @@ def ref_program(Q, K, V, is_causal, groups=1):
 
 def get_configs():
     """Generates configurations for the autotuner, tailored for FA-2 style parallelism."""
-    block_M = [64, 128, 256]
-    block_N = [32, 64, 128]
-    threads = [128, 256, 512]
-    num_split_q = [32, 64, 128]
-    num_stages = [0, 1, 2]
-    enable_rasterization = [True, False]
-    k_pack = [1, 2]
+    block_M = [32, 64, 128, 256]
+    block_N = [32, 64, 128, 256]
+    threads = [64, 128, 192, 256, 512, 1024]
+    num_split_q = [32, 64, 128, 256, 256]
+    num_stages = [0]
+    enable_rasterization = [True]
+    k_pack = [2]
+    panel_size = [7, 8, 9, 10]
+    qk_coalesced_width = [8]
+    v_coalesced_width = [4]
 
     valid_configs = []
 
-    for m, n, s, t, stages, r, k in itertools.product(block_M, block_N, num_split_q, threads,
-                                                      num_stages, enable_rasterization, k_pack):
+    for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q,
+                                                                  threads, num_stages,
+                                                                  enable_rasterization, k_pack,
+                                                                  panel_size, qk_coalesced_width,
+                                                                  v_coalesced_width):
         valid_configs.append({
             "block_M": m,
            "block_N": n,
            "num_split_q": s,
            "threads": t,
            "num_stages": stages,
            "enable_rasterization": r,
-            "k_pack": k
+            "k_pack": k,
+            "panel_size": p,
+            "qk_coalesced_width": qkw,
+            "v_coalesced_width": vw,
        })
     valid_configs.append({
         'block_M': 64,
@@ -57,7 +67,10 @@ def get_configs():
         'threads': 256,
         'num_stages': 1,
         'enable_rasterization': True,
-        'k_pack': 2
+        'k_pack': 2,
+        'panel_size': 64,
+        'qk_coalesced_width': 8,
+        'v_coalesced_width': 8,
     })
     return valid_configs
 
@@ -78,6 +91,9 @@ def fast_flashattn(
     num_stages: int,
     enable_rasterization: bool,
     k_pack: int,
+    panel_size: int,
+    qk_coalesced_width: int,
+    v_coalesced_width: int,
 ):
     scale = (1.0 / dim)**0.5 * 1.44269504
     head_kv = heads // groups
@@ -86,8 +102,8 @@ def fast_flashattn(
     dtype = "float16"
     accum_dtype = "float"
 
-    v_vec_size = 4
-    vec_size = 4 * k_pack
+    vec_size = qk_coalesced_width
+    v_vec_size = v_coalesced_width
 
     @T.prim_func
     def main(
@@ -97,31 +113,32 @@ def main(
             Output: T.Tensor(q_shape, dtype),
     ):
         with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
-            T.use_swizzle(10, enable=enable_rasterization)
+            T.use_swizzle(panel_size, enable=enable_rasterization)
 
             bz = byz_combined // heads
             by = byz_combined % heads
 
             num_q_blocks = T.ceildiv(seq_len, block_M)
 
             bx = T.alloc_var("int32")
-            bx[0] = b_split
+            bx = b_split
 
-            with T.While(bx[0] < num_q_blocks):
+            with T.While(bx < num_q_blocks):
                 acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
                 m_i = T.alloc_fragment([block_M], accum_dtype)
                 l_i = T.alloc_fragment([block_M], accum_dtype)
                 T.fill(acc_o, 0)
                 T.fill(m_i, -T.infinity(accum_dtype))
                 T.fill(l_i, 0)
 
-                current_bx = bx[0]
+                current_bx = bx
                 q_block_offset = current_bx * block_M
 
                 Q_shared = T.alloc_shared([block_M, dim], dtype)
                 K_shared = T.alloc_shared([block_N, dim], dtype)
                 V_shared = T.alloc_shared([block_N, dim], dtype)
-                P_shared = T.alloc_shared([block_M, block_N], dtype)
+                # Use register fragment for P instead of shared memory to reduce LDS usage
+                acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
 
                 acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
                 m_prev = T.alloc_fragment([block_M], accum_dtype)
@@ -135,6 +152,8 @@ def main(
                 loop_end_k = T.ceildiv(q_block_offset + block_M,
                                        block_N) if is_causal else T.ceildiv(seq_len, block_N)
 
+                row_sum = T.alloc_fragment([block_M], accum_dtype)
+
                 for k in T.Pipelined(loop_end_k, num_stages=num_stages):
                     kv_idx = k * block_N
 
@@ -147,13 +166,20 @@ def main(
                         V_shared,
                         coalesced_width=v_vec_size)
 
-                    T.clear(acc_s)
-                    T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, k_pack=k_pack)
-
                     if is_causal:
                         for i, j in T.Parallel(block_M, block_N):
-                            acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j,
-                                                         acc_s[i, j], -T.infinity(acc_s.dtype))
+                            acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0,
+                                                         -T.infinity(acc_s.dtype))
+                    else:
+                        T.clear(acc_s)
+                    T.gemm(
+                        Q_shared,
+                        K_shared,
+                        acc_s,
+                        transpose_B=True,
+                        k_pack=k_pack,
+                        policy=GemmWarpPolicy.FullRow,
+                    )
 
                     T.copy(m_i, m_prev)
                     T.reduce_max(acc_s, m_i, dim=1, clear=False)
@@ -169,15 +195,14 @@ def main(
                     for i, j in T.Parallel(block_M, block_N):
                         acc_s[i, j] = T.exp2(acc_s[i, j] * scale - m_i[i] * scale)
 
-                    row_sum = T.alloc_fragment([block_M], accum_dtype)
                     T.reduce_sum(acc_s, row_sum, dim=1)
                     for i in T.Parallel(block_M):
                         l_i[i] += row_sum[i]
 
-                    T.copy(acc_s, P_shared)
-                    T.sync_threads()
+                    # Cast acc_s (accum_dtype) to dtype in registers and directly GEMM with V
+                    T.copy(acc_s, acc_s_cast)
 
-                    T.gemm(P_shared, V_shared, acc_o)
+                    T.gemm(acc_s_cast, V_shared, acc_o, policy=GemmWarpPolicy.FullRow)
 
                 l_inv = T.alloc_fragment([block_M], accum_dtype)
                 for i in T.Parallel(block_M):
@@ -187,7 +212,7 @@ def main(
                 for i, j in T.Parallel(block_M, dim):
                     Output[bz, q_block_offset + i, by, j] = acc_o[i, j] * l_inv[i]
 
-                bx[0] = current_bx + num_split_q
+                bx = current_bx + num_split_q
 
     return main
 
requirements-rocm.txt

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+# lint requirements
+-r requirements-lint.txt
+# build requirements
+Cython
+cmake>=3.26
+# runtime requirements
+cffi
+cpplint
+Cython
+docutils
+dtlib
+numpy>=1.23.5
+pytest>=6.2.4
+pytest_xdist>=2.2.1
+packaging>=21.0
+PyYAML
+tqdm>=4.62.3
+typing_extensions>=4.10.0
+requests
+cloudpickle
+ml_dtypes
+psutil
+torch
+tabulate
+wheel
+setuptools
+einops
+scipy
+tornado

src/op/builtin.cc

Lines changed: 18 additions & 0 deletions
@@ -141,6 +141,24 @@ TIR_DEFINE_TL_BUILTIN(tl_gemm_sp)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_TL_BUILTIN(tvm_mfma).set_num_inputs(12).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(tvm_mfma_store)
+    .set_num_inputs(6)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(tvm_rdna_wmma)
+    .set_num_inputs(12)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(tvm_rdna_wmma_store)
+    .set_num_inputs(6)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_TL_BUILTIN(tl_shuffle_elect)
     .set_num_inputs(1)
     .set_attr<TCallEffectKind>("TCallEffectKind",

src/target/codegen_hip.cc

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 
 #include "codegen_hip.h"
 #include <tvm/arith/analyzer.h>
-#include <tvm/runtime/registry.h>
+#include <tvm/ffi/function.h>
 #include <tvm/tir/index_map.h>
 #include <tvm/tir/op.h>
 
@@ -882,7 +882,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
       this->PrintExpr(op->args[i * 2 + 1], os);
       os << "]" << ((i < 3) ? ", " : ")");
     }
-  } else if (op->op.same_as(builtin::tvm_mfma())) {
+  } else if (op->op.same_as(tl::tvm_mfma())) {
    // arg 0: prefix: {otype}_16x16x16{itype}
    // arg 1: A layout: row/col
    // arg 2: B layout: row/col
