
Commit 93d610a

altanh authored and jroesch committed
[Relay][Training] Add checkpoint annotation for checkpointing memory optimization (#4146)
* add checkpoint annotation for checkpointing memory optimization
* add alpha-equivalence checkpoint test and fix gradient type issue
* fix build issues
* ignore checkpoint annotation when checking missing gradients
* refactor, fix checkpoint compute for tuple and add tests
1 parent 7732873 commit 93d610a

6 files changed, +309 -36 lines changed

python/tvm/relay/op/annotation/annotation.py

Lines changed: 18 additions & 1 deletion
@@ -17,10 +17,10 @@
 """Annotation operations."""
 from __future__ import absolute_import as _abs
 from . import _make
+from ..op import register_schedule, schedule_injective
 from .... import nd as _nd
 from .... import TVMContext as _TVMContext

-
 def on_device(data, device):
     """Annotate an expression with a certain device type.

@@ -61,3 +61,20 @@ def stop_fusion(data):
         The annotated expression.
     """
     return _make.stop_fusion(data)
+
+def checkpoint(data):
+    """Annotate an expression to be a checkpoint for the checkpointing memory optimization.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The expression to be annotated.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The annotated expression.
+    """
+    return _make.checkpoint(data)
+
+register_schedule("annotation.checkpoint", schedule_injective)
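For orientation, the new Python entry point is used like any other Relay annotation. A minimal usage sketch (the variable names are illustrative, and it assumes a TVM build that includes this commit):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1,))
    y = relay.var("y", shape=(1,))
    # Mark the intermediate sum as a checkpoint: its value may be recomputed
    # during the backward pass rather than kept alive for it.
    z = relay.annotation.checkpoint(relay.add(x, y))
    func = relay.Function([x, y], relay.multiply(z, z))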

src/relay/op/annotation/annotation.cc

Lines changed: 27 additions & 0 deletions
@@ -144,5 +144,32 @@ Mark the end of bitpacking.
   return {topi::identity(inputs[0])};
 });

+TVM_REGISTER_API("relay.op.annotation._make.checkpoint")
+.set_body_typed<Expr(Expr)>([](Expr data) {
+  static const Op& op = Op::Get("annotation.checkpoint");
+  return CallNode::make(op, {data}, Attrs{}, {});
+});
+
+RELAY_REGISTER_OP("annotation.checkpoint")
+.describe(R"code(
+Mark a checkpoint for checkpointing memory optimization.
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_support_level(10)
+.add_type_rel("Identity", IdentityRel)
+.set_attr<TOpPattern>("TOpPattern", kOpaque)
+.set_attr<TOpIsStateful>("TOpIsStateful", false)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                               ElemwiseArbitraryLayout)
+.set_attr<FTVMCompute>("FTVMCompute",
+                       [](const Attrs& attrs, const Array<Tensor>& inputs,
+                          const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                         Array<Tensor> outputs;
+                         for (size_t i = 0; i < inputs.size(); ++i) {
+                           outputs.push_back(topi::identity(inputs[i]));
+                         }
+                         return outputs;
+                       });
+
 }  // namespace relay
 }  // namespace tvm
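At compile time the op is an identity: FTVMCompute emits topi::identity for each input and the Python side registers schedule_injective for it, so the forward value is untouched and only the gradient pass treats the annotation specially. A quick check of this (a sketch only, assuming the interpreter-based executor of this TVM version; names are illustrative):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1,))
    y = relay.var("y", shape=(1,))
    func = relay.Function([x, y], relay.annotation.checkpoint(relay.add(x, y)))

    # The checkpoint annotation lowers to an identity, so this simply prints x + y.
    ex = relay.create_executor()
    res = ex.evaluate(func)(np.ones(1, "float32"), np.ones(1, "float32"))
    print(res.asnumpy())  # [2.]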

src/relay/pass/de_duplicate.cc

Lines changed: 3 additions & 1 deletion
@@ -52,7 +52,9 @@ Expr DeDup(const Expr& e) {
   }

   Expr VisitExpr(const Expr& e) final {
-    return ExprMutator::VisitExpr(e);
+    auto ret = ExprMutator::VisitExpr(e);
+    ret->checked_type_ = e->checked_type_;
+    return ret;
   }

   Expr VisitExpr_(const VarNode* op) final {

src/relay/pass/gradient.cc

Lines changed: 128 additions & 34 deletions
@@ -273,50 +273,93 @@ Type ReverseType(const Type& t) {
  * by doing a structure preserving map.
  */
 Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
-                const Type& t,
+                const std::function<Type(const Type&)>& tf,
+                const Type& forward_type,
                 const Expr& e,
                 LetList* ll) {
   CHECK(IsAtomic(e)) << e;
-  if (t.as<TensorTypeNode>()) {
+  if (forward_type.as<TensorTypeNode>()) {
     auto ret = f(e);
-    ret->checked_type_ = t;
+    ret->checked_type_ = tf(forward_type);
     return ret;
-  } else if (auto* tt = t.as<TupleTypeNode>()) {
+  } else if (auto* tt = forward_type.as<TupleTypeNode>()) {
     tvm::Array<Expr> fields;
+    tvm::Array<Type> types;
     for (size_t i = 0; i < tt->fields.size(); ++i) {
-      fields.push_back(LiftTensor(f,
-                                  tt->fields[i],
-                                  ll->Push(GetField(e, i)),
-                                  ll));
+      auto field = LiftTensor(f,
+                              tf,
+                              tt->fields[i],
+                              ll->Push(GetField(e, i)),
+                              ll);
+      fields.push_back(field);
+      types.push_back(field->checked_type_);
     }
     auto ret = TupleNode::make(fields);
-    ret->checked_type_ = t;
+    ret->checked_type_ = TupleTypeNode::make(types);
     return std::move(ret);
   } else {
     LOG(FATAL) << "unsupported input/output type: " << tt;
     throw;
   }
 }

+/*! \brief Transfers the gradients from an Expr to a deep duplication of the Expr,
+ * by stitching the references in the AD values.
+ */
+void TransferGrads(const Type& forward_type,
+                   const Expr& from,
+                   const Expr& to,
+                   LetList* ll) {
+  CHECK(IsAtomic(from)) << from;
+  CHECK(IsAtomic(to)) << to;
+  if (forward_type.as<TensorTypeNode>()) {
+    auto from_ref = TupleGetItemNode::make(from, 1);
+    auto to_ref = TupleGetItemNode::make(to, 1);
+    ll->Push(RefWriteNode::make(to_ref, RefReadNode::make(from_ref)));
+  } else if (auto* tt = forward_type.as<TupleTypeNode>()) {
+    for (size_t i = 0; i < tt->fields.size(); ++i) {
+      TransferGrads(tt->fields[i],
+                    ll->Push(TupleGetItemNode::make(from, i)),
+                    ll->Push(TupleGetItemNode::make(to, i)),
+                    ll);
+    }
+  } else {
+    LOG(FATAL) << "Unsupported input/output type: " << forward_type;
+    throw;
+  }
+}
+
 /*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
-Expr GetRev(const Type& t, const Expr& e, LetList* ll) {
+Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) {
   auto rev = [&](const Expr& e) {
     return Pair(e, ll->Push(RefCreateNode::make(ZerosLike(e))));
   };
-  return LiftTensor(rev, t, e, ll);
+  auto rev_type = [&](const Type& forward_type) {
+    return ReverseType(forward_type);
+  };
+  return LiftTensor(rev, rev_type, forward_type, e, ll);
 }

 /*! \brief ReverseType(t) -> t. Get the original value. */
-Expr GetValue(const Type& t, const Expr& e, LetList* ll) {
-  return LiftTensor([&](const Expr& e) { return GetField(e, 0); }, t, e, ll);
+Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) {
+  auto val = [&](const Expr& e) {
+    return GetField(e, 0);
+  };
+  auto val_type = [&](const Type& forward_type) {
+    return forward_type;
+  };
+  return LiftTensor(val, val_type, forward_type, e, ll);
 }

 /*! \brief ReverseType(t) -> t. Get the gradient. */
-Expr GetGrad(const Type& t, const Expr& e, LetList* ll) {
+Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) {
   auto grad = [&](const Expr& e) {
     return ll->Push(RefReadNode::make(GetField(e, 1)));
   };
-  return LiftTensor(grad, t, e, ll);
+  auto grad_type = [&](const Type& forward_type) {
+    return forward_type;
+  };
+  return LiftTensor(grad, grad_type, forward_type, e, ll);
 }

 void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
@@ -337,42 +380,87 @@ void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
   }
 }

+Expr BPEmpty() {
+  Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
+  return RefCreateNode::make(unitF);
+}
+
 struct ReverseAD : ExprMutator {
+  using ADVarMap = std::unordered_map<Var, Var, NodeHash, NodeEqual>;
+
   Var bp;
+  std::shared_ptr<ADVarMap> ad_vars;
   const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");

-  explicit ReverseAD(const Var& bp) : bp(bp) { }
+  explicit ReverseAD(const Var& bp, std::shared_ptr<ADVarMap> ad_vars)
+      : bp(bp), ad_vars(ad_vars) { }

   Expr VisitExpr_(const OpNode* op) final {
     LOG(FATAL) << "op should only be inside call";
     throw;
   }

-  Expr VisitExpr_(const CallNode* op) final {
-    if (const OpNode* op_node = op->op.as<OpNode>()) {
+  Expr VisitCheckpoint(const CallNode *call) {
+    const OpNode* op_node = call->op.as<OpNode>();
+    CHECK(op_node) << "expected op in call";
+    Op op_ref = GetRef<Op>(op_node);
+    CHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation";
+    auto x = call->args[0];
+    return LetList::With([&](LetList* ll) {
+      auto x_var = ll->Push(x);
+      auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
+      auto bpv = ll->Push(RefReadNode::make(bp));
+      Expr nbp = FunctionNode::make(
+        {},
+        LetList::With([&](LetList* ll) {
+          // we need a new ReverseAD visitor to avoid clobbering the bp local var
+          auto dup_bp = ll->Push(BPEmpty());
+          ReverseAD dup_diff(dup_bp, ad_vars);
+          auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));
+
+          TransferGrads(call->checked_type(), ret, dup_ad, ll);
+          ll->Push(CallNode::make(RefReadNode::make(dup_bp), {}));
+          return CallNode::make(bpv, {});
+        }),
+        TupleTypeNode::make({}),
+        {});
+      ll->Push(RefWriteNode::make(bp, nbp));
+      return ret;
+    });
+  }
+
+  Expr VisitExpr_(const CallNode* call) final {
+    if (const OpNode* op_node = call->op.as<OpNode>()) {
       Op op_ref = GetRef<Op>(op_node);
+
+      if (op_ref->name == "annotation.checkpoint") {
+        return VisitCheckpoint(call);
+      }
+
+      CHECK(rev_map.count(op_ref))
+        << op_node->name << " does not have reverse mode defined";
       return LetList::With([&](LetList* ll) {
         std::vector<Var> args;
-        for (const auto& arg : op->args) {
+        for (const auto& arg : call->args) {
           args.push_back(ll->Push(VisitExpr(arg)));
         }
         std::vector<Expr> orig_args;
         for (size_t i = 0; i < args.size(); i++) {
-          orig_args.push_back(GetValue(op->args[i]->checked_type(), args[i], ll));
+          orig_args.push_back(GetValue(call->args[i]->checked_type(), args[i], ll));
         }
-        Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args);
-        orig->checked_type_ = op->checked_type();
+        Expr orig = CallNode::make(call->op, orig_args, call->attrs, call->type_args);
+        orig->checked_type_ = call->checked_type();
         Var orig_var = ll->Push(orig);
-        orig_var->checked_type_ = op->checked_type();
-        auto ret = ll->Push(GetRev(op->checked_type(), orig_var, ll));
+        orig_var->checked_type_ = call->checked_type();
+        auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
         auto bpv = ll->Push(RefReadNode::make(bp));
         Expr nbp = FunctionNode::make(
           {},
           LetList::With([&](LetList* ll) {
-            tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(op->checked_type(), ret, ll));
+            tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
             CHECK(args.size() == rev.size());
             for (size_t i = 0; i < args.size(); ++i) {
-              UpdateGrad(op->args[i]->checked_type(), args[i], rev[i], ll);
+              UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
             }
             return CallNode::make(bpv, {});
           }),
@@ -382,7 +470,7 @@ struct ReverseAD : ExprMutator {
         return ret;
       });
     }
-    return ExprMutator::VisitExpr_(op);
+    return ExprMutator::VisitExpr_(call);
   }

   Expr VisitExpr_(const ConstantNode* op) final {
@@ -396,24 +484,30 @@ struct ReverseAD : ExprMutator {
                         VisitExpr(op->false_branch));
   }

+  Expr VisitExpr_(const VarNode* var) final {
+    // memoize Var -> ADVar so we don't end up with free Vars when checkpointing
+    auto var_ref = GetRef<Var>(var);
+    if (!ad_vars->count(var_ref)) {
+      auto res = Downcast<Var>(ExprMutator::VisitExpr_(var));
+      (*ad_vars)[var_ref] = res;
+    }
+
+    return ad_vars->at(var_ref);
+  }
+
   Type VisitType(const Type& t) final {
     return t.defined() ? ReverseType(t) : t;
   }
 };

-Expr BPEmpty() {
-  Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
-  return RefCreateNode::make(unitF);
-}
-
 bool MissingGrad(const Expr& e) {
   struct MGVisitor : ExprVisitor {
     const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
     std::unordered_set<std::string> op_names;

     void VisitExpr_(const OpNode* op) final {
       Op op_ref = GetRef<Op>(op);
-      if (!rev_map.count(op_ref)) {
+      if (op_ref->name != "annotation.checkpoint" && !rev_map.count(op_ref)) {
         op_names.insert(op_ref->name);
       }
       ExprVisitor::VisitExpr_(op);
@@ -445,7 +539,7 @@ Expr Gradient(const Expr& re, const Module& mod) {
   CHECK(!MissingGrad(e)) << "input has operators with missing gradients";
   Expr body = LetList::With([&](LetList* ll) {
     Var bp = ll->Push(BPEmpty());
-    Expr rev = ReverseAD(bp)(e);
+    Expr rev = ReverseAD(bp, std::make_shared<ReverseAD::ADVarMap>())(e);
     std::vector<Expr> args;
     for (const auto& p : f->params) {
       args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p)))));
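The heart of the change is VisitCheckpoint: it builds the reverse-mode value for the annotated expression with GetRev, then extends the backpropagator so that, when it runs, it differentiates a DeDup'd copy of the annotated subexpression with a fresh ReverseAD visitor (sharing ad_vars so free variables line up), moves the accumulated gradients onto that copy via TransferGrads, invokes the copy's backpropagator, and finally chains to the previous one. Driven from Python, the pass is reached roughly as follows (a sketch only, under the assumption that the relay.Module and relay.transform.gradient APIs of this TVM version are used; names are illustrative):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1,))
    y = relay.var("y", shape=(1,))
    # Checkpoint the product so ReverseAD recomputes it inside the backward closure.
    func = relay.Function([x, y], relay.annotation.checkpoint(relay.multiply(x, y)))

    mod = relay.Module.from_expr(func)
    mod = relay.transform.InferType()(mod)
    grad_func = relay.transform.gradient(mod["main"], mode="higher_order")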

tests/python/relay/test_op_grad_level10.py

Lines changed: 12 additions & 0 deletions
@@ -30,6 +30,18 @@ def test_cross_entropy_with_logits_grad():
     x = relay.var("x", shape=(2, 5))
     y = relay.var("y", shape=(2, 5))
     check_grad(relay.Function([x, y], relay.op.nn.cross_entropy_with_logits(x, y)), eps=0.01, scale=0.1, mean=1)
+
+def test_checkpoint():
+    inputs = [relay.var("x{}".format(i), shape=(1,)) for i in range(4)]
+    output = relay.multiply(relay.add(inputs[0], inputs[1]),
+                            relay.add(inputs[2], inputs[3]))
+    check_grad(relay.Function(inputs, relay.annotation.checkpoint(output)))
+
+    out_tuple = relay.Tuple([relay.add(inputs[0], inputs[1]),
+                             relay.multiply(inputs[2], inputs[3])])
+    out_single = relay.subtract(relay.TupleGetItem(relay.annotation.checkpoint(out_tuple), 0),
+                                relay.TupleGetItem(out_tuple, 1))
+    check_grad(relay.Function(inputs, out_single))


 if __name__ == "__main__":
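For reference, the first case in test_checkpoint computes (x0 + x1) * (x2 + x3), so check_grad should recover d/dx0 = d/dx1 = x2 + x3 and d/dx2 = d/dx3 = x0 + x1. A plain numpy cross-check with hypothetical inputs (independent of TVM):

    import numpy as np

    x0, x1, x2, x3 = [np.array([v], dtype="float32") for v in (1.0, 2.0, 3.0, 4.0)]
    # forward value and analytic gradients of (x0 + x1) * (x2 + x3)
    fwd = (x0 + x1) * (x2 + x3)
    grads = [x2 + x3, x2 + x3, x0 + x1, x0 + x1]
    print(fwd.item(), [g.item() for g in grads])  # 21.0 [7.0, 7.0, 3.0, 3.0]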
