@@ -273,50 +273,93 @@ Type ReverseType(const Type& t) {
  * by doing a structure preserving map.
  */
 Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
-                const Type& t,
+                const std::function<Type(const Type&)>& tf,
+                const Type& forward_type,
                 const Expr& e,
                 LetList* ll) {
   CHECK(IsAtomic(e)) << e;
-  if (t.as<TensorTypeNode>()) {
+  if (forward_type.as<TensorTypeNode>()) {
     auto ret = f(e);
-    ret->checked_type_ = t;
+    ret->checked_type_ = tf(forward_type);
     return ret;
-  } else if (auto* tt = t.as<TupleTypeNode>()) {
+  } else if (auto* tt = forward_type.as<TupleTypeNode>()) {
     tvm::Array<Expr> fields;
+    tvm::Array<Type> types;
     for (size_t i = 0; i < tt->fields.size(); ++i) {
-      fields.push_back(LiftTensor(f,
-                                  tt->fields[i],
-                                  ll->Push(GetField(e, i)),
-                                  ll));
+      auto field = LiftTensor(f,
+                              tf,
+                              tt->fields[i],
+                              ll->Push(GetField(e, i)),
+                              ll);
+      fields.push_back(field);
+      types.push_back(field->checked_type_);
     }
     auto ret = TupleNode::make(fields);
-    ret->checked_type_ = t;
+    ret->checked_type_ = TupleTypeNode::make(types);
     return std::move(ret);
   } else {
     LOG(FATAL) << "unsupported input/output type: " << tt;
     throw;
   }
 }
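For intuition: f rewrites each tensor-typed leaf and tf computes the type of the rewritten leaf, so checked_type_ stays consistent with the new expression. A minimal sketch of a caller, using only helpers visible in this file (the WrapLeaves name is hypothetical):

// Sketch: lift "wrap every tensor leaf in a one-element tuple" over any
// nesting of tuples; LiftTensor itself maps tuple types field-wise.
// e is assumed atomic (a Var or similar), as LiftTensor requires.
Expr WrapLeaves(const Type& forward_type, const Expr& e, LetList* ll) {
  auto f = [](const Expr& leaf) -> Expr { return TupleNode::make({leaf}); };
  auto tf = [](const Type& t) -> Type { return TupleTypeNode::make({t}); };
  return LiftTensor(f, tf, forward_type, e, ll);
}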
 
+/*! \brief Transfers the gradients from an Expr to a deep duplication of the Expr,
+ * by stitching the references in the AD values.
+ */
+void TransferGrads(const Type& forward_type,
+                   const Expr& from,
+                   const Expr& to,
+                   LetList* ll) {
+  CHECK(IsAtomic(from)) << from;
+  CHECK(IsAtomic(to)) << to;
+  if (forward_type.as<TensorTypeNode>()) {
+    auto from_ref = TupleGetItemNode::make(from, 1);
+    auto to_ref = TupleGetItemNode::make(to, 1);
+    ll->Push(RefWriteNode::make(to_ref, RefReadNode::make(from_ref)));
+  } else if (auto* tt = forward_type.as<TupleTypeNode>()) {
+    for (size_t i = 0; i < tt->fields.size(); ++i) {
+      TransferGrads(tt->fields[i],
+                    ll->Push(TupleGetItemNode::make(from, i)),
+                    ll->Push(TupleGetItemNode::make(to, i)),
+                    ll);
+    }
+  } else {
+    LOG(FATAL) << "Unsupported input/output type: " << forward_type;
+    throw;
+  }
+}
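Every AD value stores its gradient behind a mutable reference at tuple index 1 of each tensor leaf, so TransferGrads amounts to a leaf-wise "to.1 := !from.1". A sketch of the invariant it establishes, using GetGrad defined just below (the CheckTransfer name is hypothetical; from and to are assumed atomic):

// Sketch: after the transfer, the duplicate reads back the same gradient
// value as the original, though the two references remain distinct cells.
Expr CheckTransfer(const Type& t, const Expr& from, const Expr& to, LetList* ll) {
  TransferGrads(t, from, to, ll);
  Expr g_from = GetGrad(t, from, ll);  // gradient accumulated on the original
  Expr g_to = GetGrad(t, to, ll);      // same tensor value after the transfer
  return Pair(g_from, g_to);
}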
+
 /*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
-Expr GetRev(const Type& t, const Expr& e, LetList* ll) {
+Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) {
   auto rev = [&](const Expr& e) {
     return Pair(e, ll->Push(RefCreateNode::make(ZerosLike(e))));
   };
-  return LiftTensor(rev, t, e, ll);
+  auto rev_type = [&](const Type& forward_type) {
+    return ReverseType(forward_type);
+  };
+  return LiftTensor(rev, rev_type, forward_type, e, ll);
 }
 
 /*! \brief ReverseType(t) -> t. Get the original value. */
-Expr GetValue(const Type& t, const Expr& e, LetList* ll) {
-  return LiftTensor([&](const Expr& e) { return GetField(e, 0); }, t, e, ll);
+Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) {
+  auto val = [&](const Expr& e) {
+    return GetField(e, 0);
+  };
+  auto val_type = [&](const Type& forward_type) {
+    return forward_type;
+  };
+  return LiftTensor(val, val_type, forward_type, e, ll);
 }
 
 /*! \brief ReverseType(t) -> t. Get the gradient. */
-Expr GetGrad(const Type& t, const Expr& e, LetList* ll) {
+Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) {
   auto grad = [&](const Expr& e) {
     return ll->Push(RefReadNode::make(GetField(e, 1)));
   };
-  return LiftTensor(grad, t, e, ll);
+  auto grad_type = [&](const Type& forward_type) {
+    return forward_type;
+  };
+  return LiftTensor(grad, grad_type, forward_type, e, ll);
 }
 
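Taken together, the three helpers witness the encoding at every tensor leaf: GetRev builds (value, ref(zeros)), GetValue projects field 0, and GetGrad reads the reference in field 1. A small sketch of the round trip (the RoundTrip name is hypothetical):

// Sketch: wrap a forward value, then project both components back out.
Expr RoundTrip(const Type& forward_type, const Expr& x, LetList* ll) {
  auto x_var = ll->Push(x);                              // make x atomic
  Expr rev = ll->Push(GetRev(forward_type, x_var, ll));  // (x, ref(0s)) at leaves
  Expr val = GetValue(forward_type, rev, ll);            // the original value
  Expr grad = GetGrad(forward_type, rev, ll);            // zeros until backprop runs
  return Pair(val, grad);
}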
 void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
@@ -337,42 +380,87 @@ void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
   }
 }
 
+Expr BPEmpty() {
+  Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
+  return RefCreateNode::make(unitF);
+}
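bp is a reference cell holding a zero-argument closure that runs the backward pass accumulated so far, and BPEmpty seeds it with a no-op. Every differentiated call then prepends its own gradient work to the cell. A sketch of that pattern, factored out of the two visitors below (the PrependBP name is hypothetical):

// Sketch: extend the backprop thunk. The new thunk performs this call's
// gradient updates, then invokes whatever thunk was installed before it,
// so updates run in reverse program order.
void PrependBP(const Var& bp, const std::function<void(LetList*)>& work,
               LetList* ll) {
  auto bpv = ll->Push(RefReadNode::make(bp));  // capture the current thunk
  Expr nbp = FunctionNode::make(
      {},
      LetList::With([&](LetList* ll) {
        work(ll);                        // new gradient updates
        return CallNode::make(bpv, {});  // then run the older thunk
      }),
      TupleTypeNode::make({}),
      {});
  ll->Push(RefWriteNode::make(bp, nbp));
}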
+
 struct ReverseAD : ExprMutator {
+  using ADVarMap = std::unordered_map<Var, Var, NodeHash, NodeEqual>;
+
   Var bp;
+  std::shared_ptr<ADVarMap> ad_vars;
   const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
 
-  explicit ReverseAD(const Var& bp) : bp(bp) { }
+  explicit ReverseAD(const Var& bp, std::shared_ptr<ADVarMap> ad_vars)
+    : bp(bp), ad_vars(ad_vars) { }
 
   Expr VisitExpr_(const OpNode* op) final {
     LOG(FATAL) << "op should only be inside call";
     throw;
   }
 
-  Expr VisitExpr_(const CallNode* op) final {
-    if (const OpNode* op_node = op->op.as<OpNode>()) {
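+  /*! \brief Handle a call to the annotation.checkpoint operator.
+   *
+   *  The forward value is computed as usual, but the backward closure
+   *  re-runs the checkpointed subexpression (via DeDup) under a fresh
+   *  backprop thunk, transfers the gradients accumulated on the original
+   *  result onto the recomputed copy, and only then backpropagates
+   *  through it, trading recomputation for memory.
+   */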
+  Expr VisitCheckpoint(const CallNode* call) {
+    const OpNode* op_node = call->op.as<OpNode>();
+    CHECK(op_node) << "expected op in call";
+    Op op_ref = GetRef<Op>(op_node);
+    CHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation";
+    auto x = call->args[0];
+    return LetList::With([&](LetList* ll) {
+      auto x_var = ll->Push(x);
+      auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
+      auto bpv = ll->Push(RefReadNode::make(bp));
+      Expr nbp = FunctionNode::make(
+        {},
+        LetList::With([&](LetList* ll) {
+          // we need a new ReverseAD visitor to avoid clobbering the bp local var
+          auto dup_bp = ll->Push(BPEmpty());
+          ReverseAD dup_diff(dup_bp, ad_vars);
+          auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));
+
+          TransferGrads(call->checked_type(), ret, dup_ad, ll);
+          ll->Push(CallNode::make(RefReadNode::make(dup_bp), {}));
+          return CallNode::make(bpv, {});
+        }),
+        TupleTypeNode::make({}),
+        {});
+      ll->Push(RefWriteNode::make(bp, nbp));
+      return ret;
+    });
+  }
+
+  Expr VisitExpr_(const CallNode* call) final {
+    if (const OpNode* op_node = call->op.as<OpNode>()) {
       Op op_ref = GetRef<Op>(op_node);
+
+      if (op_ref->name == "annotation.checkpoint") {
+        return VisitCheckpoint(call);
+      }
+
+      CHECK(rev_map.count(op_ref))
+        << op_node->name << " does not have reverse mode defined";
       return LetList::With([&](LetList* ll) {
         std::vector<Var> args;
-        for (const auto& arg : op->args) {
+        for (const auto& arg : call->args) {
           args.push_back(ll->Push(VisitExpr(arg)));
         }
         std::vector<Expr> orig_args;
         for (size_t i = 0; i < args.size(); i++) {
-          orig_args.push_back(GetValue(op->args[i]->checked_type(), args[i], ll));
+          orig_args.push_back(GetValue(call->args[i]->checked_type(), args[i], ll));
         }
-        Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args);
-        orig->checked_type_ = op->checked_type();
+        Expr orig = CallNode::make(call->op, orig_args, call->attrs, call->type_args);
+        orig->checked_type_ = call->checked_type();
         Var orig_var = ll->Push(orig);
-        orig_var->checked_type_ = op->checked_type();
-        auto ret = ll->Push(GetRev(op->checked_type(), orig_var, ll));
+        orig_var->checked_type_ = call->checked_type();
+        auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
         auto bpv = ll->Push(RefReadNode::make(bp));
         Expr nbp = FunctionNode::make(
           {},
           LetList::With([&](LetList* ll) {
-            tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(op->checked_type(), ret, ll));
+            tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
             CHECK(args.size() == rev.size());
             for (size_t i = 0; i < args.size(); ++i) {
-              UpdateGrad(op->args[i]->checked_type(), args[i], rev[i], ll);
+              UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
             }
             return CallNode::make(bpv, {});
           }),
@@ -382,7 +470,7 @@ struct ReverseAD : ExprMutator {
         return ret;
       });
     }
-    return ExprMutator::VisitExpr_(op);
+    return ExprMutator::VisitExpr_(call);
   }
 
   Expr VisitExpr_(const ConstantNode* op) final {
@@ -396,24 +484,30 @@ struct ReverseAD : ExprMutator {
                         VisitExpr(op->false_branch));
   }
 
+  Expr VisitExpr_(const VarNode* var) final {
+    // memoize Var -> ADVar so we don't end up with free Vars when checkpointing
+    auto var_ref = GetRef<Var>(var);
+    if (!ad_vars->count(var_ref)) {
+      auto res = Downcast<Var>(ExprMutator::VisitExpr_(var));
+      (*ad_vars)[var_ref] = res;
+    }
+
+    return ad_vars->at(var_ref);
+  }
+
   Type VisitType(const Type& t) final {
     return t.defined() ? ReverseType(t) : t;
   }
 };
 
-Expr BPEmpty() {
-  Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
-  return RefCreateNode::make(unitF);
-}
-
 bool MissingGrad(const Expr& e) {
   struct MGVisitor : ExprVisitor {
     const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
     std::unordered_set<std::string> op_names;
 
     void VisitExpr_(const OpNode* op) final {
       Op op_ref = GetRef<Op>(op);
-      if (!rev_map.count(op_ref)) {
+      if (op_ref->name != "annotation.checkpoint" && !rev_map.count(op_ref)) {
         op_names.insert(op_ref->name);
       }
       ExprVisitor::VisitExpr_(op);
@@ -445,7 +539,7 @@ Expr Gradient(const Expr& re, const Module& mod) {
   CHECK(!MissingGrad(e)) << "input has operators with missing gradients";
   Expr body = LetList::With([&](LetList* ll) {
     Var bp = ll->Push(BPEmpty());
-    Expr rev = ReverseAD(bp)(e);
+    Expr rev = ReverseAD(bp, std::make_shared<ReverseAD::ADVarMap>())(e);
     std::vector<Expr> args;
     for (const auto& p : f->params) {
       args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p)))));
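For completeness, a sketch of marking an expression for checkpointing on the C++ side. Op::Get is the standard operator registry lookup; the Checkpoint helper name is hypothetical, and the annotation.checkpoint operator itself is registered outside this file:

// Hypothetical helper: wrap an expression so that ReverseAD recomputes it
// during the backward pass instead of keeping its intermediates alive.
Expr Checkpoint(const Expr& e) {
  static const Op& op = Op::Get("annotation.checkpoint");
  return CallNode::make(op, {e}, Attrs(), {});
}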