
Commit 27701c3

[Parallel] Support T.Parallel with dynamic extents (#990)
* Allow dynamic extents in loop partition; warn when layout inversion falls back to NoCheck
* add test and introduce predicate
* test fix
* fix
* enhance
* inverse with level
* test fix
* bug fix
1 parent d66b83c · commit 27701c3

File tree: 5 files changed, +196 -29 lines changed

src/layout/layout.cc

Lines changed: 37 additions & 6 deletions
@@ -229,11 +229,34 @@ Fragment FragmentNode::BindThreadRange(Range thread_range) const {
   return Fragment(n);
 }
 
-Layout LayoutNode::Inverse() const {
+std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
   arith::Analyzer analyzer;
+  auto collect_symbolic = [&](const Array<PrimExpr> &shape) {
+    Array<PrimExpr> symbolic_dims;
+    for (const auto &dim : shape) {
+      if (!as_const_int(dim)) {
+        symbolic_dims.push_back(dim);
+      }
+    }
+    return symbolic_dims;
+  };
+  Array<PrimExpr> symbolic_dims = collect_symbolic(input_size_);
+  Array<PrimExpr> output_shape = OutputShape();
+  symbolic_dims.insert(symbolic_dims.end(), output_shape.begin(),
+                       output_shape.end());
+  symbolic_dims = collect_symbolic(symbolic_dims);
+  bool is_static_shape = symbolic_dims.empty();
+  auto level = is_static_shape ? arith::IterMapLevel::Bijective
+                               : arith::IterMapLevel::NoCheck;
+  if (!is_static_shape) {
+    // Runtime guards keep dynamic tails safe, so we allow NoCheck here and
+    // warn.
+    LOG(WARNING) << "Layout::Inverse on symbolic layout, falling back to "
+                    "NoCheck; symbolic dims: "
+                 << symbolic_dims;
+  }
   arith::IterMapResult res =
-      arith::DetectIterMap(forward_index_, getVarMap(), 1,
-                           arith::IterMapLevel::Bijective, &analyzer);
+      arith::DetectIterMap(forward_index_, getVarMap(), 1, level, &analyzer);
   ICHECK(res->errors.empty())
       << "Layout " << DebugOutput() << " has errors: " << res->errors;
 
@@ -254,9 +277,13 @@ Layout LayoutNode::Inverse() const {
     }
   }
 
-  return Layout(outputs_shape, backward_index);
+  return {Layout(outputs_shape, backward_index), level};
 }
 
+Layout LayoutNode::Inverse() const {
+  auto inverse_result = InverseWithLevel();
+  return std::move(inverse_result.first);
+}
 PrimExpr infer_fragment_index(const Map<Var, Range> &input_iters,
                               const PrimExpr &forward_thread,
                               arith::Analyzer *analyzer) {
@@ -366,15 +393,19 @@ PrimExpr FragmentNode::ForwardThread(const Array<PrimExpr> &vars,
 }
 
 Layout FragmentNode::Inverse() const {
+  auto result = InverseWithLevel();
+  return std::move(result.first);
+}
+
+std::pair<Layout, arith::IterMapLevel> FragmentNode::InverseWithLevel() const {
   auto input_size_copy = input_size_;
   input_size_copy.push_back(ReplicateExtent());
   auto forward_index_copy = forward_index_;
   forward_index_copy.push_back(
       Substitute(forward_thread_,
                  {{ReplicationPlaceholder(), InputPlaceholder(InputDim())}}));
   auto fwd = Layout(input_size_copy, forward_index_copy);
-  auto bwd = fwd->Inverse();
-  return bwd;
+  return fwd->InverseWithLevel();
 }
 
 Fragment FragmentNode::CondenseReplicateVar() const {
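The crux of the layout.cc change is the level selection: InverseWithLevel() only demands a provably bijective iterator map when every input and output dimension is a compile-time constant, and otherwise downgrades to NoCheck and leaves safety to the runtime guard. A plain-Python paraphrase of that decision (pick_itermap_level and the string levels are illustrative stand-ins for arith::IterMapLevel, not tilelang API):

# Stand-in for LayoutNode::InverseWithLevel()'s level selection.
# Dims are ints when static; anything else counts as symbolic.
def pick_itermap_level(input_shape, output_shape):
    symbolic_dims = [d for d in list(input_shape) + list(output_shape)
                     if not isinstance(d, int)]
    # Static shapes must be provably bijective; symbolic ones fall back
    # to NoCheck and rely on the guard emitted by PartitionLoop.
    return "Bijective" if not symbolic_dims else "NoCheck"

assert pick_itermap_level((16, 16), (256,)) == "Bijective"
assert pick_itermap_level((16, "n"), (256,)) == "NoCheck"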

src/layout/layout.h

Lines changed: 4 additions & 0 deletions
@@ -7,6 +7,8 @@
 #define TVM_TL_LAYOUT_LAYOUT_H_
 
 #include <tvm/arith/analyzer.h>
+#include <tvm/arith/iter_affine_map.h>
+#include <utility>
 
 namespace tvm {
 namespace tl {
@@ -36,6 +38,7 @@ class LayoutNode : public Object {
   virtual Array<PrimExpr> Forward(const Array<PrimExpr> &vars) const;
 
   virtual Layout Inverse() const;
+  virtual std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const;
 
   virtual std::string DebugOutput() const;
 
@@ -76,6 +79,7 @@ class FragmentNode : public LayoutNode {
   Array<PrimExpr> GetForwardVars() const final;
 
   Layout Inverse() const final;
+  std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const final;
 
   PrimExpr ThreadExtent() const;
 
src/transform/loop_partition.cc

Lines changed: 82 additions & 22 deletions
@@ -64,28 +64,88 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
   ICHECK(thread_var.defined());
   int old_loop_depth = loop_layout->InputDim();
   int new_loop_depth = loop_layout->OutputDim();
-
   // Create the new loop iter var
   Array<Var> vars;
   for (int i = 0; i < new_loop_depth; i++) {
     Var var = Var(std::string{char('i' + i)});
+    analyzer->Bind(var, Range::FromMinExtent(make_zero(var->dtype),
+                                             loop_layout->OutputShape()[i]));
     vars.push_back(var);
   }
   vars.push_back(thread_var);
   // create the substitute map, and the loop body
   Map<Var, PrimExpr> vmap;
   Stmt body = std::move(op);
-  auto inv_loop = loop_layout->Inverse();
+  Array<PrimExpr> loop_mins;
+  Array<PrimExpr> loop_extents;
+  auto inverse_info = loop_layout->InverseWithLevel();
+  auto inv_loop = inverse_info.first;
+  // Must check the guard if the layout can not be proved as bijective
+  bool need_guard = inverse_info.second != arith::IterMapLevel::Bijective;
   auto indices = inv_loop->Forward(Array<PrimExpr>(vars.begin(), vars.end()));
+  // Normalize thread var once so we can reuse the same substitution later.
+  Map<Var, PrimExpr> thread_offset_map;
+  bool has_thread_offset = false;
+  if (loop_layout->ThreadRange().defined()) {
+    auto range = loop_layout->ThreadRange();
+    thread_offset_map.Set(thread_var, thread_var - range->min);
+    has_thread_offset = true;
+  }
   for (int i = 0; i < old_loop_depth; i++) {
     const ForNode *loop = body.as<ForNode>();
     ICHECK(loop != nullptr);
     vmap.Set(loop->loop_var, indices[i]);
+    loop_mins.push_back(loop->min);
+    loop_extents.push_back(loop->extent);
     body = loop->body;
   }
-
   // substitute and re-construct the serial loop
   body = Substitute(body, vmap);
+  // Guard executes the recovered loop body only if each inverse-mapped iterator
+  // falls back into the original For ranges. We first check every axis from the
+  // old loop nest (old_loop_depth) and then the extra index produced by inverse
+  // layouts that carry a replicate/thread component (`inv_output_shape`). Both
+  // must stay within bounds to ensure correctness. Example: layout([i, j]) =
+  // floor((i * 16 + j) / 32) may generate extra points when the new loop
+  // enumerates 0..31; the guard drops iterations whose inverse-mapped (i, j)
+  // or replicate index fall outside their original extents.
+  // Example: layout([i, j]) = floor((i * 16 + j) / 32) may produce extra points
+  // when the new loop enumerates 0..31; this guard skips iterations where the
+  // inverse i, j land outside the original extents. This protects
+  // non-surjective loop_layout mappings that otherwise over-cover the parallel
+  // space.
+  PrimExpr guard = const_true();
+
+  if (need_guard) {
+    for (int i = 0; i < old_loop_depth; i++) {
+      PrimExpr index = indices[i];
+      if (has_thread_offset) {
+        index = Substitute(index, thread_offset_map);
+      }
+      PrimExpr lower_bound = analyzer->Simplify(index >= loop_mins[i]);
+      PrimExpr upper_bound =
+          analyzer->Simplify(index < loop_mins[i] + loop_extents[i]);
+      guard = And(guard, And(lower_bound, upper_bound));
+    }
+    auto inv_output_shape = inv_loop->OutputShape();
+    if (inv_output_shape.size() > static_cast<size_t>(old_loop_depth)) {
+      PrimExpr replicate_index = indices[old_loop_depth];
+      if (has_thread_offset) {
+        replicate_index = Substitute(replicate_index, thread_offset_map);
+      }
+      PrimExpr replicate_extent = inv_output_shape[old_loop_depth];
+      PrimExpr lower_bound = analyzer->Simplify(
+          replicate_index >= make_zero(replicate_index.dtype()));
+      PrimExpr upper_bound =
+          analyzer->Simplify(replicate_index < replicate_extent);
+      guard = And(guard, And(lower_bound, upper_bound));
+    }
+    PrimExpr simplified_guard = analyzer->Simplify(guard);
+    if (!analyzer->CanProve(simplified_guard)) {
+      body = IfThenElse(simplified_guard, body, Stmt());
+    }
+  }
+
   for (int i = new_loop_depth - 1; i >= 0; i--) {
     body = For(vars[i], make_zero(vars[i]->dtype), inv_loop->InputShape()[i],
                ForKind::kSerial, body);
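The comment's over-coverage example can be made concrete with a toy model (plain Python with made-up extents, not the pass itself): the partitioned space enumerates more points than the original nest, and the guard predicate drops the excess before the body runs.

# Original nest: i in [0, 2), j in [0, 12), i.e. 24 points. The partitioned
# loop enumerates a flat k in [0, 32); the inverse layout recovers
# (i, j) = (k // 16, k % 16). The guard keeps only in-bounds points,
# mirroring IfThenElse(simplified_guard, body, Stmt()).
old_extents = (2, 12)
executed = []
for k in range(32):
    i, j = k // 16, k % 16
    if 0 <= i < old_extents[0] and 0 <= j < old_extents[1]:
        executed.append((i, j))  # recovered loop body runs here
assert len(executed) == 24  # the 8 over-covered points were dropped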
@@ -94,13 +154,11 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
 
   body = BufferIndiceSimplify(analyzer)(body);
 
-  auto for_node = LoopPragmaUnroll(Downcast<For>(body));
-  if (loop_layout->ThreadRange().defined()) {
-    auto range = loop_layout->ThreadRange();
-    auto thread_var_with_offset = thread_var - range->min;
-    for_node.CopyOnWrite()->body =
-        Substitute(for_node->body, {{thread_var, thread_var_with_offset}});
+  if (has_thread_offset) {
+    body = Substitute(body, thread_offset_map);
   }
+
+  auto for_node = LoopPragmaUnroll(Downcast<For>(body));
   return for_node;
 }
 
111169
private:
112170
Stmt VisitStmt_(const ForNode *node) final {
113171
if (node->kind == ForKind::kSerial) {
172+
auto analyzer = std::make_shared<arith::Analyzer>();
173+
if (as_const_int(analyzer->Simplify(node->extent)) == nullptr) {
174+
return StmtExprMutator::VisitStmt_(node);
175+
}
114176
For new_for = GetRef<For>(node);
115177
auto for_ptr = new_for.CopyOnWrite();
116178
for_ptr->annotations.Set(tir::attr::pragma_unroll_explicit, Bool(false));
@@ -127,22 +189,20 @@ class LoopPartitioner : public StmtExprVisitor {
 
   Fragment Partition(const For &op, int num_thread, int vectorize_size) {
     this->VisitStmt(op);
-    int loop_size_full = 1;
-    PrimExpr flattened = 0;
+    ICHECK(!loop_vars_.empty());
+    DataType dtype = loop_vars_[0]->var.dtype();
+    PrimExpr flattened = make_const(dtype, 0);
+    PrimExpr vector_extent = make_const(dtype, vectorize_size);
+    PrimExpr thread_extent_const = make_const(dtype, num_thread);
     for (size_t i = 0; i < loop_vars_.size(); i++) {
-      auto ext_ptr = as_const_int(loop_vars_[i]->dom->extent);
-      ICHECK(ext_ptr)
-          << "Loop partitioner only works with constant loop sizes, but got "
-          << loop_vars_[i]->dom->extent;
-      int extent = *ext_ptr;
-      loop_size_full *= extent;
+      PrimExpr extent = loop_vars_[i]->dom->extent;
       flattened = flattened * extent + loop_vars_[i]->var;
     }
-    ICHECK(loop_size_full % vectorize_size == 0);
-    PrimExpr access_idx = FloorDiv(flattened, vectorize_size);
-    PrimExpr thd = FloorMod(access_idx, num_thread);
-    PrimExpr idx = FloorDiv(access_idx, num_thread) * vectorize_size +
-                   FloorMod(flattened, vectorize_size);
+    PrimExpr access_idx = FloorDiv(flattened, vector_extent);
+    PrimExpr thd = FloorMod(access_idx, thread_extent_const);
+    PrimExpr idx = FloorDiv(access_idx, thread_extent_const) * vector_extent +
+                   FloorMod(flattened, vector_extent);
+
     auto fragment = Fragment(loop_vars_, {idx}, {thd}, {});
     if (has_fragment_) {
       // for fragment buffer, we don't need to replicate the loop layout
src/transform/loop_vectorize.cc

Lines changed: 1 addition & 1 deletion
@@ -94,7 +94,7 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
 private:
   void VisitStmt_(const ForNode *node) final {
     inner_for_ = node;
-    auto extent_ptr = as_const_int(node->extent);
+    auto extent_ptr = as_const_int(analyzer_.Simplify(node->extent));
     // Here I disable dynamic shape completely,
     // In order to do it, the Planner should accept an analyzer with
     // arithmetic info outside to prove the dividiblity of vector size
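This one-line change matters because a dynamic-looking extent such as min(valid_len, max_len) can still fold to a constant once the analyzer simplifies it; as_const_int on the raw expression would miss that. A standalone TVM sketch of the effect (assumes a stock tvm install; not code from this commit):

from tvm import arith, tir

analyzer = arith.Analyzer()
# min(512, 512) built as a raw Min node is not an IntImm, so a direct
# constant check fails; Analyzer.simplify folds it to the constant first.
extent = tir.Min(tir.IntImm("int32", 512), tir.IntImm("int32", 512))
assert not isinstance(extent, tir.IntImm)
assert isinstance(analyzer.simplify(extent), tir.IntImm)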
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+import tilelang
+import tilelang.language as T
+import torch
+import tilelang.testing
+import pytest
+
+tilelang.testing.set_random_seed()
+
+
+@tilelang.jit(out_idx=[1])
+def parallel_elementwise_static(length=256, dtype="float32"):
+
+    @T.prim_func
+    def main(
+            A: T.Tensor((length,), dtype),
+            B: T.Tensor((length,), dtype),
+    ):
+        with T.Kernel(1, threads=length) as _:
+            for i in T.Parallel(length):
+                B[i] = A[i] + 1.0
+
+    return main
+
+
+@tilelang.jit(out_idx=[1])
+def parallel_elementwise_dynamic(max_len=512, threads=256, dtype="float32"):
+
+    @T.prim_func
+    def main(
+            A: T.Tensor((max_len,), dtype),
+            B: T.Tensor((max_len,), dtype),
+            valid_len: T.int32,
+    ):
+        with T.Kernel(1, threads=threads) as _:
+            for i in T.Parallel(max_len):
+                B[i] = 0.0
+            span = T.min(valid_len, max_len)
+            for i in T.Parallel(span):
+                B[i] = A[i] - 1.0
+
+    return main
+
+
+def _require_cuda_tensor(shape, dtype=torch.float32):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+    try:
+        return torch.randn(*shape, device="cuda", dtype=dtype)
+    except RuntimeError as err:
+        pytest.skip(f"CUDA runtime unavailable: {err}")
+
+
+def test_parallel_static_extent():
+    kernel = parallel_elementwise_static(length=256)
+    data = _require_cuda_tensor((256,), torch.float32)
+    result = kernel(data)
+    torch.testing.assert_close(result, data + 1.0, atol=1e-5, rtol=1e-5)
+
+
+def test_parallel_dynamic_extent():
+    kernel = parallel_elementwise_dynamic(max_len=512, threads=256)
+    data = _require_cuda_tensor((512,), torch.float32)
+    for valid_len in [0, 13, 200, 600]:
+        out = kernel(data, valid_len)
+        reference = torch.zeros_like(data)
+        clip = min(valid_len, data.shape[0])
+        reference[:clip] = data[:clip] - 1.0
+        torch.testing.assert_close(out, reference, atol=1e-5, rtol=1e-5)
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
