From 3f472f94b3d8140dbb4f77f75c9781cb3e850a7f Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Mon, 4 Nov 2019 10:37:41 -0800 Subject: [PATCH] [Relay][Frontend][Tensorflow] Fix GatherV2, Add StopGradient (#4238) * Add StopGradient. Add batch_dims attr to ignore list for GatherV2 * Trigger CI --- python/tvm/relay/frontend/tensorflow.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 4bee712e9200..0abcb09d6ace 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -872,11 +872,14 @@ def _impl(inputs, attr, params): axis = _get_num_param(params, inputs.pop(2)) else: axis = 0 + if int(attr.get('batch_dims', 0)) != 0: + raise tvm.error.OpAttributeUnImplemented( + 'Attribute batch_dims is not supported') new_input = inputs[0:2] return AttrCvt(op_name="take", extras={'axis': tvm.const(axis, 'int32')}, ignores=['Tindices', 'Tparams', 'validate_indices', - 'Taxis', '_class'])(new_input, attr) + 'Taxis', '_class', 'batch_dims'])(new_input, attr) return _impl def _gather_nd(): @@ -1472,6 +1475,7 @@ def _impl(inputs, attr, params): 'Square' : _square(), 'SquaredDifference' : _squared_difference(), 'Squeeze' : _squeeze(), + 'StopGradient' : _identity(), 'StridedSlice' : _stridedSlice(), 'Sub' : _elemwise('subtract'), 'Sum' : _sum(),