@@ -255,6 +255,29 @@ def _test_tuple(mode):
255
255
tvm .testing .assert_allclose (grad_z .asnumpy (), - 1 * np .ones_like (grad_z .asnumpy ()))
256
256
257
257
258
+ def _test_tuple_argument (mode ):
259
+ shape = (2 , 3 )
260
+ dtype = "float32"
261
+ tensor_type = relay .TensorType (shape , dtype )
262
+ fields = 3
263
+ tuple_type = relay .TupleType ([tensor_type ] * fields )
264
+ tup = relay .var ("tup" , type_annotation = tuple_type )
265
+ body = relay .TupleGetItem (tup , 0 )
266
+ for i in range (1 , fields ):
267
+ body = relay .add (body , relay .TupleGetItem (tup , i ))
268
+ func = relay .Function ([tup ], body )
269
+ func = run_infer_type (func )
270
+ back_func = run_infer_type (gradient (func , mode = mode ))
271
+ xs = [rand (dtype , * shape ) for _ in range (fields )]
272
+ xs_np = np .array ([x .asnumpy () for x in xs ])
273
+ expected_forward = np .sum (xs_np , axis = 0 )
274
+ ex = create_executor ()
275
+ forward , grad = ex .evaluate (back_func )(tuple (xs ))
276
+ tvm .testing .assert_allclose (forward .asnumpy (), expected_forward )
277
+ for field in grad [0 ]:
278
+ tvm .testing .assert_allclose (field .asnumpy (), np .ones_like (field .asnumpy ()))
279
+
280
+
258
281
def test_tuple ():
259
282
_test_tuple ("higher_order" )
260
283
@@ -263,6 +286,16 @@ def test_tuple_first_order():
263
286
_test_tuple ("first_order" )
264
287
265
288
289
+ @pytest .mark .xfail (raises = tvm .error .TVMError )
290
+ def test_tuple_argument ():
291
+ # fails until we add support for top-level tuple arguments in higher-order AD
292
+ _test_tuple_argument ("higher_order" )
293
+
294
+
295
+ def test_tuple_argument_first_order ():
296
+ _test_tuple_argument ("first_order" )
297
+
298
+
266
299
def test_pow ():
267
300
mod = tvm .IRModule ()
268
301
p = Prelude (mod )
0 commit comments