
Commit 3390745

[Relay][Frontend][TF] Add tensor array ops
1 parent 4b431c6 commit 3390745

File tree

8 files changed: +817 -30 lines changed

python/tvm/relay/frontend/tensorflow.py

Lines changed: 72 additions & 1 deletion
@@ -25,6 +25,9 @@
 import numpy as np
 
 import tvm
+
+from tvm.relay.prelude import Prelude
+
 from .. import analysis
 from .. import expr as _expr
 from .. import op as _op
@@ -505,6 +508,61 @@ def _impl(inputs, attr, params):
         return _op.concatenate(inputs_reshaped, axis)
     return _impl
 
+def _tensor_array():
+    def _impl(inputs, attr, params, prelude):
+        return prelude.tensor_array(_op.take(inputs[0], tvm.relay.const(0)))
+    return _impl
+
+def _tensor_array_scatter():
+    def _impl(inputs, attr, params, prelude):
+        values = None
+        if len(inputs[2].type_annotation.shape) == 1:
+            pass
+        elif len(inputs[2].type_annotation.shape) == 2:
+            values = prelude.tensor_array_unstack_tensor2(inputs[2])
+
+        return prelude.tensor_array_scatter(inputs[0], inputs[1], values)
+    return _impl
+
+def _tensor_array_gather():
+    def _impl(inputs, attr, params, prelude):
+        return prelude.tensor_array_gather(inputs[2], inputs[1])
+    return _impl
+
+def _tensor_array_size():
+    def _impl(inputs, attr, params, prelude):
+        return prelude.tensor_array_size(inputs[0])
+    return _impl
+
+def _tensor_array_write():
+    def _impl(inputs, attr, params, prelude):
+        if len(inputs[2].type_annotation.shape) == 2:
+            v = prelude.tensor2(inputs[2])
+        elif len(inputs[2].type_annotation.shape) == 3:
+            v = prelude.tensor3(inputs[2])
+        return prelude.tensor_array_write(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v)
+    return _impl
+
+def _tensor_array_read():
+    def _impl(inputs, attr, params, prelude):
+        return prelude.tensor_array_read(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
+    return _impl
+
+def _tensor_array_split():
+    def _impl(inputs, attr, params, prelude):
+        if len(inputs[1].type_annotation.shape) == 2:
+            v = prelude.tensor2(inputs[1])
+        elif len(inputs[1].type_annotation.shape) == 3:
+            v = prelude.tensor3(inputs[1])
+        lengths = _op.cast(inputs[2], 'int32')
+        return prelude.tensor_array_split(inputs[0], v, lengths)
+    return _impl
+
+def _tensor_array_concat():
+    def _impl(inputs, attr, params, prelude):
+        return prelude.tensor_array_concat(inputs[1])
+    return _impl
+
 def _tile():
     def _impl(inputs, attr, params):
         reps = _get_list_param(params, inputs.pop())
@@ -1302,6 +1360,14 @@ def _impl(inputs, attr, params):
     'NotEqual'                : _broadcast('not_equal'),
     'OneHot'                  : _one_hot(),
     'Pack'                    : _pack(),
+    'TensorArrayV3'           : _tensor_array(),
+    'TensorArrayScatterV3'    : _tensor_array_scatter(),
+    'TensorArrayGatherV3'     : _tensor_array_gather(),
+    'TensorArraySizeV3'       : _tensor_array_size(),
+    'TensorArrayWriteV3'      : _tensor_array_write(),
+    'TensorArrayReadV3'       : _tensor_array_read(),
+    'TensorArraySplitV3'      : _tensor_array_split(),
+    'TensorArrayConcatV3'     : _tensor_array_concat(),
     'Pad'                     : _pad('Pad'),
     'PadV2'                   : _pad('PadV2'),
     'Pow'                     : _elemwise('power'),
@@ -1847,6 +1913,7 @@ def __init__(self):
         self._loops = {}
         self._branches = {}
         self._mod = _module.Module({})
+        self._prelude = Prelude(self._mod)
 
     def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
         """Construct relay nodes from tensorflow graph definition - GraphDef.
@@ -2322,7 +2389,11 @@ def _convert_operator(self, op_name, inputs, attrs,
         if op_name in identity_list:
             sym = get_relay_op(op_name)(*inputs, **attrs)
         elif op_name in convert_map:
-            sym = convert_map[op_name](inputs, attrs, self._params)
+            if 'TensorArray' in op_name:
+                sym = convert_map[op_name](inputs, attrs, self._params, self._prelude)
+            else:
+                sym = convert_map[op_name](inputs, attrs, self._params)
+
         elif op_name in convert_map_rnn:
             sym = self._convert_rnn_operator(op_name, inputs, attrs,
                                              self._params, graph,
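
For context, the TensorArray*V3 node types registered in convert_map above are what a TensorFlow 1.x GraphDef contains once a graph uses tf.TensorArray, and the _convert_operator change routes any op whose name contains 'TensorArray' through the extra prelude argument so the converters can build Relay's Prelude-backed tensor array. A minimal sketch of a graph that would exercise these converters (not part of this commit; it assumes the TF 1.x tf.TensorArray API and the from_tensorflow entry point defined in this module):

    # Hypothetical usage sketch -- not part of this diff.
    import tensorflow as tf
    from tvm import relay

    with tf.Graph().as_default() as graph:
        t = tf.placeholder(tf.float32, shape=(2, 3), name="t")
        ta = tf.TensorArray(dtype=tf.float32, size=2)  # emits TensorArrayV3
        ta = ta.write(0, t)                            # emits TensorArrayWriteV3
        out = tf.identity(ta.read(0), name="out")      # emits TensorArrayReadV3

    # Each TensorArray*V3 node is dispatched through convert_map with
    # self._prelude, per the _convert_operator hunk above.
    mod, params = relay.frontend.from_tensorflow(graph.as_graph_def(),
                                                 shape={"t": (2, 3)})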

python/tvm/relay/op/_tensor.py

Lines changed: 26 additions & 0 deletions
@@ -107,6 +107,29 @@ def clip_compute(attrs, inputs, output_type, target):
 
 register_schedule("clip", schedule_elemwise)
 
+@script
+def _cast_shape_function(x):
+    out_ndim = len(x)
+    out = output_tensor((out_ndim,), "int64")
+    for i in const_range(out_ndim):
+        out[i] = x[i]
+    return out
+
+def cast_shape_func(attrs, inputs, out_ndims):
+    return [_cast_shape_function(*inputs)]
+
+@script
+def _expand_dims_shape_func(x):
+    ndim = len(x.shape)
+    out = output_tensor((ndim+1,), "int64")
+    out[0] = int64(1)
+    for i in const_range(0, ndim):
+        out[i+1] = int64(x.shape[i])
+    return out
+
+def expand_dims_shape_func(attrs, inputs, out_ndims):
+    return [_expand_dims_shape_func(*inputs)]
+
 # shape func
 @script
 def _broadcast_shape_func(x, y, ndim):
@@ -139,6 +162,9 @@ def _broadcast_shape_func(x, y, ndim):
 def broadcast_shape_func(attrs, inputs, out_ndims):
     return [_broadcast_shape_func(*inputs, out_ndims[0])]
 
+register_shape_func("expand_dims", False, expand_dims_shape_func)
+register_shape_func("cast", False, cast_shape_func)
+
 register_shape_func("add", False, broadcast_shape_func)
 register_shape_func("subtract", False, broadcast_shape_func)
 register_shape_func("multiply", False, broadcast_shape_func)
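
These hybrid-script shape functions let Relay compute output shapes for dynamically shaped inputs; the False flag passed to register_shape_func marks them as data-independent, i.e. they receive input shapes rather than input values. A plain-Python sketch of the arithmetic they encode (illustrative only, not part of the patch):

    # cast is element-wise, so its output shape is the input shape copied
    # through unchanged (the hybrid version only converts entries to int64).
    def cast_shape(in_shape):
        return list(in_shape)

    # expand_dims, as registered here, prepends one axis of size 1:
    # out[0] = 1 and out[i+1] = in_shape[i]. Note that the attrs-selected
    # axis is not consulted by this shape function.
    def expand_dims_shape(in_shape):
        return [1] + list(in_shape)

    assert cast_shape([2, 3]) == [2, 3]
    assert expand_dims_shape([2, 3]) == [1, 2, 3]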
