|
24 | 24 | # Numpy support |
25 | 25 | import numpy as np |
26 | 26 |
|
27 | 29 | import tvm |
| 30 | + |
| 31 | +from tvm.relay.prelude import Prelude |
| 32 | +from topi.util import get_const_tuple |
| 33 | + |
28 | 34 | from .. import analysis |
29 | 35 | from .. import expr as _expr |
30 | 36 | from .. import op as _op |
@@ -506,6 +512,69 @@ def _impl(inputs, attr, params): |
506 | 512 |         return _op.concatenate(inputs_reshaped, axis) |
507 | 513 |     return _impl |
508 | 514 |
|
| 515 | +def _tensor_array(): |
| 516 | +    def _impl(inputs, attr, params, prelude): |
| 517 | +        return prelude.tensor_array(_op.take(inputs[0], tvm.relay.const(0))) |
| 518 | +    return _impl |
| 519 | + |
| 520 | +def _tensor_array_scatter(): |
| 521 | +    def _impl(inputs, attr, params, prelude): |
| 522 | +        # Unstack the rank-2 values tensor along axis 0 so each row becomes an |
| 523 | +        # element, then scatter the rows at the given indices (inputs[1]). |
| 524 | +        values_rank = len(inputs[2].type_annotation.shape) |
| 525 | +        if values_rank == 2: |
| 526 | +            values = prelude.tensor_array_unstack_tensor2(inputs[2]) |
| 527 | +        else: |
| 528 | +            raise tvm.error.OpNotImplemented( |
| 529 | +                "TensorArrayScatterV3 only supports rank-2 values.") |
| 530 | +        return prelude.tensor_array_scatter(inputs[0], inputs[1], values) |
| 531 | +    return _impl |
| 532 | + |
| 533 | +def _tensor_array_gather(): |
| 534 | +    def _impl(inputs, attr, params, prelude): |
| 535 | +        return prelude.tensor_array_gather(inputs[2], inputs[1]) |
| 536 | +    return _impl |
| 537 | + |
| 538 | +def _tensor_array_size(): |
| 539 | +    def _impl(inputs, attr, params, prelude): |
| 540 | +        return prelude.tensor_array_size(inputs[0]) |
| 541 | +    return _impl |
| 542 | + |
| 543 | +def _tensor_array_write(): |
| 544 | +    def _impl(inputs, attr, params, prelude): |
| 545 | +        # Wrap the value in the rank-matched tensor constructor, then write it |
| 546 | +        # at the scalar index taken from inputs[1]. |
| 547 | +        value_rank = len(inputs[2].type_annotation.shape) |
| 548 | +        if value_rank == 2: |
| 549 | +            v = prelude.tensor2(inputs[2]) |
| 550 | +        elif value_rank == 3: |
| 551 | +            v = prelude.tensor3(inputs[2]) |
| 552 | +        else: |
| 553 | +            raise tvm.error.OpNotImplemented("TensorArrayWriteV3 only supports rank-2 and rank-3 values.") |
| 554 | +        return prelude.tensor_array_write(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v) |
| 555 | +    return _impl |
| 556 | + |
| 554 | +def _tensor_array_read(): |
| 555 | +    def _impl(inputs, attr, params, prelude): |
| 556 | +        # inputs[1] holds the index as a tensor; take element 0 to get a scalar index. |
| 557 | +        return prelude.tensor_array_read(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) |
| 558 | +    return _impl |
| 559 | + |
| 561 | +def _tensor_array_split(): |
| 562 | +    def _impl(inputs, attr, params, prelude): |
| 563 | +        value_rank = len(inputs[1].type_annotation.shape) |
| 564 | +        if value_rank == 2: |
| 565 | +            v = prelude.tensor2(inputs[1]) |
| 566 | +        elif value_rank == 3: |
| 567 | +            v = prelude.tensor3(inputs[1]) |
| 568 | +        else: |
| 569 | +            raise tvm.error.OpNotImplemented("TensorArraySplitV3 only supports rank-2 and rank-3 values.") |
| 570 | +        lengths = _op.cast(inputs[2], 'int32') |
| 571 | +        return prelude.tensor_array_split(inputs[0], v, lengths) |
| 572 | +    return _impl |
| 573 | + |
| 573 | +def _tensor_array_concat(): |
| 574 | +    def _impl(inputs, attr, params, prelude): |
| 575 | +        return prelude.tensor_array_concat(inputs[1]) |
| 576 | +    return _impl |
| 577 | + |
509 | 578 | def _tile(): |
510 | 579 |     def _impl(inputs, attr, params): |
511 | 580 |         reps = params[inputs.pop().name_hint].asnumpy() |
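
For reference, here is a minimal TensorFlow 1.x sketch of the kind of graph these converters target; the placeholder name, shapes, and indices are illustrative and not taken from this patch. Each tf.TensorArray method below lowers to one of the TensorArray*V3 nodes handled above.

    import tensorflow as tf

    with tf.Graph().as_default():
        data = tf.placeholder(tf.float32, shape=(4, 8), name="data")
        ta = tf.TensorArray(dtype=tf.float32, size=4)    # -> TensorArrayV3
        ta = ta.scatter(tf.range(4), data)               # -> TensorArrayScatterV3
        first = ta.read(0)                               # -> TensorArrayReadV3
        pair = ta.gather(tf.constant([1, 2]))            # -> TensorArrayGatherV3
        out = tf.identity(pair, name="out")              # named graph output
        graph_def = tf.get_default_graph().as_graph_def()
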
@@ -1285,6 +1355,14 @@ def _impl(inputs, attr, params): |
1285 | 1355 |     'Neg' : AttrCvt('negative'), |
1286 | 1356 |     'NotEqual' : _broadcast('not_equal'), |
1287 | 1357 |     'Pack' : _pack(), |
| 1358 | +    'TensorArrayV3' : _tensor_array(), |
| 1359 | +    'TensorArrayScatterV3' : _tensor_array_scatter(), |
| 1360 | +    'TensorArrayGatherV3' : _tensor_array_gather(), |
| 1361 | +    'TensorArraySizeV3' : _tensor_array_size(), |
| 1362 | +    'TensorArrayWriteV3' : _tensor_array_write(), |
| 1363 | +    'TensorArrayReadV3' : _tensor_array_read(), |
| 1364 | +    'TensorArraySplitV3' : _tensor_array_split(), |
| 1365 | +    'TensorArrayConcatV3' : _tensor_array_concat(), |
1288 | 1366 |     'Pad' : _pad('Pad'), |
1289 | 1367 |     'PadV2' : _pad('PadV2'), |
1290 | 1368 |     'Pow' : _elemwise('power'), |
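
With these entries registered, a GraphDef that uses TensorArray ops can go through the normal import path. A hedged usage sketch follows; the (mod, params) return value and the "main" entry name are assumptions about this revision of the frontend, and graph_def, the shape dict, and the output name come from the TensorFlow sketch above.

    from tvm import relay

    mod, params = relay.frontend.from_tensorflow(
        graph_def, layout="NHWC", shape={"data": (4, 8)}, outputs=["out"])
    print(mod["main"])   # the prelude's tensor-array calls appear in the Relay program
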
@@ -1830,6 +1908,7 @@ def __init__(self): |
1830 | 1908 |         self._loops = {} |
1831 | 1909 |         self._branches = {} |
1832 | 1910 |         self._mod = _module.Module({}) |
| 1911 | +        self._prelude = Prelude(self._mod) |
1833 | 1912 |
|
1834 | 1913 |     def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): |
1835 | 1914 |         """Construct relay nodes from tensorflow graph definition - GraphDef. |
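
The Prelude constructed here is what provides the tensor_array, tensor2/tensor3, and tensor_array_* helpers used by the converters above. A minimal sketch of that dependency, assuming the attribute-style prelude API this patch relies on and that relay.Module() is equivalent to the _module.Module({}) created in __init__:

    from tvm import relay
    from tvm.relay.prelude import Prelude

    mod = relay.Module()          # assumption: same module type as self._mod above
    p = Prelude(mod)              # registers the tensor-array ADT and helpers in mod

    ta = p.tensor_array(relay.const(3))                 # empty tensor array of size 3
    t = p.tensor2(relay.var("x", shape=(2, 2)))         # wrap a rank-2 tensor as an ADT value
    ta2 = p.tensor_array_write(ta, relay.const(0), t)   # same call shape the converters emit
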
@@ -2306,7 +2385,11 @@ def _convert_operator(self, op_name, inputs, attrs, |
2306 | 2385 |         if op_name in identity_list: |
2307 | 2386 |             sym = get_relay_op(op_name)(*inputs, **attrs) |
2308 | 2387 |         elif op_name in convert_map: |
2309 | | -            sym = convert_map[op_name](inputs, attrs, self._params) |
| 2388 | +            if 'TensorArray' in op_name: |
| 2389 | +                sym = convert_map[op_name](inputs, attrs, self._params, self._prelude) |
| 2390 | +            else: |
| 2391 | +                sym = convert_map[op_name](inputs, attrs, self._params) |
2310 | 2393 |         elif op_name in convert_map_rnn: |
2311 | 2394 |             sym = self._convert_rnn_operator(op_name, inputs, attrs, |
2312 | 2395 |                                              self._params, graph, |
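
The dispatch change above keeps the existing three-argument converter signature untouched and only threads the prelude through for TensorArray ops, at the cost of a name-based check. In isolation the rule amounts to the sketch below, where convert_map and prelude stand in for the instance state used by _convert_operator.

    def dispatch(op_name, inputs, attrs, params, prelude, convert_map):
        convert = convert_map[op_name]
        if 'TensorArray' in op_name:
            # TensorArray converters need the prelude to build ADT calls.
            return convert(inputs, attrs, params, prelude)
        return convert(inputs, attrs, params)
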
|