Merged
19 changes: 18 additions & 1 deletion python/tvm/relay/op/annotation/annotation.py
@@ -17,10 +17,10 @@
"""Annotation operations."""
from __future__ import absolute_import as _abs
from . import _make
from ..op import register_schedule, schedule_injective
from .... import nd as _nd
from .... import TVMContext as _TVMContext


def on_device(data, device):
"""Annotate an expression with a certain device type.

@@ -61,3 +61,20 @@ def stop_fusion(data):
The annotated expression.
"""
return _make.stop_fusion(data)

def checkpoint(data):
"""Annotate an expression to be a checkpoint for the checkpointing memory optimization.

Parameters
----------
data : tvm.relay.Expr
The expression to be annotated.

Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
return _make.checkpoint(data)

register_schedule("annotation.checkpoint", schedule_injective)
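For orientation, a minimal usage sketch of the new Python API (not part of the diff; the variable names and shapes are illustrative):

from tvm import relay

# Annotate a subexpression so that, under the gradient pass, its body is
# recomputed inside the backward closure instead of being kept live.
x = relay.var("x", shape=(1,))
y = relay.var("y", shape=(1,))
out = relay.annotation.checkpoint(relay.multiply(relay.add(x, y), y))
func = relay.Function([x, y], out)
print(func)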
27 changes: 27 additions & 0 deletions src/relay/op/annotation/annotation.cc
@@ -144,5 +144,32 @@ Mark the end of bitpacking.
return {topi::identity(inputs[0])};
});

TVM_REGISTER_API("relay.op.annotation._make.checkpoint")
.set_body_typed<Expr(Expr)>([](Expr data) {
static const Op& op = Op::Get("annotation.checkpoint");
return CallNode::make(op, {data}, Attrs{}, {});
});

RELAY_REGISTER_OP("annotation.checkpoint")
.describe(R"code(
Mark a checkpoint for checkpointing memory optimization.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
Array<Tensor> outputs;
for (size_t i = 0; i < inputs.size(); ++i) {
outputs.push_back(topi::identity(inputs[i]));
}
return outputs;
});

} // namespace relay
} // namespace tvm
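Since the registered FTVMCompute is just topi::identity (and the Python side registers schedule_injective), a checkpointed expression run forward should match the unannotated one. A quick sketch of such a check, assuming the interpreter-based relay.create_executor; not part of the diff:

import numpy as np
from tvm import relay

x = relay.var("x", shape=(3,), dtype="float32")
plain = relay.Function([x], relay.exp(x))
ckpt = relay.Function([x], relay.annotation.checkpoint(relay.exp(x)))

# Evaluate both versions on the same input; the checkpoint op is compiled as identity.
ex = relay.create_executor()
data = np.random.rand(3).astype("float32")
np.testing.assert_allclose(ex.evaluate(plain)(data).asnumpy(),
                           ex.evaluate(ckpt)(data).asnumpy())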
4 changes: 3 additions & 1 deletion src/relay/pass/de_duplicate.cc
@@ -52,7 +52,9 @@ Expr DeDup(const Expr& e) {
}

Expr VisitExpr(const Expr& e) final {
return ExprMutator::VisitExpr(e);
auto ret = ExprMutator::VisitExpr(e);
ret->checked_type_ = e->checked_type_;
return ret;
}

Expr VisitExpr_(const VarNode* op) final {
162 changes: 128 additions & 34 deletions src/relay/pass/gradient.cc
@@ -273,50 +273,93 @@ Type ReverseType(const Type& t) {
* by doing a structure preserving map.
*/
Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
const Type& t,
const std::function<Type(const Type&)>& tf,
const Type& forward_type,
const Expr& e,
LetList* ll) {
CHECK(IsAtomic(e)) << e;
if (t.as<TensorTypeNode>()) {
if (forward_type.as<TensorTypeNode>()) {
auto ret = f(e);
ret->checked_type_ = t;
ret->checked_type_ = tf(forward_type);
return ret;
} else if (auto* tt = t.as<TupleTypeNode>()) {
} else if (auto* tt = forward_type.as<TupleTypeNode>()) {
tvm::Array<Expr> fields;
tvm::Array<Type> types;
for (size_t i = 0; i < tt->fields.size(); ++i) {
fields.push_back(LiftTensor(f,
tt->fields[i],
ll->Push(GetField(e, i)),
ll));
auto field = LiftTensor(f,
tf,
tt->fields[i],
ll->Push(GetField(e, i)),
ll);
fields.push_back(field);
types.push_back(field->checked_type_);
}
auto ret = TupleNode::make(fields);
ret->checked_type_ = t;
ret->checked_type_ = TupleTypeNode::make(types);
return std::move(ret);
} else {
LOG(FATAL) << "unsupported input/output type: " << tt;
throw;
}
}

/*! \brief Transfers the gradients from an Expr to a deep duplication of the Expr,
* by stitching the references in the AD values.
*/
void TransferGrads(const Type& forward_type,
const Expr& from,
const Expr& to,
LetList* ll) {
CHECK(IsAtomic(from)) << from;
CHECK(IsAtomic(to)) << to;
if (forward_type.as<TensorTypeNode>()) {
auto from_ref = TupleGetItemNode::make(from, 1);
auto to_ref = TupleGetItemNode::make(to, 1);
ll->Push(RefWriteNode::make(to_ref, RefReadNode::make(from_ref)));
} else if (auto* tt = forward_type.as<TupleTypeNode>()) {
for (size_t i = 0; i < tt->fields.size(); ++i) {
TransferGrads(tt->fields[i],
ll->Push(TupleGetItemNode::make(from, i)),
ll->Push(TupleGetItemNode::make(to, i)),
ll);
}
} else {
LOG(FATAL) << "Unsupported input/output type: " << forward_type;
throw;
}
}

/*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
Expr GetRev(const Type& t, const Expr& e, LetList* ll) {
Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) {
auto rev = [&](const Expr& e) {
return Pair(e, ll->Push(RefCreateNode::make(ZerosLike(e))));
};
return LiftTensor(rev, t, e, ll);
auto rev_type = [&](const Type& forward_type) {
return ReverseType(forward_type);
};
return LiftTensor(rev, rev_type, forward_type, e, ll);
}

/*! \brief ReverseType(t) -> t. Get the original value. */
Expr GetValue(const Type& t, const Expr& e, LetList* ll) {
return LiftTensor([&](const Expr& e) { return GetField(e, 0); }, t, e, ll);
Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) {
auto val = [&](const Expr& e) {
return GetField(e, 0);
};
auto val_type = [&](const Type& forward_type) {
return forward_type;
};
return LiftTensor(val, val_type, forward_type, e, ll);
}

/*! \brief ReverseType(t) -> t. Get the gradient. */
Expr GetGrad(const Type& t, const Expr& e, LetList* ll) {
Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) {
auto grad = [&](const Expr& e) {
return ll->Push(RefReadNode::make(GetField(e, 1)));
};
return LiftTensor(grad, t, e, ll);
auto grad_type = [&](const Type& forward_type) {
return forward_type;
};
return LiftTensor(grad, grad_type, forward_type, e, ll);
}

void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
@@ -337,42 +380,87 @@ void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
}
}

Expr BPEmpty() {
Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
return RefCreateNode::make(unitF);
}

struct ReverseAD : ExprMutator {
using ADVarMap = std::unordered_map<Var, Var, NodeHash, NodeEqual>;

Var bp;
std::shared_ptr<ADVarMap> ad_vars;
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");

explicit ReverseAD(const Var& bp) : bp(bp) { }
explicit ReverseAD(const Var& bp, std::shared_ptr<ADVarMap> ad_vars)
: bp(bp), ad_vars(ad_vars) { }

Expr VisitExpr_(const OpNode* op) final {
LOG(FATAL) << "op should only be inside call";
throw;
}

Expr VisitExpr_(const CallNode* op) final {
if (const OpNode* op_node = op->op.as<OpNode>()) {
Expr VisitCheckpoint(const CallNode *call) {
const OpNode* op_node = call->op.as<OpNode>();
CHECK(op_node) << "expected op in call";
Op op_ref = GetRef<Op>(op_node);
CHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation";
auto x = call->args[0];
return LetList::With([&](LetList* ll) {
auto x_var = ll->Push(x);
auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
auto bpv = ll->Push(RefReadNode::make(bp));
Expr nbp = FunctionNode::make(
{},
LetList::With([&](LetList* ll) {
// we need a new ReverseAD visitor to avoid clobbering the bp local var
auto dup_bp = ll->Push(BPEmpty());
ReverseAD dup_diff(dup_bp, ad_vars);
auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));

TransferGrads(call->checked_type(), ret, dup_ad, ll);
ll->Push(CallNode::make(RefReadNode::make(dup_bp), {}));
return CallNode::make(bpv, {});
}),
TupleTypeNode::make({}),
{});
ll->Push(RefWriteNode::make(bp, nbp));
return ret;
});
}

Expr VisitExpr_(const CallNode* call) final {
if (const OpNode* op_node = call->op.as<OpNode>()) {
Op op_ref = GetRef<Op>(op_node);

if (op_ref->name == "annotation.checkpoint") {
Contributor review comment on this line: I think you should probably factor out this branch so it will be easy to isolate and document the special handling of checkpointing.

return VisitCheckpoint(call);
}

CHECK(rev_map.count(op_ref))
<< op_node->name << " does not have reverse mode defined";
return LetList::With([&](LetList* ll) {
std::vector<Var> args;
for (const auto& arg : op->args) {
for (const auto& arg : call->args) {
args.push_back(ll->Push(VisitExpr(arg)));
}
std::vector<Expr> orig_args;
for (size_t i = 0; i < args.size(); i++) {
orig_args.push_back(GetValue(op->args[i]->checked_type(), args[i], ll));
orig_args.push_back(GetValue(call->args[i]->checked_type(), args[i], ll));
}
Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args);
orig->checked_type_ = op->checked_type();
Expr orig = CallNode::make(call->op, orig_args, call->attrs, call->type_args);
orig->checked_type_ = call->checked_type();
Var orig_var = ll->Push(orig);
orig_var->checked_type_ = op->checked_type();
auto ret = ll->Push(GetRev(op->checked_type(), orig_var, ll));
orig_var->checked_type_ = call->checked_type();
auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
auto bpv = ll->Push(RefReadNode::make(bp));
Expr nbp = FunctionNode::make(
{},
LetList::With([&](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(op->checked_type(), ret, ll));
tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
CHECK(args.size() == rev.size());
for (size_t i = 0; i < args.size(); ++i) {
UpdateGrad(op->args[i]->checked_type(), args[i], rev[i], ll);
UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
}
return CallNode::make(bpv, {});
}),
@@ -382,7 +470,7 @@ struct ReverseAD : ExprMutator {
return ret;
});
}
return ExprMutator::VisitExpr_(op);
return ExprMutator::VisitExpr_(call);
}

Expr VisitExpr_(const ConstantNode* op) final {
@@ -396,24 +484,30 @@ struct ReverseAD : ExprMutator {
VisitExpr(op->false_branch));
}

Expr VisitExpr_(const VarNode* var) final {
// memoize Var -> ADVar so we don't end up with free Vars when checkpointing
auto var_ref = GetRef<Var>(var);
if (!ad_vars->count(var_ref)) {
auto res = Downcast<Var>(ExprMutator::VisitExpr_(var));
(*ad_vars)[var_ref] = res;
}

return ad_vars->at(var_ref);
}

Type VisitType(const Type& t) final {
return t.defined() ? ReverseType(t) : t;
}
};

Expr BPEmpty() {
Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
return RefCreateNode::make(unitF);
}

bool MissingGrad(const Expr& e) {
struct MGVisitor : ExprVisitor {
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
std::unordered_set<std::string> op_names;

void VisitExpr_(const OpNode* op) final {
Op op_ref = GetRef<Op>(op);
if (!rev_map.count(op_ref)) {
if (op_ref->name != "annotation.checkpoint" && !rev_map.count(op_ref)) {
op_names.insert(op_ref->name);
}
ExprVisitor::VisitExpr_(op);
@@ -445,7 +539,7 @@ Expr Gradient(const Expr& re, const Module& mod) {
CHECK(!MissingGrad(e)) << "input has operators with missing gradients";
Expr body = LetList::With([&](LetList* ll) {
Var bp = ll->Push(BPEmpty());
Expr rev = ReverseAD(bp)(e);
Expr rev = ReverseAD(bp, std::make_shared<ReverseAD::ADVarMap>())(e);
std::vector<Expr> args;
for (const auto& p : f->params) {
args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p)))));
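To see the pass end to end, one can run the gradient transform over a checkpointed function. A sketch, assuming the relay.transform.gradient entry point and the relay.Module/InferType API of this TVM version; names and shapes are illustrative:

from tvm import relay
from tvm.relay import transform

x = relay.var("x", shape=(1,))
y = relay.var("y", shape=(1,))
out = relay.annotation.checkpoint(relay.multiply(relay.add(x, y), y))
func = relay.Function([x, y], out)

mod = relay.Module.from_expr(func)
mod = transform.InferType()(mod)
# The resulting function returns the forward value paired with the input gradients;
# the checkpointed subexpression is rebuilt (via DeDup plus a fresh ReverseAD) inside
# the backward closure instead of being read from stored intermediates.
grad_func = transform.gradient(mod["main"], mod)
print(grad_func)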
12 changes: 12 additions & 0 deletions tests/python/relay/test_op_grad_level10.py
@@ -30,6 +30,18 @@ def test_cross_entropy_with_logits_grad():
x = relay.var("x", shape=(2, 5))
y = relay.var("y", shape=(2, 5))
check_grad(relay.Function([x, y], relay.op.nn.cross_entropy_with_logits(x, y)), eps=0.01, scale=0.1, mean=1)

def test_checkpoint():
inputs = [relay.var("x{}".format(i), shape=(1,)) for i in range(4)]
output = relay.multiply(relay.add(inputs[0], inputs[1]),
relay.add(inputs[2], inputs[3]))
check_grad(relay.Function(inputs, relay.annotation.checkpoint(output)))

out_tuple = relay.Tuple([relay.add(inputs[0], inputs[1]),
relay.multiply(inputs[2], inputs[3])])
out_single = relay.subtract(relay.TupleGetItem(relay.annotation.checkpoint(out_tuple), 0),
relay.TupleGetItem(out_tuple, 1))
check_grad(relay.Function(inputs, out_single))


if __name__ == "__main__":