
Commit 1b308ba

[Language] Introduce StridedTensor to support non-contiguous torch inputs (#722)
* Update submodule 'tvm' to commit e11521e6936a827efa334588d29571fbb4620107
* Support strided tensors
* Refactor target attribute helper functions for improved clarity
* No code changes made in proxy.py and setup.py
* lint fix
* lint fix via gemini
* lint fix
* test fix
* test fix
* lint fix
* Update wrapper.py
* test fix
* Enhance test for InjectSoftwarePipeline by adding LowerOpaqueBlock transformation and updating expected function signature to use match_buffer for better clarity.
* lint fix

---------

Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
1 parent c369d69 commit 1b308ba
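
For context, a minimal sketch (not part of the commit) of the kind of input this change targets: slicing a torch tensor along its last dimension keeps the original row stride, so the view is non-contiguous and previously would typically need to be materialized with .contiguous() before being handed to a kernel. The shapes below mirror the new copy test added by this commit.

    import torch

    # A (1024, 2048) buffer; taking the first 1024 columns keeps the row stride
    # at 2048 elements, so the view is non-contiguous.
    a = torch.randn(1024, 2048, device="cuda", dtype=torch.float16)
    view = a[:, :1024]
    assert not view.is_contiguous()
    assert view.stride() == (2048, 1)
    # With this commit, such a view can be described on the kernel side as
    #   A: T.StridedTensor((M, N), (NN, 1), dtype)
    # and passed directly, as test_tilelang_language_copy.py now does.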


17 files changed (+430, -158 lines)


examples/fusedmoe/example_fusedmoe_tilelang.py

Lines changed: 0 additions & 2 deletions
@@ -7,8 +7,6 @@
 from tilelang.autotuner import *
 from example_fusedmoe_torch import *
 
-# tilelang.disable_cache()
-
 
 @tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
 def moe_forward_tilelang_shared(d_hidden,

examples/warp_specialize/example_warp_specialize_flashmla.py

Lines changed: 5 additions & 29 deletions
@@ -145,20 +145,10 @@ def flash_attn(
 clear_accum=True,
 wg_wait=-1)
 T.barrier_wait(kv_shared_0_r_is_ready, k % 2)
-T.gemm(
-Q_shared_r,
-KV_shared_0_r,
-acc_s_0,
-transpose_B=True,
-wg_wait=-1)
+T.gemm(Q_shared_r, KV_shared_0_r, acc_s_0, transpose_B=True, wg_wait=-1)
 
 T.barrier_wait(kv_shared_0_pe_is_ready, k % 2)
-T.gemm(
-Q_pe_local_0,
-K_pe_shared_0,
-acc_s_0,
-transpose_B=True,
-wg_wait=-1)
+T.gemm(Q_pe_local_0, K_pe_shared_0, acc_s_0, transpose_B=True, wg_wait=-1)
 
 T.wait_wgmma(0)
 
@@ -261,20 +251,10 @@ def flash_attn(
 wg_wait=-1)
 
 T.barrier_wait(kv_shared_1_r_is_ready, k % 2)
-T.gemm(
-Q_shared_r,
-KV_shared_1_r,
-acc_s_1,
-transpose_B=True,
-wg_wait=-1)
+T.gemm(Q_shared_r, KV_shared_1_r, acc_s_1, transpose_B=True, wg_wait=-1)
 
 T.barrier_wait(kv_shared_1_pe_is_ready, k % 2)
-T.gemm(
-Q_pe_local_1,
-K_pe_shared_1,
-acc_s_1,
-transpose_B=True,
-wg_wait=-1)
+T.gemm(Q_pe_local_1, K_pe_shared_1, acc_s_1, transpose_B=True, wg_wait=-1)
 
 T.wait_wgmma(0)
 
@@ -308,11 +288,7 @@ def flash_attn(
 
 # Step 10. compute O1 with KV_shared_1_rd
 T.copy(acc_s_1, acc_s_1_cast)
-T.gemm(
-acc_s_1_cast,
-KV_shared_1_r,
-acc_o_r,
-wg_wait=-1)
+T.gemm(acc_s_1_cast, KV_shared_1_r, acc_o_r, wg_wait=-1)
 T.copy(acc_s_1_cast, SP1_shared)
 T.barrier_arrive(s_shared_ready_barrier)

setup.py

Lines changed: 9 additions & 9 deletions
@@ -1,3 +1,6 @@
+import fcntl
+import functools
+import hashlib
 import io
 import subprocess
 import shutil
@@ -12,17 +15,14 @@
 import os
 import sys
 import site
-import hashlib
 import sysconfig
-import functools
 import urllib.request
 from packaging.version import Version
 import platform
 import multiprocessing
 from setuptools.command.build_ext import build_ext
 import importlib
 import logging
-import fcntl
 
 # Configure logging with basic settings
 logging.basicConfig(
@@ -692,15 +692,15 @@ def build_cython(self, ext):
 with open(md5_path, "r") as f:
 cached_hash = f.read().strip()
 if cached_hash == code_hash:
-logger.info("Cython jit adapter is up to date, no need to compile...")
+logger.info("Cython JIT adapter is up to date, no need to compile...")
 need_compile = False
 else:
-logger.info("Cython jit adapter is out of date, need to recompile...")
+logger.info("Cython JIT adapter is out of date, need to recompile...")
 else:
-logger.info("No cached version found for cython jit adapter, need to compile...")
+logger.info("No cached version found for Cython JIT adapter, need to compile...")
 
 if need_compile:
-logger.info("Waiting for lock to compile cython jit adapter...")
+logger.info("Waiting for lock to compile Cython JIT adapter...")
 with open(lock_file, 'w') as lock:
 fcntl.flock(lock.fileno(), fcntl.LOCK_EX)
 try:
@@ -715,7 +715,7 @@ def build_cython(self, ext):
 need_compile = False
 
 if need_compile:
-logger.info("Compiling cython jit adapter...")
+logger.info("Compiling Cython JIT adapter...")
 temp_path = cache_dir / f"temp_{code_hash}.so"
 
 with open(md5_path, "w") as f:
@@ -736,7 +736,7 @@ def build_cython(self, ext):
 except Exception as e:
 if 'temp_path' in locals() and temp_path.exists():
 temp_path.unlink()
-raise Exception(f"Failed to compile cython jit adapter: {e}") from e
+raise Exception(f"Failed to compile Cython JIT adapter: {e}") from e
 finally:
 if lock_file.exists():
 lock_file.unlink()
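
The hunks above hoist fcntl, functools, and hashlib to the top of setup.py and touch the lock-guarded Cython build path. For readers unfamiliar with the pattern, a standalone sketch of advisory file locking with fcntl (the lock path and the guarded step are placeholders, not the project's actual code):

    import fcntl

    # Hypothetical sketch: serialize a build step across processes with an
    # exclusive advisory lock, mirroring the flock usage shown in the diff.
    with open("/tmp/example_build.lock", "w") as lock:
        fcntl.flock(lock.fileno(), fcntl.LOCK_EX)  # blocks until the lock is free
        try:
            pass  # the guarded compile step would run here
        finally:
            fcntl.flock(lock.fileno(), fcntl.LOCK_UN)  # release explicitly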

src/target/codegen_cuda.cc

Lines changed: 70 additions & 0 deletions
@@ -1702,6 +1702,76 @@ void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) {
 os << "))";
 }
 
+void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op,
+std::ostream &os) { // NOLINT(*)
+ICHECK_EQ(op->indices.size(), 1)
+<< "Load from non-flat memory not supported.";
+ICHECK(!op->predicate.defined())
+<< "Predicated buffer load is not supported.";
+
+DataType value_dtype = op->dtype;
+PrimExpr index = op->indices[0];
+Var buffer_var = op->buffer->data;
+DataType element_dtype = op->buffer->dtype;
+
+int lanes = op->dtype.lanes();
+// delcare type.
+if (value_dtype.lanes() == element_dtype.lanes()) {
+std::string ref = GetBufferRef(op->dtype, op->buffer.get(), index);
+HandleVolatileLoads(ref, op, os);
+} else {
+bool can_vector_load = false;
+arith::PVar<PrimExpr> base;
+if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) {
+const RampNode *ramp = index.as<RampNode>();
+ICHECK(ramp);
+can_vector_load = true;
+// arith::ModularSet me = arith::Analyzer().modular_set(ramp->base);
+// The condition: {k * coeff + base} divisible by the alignment for any k
+// if (me->coeff % op->dtype.lanes() == 0 && me->base % op->dtype.lanes()
+// == 0) {
+// can_vector_load = true;
+// }
+}
+
+if (value_dtype.is_float4_e2m1fn() && lanes != 1) {
+// A float4_e2m1fn element has 4 bits, which is an incomplete byte.
+// So we cannot vector load it.
+can_vector_load = false;
+}
+if (can_vector_load) {
+std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval());
+HandleVolatileLoads(ref, op, os);
+} else {
+std::ostringstream svalue_expr;
+std::string sindex = SSAGetID(PrintExpr(index), index.dtype());
+std::string vid = GetVarID(buffer_var.get());
+DataType elem_type = op->dtype.element_of();
+for (int i = 0; i < lanes; ++i) {
+std::ostringstream value_temp;
+if (!HandleTypeMatch(buffer_var.get(), elem_type)) {
+value_temp << "((";
+if (buffer_var.get()->dtype.is_handle()) {
+auto it = alloc_storage_scope_.find(buffer_var.get());
+if (it != alloc_storage_scope_.end()) {
+PrintStorageScope(it->second, value_temp);
+}
+}
+PrintType(elem_type, value_temp);
+value_temp << "*)" << vid << ')';
+} else {
+value_temp << vid;
+}
+value_temp << '[';
+PrintVecElemLoad(sindex, index.dtype(), i, value_temp);
+value_temp << ']';
+PrintVecElemLoadExpr(op->dtype, i, value_temp.str(), svalue_expr);
+}
+os << svalue_expr.str();
+}
+}
+}
+
 void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
 std::ostream &os) { // NOLINT(*)
 int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
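
In rough terms, the new visitor emits a plain buffer reference when the value and element lane counts match, a single vector load when the flattened index is a stride-1 ramp (and the element type is not 4-bit float4_e2m1fn), and otherwise falls back to loading each lane separately. A small Python model of that dispatch (illustrative only, not the codegen itself):

    # Illustrative model of the branch structure in the new BufferLoadNode visitor.
    def lower_buffer_load(value_lanes, element_lanes,
                          index_is_unit_stride_ramp, is_float4_e2m1fn):
        if value_lanes == element_lanes:
            return "plain buffer reference"      # GetBufferRef path
        can_vector_load = index_is_unit_stride_ramp
        if is_float4_e2m1fn and value_lanes != 1:
            # 4-bit elements do not fill a whole byte, so no vector load.
            can_vector_load = False
        if can_vector_load:
            return "single vectorized load"      # GetVecLoad path
        return "per-lane element loads"          # PrintVecElemLoad path

    print(lower_buffer_load(4, 1, True, False))   # single vectorized load
    print(lower_buffer_load(4, 1, False, False))  # per-lane element loads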

src/target/codegen_cuda.h

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@ class CodeGenTileLangCUDA final : public CodeGenC {
 void VisitStmt_(const EvaluateNode *op) final;
 void VisitStmt_(const AllocateNode *op) final;
 void VisitStmt_(const AttrStmtNode *op) final;
+void VisitExpr_(const BufferLoadNode *op, std::ostream &os) final;
 
 // Override this as a work around for __grid_constant__ parameter
 void AddFunction(const GlobalVar &gvar, const PrimFunc &f);

src/tl_templates/hip/reduce.h

Lines changed: 2 additions & 1 deletion
@@ -22,7 +22,8 @@ struct MinOp {
 }
 };
 
-template <class Reducer, int threads, int scale, int thread_offset = 0> struct AllReduce {
+template <class Reducer, int threads, int scale, int thread_offset = 0>
+struct AllReduce {
 static_assert(threads == 1024 || threads == 512 || threads == 256 ||
 threads == 128 || threads == 64 || threads == 32 ||
 threads == 16 || threads == 8 || threads == 4 || threads == 2);

src/transform/loop_vectorize.cc

Lines changed: 28 additions & 6 deletions
@@ -136,11 +136,23 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
 max_vector_size = gcd_base;
 }
 vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);
+
+// Generate strides if not existed
+auto strides = buffer->strides;
+if (buffer->strides.size() == 0) {
+PrimExpr stride = 1;
+for (int i = indices.size() - 1; i >= 0; --i) {
+strides.push_back(stride);
+stride = stride * buffer->shape[i];
+}
+strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
+}
+
+// Generate and check element offset expression
+ICHECK(indices.size() == strides.size()) << "Invalid indices and strides";
 PrimExpr elem_offset = 0;
-PrimExpr stride = 1;
-for (int i = indices.size() - 1; i >= 0; --i) {
-elem_offset = elem_offset + indices[i] * stride;
-stride = stride * buffer->shape[i];
+for (int i = 0; i < indices.size(); ++i) {
+elem_offset += indices[i] * strides[i];
 }
 while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
 inner_for_->extent, vector_size_,
@@ -229,10 +241,19 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
 ICHECK(target_vectorized_size >= 1);
 if (target_vectorized_size == 1)
 return true;
-// bind thread range
+
+// Extent must be divisible
 if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size),
 0))
 return false;
+
+// The base offset must be divisible
+if (!analyzer->CanProveEqual(
+FloorMod(Substitute(expr, {{var, 0}}), target_vectorized_size), 0)) {
+return false;
+}
+
+// Bind thread range
 Var v0("v0"), v1("v1");
 analyzer->Bind(v0, Range(0, target_vectorized_size));
 analyzer->Bind(v1, Range(0, analyzer->Simplify(FloorDiv(
@@ -241,7 +262,8 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
 Substitute(expr, {{var, v0 + v1 * target_vectorized_size}}));
 Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size));
 PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed);
-// This simplify is necessary for thread region specifiled
+
+// This simplify is necessary for thread region specified
 // optimizations.
 expr_vectorized = analyzer->Simplify(expr_vectorized);
 auto ramp_node = expr_vectorized.as<RampNode>();
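
Two changes are visible here: the planner now derives row-major strides when a buffer declares none and builds the element offset from indices and strides (so non-contiguous buffers get correct offsets), and IndiceCanVectorize additionally requires the base offset, not just the loop extent, to be divisible by the vector width. A small Python illustration of the offset computation (not part of the commit):

    # Row-major strides are generated only when the buffer declares none;
    # the element offset is then a dot product of indices and strides.
    def elem_offset(indices, shape, strides=None):
        if strides is None:
            strides, s = [], 1
            for extent in reversed(shape):
                strides.append(s)
                s *= extent
            strides.reverse()
        assert len(indices) == len(strides), "Invalid indices and strides"
        return sum(i * st for i, st in zip(indices, strides))

    # Contiguous (4, 8) buffer vs. a strided view with row stride 16:
    print(elem_offset([2, 3], shape=[4, 8]))                   # 2 * 8 + 3 = 19
    print(elem_offset([2, 3], shape=[4, 8], strides=[16, 1]))  # 2 * 16 + 3 = 35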

testing/python/language/test_tilelang_language_copy.py

Lines changed: 46 additions & 2 deletions
@@ -28,8 +28,8 @@ def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16")
 out_idx=[1],
 target="cuda",
 pass_configs={
-"tl.disable_warp_specialized": True,
-"tl.disable_tma_lower": True
+tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True
 })
 a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
 b = kernel(a)
@@ -42,5 +42,49 @@ def test_tilelang_copy():
 run_tilelang_copy(M=1024, N=576, block_M=32, block_N=576, dtype="float")
 
 
+def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"):
+
+@T.prim_func
+def main(
+A: T.StridedTensor((M, N), (NN, 1), dtype),
+B: T.Tensor((M, N), dtype),
+):
+# Initialize Kernel Context
+with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
+for i, j in T.Parallel(block_M, block_N):
+B[by * block_M + i, bx * block_N + j] = A[by * block_M + i, bx * block_N + j]
+
+return main
+
+
+def run_tilelang_copy_with_stride(M=1024,
+N=1024,
+NN=2048,
+block_M=128,
+block_N=128,
+dtype="float16"):
+if isinstance(NN, int):
+assert NN > N, "NN must be greater than N"
+program = tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype)
+kernel = tilelang.compile(
+program,
+out_idx=[1],
+target="cuda",
+pass_configs={
+tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+})
+if isinstance(NN, T.Var):
+NN = N * 2
+a = torch.randn(M, NN, device="cuda", dtype=getattr(torch, dtype))
+b = kernel(a[:, :N])
+torch.testing.assert_close(b, a[:, :N], rtol=1e-2, atol=1e-2)
+
+
+def test_tilelang_copy_with_stride():
+run_tilelang_copy_with_stride(M=1024, N=1024, NN=2048, block_M=128, block_N=128)
+run_tilelang_copy_with_stride(M=1024, N=1024, NN=T.symbolic("NN"), block_M=128, block_N=128)
+
+
 if __name__ == "__main__":
 tilelang.testing.main()

testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py

Lines changed: 11 additions & 26 deletions
@@ -9,6 +9,7 @@ def _check(original, transformed):
 mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
 mod = tl.transform.InjectSoftwarePipeline()(mod)
 mod = tl.transform.Simplify()(mod)
+mod = tl.transform.LowerOpaqueBlock()(mod)
 tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"),
 True)
 
@@ -39,32 +40,16 @@ def before(A: T.Tensor((16, 1), "float32"), C: T.Tensor((16, 1), "float32")):
 C[tx, i] = B[tx, 0] + T.float32(1)
 
 @T.prim_func
-def expected(A: T.Buffer((16, 1), "float32"), C: T.Buffer((16, 1), "float32")):
-for tx in T.thread_binding(16, thread="threadIdx.x"):
-with T.block():
-T.reads(A[tx, 0])
-T.writes(C[tx, 0])
-B = T.alloc_buffer((2, 16, 1), scope="shared")
-with T.block():
-T.reads(A[tx, 0])
-T.writes(B[0, tx, 0])
-B[0, tx, 0] = A[tx, 0] * T.float32(2.0)
-with T.block():
-T.reads(A[tx, 1:1], B[0:2, tx, 0])
-T.writes(B[1:1, tx, 0], C[tx, 0:0])
-for i in range(0):
-with T.block():
-T.reads(A[tx, i + 1])
-T.writes(B[i + 1, tx, 0])
-B[i + 1, tx, 0] = A[tx, i + 1] * T.float32(2.0)
-with T.block():
-T.reads(B[i, tx, 0])
-T.writes(C[tx, i])
-C[tx, i] = B[i, tx, 0] + T.float32(1.0)
-with T.block():
-T.reads(B[0, tx, 0])
-T.writes(C[tx, 0])
-C[tx, 0] = B[0, tx, 0] + T.float32(1.0)
+def expected(A_handle: T.handle, C_handle: T.handle):
+A = T.match_buffer(A_handle, (16, 1), strides=(1, 1))
+C = T.match_buffer(C_handle, (16, 1), strides=(1, 1))
+tx = T.launch_thread("threadIdx.x", 16)
+B = T.decl_buffer((2, 16, 1), scope="shared")
+B[0, tx, 0] = A[tx, 0] * T.float32(2.0)
+for i in range(0):
+B[i + 1, tx, 0] = A[tx, i + 1] * T.float32(2.0)
+C[tx, i] = B[i, tx, 0] + T.float32(1.0)
+C[tx, 0] = B[0, tx, 0] + T.float32(1.0)
 
 _check(before, expected)
