|
24 | 24 | # Numpy support |
25 | 25 | import numpy as np |
26 | 26 |
|
27 | 29 | import tvm |
| 30 | + |
| 31 | +from tvm.relay.prelude import Prelude |
| 32 | +from topi.util import get_const_tuple |
| 33 | + |
28 | 34 | from .. import analysis |
29 | 35 | from .. import expr as _expr |
30 | 36 | from .. import op as _op |
@@ -506,6 +512,64 @@ def _impl(inputs, attr, params): |
506 | 512 | return _op.concatenate(inputs_reshaped, axis) |
507 | 513 | return _impl |
508 | 514 |
|
| 515 | +def _tensor_array(): |
| 516 | +    def _impl(inputs, attr, params, prelude): |
| 517 | +        return prelude.tensor_array(_op.take(inputs[0], tvm.relay.const(0))) |
| 518 | +    return _impl |
| 519 | + |
| 520 | +def _tensor_array_scatter(): |
| 521 | +    def _impl(inputs, attr, params, prelude): |
| 522 | +        # Unstack the values tensor along axis 0 into tensor array elements, |
| 523 | +        # picking the Prelude unstack helper that matches the values rank |
| 524 | +        # (only rank 1 and rank 2 are handled for now; the rank-1 helper is |
| 525 | +        # assumed to mirror tensor_array_unstack_tensor2). |
| 526 | +        if len(inputs[2].type_annotation.shape) == 1: |
| 527 | +            values = prelude.tensor_array_unstack_tensor1(inputs[2]) |
| 528 | +        elif len(inputs[2].type_annotation.shape) == 2: |
| 529 | +            values = prelude.tensor_array_unstack_tensor2(inputs[2]) |
| 530 | +        return prelude.tensor_array_scatter(inputs[0], inputs[1], values) |
| 531 | +    return _impl |
| 532 | + |
| 533 | +def _tensor_array_gather(): |
| 534 | +    def _impl(inputs, attr, params, prelude): |
| 535 | +        return prelude.tensor_array_gather(inputs[2], inputs[1]) |
| 536 | +    return _impl |
| 537 | + |
| 538 | +def _tensor_array_size(): |
| 539 | +    def _impl(inputs, attr, params, prelude): |
| 540 | +        return prelude.tensor_array_size(inputs[0]) |
| 541 | +    return _impl |
| 542 | + |
| 543 | +def _tensor_array_write(): |
| 544 | +    def _impl(inputs, attr, params, prelude): |
| 545 | +        # Wrap the value in the tensor_t constructor matching its rank |
| 546 | +        # (only rank-2 and rank-3 values are handled for now). |
| 547 | +        if len(inputs[2].type_annotation.shape) == 2: |
| 548 | +            v = prelude.tensor2(inputs[2]) |
| 549 | +        elif len(inputs[2].type_annotation.shape) == 3: |
| 550 | +            v = prelude.tensor3(inputs[2]) |
| 551 | +        return prelude.tensor_array_write(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v) |
| 552 | +    return _impl |
| 553 | + |
| 554 | +def _tensor_array_read(): |
| 555 | +    def _impl(inputs, attr, params, prelude): |
| 556 | +        return prelude.tensor_array_read(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) |
| 557 | +    return _impl |
| 560 | + |
| 561 | +def _tensor_array_split(): |
| 562 | +    def _impl(inputs, attr, params, prelude): |
| 563 | +        # Wrap the value tensor according to its rank (rank 2 and 3 only), |
| 564 | +        # then split it into the tensor array using the given lengths. |
| 565 | +        if len(inputs[1].type_annotation.shape) == 2: |
| 566 | +            v = prelude.tensor2(inputs[1]) |
| 567 | +        elif len(inputs[1].type_annotation.shape) == 3: |
| 568 | +            v = prelude.tensor3(inputs[1]) |
| 569 | +        lengths = _op.cast(inputs[2], 'int32') |
| 570 | +        return prelude.tensor_array_split(inputs[0], v, lengths) |
| 571 | +    return _impl |
| 572 | + |
509 | 573 | def _tile(): |
510 | 574 | def _impl(inputs, attr, params): |
511 | 575 | reps = params[inputs.pop().name_hint].asnumpy() |
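For context, the Prelude helpers the converters above rely on can be exercised directly. A minimal sketch, not part of the patch: it assumes the float32 tensor-array helpers in `tvm.relay.prelude` keep the names used above (`tensor_array`, `tensor2`, `tensor_array_write`, `tensor_array_read`) and the `relay.Module` spelling that this file's `_module.Module` refers to.

```python
# Sketch only: write a rank-2 tensor into a tensor array and read it back,
# mirroring what the TensorArrayWriteV3/TensorArrayReadV3 converters emit.
# Helper names follow the converter code; exact signatures are assumed.
import tvm
from tvm import relay
from tvm.relay.prelude import Prelude

mod = relay.Module({})
p = Prelude(mod)

x = relay.var("x", shape=(2, 3), dtype="float32")
ta = p.tensor_array(relay.const(1))                         # tensor array of length 1
ta = p.tensor_array_write(ta, relay.const(0), p.tensor2(x)) # store x at index 0
read = p.tensor_array_read(ta, relay.const(0))              # read it back
mod["main"] = relay.Function([x], read)
```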
@@ -968,6 +1032,7 @@ def _impl(inputs, attr, params): |
968 | 1032 |
|
969 | 1033 | def _range(): |
970 | 1034 | def _impl(inputs, attr, params): |
971 | 1036 | start = params.pop(inputs[0].name_hint).asnumpy()[0] |
972 | 1037 | limit = params.pop(inputs[1].name_hint).asnumpy()[0] \ |
973 | 1038 | if hasattr(inputs[1], "name_hint") else params.pop('Rank').asnumpy()[0] |
@@ -1285,6 +1350,13 @@ def _impl(inputs, attr, params): |
1285 | 1350 | 'Neg' : AttrCvt('negative'), |
1286 | 1351 | 'NotEqual' : _broadcast('not_equal'), |
1287 | 1352 | 'Pack' : _pack(), |
| 1353 | + 'TensorArrayV3' : _tensor_array(), |
| 1354 | + 'TensorArrayScatterV3' : _tensor_array_scatter(), |
| 1355 | + 'TensorArrayGatherV3' : _tensor_array_gather(), |
| 1356 | + 'TensorArraySizeV3' : _tensor_array_size(), |
| 1357 | + 'TensorArrayWriteV3' : _tensor_array_write(), |
| 1358 | + 'TensorArrayReadV3' : _tensor_array_read(), |
| 1359 | + 'TensorArraySplitV3' : _tensor_array_split(), |
1288 | 1360 | 'Pad' : _pad('Pad'), |
1289 | 1361 | 'PadV2' : _pad('PadV2'), |
1290 | 1362 | 'Pow' : _elemwise('power'), |
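For reference, this is the kind of TF 1.x program that emits the node types registered above; a hypothetical snippet (`tf.TensorArray` lowers to the `TensorArray*V3` ops in the GraphDef):

```python
# Hypothetical TF 1.x snippet: its GraphDef contains TensorArrayV3,
# TensorArrayWriteV3, TensorArrayReadV3 and TensorArraySizeV3 nodes,
# i.e. several of the ops added to convert_map above.
import tensorflow as tf

with tf.Graph().as_default() as graph:
    ta = tf.TensorArray(dtype=tf.float32, size=2)
    ta = ta.write(0, tf.constant([[1.0, 2.0], [3.0, 4.0]]))
    out = ta.read(0)
    size = ta.size()
    graph_def = graph.as_graph_def()
```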
@@ -1830,6 +1902,7 @@ def __init__(self): |
1830 | 1902 | self._loops = {} |
1831 | 1903 | self._branches = {} |
1832 | 1904 | self._mod = _module.Module({}) |
| 1905 | + self._prelude = Prelude(self._mod) |
1833 | 1906 |
|
1834 | 1907 | def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): |
1835 | 1908 | """Construct relay nodes from tensorflow graph definition - GraphDef. |
@@ -2306,7 +2379,11 @@ def _convert_operator(self, op_name, inputs, attrs, |
2306 | 2379 | if op_name in identity_list: |
2307 | 2380 | sym = get_relay_op(op_name)(*inputs, **attrs) |
2308 | 2381 | elif op_name in convert_map: |
2309 | | - sym = convert_map[op_name](inputs, attrs, self._params) |
| 2382 | +            if 'TensorArray' in op_name: |
| 2383 | +                # TensorArray converters also need the Prelude for the tensor array helpers. |
| 2384 | +                sym = convert_map[op_name](inputs, attrs, self._params, self._prelude) |
| 2385 | +            else: |
| 2386 | +                sym = convert_map[op_name](inputs, attrs, self._params) |
2310 | 2387 | elif op_name in convert_map_rnn: |
2311 | 2388 | sym = self._convert_rnn_operator(op_name, inputs, attrs, |
2312 | 2389 | self._params, graph, |
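End to end, the new dispatch is exercised by importing such a graph through the frontend. A rough usage sketch; `graph_def` and the output node name are placeholders, not taken from the patch:

```python
# Rough usage sketch: import a GraphDef that contains TensorArray*V3 ops.
# 'graph_def' is assumed to come from a TF 1.x graph like the snippet above;
# from_tensorflow returns the Relay module and the parameter dict.
from tvm import relay

mod, params = relay.frontend.from_tensorflow(graph_def,
                                             layout="NHWC",
                                             outputs=["TensorArrayReadV3"])
```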
|