Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ class IterMapRewriter : public ExprMutator {
bool requires_padding_{false};

// The map for sum that maps flattened form to IterMark with normal form and extent (and possibly
// an extra offset)
// an extra offset). The normal form always has a minimum value of zero.
// Example(1): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
// predicate: j*2 + k < 9
// Then, flattened form = IterSum(IterSplit(i, scale=9),
Expand All @@ -497,6 +497,7 @@ class IterMapRewriter : public ExprMutator {
// IterSplit(k, scale=1)),
// extent=9)
// scale=1))
// offset = 0
// Example(2): expr = i*8 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
// predicate: 1 <= j*2 + k < 9
// Then, flattened form = IterSum(IterSplit(i, scale=8),
Expand All @@ -507,9 +508,15 @@ class IterMapRewriter : public ExprMutator {
// IterSplit(k, scale=1), base=-1),
// extent=9-1)
// scale=1),
// base=1)
// base=0)
// offset = 1
std::unordered_map<IterSumExpr, IterMarkWithOffset, IterSumHash, IterSumEqual> sum_fuse_map_;
// The map for sum that maps normal form to flattened form
// For sum_fuse_map_ and flattened_map_ the following invariants hold:
// for any IterSumExpr e in flattened form, we have
// (iter_mark, mark_offset) = sum_fuse_map_[e], and
// flattened_map_[normal_form] = e, where normal_form = iter_mark->args[0] and
// iter_mark->args.size() = 1
std::unordered_map<IterSumExpr, IterSumExpr, IterSumHash, IterSumEqual> flattened_map_;
// The flattened forms of constrained iters
std::vector<IterSumExpr> constrained_iters_flattened_;
Expand Down Expand Up @@ -685,7 +692,10 @@ class IterMapRewriter : public ExprMutator {
PrimExpr mark_offset = it_mark->second.offset;
PrimExpr iter_min = mark_offset;
PrimExpr iter_max = iter_min + mark->extent;
// the amount by which the predicate-induced lower bound differs from the original iter_min
// (stays zero when no lower-bound predicate is present)
PrimExpr iter_min_delta = make_const(iter_min.dtype(), 0);
if (predicate_induced_min.defined()) {
iter_min_delta = predicate_induced_min.value() - iter_min;
iter_min = max(predicate_induced_min.value(), iter_min);
}
if (predicate_induced_max.defined()) {
Expand All @@ -704,10 +714,12 @@ class IterMapRewriter : public ExprMutator {
iter_max = min(predicate_induced_max.value(), iter_max);
}
}
if (!is_zero(iter_min)) {
// When iter_min_delta is nonzero, we need to normalize the structured form to have a minimum
// of 0, and add the delta to the mark_offset
if (!is_zero(iter_min_delta)) {
// structured form's offset should be updated
flattened_map_.erase(structured_form);
structured_form.CopyOnWrite()->base = -iter_min;
structured_form.CopyOnWrite()->base -= iter_min_delta;
mark.CopyOnWrite()->source = structured_form;
flattened_map_[structured_form] = flattened_form;
}
Expand All @@ -716,8 +728,9 @@ class IterMapRewriter : public ExprMutator {
// we need to note down the flattened form of constrained iterators
// to check the validity of constraints, see also CheckConstraints()
constrained_iters_flattened_.push_back(flattened_form);
expr.CopyOnWrite()->args = Array<IterSplitExpr>({split});
expr.CopyOnWrite()->base = base + iter_min;
IterSumExprNode* normalized_expr = expr.CopyOnWrite();
normalized_expr->args = Array<IterSplitExpr>({split});
normalized_expr->base = base;
return expr;
}
ErrorLogger(this) << "Could not normalize iterators using the constraints given.";
Expand Down Expand Up @@ -1089,8 +1102,8 @@ class IterMapRewriter : public ExprMutator {
std::vector<IterSplitExpr> flattened_iters, grouped_iters;

// check if it can be remapped into a fused pattern.
PrimExpr expected_extra_base = 0;
PrimExpr tail_extent = 0;
PrimExpr expected_extra_base = make_const(expr.dtype(), 0);
PrimExpr tail_extent = make_const(expr.dtype(), 0);
PrimExpr expected_scale = base_scale;
int first_possible_unit_extent_pos = FindFirstPossibleUnitExtentIndex(expr);

Expand Down Expand Up @@ -1143,8 +1156,9 @@ class IterMapRewriter : public ExprMutator {
size_t k = 0;
for (; k < expr->args.size(); ++k) {
if (!visited[k] && IterSplitEqual(expr->args[k], *it, false)) {
if (analyzer_->CanProveEqual((*it)->scale * matched_scale, expr->args[k]->scale))
if (analyzer_->CanProveEqual((*it)->scale * matched_scale, expr->args[k]->scale)) {
break;
}
}
}
if (k == expr->args.size()) {
Expand Down Expand Up @@ -1201,7 +1215,7 @@ class IterMapRewriter : public ExprMutator {
} else {
// new iter, form a new mark
IterMark mark = IterMark(structured_form, div(expected_scale, base_scale) + tail_extent);
sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, 0);
sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, expected_extra_base);
flattened_map_[structured_form] = flattened_form;
return IterSumExpr({IterSplitExpr(mark, base_scale)}, expr->base + expected_extra_base);
}
Expand Down
7 changes: 7 additions & 0 deletions tests/python/unittest/test_arith_iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def test_compound_floormod_two_regression():
def test_predicate():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
z = tvm.tir.Var("z", "int32")

# available constraints
# upper bound only
Expand Down Expand Up @@ -269,6 +270,12 @@ def test_predicate():
predicate=tvm.tir.And(x * 10 + y >= 6, x * 10 + y <= 127),
)

assert_iter_sum_pattern(
{x * 64 + y * 4 + z: (16, 16)},
var_dom([(x, 16), (y, 16), (z, 4)]),
predicate=tvm.tir.And(x * 64 + y * 4 + z < 32, 4 <= x * 16 + y),
)

# constraint on one fused iter
i = tvm.tir.Var("i", "int32")
j = tvm.tir.Var("j", "int32")
Expand Down