Skip to content

Commit 7880876

Browse files
committed
Merge branch 'main' of https://github.com/tile-ai/tilelang into kurisu-patch-1
2 parents 61318ce + 6b12502 commit 7880876

22 files changed

+370
-953
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,18 +104,18 @@ jobs:
104104
- name: Install project (wheel form)
105105
run: |
106106
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
107-
pip install . --no-user
107+
pip install . --no-user -v
108108
109109
- name: Run examples
110110
run: |
111111
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
112112
cd examples
113113
unset PYTHONPATH
114-
python -m pytest -n 4 **/test*.py
114+
python -m pytest -n 4 **/test*.py -v -r fE
115115
116116
- name: Run tests
117117
run: |
118118
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
119119
cd testing/python
120120
unset PYTHONPATH
121-
python -m pytest -n 4
121+
python -m pytest -n 4 -v -r fE

pyproject.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
[build-system]
22
requires = [
3+
"build",
34
"cmake>=3.26",
4-
"cython",
55
"packaging",
66
"setuptools>=61",
7+
"torch",
78
"wheel",
9+
"tox",
10+
"auditwheel",
11+
"patchelf",
12+
"ninja",
13+
"Cython",
814
]
915
build-backend = "setuptools.build_meta"
1016

setup.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ def get_nvcc_cuda_version():
112112
113113
Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
114114
"""
115-
nvcc_output = subprocess.check_output(["nvcc", "-V"], universal_newlines=True)
115+
nvcc_path = os.path.join(CUDA_HOME, "bin", "nvcc")
116+
nvcc_output = subprocess.check_output([nvcc_path, "-V"], universal_newlines=True)
116117
output = nvcc_output.split()
117118
release_idx = output.index("release") + 1
118119
nvcc_cuda_version = Version(output[release_idx].split(",")[0])
@@ -788,26 +789,46 @@ def build_cmake(self, ext):
788789
build_temp = os.path.abspath(self.build_temp)
789790
os.makedirs(build_temp, exist_ok=True)
790791

791-
# Copy the default 'config.cmake' from the source tree into our build directory.
792-
src_config_cmake = os.path.join(ext.sourcedir, "3rdparty", "tvm", "cmake", "config.cmake")
793-
dst_config_cmake = os.path.join(build_temp, "config.cmake")
794-
shutil.copy(src_config_cmake, dst_config_cmake)
795-
796-
# Append some configuration variables to 'config.cmake'
797-
with open(dst_config_cmake, "a") as config_file:
798-
config_file.write(f"set(USE_LLVM {llvm_config_path})\n")
799-
if USE_ROCM:
800-
config_file.write(f"set(USE_ROCM {ROCM_HOME})\n")
801-
config_file.write("set(USE_CUDA OFF)\n")
802-
else:
803-
config_file.write(f"set(USE_CUDA {CUDA_HOME})\n")
804-
config_file.write("set(USE_ROCM OFF)\n")
792+
# Paths to the source and destination config.cmake files
793+
src_config = Path(ext.sourcedir) / "3rdparty" / "tvm" / "cmake" / "config.cmake"
794+
dst_config = Path(build_temp) / "config.cmake"
795+
796+
# Read the default config template
797+
content_lines = src_config.read_text().splitlines()
798+
799+
# Add common LLVM configuration
800+
content_lines.append(f"set(USE_LLVM {llvm_config_path})")
801+
802+
# Append GPU backend configuration based on environment
803+
if USE_ROCM:
804+
content_lines += [
805+
f"set(USE_ROCM {ROCM_HOME})",
806+
"set(USE_CUDA OFF)",
807+
]
808+
else:
809+
content_lines += [
810+
f"set(USE_CUDA {CUDA_HOME})",
811+
"set(USE_ROCM OFF)",
812+
]
813+
814+
# Create the final file content
815+
new_content = "\n".join(content_lines) + "\n"
816+
817+
# Write the file only if it does not exist or has changed
818+
if not dst_config.exists() or dst_config.read_text() != new_content:
819+
dst_config.write_text(new_content)
820+
print(f"[Config] Updated: {dst_config}")
821+
else:
822+
print(f"[Config] No changes: {dst_config}")
805823

806824
# Run CMake to configure the project with the given arguments.
807-
subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=build_temp)
825+
if not os.path.exists(build_temp + "/build.ninja"):
826+
subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=build_temp)
808827

809828
# Build the project in "Release" mode with all available CPU cores ("-j").
810-
subprocess.check_call(["cmake", "--build", ".", "--config", "Release", "-j"],
829+
num_jobs = max(1, int(multiprocessing.cpu_count() * 0.75))
830+
subprocess.check_call(["cmake", "--build", ".", "--config", "Release", "-j",
831+
str(num_jobs)],
811832
cwd=build_temp)
812833

813834

src/op/builtin.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,6 @@ TIR_DEFINE_TL_BUILTIN(ptx_stmatrix)
9090
.set_attr<TCallEffectKind>("TCallEffectKind",
9191
Integer(CallEffectKind::kOpaque));
9292

93-
TIR_DEFINE_TL_BUILTIN(sync_thread_partial)
94-
.set_num_inputs(2)
95-
.set_attr<TCallEffectKind>("TCallEffectKind",
96-
Integer(CallEffectKind::kOpaque));
97-
9893
TIR_DEFINE_TL_BUILTIN(fence_proxy_async)
9994
.set_num_inputs(0)
10095
.set_attr<TCallEffectKind>("TCallEffectKind",

src/op/builtin.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,14 +169,6 @@ TVM_DLL const Op &ptx_stmatrix();
169169
*/
170170
TVM_DLL const Op &pack_b16();
171171

172-
/*!
173-
* \brief Similar to __syncthreads(), but can be used to sync partial threads
174-
*
175-
* sync_thread_partial(num_partial_threads or mbarrier)
176-
*
177-
*/
178-
TVM_DLL const Op &sync_thread_partial();
179-
180172
/*!
181173
* \brief Issue a shared memory fence for async operations
182174
*

src/target/codegen_cuda.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,8 +1050,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
10501050
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
10511051
auto phase = this->PrintExpr(op->args[1]);
10521052
this->stream << mbarrier_obj << ".wait(" << phase << ");\n";
1053-
} else if (op->op.same_as(tl::sync_thread_partial())) {
1054-
print_extern_call_stmt("cutlass::arch::NamedBarrier::sync");
10551053
} else if (op->op.same_as(tl::no_set_max_nreg())) {
10561054
return;
10571055
} else if (op->op.same_as(tl::tma_load())) {

src/target/codegen_hip.cc

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -784,8 +784,28 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
784784
int n = Downcast<IntImm>(op->args[0])->value;
785785
std::string func_name = "tl::cp_async_wait<" + std::to_string(n) + ">";
786786
print_extern_call_stmt(func_name, 1);
787-
} else if (op->op.same_as(tl::sync_thread_partial())) {
788-
print_extern_call_stmt("tl::syncthreads_partial");
787+
} else if (op->op.same_as(builtin::create_barriers())) {
788+
this->PrintIndent();
789+
int barrier_count = Downcast<IntImm>(op->args[0])->value;
790+
std::string barrier_name = "_mbarrier";
791+
this->stream << "__shared__ uint64_t " << barrier_name << "["
792+
<< barrier_count << "];\n";
793+
} else if (op->op.same_as(tl::get_mbarrier())) {
794+
std::string barrier_name = "_mbarrier";
795+
std::string barrier_id = this->PrintExpr(op->args[0]);
796+
os << barrier_name + "[" + barrier_id + "]";
797+
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
798+
print_extern_call_stmt("tl::mbarrier_arrive");
799+
} else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
800+
print_extern_call_stmt("tl::mbarrier_init");
801+
} else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
802+
print_extern_call_stmt("tl::mbarrier_arrive_expect_tx");
803+
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
804+
print_extern_call_stmt("tl::mbarrier_cp_async_arrive");
805+
} else if (op->op.same_as(tl::mbarrier_expect_tx())) {
806+
print_extern_call_stmt("tl::mbarrier_expect_tx");
807+
} else if (op->op.same_as(tl::mbarrier_wait_parity())) {
808+
print_extern_call_stmt("tl::mbarrier_wait");
789809
} else if (op->op.same_as(tl::ptx_stmatrix())) {
790810
int trans = Downcast<IntImm>(op->args[0])->value;
791811
int num = Downcast<IntImm>(op->args[1])->value;

src/tl_templates/cuda/common.h

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,12 +241,43 @@ TL_DEVICE void __sync_thread_partial() {
241241
asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(thread_count));
242242
}
243243

244+
// Template parameter:
245+
// thread_extent: the logical size (in number of threads) of each "group"
246+
// within which we want to elect exactly ONE representative
247+
// thread.
244248
template <int thread_extent> TL_DEVICE bool tl_shuffle_elect() {
249+
250+
// Special case: thread_extent == 0 means "elect exactly one thread
251+
// in the entire thread block", i.e., the leader of the first warp of the
252+
// block.
245253
if constexpr (thread_extent == 0) {
254+
// cutlass::canonical_warp_idx_sync():
255+
// Returns the warp ID within the thread block in a "canonical" way
256+
// (0 for the first warp, 1 for the second, ...).
257+
// cute::elect_one_sync():
258+
// Elect exactly one lane in the warp to return true (typically lane 0),
259+
// other lanes return false.
260+
// The condition ensures that:
261+
// (1) We are in warp 0 of the block.
262+
// (2) We are the elected lane in this warp.
246263
return cutlass::canonical_warp_idx_sync() == 0 && cute::elect_one_sync();
247264
}
248-
return __shfl_sync(0xffffffff, (threadIdx.x / 32) % (thread_extent / 32),
249-
0) == 0 &&
265+
266+
// General case: thread_extent != 0
267+
// (threadIdx.x / 32) is the warp index in the block.
268+
// (thread_extent / 32) is the number of warps in one group of size
269+
// thread_extent. We take warp_id % num_warps_in_group to get the warp's index
270+
// within the group.
271+
// __shfl_sync(mask, value, srcLane): broadcast 'value' from srcLane to all
272+
// lanes in the warp. Here it broadcasts the group-local warp index from lane
273+
// 0. Comparing to 0 selects only the group's warp 0.
274+
return __shfl_sync(0xffffffff, // full warp mask
275+
(threadIdx.x / 32) %
276+
(thread_extent / 32), // warp index within group
277+
0 // take the value from lane 0
278+
) == 0 &&
279+
// Within that group leader warp, elect exactly one lane (typically
280+
// lane 0) to be the single representative for the group.
250281
cute::elect_one_sync();
251282
}
252283

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
/*!
 * \file thread_sync_types.h
 * \brief Key type and reserved hardware-barrier IDs used by thread-sync
 *        lowering.
 */
#ifndef TVM_TL_THREAD_BOUND_KEY_H_
#define TVM_TL_THREAD_BOUND_KEY_H_

#include <cstdint>
#include <functional>

namespace tvm {
namespace tl {

/*!
 * \brief Identifies an inclusive rectangular thread range
 *        [tx_min, tx_max] x [ty_min, ty_max] x [tz_min, tz_max].
 *
 * Intended for use as an unordered-container key so that synchronization
 * covering the same set of threads can be deduplicated.
 */
struct ThreadBoundKey {
  int64_t tx_min, tx_max, ty_min, ty_max, tz_min, tz_max;

  /*! \brief Field-wise equality over all six bounds. */
  bool operator==(const ThreadBoundKey &other) const noexcept {
    return tx_min == other.tx_min && tx_max == other.tx_max &&
           ty_min == other.ty_min && ty_max == other.ty_max &&
           tz_min == other.tz_min && tz_max == other.tz_max;
  }
};

// There are 16 named barriers provided by hardware starting in Hopper.
// Their IDs are in the range 0-15.
// The number of threads syncing on a barrier must be a multiple of the
// warp size.
// NOTE(review): ID 0 may be used by other driver APIs (e.g. __syncthreads)
// and conflict with other uses; it is reserved here as kSyncThreads, and
// freely assignable barriers start at kFirstUsedBarrier — confirm this
// aliasing of ID 0 is intentional.
enum class ReservedNamedBarriers {
  kSyncThreads = 0,
  kReduce_0 = 1,
  kReduce_1 = 2,
  kFirstUsedBarrier = kReduce_1 + 1
};

} // namespace tl
} // namespace tvm

namespace std {
/*!
 * \brief std::hash specialization so ThreadBoundKey works in
 *        unordered_map/unordered_set.
 */
template <> struct hash<tvm::tl::ThreadBoundKey> {
  size_t operator()(const tvm::tl::ThreadBoundKey &k) const noexcept {
    // 31-based polynomial combine over the six bounds, in the same field
    // order as operator== so equal keys always hash equally.
    size_t h = std::hash<int64_t>()(k.tx_min);
    h = h * 31 + std::hash<int64_t>()(k.tx_max);
    h = h * 31 + std::hash<int64_t>()(k.ty_min);
    h = h * 31 + std::hash<int64_t>()(k.ty_max);
    h = h * 31 + std::hash<int64_t>()(k.tz_min);
    h = h * 31 + std::hash<int64_t>()(k.tz_max);
    return h;
  }
};
} // namespace std

#endif // TVM_TL_THREAD_BOUND_KEY_H_

src/transform/storage_access.cc

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ using namespace tir;
3838

3939
void TileLangStorageAccessVisitor::VisitExpr_(const BufferLoadNode *op) {
4040
Var buf = op->buffer->data;
41+
buffer_data_to_buffer_.Set(GetRef<Var>(buf.get()), op->buffer);
4142
StorageScope scope = GetScope(buf);
4243
if (Enabled(buf.get(), scope)) {
4344
ICHECK(allow_append_) << GetRef<BufferLoad>(op) << " " << scope.to_string();
@@ -64,6 +65,7 @@ void TileLangStorageAccessVisitor::VisitStmt_(const BufferStoreNode *op) {
6465
curr_stmt_.stmt = op;
6566

6667
Var buf = op->buffer->data;
68+
buffer_data_to_buffer_.Set(GetRef<Var>(buf.get()), op->buffer);
6769
StorageScope scope = GetScope(buf);
6870
if (Enabled(buf.get(), scope)) {
6971
AccessEntry e;
@@ -115,6 +117,15 @@ void TileLangStorageAccessVisitor::VisitStmt_(const LetStmtNode *op) {
115117
this->VisitStmt(op->body);
116118
}
117119

120+
void TileLangStorageAccessVisitor::VisitStmt_(const BlockNode *op) {
121+
auto block = Downcast<Block>(op);
122+
for (const auto &buffer : block->alloc_buffers) {
123+
ICHECK(buffer->IsInstance<BufferNode>());
124+
buffer_data_to_buffer_.Set(buffer->data, buffer);
125+
}
126+
IRVisitorWithAnalyzer::VisitStmt_(op);
127+
}
128+
118129
void TileLangStorageAccessVisitor::VisitStmt_(const AttrStmtNode *op) {
119130
if (op->attr_key == tvm::tir::attr::double_buffer_write) {
120131
ICHECK(double_buffer_write_ == nullptr);
@@ -271,18 +282,27 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
271282
Buffer buffer = load->buffer;
272283
DataType dtype = buffer->dtype;
273284
const VarNode *buffer_var = buffer->data.as<VarNode>();
285+
buffer_data_to_buffer_.Set(GetRef<Var>(buffer_var), buffer);
274286
StorageScope scope = GetScope(GetRef<Var>(buffer_var));
287+
Array<Range> buffer_ranges;
288+
// from indices to buffer indices
289+
ICHECK(buffer->shape.size() == load->indices.size());
290+
for (size_t i = 0; i < buffer->shape.size(); ++i) {
291+
buffer_ranges.push_back(
292+
Range::FromMinExtent(load->indices[i], buffer->shape[i]));
293+
}
275294
if (Enabled(buffer_var, scope)) {
276295
ICHECK(allow_append_);
277296
AccessEntry e;
278297
e.threads = env_threads();
279298
e.thread_range = this->ComputeThreadRange(e.threads);
280299
e.dtype = dtype;
281300
e.buffer = Downcast<Var>(buffer->data);
282-
e.buffer_indices = load->indices;
301+
e.buffer_ranges = buffer_ranges;
283302
for (const auto &index : load->indices) {
284303
e.touched.push_back(arith::IntSet::Vector(index));
285304
}
305+
e.is_pointer_access = true;
286306
e.type = kRead;
287307
e.scope = scope;
288308
curr_stmt_.access.emplace_back(e);
@@ -294,20 +314,54 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
294314
} else if (op->op.same_as(builtin::tvm_access_ptr())) {
295315
ICHECK_EQ(op->args.size(), 5U);
296316
DataType dtype = op->args[0].dtype();
297-
const VarNode *buffer = op->args[1].as<VarNode>();
317+
const VarNode *buffer_var = op->args[1].as<VarNode>();
298318
PrimExpr offset = op->args[2];
299319
PrimExpr extent = op->args[3];
300320
const IntImmNode *flag = op->args[4].as<IntImmNode>();
301-
StorageScope scope = GetScope(GetRef<Var>(buffer));
321+
StorageScope scope = GetScope(GetRef<Var>(buffer_var));
302322
// The buffer scope.
303-
if (Enabled(buffer, scope)) {
323+
if (Enabled(buffer_var, scope)) {
304324
ICHECK(allow_append_);
325+
Array<Range> buffer_ranges;
326+
if (buffer_data_to_buffer_.find(GetRef<Var>(buffer_var)) ==
327+
buffer_data_to_buffer_.end()) {
328+
// cannot find buffer map, use the default buffer
329+
buffer_ranges = {Range::FromMinExtent(offset, extent)};
330+
} else {
331+
Buffer buffer = buffer_data_to_buffer_.at(GetRef<Var>(buffer_var));
332+
auto buffer_shape = buffer->shape;
333+
// convert 1d offset to multi-dimensional index
334+
auto linear_to_indices = [this](PrimExpr offset,
335+
const Array<PrimExpr> &shape) {
336+
Array<PrimExpr> indices;
337+
PrimExpr remaining = offset;
338+
for (size_t i = 0; i < shape.size(); ++i) {
339+
PrimExpr stride = make_const(DataType::Int(32), 1);
340+
for (size_t j = i + 1; j < shape.size(); ++j) {
341+
stride = stride * shape[j];
342+
}
343+
PrimExpr idx = FloorDiv(remaining, stride);
344+
remaining = FloorMod(remaining, stride);
345+
indices.push_back(analyzer_.Simplify(idx));
346+
}
347+
return indices;
348+
};
349+
Array<PrimExpr> start_indices = linear_to_indices(offset, buffer_shape);
350+
Array<PrimExpr> end_indices =
351+
linear_to_indices(offset + extent, buffer_shape);
352+
for (size_t i = 0; i < buffer_shape.size(); ++i) {
353+
buffer_ranges.push_back(Range::FromMinExtent(
354+
start_indices[i],
355+
analyzer_.Simplify(end_indices[i] - start_indices[i])));
356+
}
357+
}
305358
AccessEntry e;
306359
e.threads = env_threads();
307360
e.thread_range = this->ComputeThreadRange(e.threads);
308361
e.dtype = dtype;
309-
e.buffer = Downcast<Var>(op->args[1]);
310-
e.buffer_indices = {offset, extent};
362+
e.buffer = GetRef<Var>(buffer_var);
363+
e.buffer_ranges = buffer_ranges;
364+
e.is_pointer_access = true;
311365
e.touched = {
312366
arith::IntSet::FromRange(Range::FromMinExtent(offset, extent))};
313367
e.scope = scope;

0 commit comments

Comments
 (0)