Skip to content

Commit 4b38278

Browse files
[Feat] Introduce an offset option in threadblock swizzle (#668)
* [Feat] Introduce an offset option in threadblock swizzle * rebase ci --------- Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent 94bd9ce commit 4b38278

File tree

3 files changed

+81
-39
lines changed

3 files changed

+81
-39
lines changed

.github/workflows/ci.yml

Lines changed: 75 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,118 @@
11
name: CI
2-
32
on: [pull_request]
43

4+
env:
5+
PYTHON_VERSION: '3.12'
6+
VENV_DIR: tilelang_ci
7+
58
jobs:
69
format-check:
710
runs-on: self-hosted
811

12+
permissions:
13+
contents: write
14+
915
steps:
1016
- name: Checkout repository
11-
uses: actions/checkout@v2
17+
uses: actions/checkout@v4
1218
with:
1319
fetch-depth: 0
1420

1521
- name: Set up Python
1622
uses: actions/setup-python@v2
1723
with:
18-
python-version: '3.9'
24+
python-version: ${{ env.PYTHON_VERSION }}
1925

20-
- name: Create virtual environment
21-
run: python -m venv tilelang_ci
22-
23-
- name: Activate virtual environment and install dependencies
26+
- name: Ensure venv (local & persistent)
2427
run: |
25-
source tilelang_ci/bin/activate
26-
python -m pip install --upgrade pip
27-
if [ -f requirements-dev.txt ]; then python -m pip install -r requirements-dev.txt; fi
28-
29-
- name: Update submodules recursively
30-
run: git submodule update --init --recursive
28+
set -e
29+
REQS_HASH=$(cat requirements-test.txt 2>/dev/null || true)
30+
MARKER="${{ runner.tool_cache }}/.venv_marker_${{ env.PYTHON_VERSION }}_${REQS_HASH:0:8}"
31+
32+
if [[ -f "$MARKER" ]] && [[ -f "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" ]]; then
33+
echo "venv exists and hash matches – reuse it"
34+
else
35+
echo "venv stale or missing – recreating"
36+
rm -rf "${{ runner.tool_cache }}/${{ env.VENV_DIR }}" "$MARKER"
37+
python -m venv "${{ runner.tool_cache }}/${{ env.VENV_DIR }}"
38+
# shellcheck source=/dev/null
39+
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
40+
python -m pip install --upgrade pip --no-user
41+
[[ -f requirements-test.txt ]] && \
42+
PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-test.txt --no-user
43+
touch "$MARKER"
44+
fi
3145
3246
- name: Run format check
3347
run: |
34-
source tilelang_ci/bin/activate
35-
./format.sh
48+
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
49+
if ! output=$(./format.sh 2>&1); then
50+
echo "------------------------------------"
51+
echo "message:"
52+
echo "$output"
53+
printf '%s\n' "$output" | grep "Please review and stage the changes."
54+
echo "------------------------------------"
55+
exit 1
56+
fi
57+
58+
- name: Commit and Push Changes
59+
uses: stefanzweifel/git-auto-commit-action@v5
60+
with:
61+
commit_message: "lint"
3662

3763
build-test:
3864
runs-on: self-hosted
3965
needs: format-check
40-
66+
permissions:
67+
contents: read
4168
steps:
4269
- name: Checkout repository
43-
uses: actions/checkout@v2
70+
uses: actions/checkout@v4
4471
with:
4572
fetch-depth: 0
73+
repository: ${{ github.event.pull_request.head.repo.full_name }}
74+
ref: ${{ github.event.pull_request.head.ref }}
4675

4776
- name: Set up Python
4877
uses: actions/setup-python@v2
4978
with:
50-
python-version: '3.9'
79+
python-version: ${{ env.PYTHON_VERSION }}
5180

52-
- name: Create virtual environment
53-
run: python -m venv tilelang_ci
54-
55-
- name: Activate virtual environment and install dependencies
81+
- name: Ensure venv (local & persistent)
5682
run: |
57-
source tilelang_ci/bin/activate
58-
python -m pip install --upgrade pip
59-
if [ -f requirements-test.txt ]; then PIP_NO_BUILD_ISOLATION=1 python -m pip install -r requirements-test.txt; fi
60-
61-
- name: Install project in wheel mode
83+
set -e
84+
REQS_HASH=$(cat requirements-test.txt 2>/dev/null || true)
85+
MARKER="${{ runner.tool_cache }}/.venv_marker_${{ env.PYTHON_VERSION }}_${REQS_HASH:0:8}"
86+
87+
if [[ -f "$MARKER" ]] && [[ -f "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" ]]; then
88+
echo "venv exists and hash matches – reuse it"
89+
else
90+
echo "venv stale or missing – recreating"
91+
rm -rf "${{ runner.tool_cache }}/${{ env.VENV_DIR }}" "$MARKER"
92+
python -m venv "${{ runner.tool_cache }}/${{ env.VENV_DIR }}"
93+
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
94+
python -m pip install --upgrade pip --no-user
95+
[[ -f requirements-test.txt ]] && \
96+
PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-test.txt --no-user
97+
pip install . --no-user
98+
touch "$MARKER"
99+
fi
100+
101+
- name: Install project (wheel form)
62102
run: |
63-
source tilelang_ci/bin/activate
64-
python -m pip install .
103+
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
104+
pip install . --no-user
65105
66106
- name: Run examples
67107
run: |
68-
source tilelang_ci/bin/activate
108+
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
69109
cd examples
70-
python -m pytest **/test*.py
110+
unset PYTHONPATH
111+
python -m pytest -n 8 **/test*.py
71112
72113
- name: Run tests
73114
run: |
74-
source tilelang_ci/bin/activate
115+
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
75116
cd testing/python
76-
python -m pytest
117+
unset PYTHONPATH
118+
python -m pytest -n 8

src/tl_templates/cuda/threadblock_swizzle.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
namespace tl {
66

7-
template <int panel_width> TL_DEVICE dim3 rasterization2DRow() {
7+
template <int panel_width, int offset> TL_DEVICE dim3 rasterization2DRow() {
88
const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
99
const unsigned int grid_size = gridDim.x * gridDim.y;
1010
const unsigned int panel_size = panel_width * gridDim.x;
@@ -18,11 +18,11 @@ template <int panel_width> TL_DEVICE dim3 rasterization2DRow() {
1818
const unsigned int col_idx = (panel_idx & 1)
1919
? gridDim.x - 1 - panel_offset / stride
2020
: panel_offset / stride;
21-
const unsigned int row_idx = panel_offset % stride + panel_idx * panel_width;
21+
const unsigned int row_idx = (panel_offset % stride + panel_idx * panel_width + offset) % gridDim.y;
2222
return {col_idx, row_idx, blockIdx.z};
2323
}
2424

25-
template <int panel_width> TL_DEVICE dim3 rasterization2DColumn() {
25+
template <int panel_width, int offset> TL_DEVICE dim3 rasterization2DColumn() {
2626
const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
2727
const unsigned int grid_size = gridDim.x * gridDim.y;
2828
const unsigned int panel_size = panel_width * gridDim.y;
@@ -36,7 +36,7 @@ template <int panel_width> TL_DEVICE dim3 rasterization2DColumn() {
3636
const unsigned int row_idx = (panel_idx & 1)
3737
? gridDim.y - 1 - panel_offset / stride
3838
: panel_offset / stride;
39-
const unsigned int col_idx = panel_offset % stride + panel_idx * panel_width;
39+
const unsigned int col_idx = (panel_offset % stride + panel_idx * panel_width + offset) % gridDim.x;
4040
return {col_idx, row_idx, blockIdx.z};
4141
}
4242

tilelang/language/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,13 @@ def symbolic(name: str, dtype: str = "int32"):
7676
return tir.Var(name, dtype)
7777

7878

79-
def use_swizzle(panel_size: int, order: str = "row", enable: bool = True):
79+
def use_swizzle(panel_size: int, order: str = "row", offset: int = 0, enable: bool = True):
8080
# If order is row, use rasterization2DRow, otherwise use rasterization2DColumn
8181
# The panel size is the number of threads in a warp
8282
# Use to improve the L2 Cache Locality
8383
device_func = ("rasterization2DRow" if order == "row" else "rasterization2DColumn")
8484
return attr(None, "threadblock_swizzle_pattern",
85-
f"tl::{device_func}<{panel_size}>") if enable else None
85+
f"tl::{device_func}<{panel_size}, {offset}>") if enable else None
8686

8787

8888
def annotate_layout(layout_map: Dict):

0 commit comments

Comments
 (0)