Skip to content

Commit 5faa6f7

Browse files
wweic authored and kevinthesun committed
[Relay][Frontend][TF] Add tensor array ops (apache#3798)
* [Relay][Frontend][TF] Add tensor array ops * rename * delete test * Move utility function * Refactor * fix tensor array ops * fix test * fix rebase * Fix serializer bug * Improve tf convert name lookup to use prelude api * Fix lint * Fix test
1 parent a6f37a2 commit 5faa6f7

File tree

8 files changed

+899
-10
lines changed

8 files changed

+899
-10
lines changed

python/tvm/relay/frontend/tensorflow.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,14 @@
2222

2323
import warnings
2424
from collections import defaultdict
25+
2526
# Numpy support
2627
import numpy as np
2728

2829
import tvm
30+
31+
from tvm.relay.prelude import Prelude
32+
2933
from .. import analysis
3034
from .. import expr as _expr
3135
from .. import op as _op
@@ -508,6 +512,69 @@ def _impl(inputs, attr, params):
508512
return _op.concatenate(inputs_reshaped, axis)
509513
return _impl
510514

515+
def _tensor_array():
    """Convert TF ``TensorArrayV3`` into a Relay prelude tensor_array construction."""
    def _impl(inputs, attr, params, prelude):
        dtype_name = attr.get('dtype').name
        # Look up the dtype-specialized constructor registered in the prelude.
        make_array = prelude.get_var('tensor_array', dtype_name)
        # inputs[0] holds the requested size; take element 0 to scalarize it.
        size = _op.take(inputs[0], tvm.relay.const(0))
        return make_array(size)
    return _impl
521+
522+
def _tensor_array_scatter():
    """Convert TF ``TensorArrayScatterV3`` via the prelude scatter helper."""
    def _impl(inputs, attr, params, prelude):
        dtype_name = attr.get('T').name
        # The rank of the values tensor selects which unstack helper to use.
        rank = len(inputs[2].type_annotation.shape)
        unstack = prelude.get_var("tensor_array_unstack_tensor{}".format(rank), dtype_name)
        unstacked_values = unstack(inputs[2])
        scatter = prelude.get_var('tensor_array_scatter', dtype_name)
        # inputs[0]: tensor array handle; inputs[1]: indices.
        return scatter(inputs[0], inputs[1], unstacked_values)
    return _impl
532+
533+
def _tensor_array_gather():
    """Convert TF ``TensorArrayGatherV3`` to a prelude gather call.

    inputs[2] is the tensor array handle and inputs[1] the indices tensor.
    """
    def _impl(inputs, attr, params, prelude):
        # Use the dtype-aware prelude lookup for consistency with every other
        # tensor array converter in this file; the bare
        # `prelude.tensor_array_gather` attribute is not dtype-specialized.
        gather_func = prelude.get_var('tensor_array_gather', attr.get('dtype').name)
        return gather_func(inputs[2], inputs[1])
    return _impl
537+
538+
def _tensor_array_size():
    """Convert TF ``TensorArraySizeV3`` to a prelude list-length call."""
    def _impl(inputs, attr, params, prelude):
        # A tensor array is represented as a prelude list, so its size is
        # simply the list length.
        array_handle = inputs[0]
        return prelude.length(array_handle)
    return _impl
542+
543+
def _tensor_array_write():
    """Convert TF ``TensorArrayWriteV3`` to a prelude write call."""
    def _impl(inputs, attr, params, prelude):
        dtype_name = attr.get('T').name
        # Wrap the value in the rank-specific prelude tensor constructor.
        rank = len(inputs[2].type_annotation.shape)
        wrap = prelude.get_var('tensor{}'.format(rank), dtype_name)
        wrapped_value = wrap(inputs[2])
        # inputs[3]: array handle (flow input); inputs[1]: index tensor,
        # scalarized via take of element 0.
        index = _op.take(inputs[1], tvm.relay.const(0))
        writer = prelude.get_var('tensor_array_write', dtype_name)
        return writer(inputs[3], index, wrapped_value)
    return _impl
555+
556+
def _tensor_array_read():
    """Convert TF ``TensorArrayReadV3`` to a prelude read call."""
    def _impl(inputs, attr, params, prelude):
        dtype_name = attr.get('dtype').name
        # inputs[2]: array handle; inputs[1]: index tensor, scalarized via take.
        index = _op.take(inputs[1], tvm.relay.const(0))
        reader = prelude.get_var('tensor_array_read', dtype_name)
        return reader(inputs[2], index)
    return _impl
561+
562+
def _tensor_array_split():
    """Convert TF ``TensorArraySplitV3`` to a prelude split call."""
    def _impl(inputs, attr, params, prelude):
        dtype_name = attr.get('T').name
        # Wrap the value tensor with the rank-specific prelude constructor.
        rank = len(inputs[1].type_annotation.shape)
        wrap = prelude.get_var("tensor{}".format(rank), dtype_name)
        wrapped_value = wrap(inputs[1])
        # Cast the lengths tensor to int32 for the prelude split function.
        lengths = _op.cast(inputs[2], 'int32')
        split_func = prelude.get_var('tensor_array_split', dtype_name)
        return split_func(inputs[0], wrapped_value, lengths)
    return _impl
571+
572+
def _tensor_array_concat():
    """Convert TF ``TensorArrayConcatV3`` to a prelude concat call."""
    def _impl(inputs, attr, params, prelude):
        dtype_name = attr['dtype'].name
        concat = prelude.get_var('tensor_array_concat', dtype_name)
        # inputs[1] is the tensor array handle.
        return concat(inputs[1])
    return _impl
577+
511578
def _tile():
512579
def _impl(inputs, attr, params):
513580
reps = _get_list_param(params, inputs.pop())
@@ -1313,6 +1380,14 @@ def _impl(inputs, attr, params):
13131380
'NotEqual' : _broadcast('not_equal'),
13141381
'OneHot' : _one_hot(),
13151382
'Pack' : _pack(),
1383+
'TensorArrayV3' : _tensor_array(),
1384+
'TensorArrayScatterV3' : _tensor_array_scatter(),
1385+
'TensorArrayGatherV3' : _tensor_array_gather(),
1386+
'TensorArraySizeV3' : _tensor_array_size(),
1387+
'TensorArrayWriteV3' : _tensor_array_write(),
1388+
'TensorArrayReadV3' : _tensor_array_read(),
1389+
'TensorArraySplitV3' : _tensor_array_split(),
1390+
'TensorArrayConcatV3' : _tensor_array_concat(),
13161391
'Pad' : _pad('Pad'),
13171392
'PadV2' : _pad('PadV2'),
13181393
'Pow' : _elemwise('power'),
@@ -1860,6 +1935,7 @@ def __init__(self):
18601935
self._loops = {}
18611936
self._branches = {}
18621937
self._mod = _module.Module({})
1938+
self._prelude = Prelude(self._mod)
18631939

18641940
def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
18651941
"""Construct relay nodes from tensorflow graph definition - GraphDef.
@@ -2335,7 +2411,11 @@ def _convert_operator(self, op_name, inputs, attrs,
23352411
if op_name in identity_list:
23362412
sym = get_relay_op(op_name)(*inputs, **attrs)
23372413
elif op_name in convert_map:
2338-
sym = convert_map[op_name](inputs, attrs, self._params)
2414+
if 'TensorArray' in op_name:
2415+
sym = convert_map[op_name](inputs, attrs, self._params, self._prelude)
2416+
else:
2417+
sym = convert_map[op_name](inputs, attrs, self._params)
2418+
23392419
elif op_name in convert_map_rnn:
23402420
sym = self._convert_rnn_operator(op_name, inputs, attrs,
23412421
self._params, graph,

python/tvm/relay/op/_tensor.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,29 @@ def clip_compute(attrs, inputs, output_type, target):
108108

109109
register_schedule("clip", schedule_elemwise)
110110

111+
# Shape function for "cast": the output shape is identical to the input
# shape, so copy it element-wise into a fresh int64 shape tensor.
# NOTE(review): this is TVM hybrid script (@script), which accepts only a
# restricted Python subset (output_tensor/const_range) — keep the loop form.
@script
def _cast_shape_function(x):
    out_ndim = len(x)
    out = output_tensor((out_ndim,), "int64")
    for i in const_range(out_ndim):
        # identity copy: cast changes dtype only, never the shape
        out[i] = x[i]
    return out
118+
119+
def cast_shape_func(attrs, inputs, out_ndims):
    """Shape function registered for "cast": output shape equals input shape."""
    result = _cast_shape_function(*inputs)
    return [result]
121+
122+
# Shape function for "expand_dims": prepend a leading dimension of extent 1
# to the input shape. Written in TVM hybrid script (@script), so only the
# restricted hybrid subset (output_tensor/const_range/int64) is used.
# NOTE(review): this only handles expansion at axis 0; the op's `axis`
# attribute is not consulted here — confirm that is intended for this use.
@script
def _expand_dims_shape_func(x):
    ndim = len(x.shape)
    out = output_tensor((ndim+1,), "int64")
    # new leading axis always has extent 1
    out[0] = int64(1)
    for i in const_range(0, ndim):
        # remaining axes shift right by one position
        out[i+1] = int64(x.shape[i])
    return out
130+
131+
def expand_dims_shape_func(attrs, inputs, out_ndims):
    """Shape function registered for "expand_dims"."""
    shape_out = _expand_dims_shape_func(*inputs)
    return [shape_out]
133+
111134
# shape func
112135
@script
113136
def _broadcast_shape_func(x, y, ndim):
@@ -140,6 +163,9 @@ def _broadcast_shape_func(x, y, ndim):
140163
def broadcast_shape_func(attrs, inputs, out_ndims):
    """Shape function for binary broadcasting ops."""
    ndim = out_ndims[0]
    return [_broadcast_shape_func(*inputs, ndim)]
142165

166+
register_shape_func("expand_dims", False, expand_dims_shape_func)
167+
register_shape_func("cast", False, cast_shape_func)
168+
143169
register_shape_func("add", False, broadcast_shape_func)
144170
register_shape_func("subtract", False, broadcast_shape_func)
145171
register_shape_func("multiply", False, broadcast_shape_func)

0 commit comments

Comments
 (0)