Skip to content

Commit

Permalink
[Relay][Frontend][Tensorflow] Fix GatherV2, Add StopGradient (#4238)
Browse files Browse the repository at this point in the history
* Add StopGradient. Add batch_dims attr to ignore list for GatherV2

* Trigger CI
  • Loading branch information
trevor-m authored and tqchen committed Nov 4, 2019
1 parent 996cf30 commit 3f472f9
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,11 +872,14 @@ def _impl(inputs, attr, params):
axis = _get_num_param(params, inputs.pop(2))
else:
axis = 0
if int(attr.get('batch_dims', 0)) != 0:
raise tvm.error.OpAttributeUnImplemented(
'Attribute batch_dims is not supported')
new_input = inputs[0:2]
return AttrCvt(op_name="take",
extras={'axis': tvm.const(axis, 'int32')},
ignores=['Tindices', 'Tparams', 'validate_indices',
'Taxis', '_class'])(new_input, attr)
'Taxis', '_class', 'batch_dims'])(new_input, attr)
return _impl

def _gather_nd():
Expand Down Expand Up @@ -1472,6 +1475,7 @@ def _impl(inputs, attr, params):
'Square' : _square(),
'SquaredDifference' : _squared_difference(),
'Squeeze' : _squeeze(),
'StopGradient' : _identity(),
'StridedSlice' : _stridedSlice(),
'Sub' : _elemwise('subtract'),
'Sum' : _sum(),
Expand Down

0 comments on commit 3f472f9

Please sign in to comment.