Skip to content

Commit

Permalink
[ARITH] normalize iter affine map expr to PrimExpr (apache#7759)
Browse files Browse the repository at this point in the history
  • Loading branch information
spectrometerHBH authored and trevor-m committed May 11, 2021
1 parent b86a277 commit cbf84c3
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@
from .pattern import detect_linear_equation, detect_clip_bound
from .int_solver import solve_linear_equations, solve_linear_inequalities
from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr
from .iter_affine_map import detect_iter_map
from .iter_affine_map import detect_iter_map, normalize_iter_map_to_expr
16 changes: 16 additions & 0 deletions python/tvm/arith/iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,19 @@ def detect_iter_map(indices, input_iters, predicate=True, require_bijective=Fals
Empty array if no match can be found.
"""
return _ffi_api.DetectIterMap(indices, input_iters, predicate, require_bijective)


def normalize_iter_map_to_expr(expr):
"""Given an IterMapExpr, transform it to normal PrimExpr
Parameters
----------
expr : IterMapExpr
the input IterMapExpr
Returns
-------
result : PrimExpr
the corresponding normal PrimExpr
"""
return _ffi_api.NormalizeIterMapToExpr(expr)
58 changes: 58 additions & 0 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1028,5 +1028,63 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) {
}
}

/*! * \brief Given an IterVarMapExpr, transform it to normal PrimExpr. */
class IterMapToExprNormalizer {
public:
explicit IterMapToExprNormalizer(Analyzer* analyzer) : analyzer_(analyzer) {}

PrimExpr Convert(const IterMapExpr& expr) {
if (const auto* op = expr.as<IterSplitExprNode>()) {
return ConvertIterSplitExpr(GetRef<IterSplitExpr>(op));
} else if (const auto* op = expr.as<IterSumExprNode>()) {
return ConvertIterSumExpr(GetRef<IterSumExpr>(op));
} else {
ICHECK(expr.defined());
LOG(FATAL) << "Unknown IterMapExpr type " << expr->GetTypeKey();
return 0;
}
}

PrimExpr ConvertIterSumExpr(const IterSumExpr& expr) {
PrimExpr res = 0;
for (const IterSplitExpr& arg : expr->args) {
res += ConvertIterSplitExpr(arg);
}
res += expr->base;
return res;
}

PrimExpr ConvertIterSplitExpr(const IterSplitExpr& expr) {
PrimExpr source;
if (const auto* op = expr->source->source.as<VarNode>()) {
source = GetRef<Var>(op);
} else if (const auto* op = expr->source->source.as<IterSumExprNode>()) {
source = ConvertIterSumExpr(GetRef<IterSumExpr>(op));
} else {
LOG(FATAL) << "Unexpected source of IterSplitExpr";
}
if (analyzer_->CanProve(expr->extent == expr->source->extent) && is_one(expr->lower_factor)) {
return source * expr->scale;
} else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor * expr->extent)) {
return floordiv(source, expr->lower_factor) * expr->scale;
} else {
return floormod(floordiv(source, expr->lower_factor), expr->extent) * expr->scale;
}
}

private:
Analyzer* analyzer_;
};

PrimExpr NormalizeIterMapToExpr(const IterMapExpr& expr) {
arith::Analyzer analyzer;
IterMapToExprNormalizer normalizer(&analyzer);
return normalizer.Convert(expr);
}

TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed([](const IterMapExpr& expr) {
return NormalizeIterMapToExpr(expr);
});

} // namespace arith
} // namespace tvm
File renamed without changes.
21 changes: 21 additions & 0 deletions tests/python/unittest/test_arith_iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,30 @@ def test_predicate():
assert len(res) == 0


def test_normalize_iter_map_to_expr():
fld = tvm.tir.floordiv
flm = tvm.tir.floormod

x = tvm.tir.Var("x", "int32"), 10
y = tvm.tir.Var("y", "int32"), 9

xo, xi = isplit(x, 5)
yo, yi = isplit(y, 3)
z = ifuse([yo, xo, yi])

res = tvm.arith.detect_iter_map([z[0], xi[0]], var_dom([x, y]))

tvm.ir.assert_structural_equal(
tvm.arith.normalize_iter_map_to_expr(res[0]),
fld(y[0], 3) * 6 + fld(x[0], 5) * 3 + flm(y[0], 3),
)
tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res[1]), flm(x[0], 5))


if __name__ == "__main__":
test_split()
test_trivial()
test_fuse()
test_compound()
test_predicate()
test_normalize_iter_map_to_expr()

0 comments on commit cbf84c3

Please sign in to comment.