@@ -22,10 +22,14 @@

 import warnings
 from collections import defaultdict
+
 # Numpy support
 import numpy as np

 import tvm
+
+from tvm.relay.prelude import Prelude
+
 from .. import analysis
 from .. import expr as _expr
 from .. import op as _op
@@ -508,6 +512,69 @@ def _impl(inputs, attr, params):
         return _op.concatenate(inputs_reshaped, axis)
     return _impl

+def _tensor_array():
+    def _impl(inputs, attr, params, prelude):
+        dtype_str = attr.get('dtype').name
+        tensor_array_constructor = prelude.get_var('tensor_array', dtype_str)
+        return tensor_array_constructor(_op.take(inputs[0], tvm.relay.const(0)))
+    return _impl
+
+def _tensor_array_scatter():
+    def _impl(inputs, attr, params, prelude):
+        dtype_str = attr.get('T').name
+        values_rank = len(inputs[2].type_annotation.shape)
+        unstack_name = "tensor_array_unstack_tensor{}".format(values_rank)
+        unstack_function = prelude.get_var(unstack_name, dtype_str)
+        values = unstack_function(inputs[2])
+        tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str)
+        return tensor_array_scatter_func(inputs[0], inputs[1], values)
+    return _impl
+
+def _tensor_array_gather():
+    def _impl(inputs, attr, params, prelude):
+        return prelude.get_var('tensor_array_gather', attr.get('dtype').name)(inputs[2], inputs[1])
+    return _impl
+
+def _tensor_array_size():
+    def _impl(inputs, attr, params, prelude):
+        return prelude.length(inputs[0])
+    return _impl
+
+def _tensor_array_write():
+    def _impl(inputs, attr, params, prelude):
+        input_rank = len(inputs[2].type_annotation.shape)
+        dtype = attr.get('T').name
+
+        tensor_name = 'tensor{}'.format(input_rank)
+        tensor_func = prelude.get_var(tensor_name, dtype)
+        v = tensor_func(inputs[2])
+        write_func = prelude.get_var('tensor_array_write', dtype)
+
+        return write_func(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v)
+    return _impl
+
+def _tensor_array_read():
+    def _impl(inputs, attr, params, prelude):
+        read_func = prelude.get_var('tensor_array_read', attr.get('dtype').name)
+        return read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
+    return _impl
+
+def _tensor_array_split():
+    def _impl(inputs, attr, params, prelude):
+        input_rank = len(inputs[1].type_annotation.shape)
+        dtype_str = attr.get('T').name
+        v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1])
+        lengths = _op.cast(inputs[2], 'int32')
+        split_var = prelude.get_var('tensor_array_split', dtype_str)
+        return split_var(inputs[0], v, lengths)
+    return _impl
+
+def _tensor_array_concat():
+    def _impl(inputs, attr, params, prelude):
+        concat_func = prelude.get_var('tensor_array_concat', attr['dtype'].name)
+        return concat_func(inputs[1])
+    return _impl
+
 def _tile():
     def _impl(inputs, attr, params):
         reps = _get_list_param(params, inputs.pop())
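
Every converter above follows one pattern: read the element dtype from the TF node attributes, ask the prelude for the matching dtype-specialized global function, and apply it to the Relay inputs. A minimal sketch of that lookup outside the frontend (the `'float32'` dtype string and the module setup are illustrative, not part of this diff):

```python
from tvm import relay
from tvm.relay import module as _module
from tvm.relay.prelude import Prelude

mod = _module.Module({})
prelude = Prelude(mod)  # registers the list ADT and tensor-array helpers into mod

# get_var(name, dtype) resolves the dtype-specialized prelude function,
# e.g. the float32 tensor-array constructor used by _tensor_array() above.
ta_new = prelude.get_var('tensor_array', 'float32')
ta = ta_new(relay.const(2))  # Relay expression for a 2-slot float32 tensor array
```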
@@ -1313,6 +1380,14 @@ def _impl(inputs, attr, params):
     'NotEqual'                          : _broadcast('not_equal'),
     'OneHot'                            : _one_hot(),
     'Pack'                              : _pack(),
+    'TensorArrayV3'                     : _tensor_array(),
+    'TensorArrayScatterV3'              : _tensor_array_scatter(),
+    'TensorArrayGatherV3'               : _tensor_array_gather(),
+    'TensorArraySizeV3'                 : _tensor_array_size(),
+    'TensorArrayWriteV3'                : _tensor_array_write(),
+    'TensorArrayReadV3'                 : _tensor_array_read(),
+    'TensorArraySplitV3'                : _tensor_array_split(),
+    'TensorArrayConcatV3'               : _tensor_array_concat(),
     'Pad'                               : _pad('Pad'),
     'PadV2'                             : _pad('PadV2'),
     'Pow'                               : _elemwise('power'),
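
Note: only the `*V3` op names are registered because the TensorFlow 1.x `tf.TensorArray` API emits the V3 variants of these ops in exported GraphDefs.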
@@ -1860,6 +1935,7 @@ def __init__(self):
         self._loops = {}
         self._branches = {}
         self._mod = _module.Module({})
+        self._prelude = Prelude(self._mod)

     def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
         """Construct relay nodes from tensorflow graph definition - GraphDef.
@@ -2335,7 +2411,11 @@ def _convert_operator(self, op_name, inputs, attrs,
         if op_name in identity_list:
             sym = get_relay_op(op_name)(*inputs, **attrs)
         elif op_name in convert_map:
-            sym = convert_map[op_name](inputs, attrs, self._params)
+            if 'TensorArray' in op_name:
+                sym = convert_map[op_name](inputs, attrs, self._params, self._prelude)
+            else:
+                sym = convert_map[op_name](inputs, attrs, self._params)
+
         elif op_name in convert_map_rnn:
             sym = self._convert_rnn_operator(op_name, inputs, attrs,
                                              self._params, graph,
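
End to end, any TF 1.x graph built with `tf.TensorArray` exercises these converters, since the dispatch above passes `self._prelude` to every op whose name contains `TensorArray`. A usage sketch (assumes TensorFlow 1.x; the shapes and values are illustrative):

```python
import tensorflow as tf
from tvm import relay

with tf.Graph().as_default():
    ta = tf.TensorArray(dtype=tf.float32, size=2, infer_shape=False)
    ta = ta.write(0, tf.constant([1.0, 2.0]))   # emits TensorArrayWriteV3
    out = ta.read(0)                            # emits TensorArrayReadV3
    graph_def = tf.get_default_graph().as_graph_def()

# The frontend routes the TensorArray* nodes through the prelude-backed converters.
mod, params = relay.frontend.from_tensorflow(graph_def, outputs=[out.op.name])
```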