Skip to content

Commit

Permalink
[Unity][Analysis] Handle PrimStructInfo in EraseToWellDefined (#16304)
Browse files Browse the repository at this point in the history
* [Unity][Analysis] Handle PrimStructInfo in EraseToWellDefined

Prior to this commit, the `EraseToWellDefined` pass would update
symbolic variable definitions in `ShapeStructInfo` and
`TensorStructInfo`, but did not in `PrimStructInfo`.  This commit
updates the `WellDefinedEraser` to include symbolic variables defined
in `PrimStructInfo`.

* Update collecting of symbolic variables in InferSymbolicVarMap

* CI bump due to flaky unit test

`tests/python/relax/test_frontend_onnx.py::test_attention` fails for
some inputs.  On `unity` head, failures occurred in 3 out of 100 test runs.
  • Loading branch information
Lunderberg authored Jan 4, 2024
1 parent d88cc42 commit d509661
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 4 deletions.
32 changes: 31 additions & 1 deletion src/relax/analysis/struct_info_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,28 @@ class WellDefinedEraser : public StructInfoMutator,
std::function<Optional<Expr>(const Var& var)> f_var_map, arith::Analyzer* ana)
: f_shape_var_map_(f_shape_var_map), f_var_map_(f_var_map), ana_(ana) {}

// Erase the known value of a PrimStructInfo when that value refers to
// symbolic variables that are not well-defined in the current scope.
//
// Outcomes:
//   * value well-defined and unchanged  -> return the node as-is
//   * value well-defined but rewritten  -> rebuild with the new value
//   * value references an undefined var -> drop the value, keep only dtype
StructInfo VisitStructInfo_(const PrimStructInfoNode* op) final {
  bool has_undefined = false;
  Optional<PrimExpr> value;

  if (op->value.defined()) {
    // Swap in a fresh flag so `has_undefined` reflects only this value,
    // not state accumulated by an enclosing visit; the second swap
    // restores the caller's flag.
    std::swap(has_undefined_, has_undefined);
    value = VisitPrimExpr(op->value.value());
    std::swap(has_undefined_, has_undefined);
  }

  // Erase the known value if it depended on an undefined symbolic var.
  if (!has_undefined) {
    if (value.same_as(op->value)) {
      return GetRef<StructInfo>(op);
    } else {
      return PrimStructInfo(value.value(), op->span);
    }
  } else {
    return PrimStructInfo(op->dtype, op->span);
  }
}

StructInfo VisitStructInfo_(const ShapeStructInfoNode* op) final {
bool has_undefined = false;
Optional<Array<PrimExpr>> values;
Expand Down Expand Up @@ -295,7 +317,15 @@ class StructInfoBaseChecker
if (other.as<ObjectStructInfoNode>()) return BaseCheckResult::kFailL1;
return BaseCheckResult::kFailL0;
}
return lhs->dtype == rhs->dtype ? BaseCheckResult::kPass : BaseCheckResult::kFailL0;

if (lhs->dtype != rhs->dtype) {
return BaseCheckResult::kFailL0;
}

if (!lhs->value.defined()) return BaseCheckResult::kPass;
if (!rhs->value.defined()) return BaseCheckResult::kFailL2;

return PrimValueMatchCheck(lhs->value.value(), rhs->value.value());
}

BaseCheckResult VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final {
Expand Down
11 changes: 10 additions & 1 deletion src/relax/ir/block_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,13 +437,22 @@ class BlockBuilderImpl : public BlockBuilderNode {

// Collect definitions of symbolic variables appearing in a shape
// annotation, e.g. `R.Shape((m, n))` defines `m` and `n`.
void VisitStructInfo_(const ShapeStructInfoNode* op) final {
  for (const PrimExpr& s : op->values.value_or(Array<PrimExpr>())) {
    // Only collect shapes defined by a single var.  Composite expressions
    // such as `R.Shape((m + 1, n + 1))` cannot be inverted to a variable
    // definition, so they are ignored.
    if (const auto* var = s.as<tir::VarNode>()) {
      shape_var_map_.Set(GetRef<tir::Var>(var), s);
    }
  }
}

// Record a symbolic-variable definition provided through a
// `R.Prim(value=...)` annotation.
void VisitStructInfo_(const PrimStructInfoNode* op) final {
  // A definition is only collected when the annotated value is itself a
  // single TIR variable.  Composite expressions such as
  // `R.Prim(value=m + 1)` cannot be inverted, so they are skipped.
  if (!op->value.defined()) {
    return;
  }
  if (auto opt_var = op->value.as<tir::Var>()) {
    shape_var_map_.Set(opt_var.value(), op->value.value());
  }
}

private:
Map<tir::Var, PrimExpr> shape_var_map_;
};
Expand Down
18 changes: 18 additions & 0 deletions src/relax/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,23 @@ tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
}
};

// Infer symbolic-variable bindings from a pair of R.Prim annotations:
// when the parameter is annotated with a known prim value and the
// argument also carries one, forward both to `bind_from_prim_expr`.
auto bind_from_prim_value = [&bind_from_prim_expr](const StructInfo& var,
                                                   const StructInfo& expr) {
  const auto* param_sinfo = var.as<PrimStructInfoNode>();
  if (param_sinfo == nullptr) {
    return;
  }

  // The argument must also be a PrimStructInfo of the same dtype.
  const auto* arg_sinfo = expr.as<PrimStructInfoNode>();
  CHECK(arg_sinfo) << "Cannot bind expression with struct type " << expr
                   << " to variable with struct type " << var;
  CHECK_EQ(param_sinfo->dtype, arg_sinfo->dtype)
      << "Cannot bind expression with struct type " << expr << " to variable with struct type "
      << var << ", due to conflicting PrimExpr DataType";

  // Nothing to infer unless both sides carry a known value.
  if (param_sinfo->value.defined() && arg_sinfo->value.defined()) {
    bind_from_prim_expr(param_sinfo->value.value(), arg_sinfo->value.value());
  }
};

auto bind_from_shape = [&bind_from_prim_expr](const StructInfo& var, const StructInfo& expr) {
auto var_shape = var.as<ShapeStructInfoNode>();
if (!var_shape) return;
Expand Down Expand Up @@ -195,6 +212,7 @@ tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(

bind_from_tensor(var_sinfo, expr_sinfo);
bind_from_shape(var_sinfo, expr_sinfo);
bind_from_prim_value(var_sinfo, expr_sinfo);
}

return tir_var_remap;
Expand Down
5 changes: 4 additions & 1 deletion tests/python/relax/test_bind_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,14 @@ def expected() -> R.Shape([16]):


def test_bind_prim_value(prim_value_dtype):
if prim_value_dtype != "int64":
pytest.xfail(reason="Currently, only support int64 as known symbolic value")

N = tir.Var("N", prim_value_dtype)
value = tir.const(16, prim_value_dtype)

@R.function
def before(A: R.Prim(value=N)):
def before(A: R.Prim(value=N)) -> R.Prim(value=N):
R.func_attr({"global_symbol": "main"})
B: R.Prim(value=N) = A
return B
Expand Down
94 changes: 93 additions & 1 deletion tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,7 +1162,7 @@ def foo(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")):
return w


def test_erase_to_well_defined():
def test_erase_to_well_defined_removes_internal_vars():
@R.function
def foo(x: R.Tensor):
q = x
Expand All @@ -1172,9 +1172,101 @@ def foo(x: R.Tensor):
return w

tvm.ir.assert_structural_equal(foo.ret_struct_info, R.Tensor(ndim=2))
assert foo.ret_struct_info.shape is None
_check(foo)


def test_erase_to_well_defined_keeps_variables_exposed_by_tensor_shape():
    """Symbolic vars defined by a tensor-shape parameter annotation survive.

    `m` and `n` are defined by the annotation `R.Tensor(["m", "n"])` on
    the parameter, so they remain in scope for the return value and
    EraseToWellDefined must keep the inferred return shape.
    """

    @R.function
    def foo(x: R.Tensor(["m", "n"])):
        q = x
        m, n = T.int64(), T.int64()
        z = R.match_cast(q, R.Tensor((m, n)))
        w = z
        return w

    assert foo.ret_struct_info.shape is not None
    _check(foo)


def test_erase_to_well_defined_keeps_variants_exposed_by_shape_expr():
    """Symbolic vars defined by a R.Shape parameter annotation survive.

    `m` and `n` are defined by the unused `R.Shape(["m", "n"])` parameter,
    so the match-cast shape of the return value stays well-defined.

    NOTE(review): "variants" in the test name looks like a typo for
    "variables"; left unchanged to avoid churning pytest test IDs.
    """

    @R.function
    def foo(x: R.Tensor, _: R.Shape(["m", "n"])):
        q = x
        m, n = T.int64(), T.int64()
        z = R.match_cast(q, R.Tensor((m, n)))
        w = z
        return w

    assert foo.ret_struct_info.shape is not None
    _check(foo)


def test_erase_to_well_defined_keeps_variants_exposed_by_prim_value():
    """Symbolic vars defined by R.Prim(value=...) parameters survive.

    `m` and `n` are defined by the `R.Prim(value="m")`/`R.Prim(value="n")`
    parameter annotations, so the match-cast shape of the return value
    stays well-defined.

    NOTE(review): "variants" in the test name looks like a typo for
    "variables"; left unchanged to avoid churning pytest test IDs.
    """

    @R.function
    def foo(x: R.Tensor, _m: R.Prim(value="m"), _n: R.Prim(value="n")):
        q = x
        m, n = T.int64(), T.int64()
        z = R.match_cast(q, R.Tensor((m, n)))
        w = z
        return w

    assert foo.ret_struct_info.shape is not None
    _check(foo)


def test_erase_to_well_defined_infers_from_shape_expr():
    """Shape vars supplied by the caller propagate through a subroutine call."""

    @I.ir_module
    class Module:
        # The subroutine's symbolic variables are only in-scope for the
        # subroutine.
        @R.function
        def subroutine(x: R.Tensor, _: R.Shape(["m", "n"])) -> R.Tensor(["m", "n"]):
            q = x
            m, n = T.int64(), T.int64()
            z = R.match_cast(q, R.Tensor((m, n)))
            w = z
            return w

        # However, struct inference can map the symbolic variables in the
        # main function to the symbolic variables in the subroutine.
        # Therefore, the tensor returned from main can have a well-defined
        # shape.
        @R.function
        def main(x: R.Tensor, shape: R.Shape(["m", "n"])):
            output = Module.subroutine(x, shape)
            return output

    assert Module["main"].ret_struct_info.shape is not None
    _check(Module)


def test_erase_to_well_defined_infers_from_prim_value():
    """Prim-value vars supplied by the caller propagate through a subroutine call."""

    @I.ir_module
    class Module:
        # The subroutine's symbolic variables are only in-scope for the
        # subroutine.
        @R.function
        def subroutine(
            x: R.Tensor, _m: R.Prim(value="m"), _n: R.Prim(value="n")
        ) -> R.Tensor(["m", "n"]):
            q = x
            m, n = T.int64(), T.int64()
            z = R.match_cast(q, R.Tensor((m, n)))
            w = z
            return w

        # However, struct inference can map the symbolic variables in the
        # main function to the symbolic variables in the subroutine.
        # Therefore, the tensor returned from main can have a well-defined
        # shape.
        @R.function
        def main(x: R.Tensor, relax_m: R.Prim(value="m"), relax_n: R.Prim(value="n")):
            output = Module.subroutine(x, relax_m, relax_n)
            return output

    assert Module["main"].ret_struct_info.shape is not None
    _check(Module)


def test_empty_tuple():
@R.function
def foo(x: R.Tuple()):
Expand Down

0 comments on commit d509661

Please sign in to comment.