
Commit 42f78bb

[Relay][Frontend][TF] Add tensor array ops
1 parent d4b66da commit 42f78bb

File tree

8 files changed: +813 −28 lines

python/tvm/relay/adt.py

Lines changed: 1 addition & 1 deletion

@@ -186,7 +186,7 @@ def __init__(self, lhs, rhs):
 class Match(Expr):
     """Pattern matching expression in Relay."""
 
-    def __init__(self, data, clauses, complete=True):
+    def __init__(self, data, clauses, complete=False):
        """Construct a Match.
 
        Parameters
python/tvm/relay/frontend/tensorflow.py

Lines changed: 84 additions & 1 deletion
@@ -24,7 +24,11 @@
 # Numpy support
 import numpy as np
 
 import tvm
+
+from tvm.relay.prelude import Prelude
+from topi.util import get_const_tuple
+
 from .. import analysis
 from .. import expr as _expr
 from .. import op as _op
@@ -506,6 +512,67 @@ def _impl(inputs, attr, params):
         return _op.concatenate(inputs_reshaped, axis)
     return _impl
 
+def _tensor_array():
+    def _impl(inputs, attr, params, prelude):
+        return prelude.tensor_array(_op.take(inputs[0], tvm.relay.const(0)))
+    return _impl
+
+def _tensor_array_scatter():
+    def _impl(inputs, attr, params, prelude):
+        if len(inputs[2].type_annotation.shape) == 2:
+            values = prelude.tensor_array_unstack_tensor2(inputs[2])
+        else:
+            raise NotImplementedError(
+                "TensorArrayScatterV3 currently supports only 2-D values")
+        return prelude.tensor_array_scatter(inputs[0], inputs[1], values)
+    return _impl
+
+def _tensor_array_gather():
+    def _impl(inputs, attr, params, prelude):
+        return prelude.tensor_array_gather(inputs[2], inputs[1])
+    return _impl
+
+def _tensor_array_size():
+    def _impl(inputs, attr, params, prelude):
+        return prelude.tensor_array_size(inputs[0])
+    return _impl
+
+def _tensor_array_write():
+    def _impl(inputs, attr, params, prelude):
+        if len(inputs[2].type_annotation.shape) == 2:
+            v = prelude.tensor2(inputs[2])
+        elif len(inputs[2].type_annotation.shape) == 3:
+            v = prelude.tensor3(inputs[2])
+        else:
+            raise NotImplementedError(
+                "TensorArrayWriteV3 currently supports only 2-D and 3-D values")
+        return prelude.tensor_array_write(inputs[3],
+                                          _op.take(inputs[1], tvm.relay.const(0)), v)
+    return _impl
+
+def _tensor_array_read():
+    def _impl(inputs, attr, params, prelude):
+        return prelude.tensor_array_read(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
+    return _impl
+
+def _tensor_array_split():
+    def _impl(inputs, attr, params, prelude):
+        if len(inputs[1].type_annotation.shape) == 2:
+            v = prelude.tensor2(inputs[1])
+        elif len(inputs[1].type_annotation.shape) == 3:
+            v = prelude.tensor3(inputs[1])
+        else:
+            raise NotImplementedError(
+                "TensorArraySplitV3 currently supports only 2-D and 3-D values")
+        lengths = _op.cast(inputs[2], 'int32')
+        return prelude.tensor_array_split(inputs[0], v, lengths)
+    return _impl
+
+def _tensor_array_concat():
+    def _impl(inputs, attr, params, prelude):
+        return prelude.tensor_array_concat(inputs[1])
+    return _impl
+
 def _tile():
     def _impl(inputs, attr, params):
         reps = params[inputs.pop().name_hint].asnumpy()
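
For orientation, a minimal sketch of the prelude calls these converters target; the names all appear in this diff, but the snippet itself is an illustration, not part of the commit:

from tvm import relay
from tvm.relay.prelude import Prelude

mod = relay.Module()
p = Prelude(mod)

ta = p.tensor_array(relay.const(2))                # what TensorArrayV3 lowers to
v = p.tensor2(relay.var("v", shape=(3, 4)))        # wrap a rank-2 value
ta2 = p.tensor_array_write(ta, relay.const(0), v)  # TensorArrayWriteV3
r = p.tensor_array_read(ta2, relay.const(0))       # TensorArrayReadV3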
@@ -1285,6 +1355,14 @@ def _impl(inputs, attr, params):
     'Neg'                   : AttrCvt('negative'),
     'NotEqual'              : _broadcast('not_equal'),
     'Pack'                  : _pack(),
+    'TensorArrayV3'         : _tensor_array(),
+    'TensorArrayScatterV3'  : _tensor_array_scatter(),
+    'TensorArrayGatherV3'   : _tensor_array_gather(),
+    'TensorArraySizeV3'     : _tensor_array_size(),
+    'TensorArrayWriteV3'    : _tensor_array_write(),
+    'TensorArrayReadV3'     : _tensor_array_read(),
+    'TensorArraySplitV3'    : _tensor_array_split(),
+    'TensorArrayConcatV3'   : _tensor_array_concat(),
     'Pad'                   : _pad('Pad'),
     'PadV2'                 : _pad('PadV2'),
     'Pow'                   : _elemwise('power'),
@@ -1830,6 +1908,7 @@ def __init__(self):
         self._loops = {}
         self._branches = {}
         self._mod = _module.Module({})
+        self._prelude = Prelude(self._mod)
 
     def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
         """Construct relay nodes from tensorflow graph definition - GraphDef.
@@ -2306,7 +2385,10 @@ def _convert_operator(self, op_name, inputs, attrs,
         if op_name in identity_list:
             sym = get_relay_op(op_name)(*inputs, **attrs)
         elif op_name in convert_map:
-            sym = convert_map[op_name](inputs, attrs, self._params)
+            if 'TensorArray' in op_name:
+                sym = convert_map[op_name](inputs, attrs, self._params, self._prelude)
+            else:
+                sym = convert_map[op_name](inputs, attrs, self._params)
         elif op_name in convert_map_rnn:
             sym = self._convert_rnn_operator(op_name, inputs, attrs,
                                              self._params, graph,
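
End to end, these mappings let a graph containing the V3 tensor-array ops through the frontend. A hedged sketch, assuming the TensorFlow 1.x API and that from_tensorflow at this revision returns a module plus params:

import tensorflow as tf
from tvm import relay

with tf.Graph().as_default() as graph:
    ta = tf.TensorArray(dtype=tf.float32, size=2, infer_shape=False)
    ta = ta.write(0, tf.constant([[1.0, 2.0], [3.0, 4.0]]))  # TensorArrayWriteV3
    out = tf.identity(ta.read(0), name="out")                # TensorArrayReadV3

mod, params = relay.frontend.from_tensorflow(graph.as_graph_def(),
                                             outputs=["out"])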

python/tvm/relay/op/_tensor.py

Lines changed: 74 additions & 0 deletions
@@ -20,6 +20,9 @@
 import topi
 from .op import register_compute, register_schedule, register_pattern
 from .op import schedule_injective, OpPattern
+from .op import register_shape_func
+from ...hybrid import script
+from ...api import convert
 
 schedule_broadcast = schedule_injective
 schedule_elemwise = schedule_injective
@@ -104,3 +106,75 @@ def clip_compute(attrs, inputs, output_type, target):
     return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)]
 
 register_schedule("clip", schedule_elemwise)
+
+@script
+def _cast_shape_function(x):
+    out_ndim = len(x)
+    out = output_tensor((out_ndim,), "int64")
+    for i in const_range(out_ndim):
+        out[i] = x[i]
+    return out
+
+def cast_shape_func(attrs, inputs, out_ndims):
+    return [_cast_shape_function(*inputs)]
+
+@script
+def _expand_dims_shape_func(x):
+    ndim = len(x.shape)
+    out = output_tensor((ndim+1,), "int64")
+    out[0] = int64(1)
+    for i in const_range(0, ndim):
+        out[i+1] = int64(x.shape[i])
+    return out
+
+def expand_dims_shape_func(attrs, inputs, out_ndims):
+    return [_expand_dims_shape_func(*inputs)]
+
+# Shape function for binary broadcast ops: NumPy-style trailing alignment.
+@script
+def _broadcast_shape_func(x, y, ndim):
+    out = output_tensor((ndim,), "int64")
+    if len(x.shape) == 0:
+        for i in const_range(ndim):
+            out[i] = y[i]
+    elif len(y.shape) == 0:
+        for i in const_range(ndim):
+            out[i] = x[i]
+    else:
+        ndim1 = x.shape[0]
+        ndim2 = y.shape[0]
+        for i in const_range(1, min(ndim1, ndim2)+1):
+            if x[ndim1-i] == y[ndim2-i]:
+                out[ndim-i] = x[ndim1-i]
+            elif x[ndim1-i] == 1:
+                out[ndim-i] = y[ndim2-i]
+            else:
+                assert y[ndim2-i] == 1, "Incompatible broadcast type %s and %s" % (
+                    x[ndim1-i], y[ndim2-i])
+                out[ndim-i] = x[ndim1-i]
+        for i in const_range(min(ndim1, ndim2)+1, ndim+1):
+            if ndim1 >= ndim2:
+                out[ndim-i] = x[ndim1-i]
+            else:
+                out[ndim-i] = y[ndim2-i]
+    return out
+
+def broadcast_shape_func(attrs, inputs, out_ndims):
+    return [_broadcast_shape_func(*inputs, out_ndims[0])]
+
+register_shape_func("expand_dims", False, expand_dims_shape_func)
+register_shape_func("cast", False, cast_shape_func)
+
+register_shape_func("add", False, broadcast_shape_func)
+register_shape_func("subtract", False, broadcast_shape_func)
+register_shape_func("multiply", False, broadcast_shape_func)
+register_shape_func("divide", False, broadcast_shape_func)
+register_shape_func("mod", False, broadcast_shape_func)
+register_shape_func("logical_and", False, broadcast_shape_func)
+register_shape_func("logical_or", False, broadcast_shape_func)
+register_shape_func("equal", False, broadcast_shape_func)
+register_shape_func("not_equal", False, broadcast_shape_func)
+register_shape_func("less", False, broadcast_shape_func)
+register_shape_func("less_equal", False, broadcast_shape_func)
+register_shape_func("greater", False, broadcast_shape_func)
+register_shape_func("greater_equal", False, broadcast_shape_func)
