
Commit ed2ef64

BUG: Look through on_device annotations when looking for shape constants
#8788 introduced a perf regression: a `shape.as<ConstantNode>()` call in the `alloc_tensor` handling was always failing due to the extra `on_device` annotation on the constant, so the static-allocation path was never taken. Fixed that, and introduced some helpers to make this situation easier to deal with. (This is CORE-102 in the OctoML JIRA.) (Second try -- the test_crp.py failure seems unrelated.)
1 parent e62075d commit ed2ef64
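
To make the failure mode concrete, here is a small sketch (mine, not code from this commit; it assumes TVM's Relay headers plus the helpers declared below in `src/relay/op/annotation/annotation.h` are on the include path):

```cpp
#include <tvm/relay/expr.h>
// Plus the on_device helpers declared in src/relay/op/annotation/annotation.h.

namespace tvm {
namespace relay {

// Hypothetical helper illustrating the bug and the fix.
const ConstantNode* ShapeAsConstant(const Expr& shape) {
  // Before this commit: if `shape` arrives as on_device(Constant, ...), the
  // direct cast sees the "on_device" CallNode, yields nullptr, and the
  // static-allocation fast path is silently lost.
  const ConstantNode* direct = shape.as<ConstantNode>();
  if (direct != nullptr) {
    return direct;
  }
  // After this commit: look through a possible on_device annotation first.
  return AsIgnoringOnDevice<ConstantNode>(shape);
}

}  // namespace relay
}  // namespace tvm
```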

7 files changed (+52, -19)

src/relay/backend/aot_executor_codegen.cc (1 addition, 2 deletions)

```diff
@@ -182,9 +182,8 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
    * \return The corresponding token.
    */
   StorageInfo GetStorage(const Expr& expr) {
-    auto props = GetOnDeviceProps(expr);
     // See through "on_device" calls.
-    Expr true_expr = props.body.defined() ? props.body : expr;
+    Expr true_expr = IgnoreOnDevice(expr);
     VisitExpr(true_expr);
     auto it = storage_device_map_.find(true_expr);
     ICHECK(it != storage_device_map_.end());
```

src/relay/backend/graph_plan_memory.cc (2 additions, 3 deletions)

```diff
@@ -146,10 +146,9 @@ class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor {
    * \return The corresponding token.
    */
   const std::vector<StorageToken*>& GetToken(const Expr& expr) {
-    this->VisitExpr(expr);
     // See through on_device calls.
-    auto props = GetOnDeviceProps(expr);
-    Expr real_expr = props.body.defined() ? props.body : expr;
+    Expr real_expr = IgnoreOnDevice(expr);
+    this->VisitExpr(real_expr);
     auto it = token_map_.find(real_expr.get());
     ICHECK(it != token_map_.end()) << "Expression not found in storage map:" << std::endl
                                    << PrettyPrint(real_expr);
```

src/relay/backend/vm/compiler.cc (4 additions, 3 deletions)

```diff
@@ -594,8 +594,9 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
       auto offset_register = last_register_;
 
       // If the shape is constant then we will emit a static tensor allocation
-      // instruction.
-      auto const_shape = args[2].as<ConstantNode>();
+      // instruction. It may be wrapped by an on_device, but it will be on the host
+      // which is assumed by the alloc_tensor instruction anyway.
+      auto const_shape = AsIgnoringOnDevice<ConstantNode>(args[2]);
 
       if (const_shape) {
         NDArray shape = const_shape->data;
@@ -619,7 +620,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
       this->VisitExpr(args[0]);
       auto size_register = last_register_;
 
-      ICHECK(args[1].as<ConstantNode>());
+      ICHECK(args[1].as<ConstantNode>());  // Always a literal.
       NDArray alignment_arr = args[1].as<ConstantNode>()->data;
       ICHECK_EQ(alignment_arr->dtype.code, 0U)
           << "The dtype of constant shape must be int32 or int64, but got "
```

src/relay/op/annotation/annotation.h (26 additions, 0 deletions)

```diff
@@ -85,6 +85,32 @@ OnDeviceProps GetOnDeviceProps(const CallNode* call_node);
  */
 OnDeviceProps GetOnDeviceProps(const Expr& expr);
 
+/*!
+ * \brief Returns the body of \p expr if it is an "on_device" annotation, otherwise returns
+ * \p expr directly.
+ */
+inline Expr IgnoreOnDevice(const Expr& expr) {
+  OnDeviceProps props = GetOnDeviceProps(expr);
+  return props.body.defined() ? props.body : expr;
+}
+
+/*!
+ * \brief Returns \p expr as \p NodeType, or null if it is not of that type. Looks through
+ * any "on_device" annotations.
+ */
+template <typename NodeType>
+const NodeType* AsIgnoringOnDevice(const Expr& expr) {
+  const auto* node = expr.as<NodeType>();
+  if (node != nullptr) {
+    return node;
+  }
+  OnDeviceProps props = GetOnDeviceProps(expr);
+  if (!props.body.defined()) {
+    return nullptr;
+  }
+  return props.body.as<NodeType>();
+}
+
 /*!
  * \brief Returns \p function annotated with "param_device_types" and "result_device_type"
  * attributes capturing parameter and result devices types respectively.
```
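
A usage note (my observation, not part of the diff): the templated helper agrees with stripping the annotation first and then casting; it just skips computing `OnDeviceProps` when the direct cast already succeeds.

```cpp
// Sketch: two equivalent spellings.
const ConstantNode* a = AsIgnoringOnDevice<ConstantNode>(expr);
const ConstantNode* b = IgnoreOnDevice(expr).as<ConstantNode>();
// a and b agree both for plain constants and for on_device-wrapped constants.
```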

src/relay/op/memory/memory.cc (3 additions, 7 deletions)

```diff
@@ -101,13 +101,9 @@ Expr AllocTensor(Expr storage, Expr offset, Expr shape, DataType dtype,
     attrs->assert_shape = assert_shape;
   } else {
     // Look through any on_device for the shape argument expression.
-    Expr literal_shape = shape;
-    auto props = GetOnDeviceProps(literal_shape);
-    if (props.body.defined()) {
-      // See through on_device calls.
-      literal_shape = props.body;
-    }
-    attrs->const_shape = Downcast<Constant>(literal_shape);
+    const auto* constant_node = AsIgnoringOnDevice<ConstantNode>(shape);
+    ICHECK(constant_node);
+    attrs->const_shape = GetRef<Constant>(constant_node);
   }
   static const Op& op = Op::Get("memory.alloc_tensor");
   return Call(op, {storage, offset, shape}, Attrs(attrs), {});
```
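
An aside on the `GetRef` idiom used above (a general TVM convention, not something introduced by this commit): it converts a raw node pointer back into its managed reference type, the inverse of `Expr::as<NodeType>()`.

```cpp
// Valid only when the pointer is non-null, hence the ICHECK in the diff.
const ConstantNode* node = AsIgnoringOnDevice<ConstantNode>(shape);
ICHECK(node != nullptr);
Constant constant = GetRef<Constant>(node);  // re-wrap the node as a Constant
```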

src/relay/transforms/pass_utils.h (2 additions, 3 deletions)

```diff
@@ -118,9 +118,8 @@ inline Expr TransformF(const std::function<Expr(const Expr&)>& func, const Expr&
  * is it atomic?
  * if so, the compute cost of the expression is bounded so it can be copy without graph mode.
  */
-inline bool IsAtomic(const Expr& e) {
-  auto props = GetOnDeviceProps(e);
-  Expr true_expr = props.body.defined() ? props.body : e;
+inline bool IsAtomic(const Expr& expr) {
+  Expr true_expr = IgnoreOnDevice(expr);
   return true_expr.as<VarNode>() || true_expr.as<OpNode>() || true_expr.as<ConstructorNode>() ||
          true_expr.as<GlobalVarNode>() ||
          true_expr.as<ConstantNode>();  // Constant is always by reference.
```

tests/python/relay/test_vm.py (14 additions, 1 deletion)

```diff
@@ -766,6 +766,19 @@ def test_vm_reshape_tensor(target, dev):
     check_result(target, dev, [x_np, y_np], x_np.reshape([8, 2, 8]), mod)
 
 
+def test_vm_reshape_and_copy(target, dev):
+    """Make sure the compiler notices the reshape result shape is a literal and can use
+    the immediate-mode alloc_tensor instruction instead of alloc_tensor_reg."""
+    x_np = np.random.uniform(size=(1, 1)).astype("float32")
+    x = relay.var("x", shape=(1, 1), dtype="float32")
+    mod = tvm.IRModule.from_expr(relay.Function([x], relay.copy(relay.reshape(x, [0, 1]))))
+    with tvm.transform.PassContext(opt_level=3):
+        exec = relay.vm.compile(mod, "llvm")
+    assert "alloc_tensor" in exec.bytecode
+    assert not "alloc_tensor_reg" in exec.bytecode
+    check_result(target, dev, [x_np], x_np.reshape([1, 1]), mod)
+
+
 def test_vm_reshape_tuple(target, dev, x_shape=(1, 4, 2), y_shape=(1, 2, 10)):
     tup = relay.var(
         "tup",
@@ -963,4 +976,4 @@ def test_benchmark_end_to_end_rpc():
 if __name__ == "__main__":
     import sys
 
-    sys.exit(pytest.main(sys.argv))
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
```
