|
25 | 25 | import numpy as np |
26 | 26 |
|
27 | 27 | import tvm |
| 28 | + |
| 29 | +from tvm.relay.prelude import Prelude |
| 30 | + |
28 | 31 | from .. import analysis |
29 | 32 | from .. import expr as _expr |
30 | 33 | from .. import op as _op |
@@ -505,6 +508,61 @@ def _impl(inputs, attr, params): |
505 | 508 | return _op.concatenate(inputs_reshaped, axis) |
506 | 509 | return _impl |
507 | 510 |
|
| 511 | +def _tensor_array(): |
| 512 | + def _impl(inputs, attr, params, prelude): |
| 513 | + return prelude.tensor_array(_op.take(inputs[0], tvm.relay.const(0))) |
| 514 | + return _impl |
| 515 | + |
| 516 | +def _tensor_array_scatter(): |
| 517 | + def _impl(inputs, attr, params, prelude): |
| 518 | + values = None |
| 519 | + if len(inputs[2].type_annotation.shape) == 1: |
| 520 | + pass |
| 521 | + elif len(inputs[2].type_annotation.shape) == 2: |
| 522 | + values = prelude.tensor_array_unstack_tensor2(inputs[2]) |
| 523 | + |
| 524 | + return prelude.tensor_array_scatter(inputs[0], inputs[1], values) |
| 525 | + return _impl |
| 526 | + |
| 527 | +def _tensor_array_gather(): |
| 528 | + def _impl(inputs, attr, params, prelude): |
| 529 | + return prelude.tensor_array_gather(inputs[2], inputs[1]) |
| 530 | + return _impl |
| 531 | + |
| 532 | +def _tensor_array_size(): |
| 533 | + def _impl(inputs, attr, params, prelude): |
| 534 | + return prelude.tensor_array_size(inputs[0]) |
| 535 | + return _impl |
| 536 | + |
| 537 | +def _tensor_array_write(): |
| 538 | + def _impl(inputs, attr, params, prelude): |
| 539 | + if len(inputs[2].type_annotation.shape) == 2: |
| 540 | + v = prelude.tensor2(inputs[2]) |
| 541 | + elif len(inputs[2].type_annotation.shape) == 3: |
| 542 | + v = prelude.tensor3(inputs[2]) |
| 543 | + return prelude.tensor_array_write(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v) |
| 544 | + return _impl |
| 545 | + |
| 546 | +def _tensor_array_read(): |
| 547 | + def _impl(inputs, attr, params, prelude): |
| 548 | + return prelude.tensor_array_read(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) |
| 549 | + return _impl |
| 550 | + |
| 551 | +def _tensor_array_split(): |
| 552 | + def _impl(inputs, attr, params, prelude): |
| 553 | + if len(inputs[1].type_annotation.shape) == 2: |
| 554 | + v = prelude.tensor2(inputs[1]) |
| 555 | + elif len(inputs[1].type_annotation.shape) == 3: |
| 556 | + v = prelude.tensor3(inputs[1]) |
| 557 | + lengths = _op.cast(inputs[2], 'int32') |
| 558 | + return prelude.tensor_array_split(inputs[0], v, lengths) |
| 559 | + return _impl |
| 560 | + |
| 561 | +def _tensor_array_concat(): |
| 562 | + def _impl(inputs, attr, params, prelude): |
| 563 | + return prelude.tensor_array_concat(inputs[1]) |
| 564 | + return _impl |
| 565 | + |
508 | 566 | def _tile(): |
509 | 567 | def _impl(inputs, attr, params): |
510 | 568 | reps = _get_list_param(params, inputs.pop()) |
@@ -1302,6 +1360,14 @@ def _impl(inputs, attr, params): |
1302 | 1360 | 'NotEqual' : _broadcast('not_equal'), |
1303 | 1361 | 'OneHot' : _one_hot(), |
1304 | 1362 | 'Pack' : _pack(), |
| 1363 | + 'TensorArrayV3' : _tensor_array(), |
| 1364 | + 'TensorArrayScatterV3' : _tensor_array_scatter(), |
| 1365 | + 'TensorArrayGatherV3' : _tensor_array_gather(), |
| 1366 | + 'TensorArraySizeV3' : _tensor_array_size(), |
| 1367 | + 'TensorArrayWriteV3' : _tensor_array_write(), |
| 1368 | + 'TensorArrayReadV3' : _tensor_array_read(), |
| 1369 | + 'TensorArraySplitV3' : _tensor_array_split(), |
| 1370 | + 'TensorArrayConcatV3' : _tensor_array_concat(), |
1305 | 1371 | 'Pad' : _pad('Pad'), |
1306 | 1372 | 'PadV2' : _pad('PadV2'), |
1307 | 1373 | 'Pow' : _elemwise('power'), |
@@ -1847,6 +1913,7 @@ def __init__(self): |
1847 | 1913 | self._loops = {} |
1848 | 1914 | self._branches = {} |
1849 | 1915 | self._mod = _module.Module({}) |
| 1916 | + self._prelude = Prelude(self._mod) |
1850 | 1917 |
|
1851 | 1918 | def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): |
1852 | 1919 | """Construct relay nodes from tensorflow graph definition - GraphDef. |
@@ -2322,7 +2389,11 @@ def _convert_operator(self, op_name, inputs, attrs, |
2322 | 2389 | if op_name in identity_list: |
2323 | 2390 | sym = get_relay_op(op_name)(*inputs, **attrs) |
2324 | 2391 | elif op_name in convert_map: |
2325 | | - sym = convert_map[op_name](inputs, attrs, self._params) |
| 2392 | + if 'TensorArray' in op_name: |
| 2393 | + sym = convert_map[op_name](inputs, attrs, self._params, self._prelude) |
| 2394 | + else: |
| 2395 | + sym = convert_map[op_name](inputs, attrs, self._params) |
| 2396 | + |
2326 | 2397 | elif op_name in convert_map_rnn: |
2327 | 2398 | sym = self._convert_rnn_operator(op_name, inputs, attrs, |
2328 | 2399 | self._params, graph, |
|
0 commit comments