
Commit 4e3a080

fix first-order AD on tuple arguments

Parent: d164aac
2 files changed: +52, -1

src/relay/transforms/gradient.cc (19 additions, 1 deletion)
```diff
@@ -181,6 +181,22 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
     return ret;
   }

+  Expr UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
+    if (t.as<TensorTypeNode>()) {
+      return ll->Push(Add(arg, grad));
+    } else if (auto* tt = t.as<TupleTypeNode>()) {
+      Array<Expr> updates;
+      for (size_t i = 0; i < tt->fields.size(); ++i) {
+        updates.push_back(this->UpdateGrad(tt->fields[i], ll->Push(GetField(arg, i)),
+                                           ll->Push(GetField(grad, i)), ll));
+      }
+      return ll->Push(Tuple(updates));
+    } else {
+      LOG(FATAL) << "unsupported arg type of operator: " << t;
+      throw;
+    }
+  }
+
   ADValue VisitExpr_(const OpNode* op) final {
     Op op_ref = GetRef<Op>(op);
     ICHECK(rev_map.count(op_ref)) << op->name << " does not have reverse mode defined";
@@ -198,8 +214,10 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
       tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
       ICHECK(args.size() == rev.size());
       for (size_t i = 0; i < args.size(); ++i) {
+        auto ad_arg = args[i]->get<ADTensor>();
+        auto ad_arg_type = ad_arg.forward->checked_type();
         args[i]->get<ADTensor>().reverse =
-            ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
+            this->UpdateGrad(ad_arg_type, ad_arg.reverse, rev[i], ll);
       }
     });
     return ret;
```
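
The change reads naturally from the diff: the old code unconditionally accumulated each argument's gradient with a single `Add`, which assumed a tensor-typed argument; the new `UpdateGrad` helper instead recurses on the argument's type, accumulating at tensor leaves and unpacking and repacking tuple fields. A plain-Python sketch of that structural recursion, for illustration only (`update_grad` and the NumPy stand-ins are hypothetical, not TVM API):

```python
import numpy as np

def update_grad(arg, grad):
    """Accumulate grad into arg, recursing through tuple structure."""
    if isinstance(arg, tuple):
        # Mirrors the TupleTypeNode branch: update each field, repack.
        return tuple(update_grad(a, g) for a, g in zip(arg, grad))
    # Mirrors the TensorTypeNode branch: a single Add at the leaf.
    return np.add(arg, grad)

# A tuple-typed "reverse" value and an incoming gradient of the same shape:
reverse = (np.zeros((2, 3)), (np.zeros(4), np.zeros(4)))
incoming = (np.ones((2, 3)), (np.ones(4), np.ones(4)))
print(update_grad(reverse, incoming))  # ones, with the nesting preserved
```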

tests/python/relay/test_pass_gradient.py (33 additions, 0 deletions)
```diff
@@ -255,6 +255,29 @@ def _test_tuple(mode):
     tvm.testing.assert_allclose(grad_z.asnumpy(), -1 * np.ones_like(grad_z.asnumpy()))


+def _test_tuple_argument(mode):
+    shape = (2, 3)
+    dtype = "float32"
+    tensor_type = relay.TensorType(shape, dtype)
+    fields = 3
+    tuple_type = relay.TupleType([tensor_type] * fields)
+    tup = relay.var("tup", type_annotation=tuple_type)
+    body = relay.TupleGetItem(tup, 0)
+    for i in range(1, fields):
+        body = relay.add(body, relay.TupleGetItem(tup, i))
+    func = relay.Function([tup], body)
+    func = run_infer_type(func)
+    back_func = run_infer_type(gradient(func, mode=mode))
+    xs = [rand(dtype, *shape) for _ in range(fields)]
+    xs_np = np.array([x.asnumpy() for x in xs])
+    expected_forward = np.sum(xs_np, axis=0)
+    ex = create_executor()
+    forward, grad = ex.evaluate(back_func)(tuple(xs))
+    tvm.testing.assert_allclose(forward.asnumpy(), expected_forward)
+    for field in grad[0]:
+        tvm.testing.assert_allclose(field.asnumpy(), np.ones_like(field.asnumpy()))
+
+
 def test_tuple():
     _test_tuple("higher_order")

@@ -263,6 +286,16 @@ def test_tuple_first_order():
     _test_tuple("first_order")


+@pytest.mark.xfail(raises=tvm.error.TVMError)
+def test_tuple_argument():
+    # fails until we add support for top-level tuple arguments in higher-order AD
+    _test_tuple_argument("higher_order")
+
+
+def test_tuple_argument_first_order():
+    _test_tuple_argument("first_order")
+
+
 def test_pow():
     mod = tvm.IRModule()
     p = Prelude(mod)
```