Commit 23ef354

Merge branch 'main' of https://github.com/tile-ai/tilelang into v2_1106
2 parents: 3f211ae + 47039f0

7 files changed (+533 / -18 lines)


CMakeLists.txt

Lines changed: 1 addition & 5 deletions
@@ -230,12 +230,8 @@ target_include_directories(tilelang_objs PRIVATE ${TILE_LANG_INCLUDES})
 
 add_library(tilelang SHARED $<TARGET_OBJECTS:tilelang_objs>)
 add_library(tilelang_module SHARED $<TARGET_OBJECTS:tilelang_objs>)
-target_link_libraries(tilelang PUBLIC tvm_runtime)
+target_link_libraries(tilelang PUBLIC tvm_runtime tvm)
 target_link_libraries(tilelang_module PUBLIC tvm)
-if(APPLE)
-  # FIXME: libtilelang should only link against tvm runtime
-  target_link_libraries(tilelang PUBLIC tvm)
-endif()
 # Build cython extension
 find_package(Python REQUIRED COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT})

src/layout/layout.cc

Lines changed: 6 additions & 2 deletions
@@ -5,6 +5,7 @@
 
 #include "layout.h"
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/runtime/logging.h>
 
 #include <tvm/arith/pattern.h>
 #include <tvm/tir/op.h>
@@ -255,8 +256,11 @@ std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
   }
   arith::IterMapResult res =
       arith::DetectIterMap(forward_index_, getVarMap(), 1, level, &analyzer);
-  ICHECK(res->errors.empty())
-      << "Layout " << DebugOutput() << " has errors: " << res->errors;
+  if (!res->errors.empty()) {
+    std::ostringstream msg;
+    msg << "Layout " << DebugOutput() << " has errors: " << res->errors;
+    throw NormalizeIterException(msg.str());
+  }
 
   auto outputs_shape = OutputShape();
   Array<PrimExpr> outputs;

src/layout/utils.cc

Lines changed: 20 additions & 8 deletions
@@ -115,6 +115,10 @@ Array<IterSplitExpr> get_unused_iters(const IterMark &mark,
   return results;
 }
 
+// Heuristic: detect per-iterator gaps ("unused" pieces) even when the iterator
+// appears in fused forms across multiple index expressions. We first normalize
+// every index into IterSumExpr, collect all splits per source Var, then
+// consolidate them to avoid misclassifying a used split as unused.
 Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr> &exprs,
                                            const Array<IterVar> input_iters,
                                            Analyzer *analyzer) {
@@ -134,17 +138,25 @@ Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr> &exprs,
   }
 
   for (const IterVar &iter : input_iters) {
-    IterMark iv_mark;
+    // Merge splits from all IterMark that share the same source Var as `iter`.
+    std::vector<IterSplitExpr> merged_splits;
     for (const IterMark &mark : collector.visited_) {
-      if (mark->source.as<Var>()->same_as(iter->var)) { // NOLINT(*)
-        iv_mark = mark;
-        break;
+      auto vexpr = mark->source.as<Var>();
+      if (vexpr && vexpr.value().same_as(iter->var)) {
+        auto it = collector.mark2splits_.find(mark);
+        if (it != collector.mark2splits_.end()) {
+          const auto &vec = it->second;
+          merged_splits.insert(merged_splits.end(), vec.begin(), vec.end());
+        }
       }
     }
-    if (iv_mark.defined()) {
-      auto splits =
-          get_unused_iters(iv_mark, collector.mark2splits_[iv_mark], analyzer);
-      // Put the small axis last
+
+    if (!merged_splits.empty()) {
+      // Use a unified mark (Var + full extent) to compute the missing pieces
+      // so that fused usages are honored as "used" and not reintroduced.
+      IterMark unified_mark(iter->var, iter->dom->extent);
+      auto splits = get_unused_iters(unified_mark, merged_splits, analyzer);
+      // Put the small axis last for a flattened ordering.
       results.insert(results.end(), splits.rbegin(), splits.rend());
    } else if (!is_one(iter->dom->extent)) {
      auto mark = IterMark(iter->var, iter->dom->extent);
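
The heuristic above is exactly what the new regression test in this commit exercises: a T.Parallel(BLOCK_MN, BLOCK_K) loop that reads a fragment through the fused index idx = i * BLOCK_K + j, split as (idx // VEC_SIZE, idx % VEC_SIZE). As a purely illustrative aid (plain Python enumeration, not tilelang or TVM code; used_fraction is a made-up helper), the sketch below checks the property the merged-splits handling must preserve: every loop point maps to a distinct index pair, so no piece of i or j is actually unused and no replicate iterator should be introduced.

from itertools import product

def used_fraction(extents, index_fns):
    """Count distinct index tuples produced over the whole loop space."""
    points = list(product(*[range(e) for e in extents]))
    images = {tuple(fn(*p) for fn in index_fns) for p in points}
    return len(images), len(points)

BLOCK_MN, BLOCK_K, VEC_SIZE = 32, 64, 32

# Fused access pattern from the new test kernel:
#   idx = i * BLOCK_K + j;  read a_fp32_local[idx // VEC_SIZE, idx % VEC_SIZE]
fused = [
    lambda i, j: (i * BLOCK_K + j) // VEC_SIZE,
    lambda i, j: (i * BLOCK_K + j) % VEC_SIZE,
]

distinct, total = used_fraction((BLOCK_MN, BLOCK_K), fused)
print(distinct, total)  # 2048 2048: the fused indices cover the loop space bijectively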

src/op/parallel.cc

Lines changed: 56 additions & 1 deletion
@@ -620,11 +620,66 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
   if (IsCommonAccessIndice(buffer)) {
     return loop_layout_;
   }
+  // Prefer a simple path: if original 2D indices form a bijective map, invert
+  // them directly and avoid introducing a synthetic replicate dimension.
+  {
+    auto res2d =
+        arith::DetectIterMap(indice_map_[buffer], ToVMap(loop_vars_), 1,
+                             arith::IterMapLevel::Bijective,
+                             const_cast<arith::Analyzer *>(&analyzer_));
+    if (res2d->errors.empty()) {
+      Layout ind_inv2d = Layout(loop_vars_, indice_map_[buffer])->Inverse();
+      PrimExpr indice_rep_extent = 1;
+      PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
+      PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
+      Array<PrimExpr> fwd2;
+      for (size_t i = 0; i < buffer->shape.size(); i++) {
+        fwd2.push_back(InputPlaceholder(i));
+      }
+      PrimExpr thd_b2 =
+          loop_layout_->ForwardThread(ind_inv2d->Forward(fwd2), std::nullopt);
+      return Fragment(buffer->shape, {}, thd_b2, dest_buffer_rep_extent,
+                      std::nullopt)
+          ->CondenseReplicateVar();
+    }
+  }
+  // Otherwise, infer an extra flattened iterator that captures truly-unused
+  // pieces of the loop space (if any), then try inversion with it.
   PrimExpr rep_b = MakeFlattenedExpression(
       DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
   auto bijective_indice = indice_map_[buffer];
   bijective_indice.push_back(rep_b);
-  Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();
+  Layout layout_before_inv = Layout(loop_vars_, bijective_indice);
+
+  // Pre-check cardinality to guard non-bijective combinations after adding
+  // rep_b.
+  PrimExpr in_prod = 1;
+  for (const auto &iv : loop_vars_)
+    in_prod *= iv->dom->extent;
+  PrimExpr out_prod = 1;
+  for (const auto &d : layout_before_inv->OutputShape())
+    out_prod *= d;
+
+  if (!analyzer_.CanProveEqual(in_prod, out_prod)) {
+    DLOG(WARNING) << " Non-bijective mapping after appending rep_b; falling "
+                     "back to no-rep inversion.";
+    Layout ind_inv_fallback =
+        Layout(loop_vars_, indice_map_[buffer])->Inverse();
+    PrimExpr indice_rep_extent = 1;
+    PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
+    PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
+    Array<PrimExpr> fwd2;
+    for (size_t i = 0; i < buffer->shape.size(); i++) {
+      fwd2.push_back(InputPlaceholder(i));
+    }
+    PrimExpr thd_b = loop_layout_->ForwardThread(
+        ind_inv_fallback->Forward(fwd2), std::nullopt);
+    return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent,
+                    std::nullopt)
+        ->CondenseReplicateVar();
+  }
+
+  Layout ind_inv = layout_before_inv->Inverse();
   PrimExpr indice_rep_extent =
       ind_inv->InputShape().back(); // this is the size of rep_b
   PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
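
The cardinality pre-check added above guards the inversion: the candidate layout (the buffer indices plus the appended rep_b axis) can only be bijective if its output space is exactly as large as the loop space. A minimal numeric sketch of that comparison, with made-up example extents rather than values from any real kernel:

from math import prod

loop_extents = [32, 64]        # extents of loop_vars_ (illustrative values)
output_shape = [64, 32, 1]     # OutputShape() of layout_before_inv, including the rep_b axis
in_prod = prod(loop_extents)   # 2048
out_prod = prod(output_shape)  # 2048

# Mirrors analyzer_.CanProveEqual(in_prod, out_prod): if the products differ,
# the code falls back to inverting the original indices without rep_b.
assert in_prod == out_prod
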
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+import pytest
+import torch
+
+import tilelang
+import tilelang.testing
+import tilelang.language as T
+
+tilelang.testing.set_random_seed()
+
+VEC_SIZE = 32
+
+
+@tilelang.jit
+def fused_index_kernel(B: int, M: int, N: int, BLOCK_MN: int, BLOCK_K: int):
+
+    @T.prim_func
+    def main(
+            a: T.Buffer((B, M, N), "bfloat16"),
+            a_out: T.Buffer((B, M, N), "float32"),
+    ):
+        with T.Kernel(
+                T.ceildiv(M, BLOCK_MN),
+                T.ceildiv(N, BLOCK_K),
+                B,
+                threads=128,
+        ) as (pid_m, pid_n, pid_b):
+            a_fp32_local = T.alloc_fragment((BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE), "float32")
+            offs_m = pid_m * BLOCK_MN
+            offs_n = pid_n * BLOCK_K
+
+            for i, j in T.Parallel(BLOCK_MN, BLOCK_K):
+                idx = i * BLOCK_K + j
+                a_out[pid_b, offs_m + i, offs_n + j] = a_fp32_local[idx // VEC_SIZE, idx % VEC_SIZE]
+
+    return main
+
+
+def _require_cuda_tensor(shape, dtype):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+    try:
+        return torch.randn(*shape, device="cuda", dtype=dtype)
+    except RuntimeError as err:
+        pytest.skip(f"CUDA runtime unavailable: {err}")
+
+
+def test_layout_infer_compiles_and_runs():
+    B, M, N = 1, 32, 64
+    BLOCK_MN, BLOCK_K = 32, 64
+    kernel = fused_index_kernel(B, M, N, BLOCK_MN, BLOCK_K)
+
+    a = _require_cuda_tensor((B, M, N), torch.bfloat16)
+    a_out = torch.empty((B, M, N), dtype=torch.float32, device=a.device)
+
+    # Ensure the kernel compiles and executes without layout inversion failure
+    kernel(a, a_out)
+
+    assert a_out.shape == a.shape
+    assert a_out.dtype == torch.float32
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()

tilelang/contrib/nvcc.py

Lines changed: 152 additions & 1 deletion
@@ -7,7 +7,10 @@
 import os
 import subprocess
 import warnings
-from tilelang.env import CUDA_HOME
+import contextlib
+from tilelang.env import CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH
+import shutil
+import tempfile
 import tvm_ffi
 from tilelang import tvm as tvm
 from tvm.target import Target
@@ -125,6 +128,154 @@ def compile_cuda(code,
     return data
 
 
+def default_compile_options(compile_flags: list[str] | None = None) -> list[str]:
+    """
+    Build a set of default NVCC compile options for TileLang generated sources.
+
+    Includes C++ standard and common include paths (TileLang templates, CUTLASS,
+    CUDA include). Merges user-provided compile flags if given.
+
+    Parameters
+    ----------
+    compile_flags : Optional[List[str]]
+        Additional flags to include. Items are split on whitespace.
+
+    Returns
+    -------
+    List[str]
+        A list of flags suitable for NVCC's command line.
+    """
+    options: list[str] = ["-std=c++17"]
+    try:
+        if TILELANG_TEMPLATE_PATH:
+            options.append(f"-I{TILELANG_TEMPLATE_PATH}")
+    except Exception:
+        pass
+    try:
+        if CUTLASS_INCLUDE_DIR:
+            options.append(f"-I{CUTLASS_INCLUDE_DIR}")
+    except Exception:
+        pass
+    try:
+        if CUDA_HOME:
+            options.append(f"-I{os.path.join(CUDA_HOME, 'include')}")
+    except Exception:
+        pass
+
+    # Preserve user flags exactly, including repeated tokens required by NVCC
+    # (e.g., multiple "-gencode" pairs or repeated "-Xcompiler" entries).
+    if compile_flags:
+        import shlex
+        for flag in compile_flags:
+            # Split each string like a shell would, preserving quoted args
+            tokens = shlex.split(flag) if isinstance(flag, str) else [str(flag)]
+            options.extend(tokens)
+    return options
+
+
+def get_ptx_from_source(code: str,
+                        compile_flags: list[str] | None = None,
+                        verbose: bool = False) -> str:
+    """
+    Compile CUDA C++ source to PTX using NVCC and return as text.
+
+    Parameters
+    ----------
+    code : str
+        CUDA C++ kernel source code.
+    compile_flags : Optional[List[str]]
+        Additional flags merged with defaults.
+    verbose : bool
+        Print NVCC output when True.
+
+    Returns
+    -------
+    str
+        PTX text.
+    """
+    opts = default_compile_options(compile_flags)
+    ptx_bytes = compile_cuda(code, target_format="ptx", options=opts, verbose=verbose)
+    try:
+        return ptx_bytes.decode("utf-8")
+    except Exception:
+        return str(ptx_bytes)
+
+
+def _find_tool(name: str) -> str | None:
+    """Find a CUDA binary in PATH or under CUDA_HOME/bin."""
+    path = shutil.which(name)
+    if path:
+        return path
+    if CUDA_HOME:
+        candidate = os.path.join(CUDA_HOME, "bin", name)
+        if os.path.exists(candidate):
+            return candidate
+    return None
+
+
+def get_sass_from_source(code: str,
+                         compile_flags: list[str] | None = None,
+                         verbose: bool = False) -> str:
+    """
+    Compile CUDA C++ source to CUBIN and disassemble to SASS.
+
+    Uses nvdisasm if available; otherwise falls back to cuobjdump.
+
+    Parameters
+    ----------
+    code : str
+        CUDA C++ kernel source code.
+    compile_flags : Optional[List[str]]
+        Additional flags merged with defaults.
+    verbose : bool
+        Print tool outputs when True.
+
+    Returns
+    -------
+    str
+        SASS text.
+    """
+    opts = default_compile_options(compile_flags)
+    cubin_bytes = compile_cuda(code, target_format="cubin", options=opts, verbose=verbose)
+
+    # Write to a temp .cubin file
+    with tempfile.NamedTemporaryFile(suffix=".cubin", delete=False) as tmp:
+        tmp.write(cubin_bytes)
+        cubin_path = tmp.name
+
+    # Try disassembly tools (prefer nvdisasm, fall back to cuobjdump)
+    cand_nvdisasm = _find_tool("nvdisasm")
+    cand_cuobjdump = _find_tool("cuobjdump")
+    if not cand_nvdisasm and not cand_cuobjdump:
+        raise RuntimeError(
+            "Cannot find 'nvdisasm' or 'cuobjdump'. Please ensure the CUDA toolkit is installed and in PATH."
+        )
+    last_err: str | None = None
+    try:
+        # Attempt nvdisasm first
+        tools_to_try = []
+        if cand_nvdisasm:
+            tools_to_try.append(("nvdisasm", [cand_nvdisasm, cubin_path]))
+        if cand_cuobjdump:
+            tools_to_try.append(("cuobjdump", [cand_cuobjdump, "--dump-sass", cubin_path]))
+
+        for tool_name, cmd in tools_to_try:
+            proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+            out, _ = proc.communicate()
+            text = py_str(out)
+            if verbose:
+                print(f"[{tool_name}] output:\n{text}")
+            if proc.returncode == 0 and text.strip():
+                return text
+            last_err = f"{tool_name} rc={proc.returncode}, output:\n{text}"
+        # If we reach here, all attempts failed
+        raise RuntimeError(f"SASS disassembly failed. Tried tools: "
+                           f"{', '.join(name for name, _ in tools_to_try)}\n{last_err or ''}")
+    finally:
+        with contextlib.suppress(Exception):
+            os.remove(cubin_path)
+
+
 def find_cuda_path():
     """Utility function to find cuda path
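
For reference, a sketch of how the new helpers might be called once this commit is installed. The import path is assumed from the file location tilelang/contrib/nvcc.py, and the sample kernel and -gencode flag are placeholders; a working CUDA toolkit is required.

from tilelang.contrib import nvcc

kernel_src = r"""
extern "C" __global__ void add_one(float *x, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] += 1.0f;
}
"""

# Adjust the architecture flag for your GPU.
flags = ["-gencode arch=compute_80,code=sm_80"]

ptx = nvcc.get_ptx_from_source(kernel_src, compile_flags=flags)
sass = nvcc.get_sass_from_source(kernel_src, compile_flags=flags)

print(ptx.splitlines()[0])   # first line of the PTX output
print(sass.splitlines()[0])  # first line of the SASS disassembly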
