Skip to content

Commit 47039f0

Browse files
authored
[Language] Refactor reduce and support shared memory as its in/out (#1219)
* [Refactor] Update ReduceOpNode to use absolute values in Max computation and remove unused shared memory reduction logic
  * Changed Max computation for AbsMax type to use absolute values of lhs and rhs.
  * Removed unused shared memory reduction logic and related checks for buffer dimensions and thread extents, simplifying the Lower method.
  * Added a fatal log for unsupported buffer scope reductions.
* reduce fix
* [Fix] Update type check for eval value in Builder class
  * Changed the type check for eval values to raise a TypeError for unsupported types, specifically excluding instances of tvm.tir.Buffer. This improves error handling and clarity in the Builder class.
1 parent 2957afc commit 47039f0

File tree

3 files changed

+69
-78
lines changed

3 files changed

+69
-78
lines changed

src/op/reduce.cc

Lines changed: 1 addition & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -104,7 +104,7 @@ PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &lhs,
104104
} else if (type->isMin()) {
105105
return Min(lhs, rhs);
106106
} else if (type->isAbsMax()) {
107-
return Max(Max(lhs, rhs), -Min(lhs, rhs));
107+
return Max(tvm::abs(lhs), tvm::abs(rhs));
108108
} else if (type->isBitAnd()) {
109109
return lhs & rhs;
110110
} else if (type->isBitOr()) {
@@ -360,70 +360,6 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
360360
return body;
361361
}
362362

363-
auto is_shared_scope = [](const std::string &scope) {
364-
return scope == "shared" || scope == "shared.dyn";
365-
};
366-
367-
if (is_shared_scope(src_scope) && is_shared_scope(dst_scope)) {
368-
Buffer src_buffer = get_buffer(this->src);
369-
Buffer dst_buffer = get_buffer(this->dst);
370-
371-
size_t src_dim = src_buffer->shape.size();
372-
size_t dst_dim = dst_buffer->shape.size();
373-
bool is_1d_reduce = (src_dim == dst_dim && dst_dim == 1);
374-
if (!is_1d_reduce) {
375-
ICHECK_EQ(src_dim, dst_dim + 1) << "Reduce dimension mismatch.";
376-
} else {
377-
ICHECK_EQ(dst_dim, 1U) << "Expect scalar layout for 1D reduce.";
378-
}
379-
380-
auto thread_extent = as_const_int(T.thread_bounds->extent);
381-
ICHECK(thread_extent)
382-
<< "Shared-memory reduce requires static thread extent.";
383-
int threads = *thread_extent;
384-
385-
if (TargetIsCuda(T.target)) {
386-
ICHECK_EQ(threads % 32, 0)
387-
<< "Shared reduce expects blockDim.x to be a multiple of 32 on CUDA.";
388-
} else if (TargetIsRocm(T.target)) {
389-
ICHECK_EQ(threads % 64, 0)
390-
<< "Shared reduce expects blockDim.x to be a multiple of 64 on HIP.";
391-
}
392-
393-
bool use_abs = this->type->isAbsSum() || this->type->isAbsMax();
394-
bool need_accumulate =
395-
(!this->clear) && (this->type->isSum() || this->type->isAbsSum() ||
396-
this->type->isBitAnd() || this->type->isBitOr() ||
397-
this->type->isBitXor());
398-
399-
PrimExpr reduce_extent = src_buffer->shape[this->dim];
400-
PrimExpr tail_extent = make_const(DataType::Int(32), 1);
401-
for (size_t i = this->dim + 1; i < src_dim; ++i) {
402-
tail_extent = analyzer->Simplify(tail_extent * src_buffer->shape[i]);
403-
}
404-
405-
PrimExpr total_dest = make_const(DataType::Int(32), 1);
406-
for (size_t i = 0; i < dst_dim; ++i) {
407-
total_dest = analyzer->Simplify(total_dest * dst_buffer->shape[i]);
408-
}
409-
410-
std::stringstream ss;
411-
std::string reducer = this->MakeCodegenReducer();
412-
ss << "tl::SharedReduceWarp<" << reducer << ", " << threads << ", "
413-
<< (use_abs ? "true" : "false") << ", "
414-
<< (need_accumulate ? "true" : "false") << ">::run";
415-
416-
Array<PrimExpr> call_args = {StringImm(ss.str()),
417-
src_buffer.access_ptr(1),
418-
dst_buffer.access_ptr(3),
419-
cast(DataType::Int(32), total_dest),
420-
cast(DataType::Int(32), reduce_extent),
421-
cast(DataType::Int(32), tail_extent),
422-
this->MakeInitValue()};
423-
424-
return Evaluate(Call(dst_buffer->dtype, builtin::call_extern(), call_args));
425-
}
426-
427363
LOG(FATAL) << "Reduce for buffers in scope (" << src_scope << ", "
428364
<< dst_scope << ") is not implemented.";
429365
return Stmt();

tilelang/language/reduce.py

Lines changed: 67 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,9 @@
22
from __future__ import annotations
33

44
from tvm import tir
5-
from tilelang.language import copy, macro, alloc_shared
5+
from tilelang.language import copy, macro, alloc_shared, alloc_fragment
6+
from tilelang.utils.language import is_shared, is_fragment
7+
from tvm.script.ir_builder import IRBuilder
68

79

810
def _legalize_dim(buffer: tir.Buffer, dim: int):
@@ -34,17 +36,70 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
3436
raise ValueError(
3537
f"Invalid reduce output shape, buffer shape is {buffer.shape}, dim is {dim}, "
3638
f"output shape is {out.shape}, expected shapes are {expected_shapes_str}")
37-
buffer = buffer.access_ptr("r")
38-
out = out.access_ptr("w")
39-
return tir.call_intrin(
40-
"handle",
41-
tir.op.Op.get("tl.reduce"),
42-
buffer,
43-
out,
44-
reduce_type,
45-
dim,
46-
clear,
47-
)
39+
40+
@macro
41+
def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool):
42+
if is_shared(buffer) and is_shared(out):
43+
red_frag_in = alloc_fragment(buffer.shape, buffer.dtype)
44+
red_frag_out = alloc_fragment(out.shape, out.dtype)
45+
46+
# rename buffers
47+
IRBuilder.name(buffer.name + "_frag", red_frag_in)
48+
IRBuilder.name(out.name + "_frag", red_frag_out)
49+
50+
copy(buffer, red_frag_in)
51+
tir.call_intrin(
52+
"handle",
53+
tir.op.Op.get("tl.reduce"),
54+
red_frag_in.access_ptr("r"),
55+
red_frag_out.access_ptr("w"),
56+
reduce_type,
57+
dim,
58+
clear,
59+
)
60+
copy(red_frag_out, out)
61+
elif is_shared(buffer) and is_fragment(out):
62+
red_frag_in = alloc_fragment(buffer.shape, buffer.dtype)
63+
IRBuilder.name(buffer.name + "_frag", red_frag_in)
64+
65+
copy(buffer, red_frag_in)
66+
tir.call_intrin(
67+
"handle",
68+
tir.op.Op.get("tl.reduce"),
69+
red_frag_in.access_ptr("r"),
70+
out.access_ptr("w"),
71+
reduce_type,
72+
dim,
73+
clear,
74+
)
75+
elif is_fragment(buffer) and is_shared(out):
76+
red_frag_out = alloc_fragment(out.shape, out.dtype)
77+
IRBuilder.name(out.name + "_frag", red_frag_out)
78+
79+
tir.call_intrin(
80+
"handle",
81+
tir.op.Op.get("tl.reduce"),
82+
buffer.access_ptr("r"),
83+
red_frag_out.access_ptr("w"),
84+
reduce_type,
85+
dim,
86+
clear,
87+
)
88+
copy(red_frag_out, out)
89+
elif is_fragment(buffer) and is_fragment(out):
90+
tir.call_intrin(
91+
"handle",
92+
tir.op.Op.get("tl.reduce"),
93+
buffer.access_ptr("r"),
94+
out.access_ptr("w"),
95+
reduce_type,
96+
dim,
97+
clear,
98+
)
99+
else:
100+
raise ValueError(f"Invalid buffer scopes: {buffer.scope()} and {out.scope()}")
101+
102+
return reduce_macro(buffer, out, reduce_type, dim, clear)
48103

49104

50105
def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):

tilelang/language/v2/builder.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -245,7 +245,7 @@ def eval(self, val: Any):
245245
pass
246246
elif isinstance(val, tvm.tir.stmt.BufferStore):
247247
tir.buffer_store(val.buffer, val.value, val.indices, val.predicate)
248-
else:
248+
elif not isinstance(val, tvm.tir.Buffer):
249249
raise TypeError(f"Unsupported eval value: {val} of type {type(val)}")
250250

251251
def ctx_for(self, it):

0 commit comments

Comments (0)