Skip to content

Commit 496f10b

Browse files
author
Xingyu Zhou
authored
[Frontend][TENSORFLOW] Add support for unpack with dim 0 after tensorlist stack (#8558)
* Enable test case where a tensorlist stack is followed by an unpack on dim 0 * Address reviews and improve the docstring
1 parent 8954968 commit 496f10b

File tree

3 files changed

+101
-54
lines changed

3 files changed

+101
-54
lines changed

python/tvm/relay/frontend/tensorflow2_ops.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,13 +133,21 @@ def _impl(inputs, attr, params, prelude):
133133
stack_func = prelude.get_global_var("tensor_array_stack", dtype_str)
134134
out = stack_func(inputs[0])
135135
else:
136-
static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_ta_shape)
136+
if "num_elements" in attr:
137+
num_elements = attr["num_elements"]
138+
static_tensor_array_ops = StaticTensorArrayOps(
139+
prelude, dtype_str, input_ta_shape, num_elements
140+
)
137141
static_tensor_array_ops.register()
138142
stack_func = prelude.get_global_var_static(
139-
"tensor_array_stack", dtype_str, input_ta_shape
143+
"tensor_array_stack", dtype_str, input_ta_shape, num_elements
140144
)
141145
out_tensor = stack_func(inputs[0])
142-
out_shape = (Any(),) + input_ta_shape
146+
out_shape = (
147+
(num_elements,) + input_ta_shape
148+
if num_elements and num_elements == 1
149+
else (Any(),) + input_ta_shape
150+
)
143151
static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, out_shape)
144152
static_tensor_array_ops.register()
145153
get_data_func = prelude.get_global_var_static("tensor_get_data", dtype_str, out_shape)

python/tvm/relay/prelude.py

Lines changed: 59 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,33 @@ def get_tensor_array_shape(expr, dtype, prelude):
7373
return None
7474

7575

76-
def _get_name_static(canonical, dtype, shape):
77-
"""Get name for static shape tensor array op corresponding
78-
to the canonical name"""
76+
def _get_name_static(canonical, dtype, shape, batch_dim=None):
77+
"""Get name for static shape tensor array op
78+
79+
By design, static ADT tensor in TVM has type name in the format
80+
of static_tensor_dim0_dim1_..._dimN_t
81+
or static_tensor_batch1_dim0_dim1_..._dimN_t if the tensorlist stack only has one item.
82+
83+
Parameters
84+
----------
85+
canonical : String
86+
Tensor array op name
87+
88+
dtype : str
89+
Data type.
90+
91+
shape : tuple of (int, Any) or None
92+
Tensor array shape
93+
94+
batch_dim: None or int
95+
1 if the tensorlist stack only has one item.
96+
None by default
97+
98+
Returns
99+
-------
100+
name : String
101+
The tensor array op name
102+
"""
79103
dim_names = []
80104
for dim in shape:
81105
if isinstance(dim, Any):
@@ -89,26 +113,31 @@ def _get_name_static(canonical, dtype, shape):
89113
shape_str = "scalar"
90114
if canonical == "tensor_t":
91115
return "static_tensor_{}_{}_t".format(dtype, shape_str)
92-
return "{}_{}_{}".format(canonical, dtype, shape_str)
116+
if batch_dim is None or canonical in ["tensor_constructor", "tensor_nil"]:
117+
return "{}_{}_{}".format(canonical, dtype, shape_str)
118+
if batch_dim != 1:
119+
return "{}_{}_{}".format(canonical, dtype, shape_str)
120+
return "{}_{}_batch{}_{}".format(canonical, dtype, str(batch_dim), shape_str)
93121

94122

95123
class StaticTensorArrayOps(object):
96124
"""Contains tensor array related ops for fixed rank tensor array"""
97125

98-
def __init__(self, prelude, dtype, shape):
126+
def __init__(self, prelude, dtype, shape, batch_dim=None):
99127
"""Create tensor array ops registry"""
100128
self.prelude = prelude
101129
self.dtype = dtype
102130
self.shape = shape
131+
self.batch_dim = batch_dim
103132
self.list, self.cons, self.nil = self.prelude.mod.get_type("List")
104133

105134
def get_name(self, canonical):
106135
"""Get name corresponding to the canonical name"""
107-
return _get_name_static(canonical, self.dtype, self.shape)
136+
return _get_name_static(canonical, self.dtype, self.shape, self.batch_dim)
108137

109138
def get_global_var(self, canonical):
110139
"""Get global corresponding to the canonical name"""
111-
return self.prelude.get_global_var_static(canonical, self.dtype, self.shape)
140+
return self.prelude.get_global_var_static(canonical, self.dtype, self.shape, self.batch_dim)
112141

113142
def get_type(self, canonical):
114143
"""Get type corresponding to the canonical name"""
@@ -262,9 +291,10 @@ def define_tensor_expand_dims(self):
262291

263292
# Note: we set the added axis to be Any() instead of 1 due to
264293
# in stack op, we need to recursively concatenate.
294+
new_axis = Any() if self.batch_dim is None or self.batch_dim != 1 else self.batch_dim
265295
tensor_type_var, tensor_constructor, _ = self._get_adt_by_shape(
266296
[
267-
Any(),
297+
new_axis,
268298
]
269299
+ list(self.shape)
270300
)
@@ -573,20 +603,27 @@ def define_tensor_array_stack(self):
573603
expand_dims_var = self.get_global_var("tensor_expand_dims")
574604

575605
# Register tensor_concatenate for output_shape
606+
new_axis = Any() if not self.batch_dim or self.batch_dim != 1 else self.batch_dim
576607
output_shape = [
577-
Any(),
608+
new_axis,
578609
] + list(self.shape)
579-
580610
_, _, output_ops = self._get_adt_by_shape(output_shape)
581611
output_ops.define_tensor_concatenate()
582612
concat_var = output_ops.get_global_var("tensor_concatenate")
583613

584614
tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array)
585-
tensors = self.prelude.foldl(
586-
concat_var,
587-
self.prelude.hd(tensor_array_expand_dims),
588-
self.prelude.tl(tensor_array_expand_dims),
589-
)
615+
if self.batch_dim is not None and self.batch_dim == 1:
616+
# only one element
617+
tensors = self.prelude.id(
618+
self.prelude.hd(tensor_array_expand_dims),
619+
)
620+
else:
621+
tensors = self.prelude.foldl(
622+
concat_var,
623+
self.prelude.hd(tensor_array_expand_dims),
624+
self.prelude.tl(tensor_array_expand_dims),
625+
)
626+
590627
output_tensor_type_var, _, _ = self._get_adt_by_shape(output_shape)
591628
self.prelude.mod[stack_var] = Function(
592629
[tensor_array], tensors, output_tensor_type_var(), []
@@ -599,8 +636,9 @@ def define_tensor_array_gather(self):
599636
helper_name = self.get_name("tensor_array_gather_helper")
600637
helper_var = self._create_global_var(helper_name)
601638

639+
new_axis = Any() if self.batch_dim is None or self.batch_dim != 1 else self.batch_dim
602640
output_shape = [
603-
Any(),
641+
new_axis,
604642
] + list(self.shape)
605643
output_tensor_type_var, _, _ = self._get_adt_by_shape(output_shape)
606644
stack_var = self.get_global_var("tensor_array_stack")
@@ -668,7 +706,7 @@ def register(self):
668706

669707
def _get_adt_by_shape(self, shape):
670708
"""Get ADT type and constructor with given shape."""
671-
adt_ops = StaticTensorArrayOps(self.prelude, self.dtype, shape)
709+
adt_ops = StaticTensorArrayOps(self.prelude, self.dtype, shape, self.batch_dim)
672710
adt_ops.define_tensor_adt()
673711
tensor_type_var = adt_ops.get_type("tensor_t")
674712
tensor_constructor = adt_ops.get_ctor("tensor_constructor")
@@ -1482,13 +1520,13 @@ def get_tensor_ctor(self, canonical, dtype):
14821520
ty = self.get_type("tensor_t", dtype)
14831521
return self.get_ctor(ty.name_hint, canonical, dtype)
14841522

1485-
def get_name_static(self, canonical, dtype, shape):
1523+
def get_name_static(self, canonical, dtype, shape, batch_dim=None):
14861524
"""Get name corresponding to the canonical name"""
1487-
return _get_name_static(canonical, dtype, shape)
1525+
return _get_name_static(canonical, dtype, shape, batch_dim)
14881526

1489-
def get_global_var_static(self, canonical, dtype, shape):
1527+
def get_global_var_static(self, canonical, dtype, shape, batch_dim=None):
14901528
"""Get var corresponding to the canonical name"""
1491-
name = self.get_name_static(canonical, dtype, shape)
1529+
name = self.get_name_static(canonical, dtype, shape, batch_dim)
14921530
return self.mod.get_global_var(name)
14931531

14941532
def get_type_static(self, canonical, dtype, shape):

tests/python/frontend/tensorflow2/test_functional_models.py

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -484,8 +484,6 @@ def get_input(self):
484484
in_tens[1] = np.zeros((3,), dtype="float32")
485485
return in_tens
486486

487-
"""2D array as input"""
488-
489487
@tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), dtype=tf.float32)])
490488
def func(self, x):
491489
dtype = tf.float32
@@ -513,8 +511,6 @@ def get_input(self):
513511
in_tens[1, :, :] = np.zeros((3, 4), dtype="float32")
514512
return in_tens
515513

516-
"""2D array as input"""
517-
518514
@tf.function(input_signature=[tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32)])
519515
def func(self, x):
520516
dtype = tf.float32
@@ -531,18 +527,8 @@ def func(self, x):
531527
run_model_graph(TensorList2D)
532528
run_func_graph(TensorList2D, runtime="vm")
533529

534-
run_test(
535-
(
536-
3,
537-
4,
538-
)
539-
)
540-
run_test(
541-
(
542-
-1,
543-
-1,
544-
)
545-
)
530+
run_test((3, 4))
531+
run_test((-1, -1))
546532

547533

548534
def test_tensorlist_stack_2d():
@@ -553,8 +539,6 @@ def get_input(self):
553539
in_tens[1, :, :] = np.zeros((3, 4), dtype="float32")
554540
return in_tens
555541

556-
"""2D array as input"""
557-
558542
@tf.function(input_signature=[tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32)])
559543
def func(self, x):
560544
dtype = tf.float32
@@ -570,18 +554,35 @@ def func(self, x):
570554
run_model_graph(TensorListStack2D)
571555
run_func_graph(TensorListStack2D, runtime="vm")
572556

573-
run_test(
574-
(
575-
3,
576-
4,
577-
)
578-
)
579-
run_test(
580-
(
581-
-1,
582-
-1,
583-
)
584-
)
557+
run_test((3, 4))
558+
run_test((-1, -1))
559+
560+
561+
def test_tensorlist_stack_unpack():
562+
def run_test(elem_shape):
563+
class TensorListStack2D(tf.Module):
564+
def get_input(self):
565+
in_tens = np.ones((1, 3, 4), dtype="float32")
566+
return in_tens
567+
568+
@tf.function(input_signature=[tf.TensorSpec(shape=(1, 3, 4), dtype=tf.float32)])
569+
def func(self, x):
570+
dtype = tf.float32
571+
tl = tf.raw_ops.TensorListReserve(
572+
element_shape=elem_shape, num_elements=1, element_dtype=dtype
573+
)
574+
tl = tf.raw_ops.TensorListSetItem(input_handle=tl, index=0, item=x[0, :, :])
575+
output = tf.raw_ops.TensorListStack(
576+
input_handle=tl, element_shape=elem_shape, element_dtype=dtype, num_elements=1
577+
)
578+
output = tf.raw_ops.Unpack(value=output, num=1, axis=0)
579+
return output
580+
581+
run_model_graph(TensorListStack2D)
582+
run_func_graph(TensorListStack2D, runtime="vm")
583+
584+
run_test((3, 4))
585+
run_test((-1, -1))
585586

586587

587588
if __name__ == "__main__":

0 commit comments

Comments (0)