Skip to content

Commit 653b6e0

Browse files
electricliliespfk-beta
authored andcommitted
[RELAY] [VIRTUALDEVICE] Change syntax for device planning and store parameter virtual devices in virtual_device_ field (apache#10352)
* parent 33082e0 author electriclilies <lilyorthsmith@gmail.com> 1643141097 -0800 committer Lily Orth-Smith <lilyorthsmith@gmail.com> 1645560059 -0800 Store function param virtual devices in virtual_device_ field Fix test_annotation.py and change result_virtual_device to virtual_device * Change plan devices tests to use the new syntax for function parameters * Fix free var problem * Fix attribute parsing if there is virtual device; most device planning tests passgit status * fixed lambda lifting * Debugging high order functions -- right now FunctionOnDevice and Bind are mutually recursive. This needs to not be the case. * tests pass wootgit status * Remove FunctionOnDevice from device planner * Don't use MaybeFunctionOnDevice in VM compiler * Remove MaybeFunctionOnDevice from lambda lifter * Delete FunctionOnDevice and MaybeFunctionOnDevice! * Reomve GetFunctionResultVirtualDevice * Remove GetFunctionParamVirtualDevice * lint * lint * Python formatting * Remove FunctionOnDevice python test * Fix bug in binds & debug output * Fix text printer * lint * Remove function on device from fold constant tests * Mark nits * Revert behavior of bind * clean up debug * Make ExprBinder public interface and use instead of Bind * Fix lambda lift * This is broken but not sure how to fix * passes all device planning tests yay! * Add substitution helper and use in device planner * Remove unnecessary check * Respond to comments * Update comment
1 parent c12b149 commit 653b6e0

File tree

17 files changed

+245
-281
lines changed

17 files changed

+245
-281
lines changed

include/tvm/ir/function.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -190,16 +190,6 @@ constexpr const char* kTarget = "target";
190190
*/
191191
constexpr const char* kGlobalSymbol = "global_symbol";
192192

193-
/*!
194-
* \brief The \p VirtualDevice which will hold each of the functions parameters.
195-
*
196-
* Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but
197-
* may be included as an annotation on user programs.
198-
*
199-
* Type: Array<VirtualDevice>
200-
*/
201-
constexpr const char* kParamVirtualDevice = "param_virtual_devices";
202-
203193
} // namespace attr
204194
} // namespace tvm
205195
#endif // TVM_IR_FUNCTION_H_

include/tvm/relay/transform.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,10 @@ TVM_DLL Pass PlanDevices(CompilationConfig config);
499499
/*!
500500
* \brief Bind the free variables to a Relay expression. This is a helper
501501
* function usually called by other pass functions to help optimizations.
502+
* If any free variables are introduced into a function, those are added
503+
* to the functoin parameters.
504+
* Additionally this may change the order of parameters if you map a variable
505+
* to a variable.
502506
*
503507
* \param expr The input expression.
504508
* \param binds The variable to expression map that will be used to help the
@@ -508,6 +512,19 @@ TVM_DLL Pass PlanDevices(CompilationConfig config);
508512
*/
509513
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
510514

515+
/*!
516+
* \brief Substitute variables with new variables (including function parameters) in a function.
517+
* This is a helper function usually called by other pass functions to help optimizations.
518+
* Expects all values in the bind map to be Vars.
519+
*
520+
* \param func The input function.
521+
* \param binds The variable to expression map that will be used to help the
522+
* binding.
523+
*
524+
* \return The updated expression.
525+
*/
526+
TVM_DLL Function SubstituteBoundVars(const Function& func, const tvm::Map<Var, Expr>& binds);
527+
511528
/*!
512529
* \brief Apply rewrite rules to rewrite the expr in post DFS order. This
513530
* function is used as a helper function to rewrtie an expression in a pass.

include/tvm/target/virtual_device.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ class VirtualDeviceCache {
367367
*
368368
* Type: VirtualDevice
369369
*/
370-
constexpr const char* kVirtualDevice = "result_virtual_device";
370+
constexpr const char* kVirtualDevice = "virtual_device";
371371

372372
} // namespace tvm
373373

src/parser/parser.cc

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -456,9 +456,13 @@ class Parser {
456456
*
457457
* "x" -> Var("x"), these are needed to map from the raw string names
458458
* to unique variable nodes.
459+
* If a virtual device is specified, sets the virtual device of the variable.
459460
*/
460-
Var BindVar(const std::string& name, const relay::Type& type_annotation) {
461+
Var BindVar(const std::string& name, const relay::Type& type_annotation,
462+
Optional<VirtualDevice> virtual_device = Optional<VirtualDevice>()) {
461463
auto var = Var(name, type_annotation);
464+
var->virtual_device_ = virtual_device.value_or(VirtualDevice::FullyUnconstrained());
465+
VLOG(1) << "Binding var named " << name << " to variable node " << PrettyPrint(var);
462466
this->expr_scopes.Add(name, var);
463467
return var;
464468
}
@@ -1113,11 +1117,26 @@ class Parser {
11131117
[&]() {
11141118
auto token = Match(TokenType::kLocal);
11151119
auto string = token.ToString();
1120+
1121+
// The fake attributes where the virtual device is specified.
1122+
VirtualDevice virtual_device;
1123+
if (WhenMatch(TokenType::kLCurly)) {
1124+
Map<String, ObjectRef> fake_attrs = ParseAttrs();
1125+
VLOG(9) << "Fake attributes for function parameter: " << fake_attrs;
1126+
Match(TokenType::kRCurly);
1127+
if (fake_attrs.size() == 1 && fake_attrs.count(kVirtualDevice)) {
1128+
ICHECK(fake_attrs[kVirtualDevice].as<VirtualDeviceNode>())
1129+
<< "Expected the " << kVirtualDevice
1130+
<< " to have type VirtualDeviceNode, but got " << virtual_device->GetTypeKey();
1131+
virtual_device = Downcast<VirtualDevice>(fake_attrs[kVirtualDevice]);
1132+
}
1133+
}
1134+
11161135
Type type;
11171136
if (WhenMatch(TokenType::kColon)) {
11181137
type = ParseType();
11191138
}
1120-
return BindVar(string, type);
1139+
return BindVar(string, type, virtual_device);
11211140
},
11221141
[&] {
11231142
auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier;
@@ -1150,8 +1169,15 @@ class Parser {
11501169
ICHECK(vid.as<VirtualDeviceNode>())
11511170
<< "Expected the " << kVirtualDevice << " to have type VirtualDeviceNode, but got "
11521171
<< vid->GetTypeKey();
1153-
raw_attrs.erase(kVirtualDevice);
1154-
Function func = relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs));
1172+
1173+
DictAttrs attrs;
1174+
// Don't fill the raw_attrs in if there's nothing other than kVirtualDevice in the
1175+
// attributes
1176+
if (raw_attrs.size() > 1) {
1177+
raw_attrs.erase(kVirtualDevice);
1178+
attrs = DictAttrs(raw_attrs);
1179+
}
1180+
Function func = relay::Function(params, body, ret_type, generics, attrs);
11551181
func->virtual_device_ = vid;
11561182
return func;
11571183
} else {

src/printer/relay_text_printer.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,13 @@ Doc RelayTextPrinter::AllocVar(const Var& var) {
220220
}
221221
Doc val = GetUniqueName("%" + name);
222222
memo_[var] = val;
223+
if (!var->virtual_device()->IsFullyUnconstrained()) {
224+
val << " {" << kVirtualDevice << "=" << PrintAttributeValue(var->virtual_device()) << "}";
225+
}
223226
if (var->type_annotation.defined()) {
224227
val << ": " << Print(var->type_annotation);
225228
}
229+
226230
val << PrintOptionalInfo(var);
227231
return val;
228232
}
@@ -445,7 +449,7 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) {
445449
for (const Doc& d : PrintDictAttrs(fn->attrs)) {
446450
params.push_back(d);
447451
}
448-
if (fn->virtual_device() != VirtualDevice::FullyUnconstrained()) {
452+
if (!fn->virtual_device()->IsFullyUnconstrained()) {
449453
Doc vid_doc;
450454
vid_doc << kVirtualDevice << "=" << PrintAttributeValue(fn->virtual_device());
451455
params.push_back(vid_doc);
@@ -454,7 +458,6 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) {
454458
if (fn->ret_type.defined()) {
455459
doc << "-> " << Print(fn->ret_type) << " ";
456460
}
457-
458461
doc << PrintBody(fn->body);
459462
return doc;
460463
}

src/relay/backend/vm/compiler.cc

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -252,21 +252,16 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
252252
// Do that flattening on-the-fly here.
253253
Function inner_func = Downcast<Function>(func->body);
254254
std::vector<Var> params;
255-
std::vector<VirtualDevice> param_virtual_devices;
256255
params.reserve(func->params.size() + inner_func->params.size());
257-
param_virtual_devices.reserve(func->params.size() + inner_func->params.size());
258256
param_device_indexes.reserve(func->params.size() + inner_func->params.size());
259257
for (size_t i = 0; i < func->params.size(); ++i) {
260258
params.emplace_back(func->params[i]);
261-
VirtualDevice param_virtual_device = GetFunctionParamVirtualDevice(func.get(), i);
262-
param_virtual_devices.push_back(param_virtual_device);
263-
param_device_indexes.push_back(GetDeviceIndex(param_virtual_device));
259+
param_device_indexes.push_back(GetDeviceIndex(func->params[i]->virtual_device()));
264260
}
265261
for (size_t i = 0; i < inner_func->params.size(); ++i) {
266262
params.emplace_back(inner_func->params[i]);
267-
VirtualDevice param_virtual_device = GetFunctionParamVirtualDevice(inner_func.get(), i);
268-
param_virtual_devices.push_back(param_virtual_device);
269-
param_device_indexes.push_back(GetDeviceIndex(param_virtual_device));
263+
264+
param_device_indexes.push_back(GetDeviceIndex(inner_func->params[i]->virtual_device()));
270265
}
271266
std::vector<TypeVar> type_params;
272267
type_params.reserve(func->type_params.size() + inner_func->type_params.size());
@@ -278,13 +273,12 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
278273
}
279274
Function flattened_func = Function(params, inner_func->body, inner_func->ret_type,
280275
type_params, func->attrs, func->span);
281-
VisitExpr(MaybeFunctionOnDevice(flattened_func, param_virtual_devices,
282-
GetFunctionResultVirtualDevice(inner_func.get())));
276+
flattened_func->virtual_device_ = inner_func->virtual_device();
277+
VisitExpr(flattened_func);
283278
} else {
284279
param_device_indexes.reserve(func->params.size());
285280
for (size_t i = 0; i < func->params.size(); ++i) {
286-
param_device_indexes.push_back(
287-
GetDeviceIndex(GetFunctionParamVirtualDevice(func.get(), i)));
281+
param_device_indexes.push_back(GetDeviceIndex(func->params[i]->virtual_device()));
288282
}
289283
VisitExpr(func);
290284
}

src/relay/backend/vm/lambda_lift.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,22 +111,21 @@ class LambdaLifter : public transform::DeviceAwareExprMutator {
111111
auto free_type_vars = FreeTypeVars(func, module_);
112112

113113
Array<Var> captured_vars;
114-
std::vector<VirtualDevice> captured_var_virtual_devices;
115114
bool recursive = false;
116115
for (const auto& var : free_vars) {
117116
if (!letrec_.empty() && var == letrec_.back()) {
118117
recursive = true;
119118
continue;
120119
}
121120
captured_vars.push_back(var);
122-
captured_var_virtual_devices.push_back(GetVirtualDevice(var));
123121
}
124122

125123
// Freshen all the captured vars.
126124
Array<Var> typed_captured_vars;
127125
Map<Var, Expr> rebinding_map;
128126
for (auto free_var : captured_vars) {
129127
auto var = Var(free_var->name_hint(), free_var->checked_type());
128+
var->virtual_device_ = GetVirtualDevice(free_var);
130129
typed_captured_vars.push_back(var);
131130
rebinding_map.Set(free_var, var);
132131
}
@@ -173,6 +172,8 @@ class LambdaLifter : public transform::DeviceAwareExprMutator {
173172
if (captured_vars.empty() && free_type_vars.empty()) {
174173
lifted_func = Function(body->params, body->body, body->ret_type, body->type_params,
175174
body->attrs, body->span);
175+
// We also need to copy the virtual device
176+
lifted_func->virtual_device_ = body->virtual_device();
176177
} else {
177178
// When a closure is locally bound in a program, we have its full type information
178179
// avalible to us.
@@ -187,14 +188,14 @@ class LambdaLifter : public transform::DeviceAwareExprMutator {
187188
// construct the "closure" function with fully annotated arguments, no longer relying
188189
// on type inference.
189190
size_t before_arity = body->params.size();
191+
VLOG(9) << "Binding " << rebinding_map << " into\n" << PrettyPrint(body->body);
190192
auto rebound_body = WithFields(func, func->params, Bind(body->body, rebinding_map));
191193
size_t after_arity = rebound_body->params.size();
192194
CHECK_EQ(before_arity, after_arity);
193195
lifted_func =
194196
Function(typed_captured_vars, rebound_body, /*ret_type=*/func->func_type_annotation(),
195197
free_type_vars, /*attrs=*/{}, func->span);
196-
lifted_func =
197-
MaybeFunctionOnDevice(lifted_func, captured_var_virtual_devices, result_virtual_device);
198+
lifted_func->virtual_device_ = result_virtual_device;
198199
lifted_func = MarkClosure(lifted_func);
199200
}
200201

src/relay/ir/expr_functor.cc

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -472,45 +472,42 @@ class ExprBinder : public MixedModeMutator, PatternMutator {
472472
const tvm::Map<Var, Expr>& args_map_;
473473
};
474474

475+
// This function should be called SubstAndBind, since it assumes any variables introduced
476+
// in the substitution right hand side should be implicitly bound in the function.
475477
Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
476478
if (const FunctionNode* func = expr.as<FunctionNode>()) {
477479
Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
478480
Array<Var> new_params;
479-
std::vector<VirtualDevice> new_param_virtual_devices;
480481
for (size_t i = 0; i < func->params.size(); ++i) {
481482
if (!args_map.count(func->params[i])) {
482483
new_params.push_back(func->params[i]);
483-
new_param_virtual_devices.push_back(GetFunctionParamVirtualDevice(func, i));
484484
}
485485
}
486486
if (new_body.same_as(func->body) && new_params.size() == func->params.size()) {
487487
return expr;
488488
}
489+
489490
auto ret =
490491
Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span);
491-
ret =
492-
MaybeFunctionOnDevice(ret, new_param_virtual_devices, GetFunctionResultVirtualDevice(func));
492+
ret->virtual_device_ = func->virtual_device();
493+
493494
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> set;
494495
for (const auto& v : FreeVars(expr)) {
495496
set.insert(v);
496497
}
497498
for (const auto& v : FreeVars(ret)) {
498499
if (set.count(v) == 0) {
499500
new_params.push_back(v);
500-
if (!GetFunctionResultVirtualDevice(func)->IsFullyUnconstrained()) {
501-
// TODO(mbs): The function has been annotated with a device, which means we are supposed
502-
// to be preserving device annotations on every transformation. However there's no
503-
// such context for the free vars in args_map.
504-
LOG(WARNING) << "introduced free var '" << PrettyPrint(v)
505-
<< "' into function body but no device is known for it";
506-
}
507-
new_param_virtual_devices.push_back(VirtualDevice::FullyUnconstrained());
508501
}
509502
}
503+
510504
ret =
511505
Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span);
512-
ret =
513-
MaybeFunctionOnDevice(ret, new_param_virtual_devices, GetFunctionResultVirtualDevice(func));
506+
ret->virtual_device_ = func->virtual_device();
507+
508+
VLOG(4) << "Expr:\n" << expr;
509+
VLOG(4) << "Ret:\n" << ret;
510+
514511
ICHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size());
515512
return std::move(ret);
516513
} else {
@@ -528,6 +525,27 @@ TVM_REGISTER_GLOBAL("relay.ir.Bind").set_body([](TVMArgs args, TVMRetValue* ret)
528525
}
529526
});
530527

528+
Function SubstituteBoundVars(const Function& func, const tvm::Map<Var, Expr>& args_map) {
529+
Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
530+
Array<Var> new_params;
531+
for (size_t i = 0; i < func->params.size(); i++) {
532+
if (!args_map.count(func->params[i])) {
533+
new_params.push_back(func->params[i]);
534+
} else {
535+
if (const VarNode* var = args_map[func->params[i]].as<VarNode>()) {
536+
new_params.push_back(GetRef<Var>(var));
537+
} else {
538+
ICHECK(false) << "Expected all values in args_map to be vars, but found "
539+
<< args_map[func->params[i]]->GetTypeKey();
540+
}
541+
}
542+
}
543+
auto ret =
544+
Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span);
545+
ret->virtual_device_ = func->virtual_device();
546+
return ret;
547+
}
548+
531549
void ExpandANormalForm(const LetNode* op, std::function<void(const LetNode*)> pre_visit,
532550
std::function<void(const LetNode*)> post_visit) {
533551
std::stack<const LetNode*> stack;

src/relay/op/memory/on_device.cc

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <tvm/relay/expr.h>
3030
#include <tvm/relay/op.h>
3131
#include <tvm/relay/op_attr_types.h>
32+
#include <tvm/relay/transform.h>
3233

3334
#include "../../transforms/infer_layout_utils.h"
3435
#include "../type_relations.h"
@@ -142,48 +143,5 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) {
142143
return {};
143144
}
144145

145-
Function FunctionOnDevice(Function function, Array<VirtualDevice> param_virtual_devices,
146-
VirtualDevice result_virtual_device) {
147-
auto func = WithAttr(
148-
WithFields(std::move(function), {}, {}, {}, {}, {}, std::move(result_virtual_device)),
149-
tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices));
150-
VLOG(1) << "Annotated func: " << PrettyPrint(func);
151-
return func;
152-
}
153-
154-
TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice);
155-
156-
Function MaybeFunctionOnDevice(Function function, Array<VirtualDevice> param_virtual_devices,
157-
VirtualDevice result_virtual_device) {
158-
if (std::all_of(param_virtual_devices.begin(), param_virtual_devices.end(),
159-
[](const VirtualDevice& virtual_device) {
160-
return virtual_device->IsFullyUnconstrained();
161-
}) &&
162-
result_virtual_device->IsFullyUnconstrained()) {
163-
// Nothing to annotate.
164-
return function;
165-
}
166-
return FunctionOnDevice(function, std::move(param_virtual_devices),
167-
std::move(result_virtual_device));
168-
}
169-
170-
VirtualDevice GetFunctionResultVirtualDevice(const FunctionNode* function_node) {
171-
return function_node->virtual_device();
172-
}
173-
174-
VirtualDevice GetFunctionParamVirtualDevice(const FunctionNode* function_node, size_t i) {
175-
ICHECK_LT(i, function_node->params.size())
176-
<< "param index " << i << " out of range for function of arity "
177-
<< function_node->params.size();
178-
auto opt_array = function_node->GetAttr<Array<VirtualDevice>>(tvm::attr::kParamVirtualDevice);
179-
if (!opt_array) {
180-
// No annotation.
181-
return VirtualDevice::FullyUnconstrained();
182-
}
183-
ICHECK_EQ(opt_array.value().size(), function_node->params.size())
184-
<< "annotation parameters do not match function arity";
185-
return opt_array.value()[i];
186-
}
187-
188146
} // namespace relay
189147
} // namespace tvm

0 commit comments

Comments
 (0)