
Commit 19b68f3

refactor, fix checkpoint compute for tuple and add tests
1 parent 7bc21ca

File tree

4 files changed (+88 / -23 lines)


src/relay/op/annotation/annotation.cc

Lines changed: 5 additions & 1 deletion

@@ -164,7 +164,11 @@ Mark a checkpoint for checkpointing memory optimization.
 .set_attr<FTVMCompute>("FTVMCompute",
                        [](const Attrs& attrs, const Array<Tensor>& inputs,
                           const Type& out_dtype, const Target& target) -> Array<Tensor> {
-                         return {topi::identity(inputs[0])};
+                         Array<Tensor> outputs;
+                         for (size_t i = 0; i < inputs.size(); ++i) {
+                           outputs.push_back(topi::identity(inputs[i]));
+                         }
+                         return outputs;
                        });

 }  // namespace relay
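
Before this change the compute rule emitted an identity only for inputs[0], so a checkpoint wrapped around a tuple dropped every field after the first when lowered; the loop now emits one identity per input tensor. A minimal sketch of the tuple case from the Python side, assuming the standard relay API (variable names are illustrative):

    from tvm import relay

    # a checkpoint around a 2-tuple carries two tensors through the op,
    # so its compute must produce an identity for each of them
    x = relay.var("x", relay.TensorType((1,), "float32"))
    y = relay.var("y", relay.TensorType((1,), "float32"))
    out = relay.annotation.checkpoint(relay.Tuple([x, y]))
    f = relay.Function([x, y], out)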

src/relay/pass/gradient.cc

Lines changed: 30 additions & 22 deletions

@@ -400,33 +400,41 @@ struct ReverseAD : ExprMutator {
     throw;
   }

+  Expr VisitCheckpoint(const CallNode *call) {
+    const OpNode* op_node = call->op.as<OpNode>();
+    CHECK(op_node) << "expected op in call";
+    Op op_ref = GetRef<Op>(op_node);
+    CHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation";
+    auto x = call->args[0];
+    return LetList::With([&](LetList* ll) {
+      auto x_var = ll->Push(x);
+      auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
+      auto bpv = ll->Push(RefReadNode::make(bp));
+      Expr nbp = FunctionNode::make(
+          {},
+          LetList::With([&](LetList* ll) {
+            // we need a new ReverseAD visitor to avoid clobbering the bp local var
+            auto dup_bp = ll->Push(BPEmpty());
+            ReverseAD dup_diff(dup_bp, ad_vars);
+            auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));
+
+            TransferGrads(call->checked_type(), ret, dup_ad, ll);
+            ll->Push(CallNode::make(RefReadNode::make(dup_bp), {}));
+            return CallNode::make(bpv, {});
+          }),
+          TupleTypeNode::make({}),
+          {});
+      ll->Push(RefWriteNode::make(bp, nbp));
+      return ret;
+    });
+  }
+
   Expr VisitExpr_(const CallNode* call) final {
     if (const OpNode* op_node = call->op.as<OpNode>()) {
       Op op_ref = GetRef<Op>(op_node);

       if (op_ref->name == "annotation.checkpoint") {
-        auto x = call->args[0];
-        return LetList::With([&](LetList* ll) {
-          auto x_var = ll->Push(x);
-          auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
-          auto bpv = ll->Push(RefReadNode::make(bp));
-          Expr nbp = FunctionNode::make(
-              {},
-              LetList::With([&](LetList* ll) {
-                // we need a new ReverseAD visitor to avoid clobbering the bp local var
-                auto dup_bp = ll->Push(BPEmpty());
-                ReverseAD dup_diff(dup_bp, ad_vars);
-                auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));
-
-                TransferGrads(call->checked_type(), ret, dup_ad, ll);
-                ll->Push(CallNode::make(RefReadNode::make(dup_bp), {}));
-                return CallNode::make(bpv, {});
-              }),
-              TupleTypeNode::make({}),
-              {});
-          ll->Push(RefWriteNode::make(bp, nbp));
-          return ret;
-        });
+        return VisitCheckpoint(call);
       }

       CHECK(rev_map.count(op_ref))
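
The refactor moves the checkpoint branch of VisitExpr_ into its own VisitCheckpoint method without changing its behavior: the annotated subexpression is re-visited with a fresh ReverseAD (over a DeDup'd copy of x) inside the backprop closure, so its intermediates are recomputed during the backward pass instead of being kept live. A minimal sketch of driving the pass from Python, assuming an InferType-based run_infer_type helper like the one the tests use:

    from tvm import relay
    from tvm.relay import transform

    def run_infer_type(expr):
        # type-inference helper mirroring the one in the test files
        mod = transform.InferType()(relay.Module.from_expr(expr))
        return mod["main"]

    x = relay.var("x", relay.TensorType((1,), "float32"))
    y = relay.var("y", relay.TensorType((1,), "float32"))
    f = relay.Function([x, y], relay.annotation.checkpoint(relay.add(x, y)))
    # reverse-mode AD: the checkpointed add is re-emitted inside the backward closure
    df = transform.gradient(run_infer_type(f))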

tests/python/relay/test_op_grad_level10.py

Lines changed: 6 additions & 0 deletions

@@ -30,6 +30,12 @@ def test_checkpoint():
                        relay.add(inputs[2], inputs[3]))
     check_grad(relay.Function(inputs, relay.annotation.checkpoint(output)))

+    out_tuple = relay.Tuple([relay.add(inputs[0], inputs[1]),
+                             relay.multiply(inputs[2], inputs[3])])
+    out_single = relay.subtract(relay.TupleGetItem(relay.annotation.checkpoint(out_tuple), 0),
+                                relay.TupleGetItem(out_tuple, 1))
+    check_grad(relay.Function(inputs, out_single))
+

 if __name__ == "__main__":
     test_cross_entropy_grad()
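
check_grad validates the gradients produced by the AD pass against numerical estimates, so the new tuple case fails if the gradient of either tuple field is dropped. A rough standalone sketch of the numerical side (numeric_grad and eps are illustrative, not TVM API):

    import numpy as np

    def numeric_grad(f, x, eps=1e-4):
        # central finite differences over a flat numpy array (illustrative only)
        grad = np.zeros_like(x)
        for i in range(x.size):
            step = np.zeros_like(x)
            step.flat[i] = eps
            grad.flat[i] = (f(x + step) - f(x - step)) / (2 * eps)
        return grad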

tests/python/relay/test_op_level10.py

Lines changed: 47 additions & 0 deletions

@@ -105,6 +105,53 @@ def test_checkpoint_alpha_equal():

     relay.analysis.assert_alpha_equal(df, df_parsed)

+def test_checkpoint_alpha_equal_tuple():
+    xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i in range(4)]
+    f = relay.Function(xs, relay.annotation.checkpoint(
+        relay.Tuple([relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3])])
+    ))
+    df = transform.gradient(run_infer_type(f))
+
+    # run PE and DCE
+    with transform.PassContext(opt_level=3):
+        passes = [transform.PartialEvaluate(),
+                  transform.DeadCodeElimination(inline_once=True)]
+        mod = transform.Sequential(passes)(relay.Module.from_expr(df))
+        df = mod["main"]
+
+    df_parsed = relay.parser.fromtext(
+        """
+        v0.0.4
+        fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32],
+            %z: Tensor[(1), float32], %w: Tensor[(1), float32])
+            -> ((Tensor[(1), float32], Tensor[(1), float32]),
+                (Tensor[(1), float32], Tensor[(1), float32],
+                 Tensor[(1), float32], Tensor[(1), float32])) {
+            let %x1: Tensor[(1), float32] = add(%x, %y) /* ty=Tensor[(1), float32] */;
+            let %x2: Tensor[(1), float32] = add(%z, %w) /* ty=Tensor[(1), float32] */;
+            let %x3: Tensor[(1), float32] = zeros_like(%x2) /* ty=Tensor[(1), float32] */;
+            let %x4: Tensor[(1), float32] = ones_like(%x1) /* ty=Tensor[(1), float32] */;
+            %0 = (%x1, %x2);
+            %1 = zeros_like(%x) /* ty=Tensor[(1), float32] */;
+            %2 = collapse_sum_like(%x4, %x) /* ty=Tensor[(1), float32] */;
+            %3 = add(%1, %2) /* ty=Tensor[(1), float32] */;
+            %4 = zeros_like(%y) /* ty=Tensor[(1), float32] */;
+            %5 = collapse_sum_like(%x4, %y) /* ty=Tensor[(1), float32] */;
+            %6 = add(%4, %5) /* ty=Tensor[(1), float32] */;
+            %7 = zeros_like(%z) /* ty=Tensor[(1), float32] */;
+            %8 = collapse_sum_like(%x3, %z) /* ty=Tensor[(1), float32] */;
+            %9 = add(%7, %8) /* ty=Tensor[(1), float32] */;
+            %10 = zeros_like(%w) /* ty=Tensor[(1), float32] */;
+            %11 = collapse_sum_like(%x3, %w) /* ty=Tensor[(1), float32] */;
+            %12 = add(%10, %11) /* ty=Tensor[(1), float32] */;
+            %13 = (%3, %6, %9, %12);
+            (%0, %13)
+        }
+        """
+    )
+
+    relay.analysis.assert_alpha_equal(df, df_parsed)
+
 def test_collapse_sum_like():
     shape = (3, 4, 5, 6)
     shape_like = (4, 5, 6)
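
assert_alpha_equal accepts the optimized gradient as long as it matches the hand-written program above up to variable renaming. A toy illustration of alpha equality, assuming the same relay.analysis API the test uses:

    from tvm import relay

    a = relay.var("a", relay.TensorType((1,), "float32"))
    b = relay.var("b", relay.TensorType((1,), "float32"))
    f1 = relay.Function([a], a)
    f2 = relay.Function([b], b)
    # passes: the two functions differ only in bound variable names
    relay.analysis.assert_alpha_equal(f1, f2)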
