[Frontend][TENSORFLOW] Add support for unpack with dim 0 after tensorlist stack (apache#8558)

* enable test case where a tensorlist stack is followed by an unpack on dim 0

* address reviews and improve the docstring
Xingyu Zhou authored and ylc committed Sep 29, 2021
1 parent 746a315 commit 4920db9
Showing 3 changed files with 101 additions and 54 deletions.
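
For context, the user-facing TensorFlow pattern this commit enables is a TensorListStack with num_elements=1 followed by an Unpack along axis 0; Unpack needs a static size on the axis it splits, which is what the frontend previously could not provide. A minimal sketch of the pattern, mirroring the new test added below (the function name is illustrative):

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=(1, 3, 4), dtype=tf.float32)])
def stack_then_unpack(x):
    # Reserve a one-element TensorList, fill it, stack it back into a
    # (1, 3, 4) tensor, then unpack along axis 0.
    tl = tf.raw_ops.TensorListReserve(
        element_shape=(3, 4), num_elements=1, element_dtype=tf.float32
    )
    tl = tf.raw_ops.TensorListSetItem(input_handle=tl, index=0, item=x[0, :, :])
    stacked = tf.raw_ops.TensorListStack(
        input_handle=tl, element_shape=(3, 4), element_dtype=tf.float32, num_elements=1
    )
    return tf.raw_ops.Unpack(value=stacked, num=1, axis=0)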
14 changes: 11 additions & 3 deletions python/tvm/relay/frontend/tensorflow2_ops.py
@@ -133,13 +133,21 @@ def _impl(inputs, attr, params, prelude):
             stack_func = prelude.get_global_var("tensor_array_stack", dtype_str)
             out = stack_func(inputs[0])
         else:
-            static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_ta_shape)
+            if "num_elements" in attr:
+                num_elements = attr["num_elements"]
+            static_tensor_array_ops = StaticTensorArrayOps(
+                prelude, dtype_str, input_ta_shape, num_elements
+            )
             static_tensor_array_ops.register()
             stack_func = prelude.get_global_var_static(
-                "tensor_array_stack", dtype_str, input_ta_shape
+                "tensor_array_stack", dtype_str, input_ta_shape, num_elements
             )
             out_tensor = stack_func(inputs[0])
-            out_shape = (Any(),) + input_ta_shape
+            out_shape = (
+                (num_elements,) + input_ta_shape
+                if num_elements and num_elements == 1
+                else (Any(),) + input_ta_shape
+            )
             static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, out_shape)
             static_tensor_array_ops.register()
             get_data_func = prelude.get_global_var_static("tensor_get_data", dtype_str, out_shape)
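
The core of the frontend change is this output-shape selection: when the TensorListStack op carries num_elements == 1, the stacked output gets a static leading dimension of 1 instead of a dynamic Any(), which downstream ops such as Unpack can then rely on. A simplified standalone sketch of that rule, with TVM's Any stubbed out for illustration:

class Any:
    """Stand-in for TVM's dynamic dimension; illustration only."""

    def __repr__(self):
        return "?"


def stacked_shape(input_ta_shape, num_elements=None):
    # num_elements == 1 pins the leading axis; anything else stays dynamic.
    if num_elements and num_elements == 1:
        return (num_elements,) + tuple(input_ta_shape)
    return (Any(),) + tuple(input_ta_shape)


print(stacked_shape((3, 4), num_elements=1))  # (1, 3, 4)
print(stacked_shape((3, 4)))                  # (?, 3, 4)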
80 changes: 59 additions & 21 deletions python/tvm/relay/prelude.py
@@ -73,9 +73,33 @@ def get_tensor_array_shape(expr, dtype, prelude):
     return None


-def _get_name_static(canonical, dtype, shape):
-    """Get name for static shape tensor array op corresponding
-    to the canonical name"""
+def _get_name_static(canonical, dtype, shape, batch_dim=None):
+    """Get name for a static-shape tensor array op
+
+    By design, a static ADT tensor in TVM has a type name of the form
+    static_tensor_dim0_dim1_..._dimN_t, or
+    static_tensor_batch1_dim0_dim1_..._dimN_t if the tensorlist stack
+    has only one item.
+
+    Parameters
+    ----------
+    canonical : String
+        Tensor array op name
+
+    dtype : str
+        Data type.
+
+    shape : tuple of (int, Any) or None
+        Tensor array shape
+
+    batch_dim : None or int
+        1 if the tensorlist stack has only one item; None by default.
+
+    Returns
+    -------
+    name : String
+        The tensor array op name
+    """
     dim_names = []
     for dim in shape:
         if isinstance(dim, Any):
@@ -89,26 +113,31 @@ def _get_name_static(canonical, dtype, shape):
         shape_str = "scalar"
     if canonical == "tensor_t":
         return "static_tensor_{}_{}_t".format(dtype, shape_str)
-    return "{}_{}_{}".format(canonical, dtype, shape_str)
+    if batch_dim is None or canonical in ["tensor_constructor", "tensor_nil"]:
+        return "{}_{}_{}".format(canonical, dtype, shape_str)
+    if batch_dim != 1:
+        return "{}_{}_{}".format(canonical, dtype, shape_str)
+    return "{}_{}_batch{}_{}".format(canonical, dtype, str(batch_dim), shape_str)


 class StaticTensorArrayOps(object):
     """Contains tensor array related ops for fixed rank tensor array"""

-    def __init__(self, prelude, dtype, shape):
+    def __init__(self, prelude, dtype, shape, batch_dim=None):
         """Create tensor array ops registry"""
         self.prelude = prelude
         self.dtype = dtype
         self.shape = shape
+        self.batch_dim = batch_dim
         self.list, self.cons, self.nil = self.prelude.mod.get_type("List")

     def get_name(self, canonical):
         """Get name corresponding to the canonical name"""
-        return _get_name_static(canonical, self.dtype, self.shape)
+        return _get_name_static(canonical, self.dtype, self.shape, self.batch_dim)

     def get_global_var(self, canonical):
         """Get global corresponding to the canonical name"""
-        return self.prelude.get_global_var_static(canonical, self.dtype, self.shape)
+        return self.prelude.get_global_var_static(canonical, self.dtype, self.shape, self.batch_dim)

     def get_type(self, canonical):
         """Get type corresponding to the canonical name"""
@@ -262,9 +291,10 @@ def define_tensor_expand_dims(self):

         # Note: we set the added axis to be Any() instead of 1 due to
         # in stack op, we need to recursively concatenate.
+        new_axis = Any() if self.batch_dim is None or self.batch_dim != 1 else self.batch_dim
         tensor_type_var, tensor_constructor, _ = self._get_adt_by_shape(
             [
-                Any(),
+                new_axis,
             ]
             + list(self.shape)
         )
@@ -573,20 +603,27 @@ def define_tensor_array_stack(self):
         expand_dims_var = self.get_global_var("tensor_expand_dims")

         # Register tensor_concatenate for output_shape
+        new_axis = Any() if not self.batch_dim or self.batch_dim != 1 else self.batch_dim
         output_shape = [
-            Any(),
+            new_axis,
         ] + list(self.shape)

         _, _, output_ops = self._get_adt_by_shape(output_shape)
         output_ops.define_tensor_concatenate()
         concat_var = output_ops.get_global_var("tensor_concatenate")

         tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array)
-        tensors = self.prelude.foldl(
-            concat_var,
-            self.prelude.hd(tensor_array_expand_dims),
-            self.prelude.tl(tensor_array_expand_dims),
-        )
+        if self.batch_dim is not None and self.batch_dim == 1:
+            # only one element
+            tensors = self.prelude.id(
+                self.prelude.hd(tensor_array_expand_dims),
+            )
+        else:
+            tensors = self.prelude.foldl(
+                concat_var,
+                self.prelude.hd(tensor_array_expand_dims),
+                self.prelude.tl(tensor_array_expand_dims),
+            )

         output_tensor_type_var, _, _ = self._get_adt_by_shape(output_shape)
         self.prelude.mod[stack_var] = Function(
             [tensor_array], tensors, output_tensor_type_var(), []
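
The batch_dim == 1 fast path above skips the fold-concatenate entirely: with a single element, the expanded head already is the stacked result, so prelude.id suffices. A NumPy analogy of the two paths (illustrative only; the real implementation builds Relay ADT expressions):

from functools import reduce

import numpy as np


def stack(tensor_array, batch_dim=None):
    expanded = [np.expand_dims(t, axis=0) for t in tensor_array]
    if batch_dim == 1:
        # Only one element: no concatenation needed.
        return expanded[0]
    # General case: fold concatenation over the expanded tensors.
    return reduce(lambda a, b: np.concatenate([a, b], axis=0), expanded)


x = np.ones((3, 4), dtype="float32")
assert stack([x], batch_dim=1).shape == (1, 3, 4)
assert stack([x, x]).shape == (2, 3, 4)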
@@ -599,8 +636,9 @@ def define_tensor_array_gather(self):
         helper_name = self.get_name("tensor_array_gather_helper")
         helper_var = self._create_global_var(helper_name)

+        new_axis = Any() if self.batch_dim is None or self.batch_dim != 1 else self.batch_dim
         output_shape = [
-            Any(),
+            new_axis,
         ] + list(self.shape)
         output_tensor_type_var, _, _ = self._get_adt_by_shape(output_shape)
         stack_var = self.get_global_var("tensor_array_stack")
@@ -668,7 +706,7 @@ def register(self):

     def _get_adt_by_shape(self, shape):
         """Get ADT type and constructor with given shape."""
-        adt_ops = StaticTensorArrayOps(self.prelude, self.dtype, shape)
+        adt_ops = StaticTensorArrayOps(self.prelude, self.dtype, shape, self.batch_dim)
         adt_ops.define_tensor_adt()
         tensor_type_var = adt_ops.get_type("tensor_t")
         tensor_constructor = adt_ops.get_ctor("tensor_constructor")
@@ -1482,13 +1520,13 @@ def get_tensor_ctor(self, canonical, dtype):
         ty = self.get_type("tensor_t", dtype)
         return self.get_ctor(ty.name_hint, canonical, dtype)

-    def get_name_static(self, canonical, dtype, shape):
+    def get_name_static(self, canonical, dtype, shape, batch_dim=None):
         """Get name corresponding to the canonical name"""
-        return _get_name_static(canonical, dtype, shape)
+        return _get_name_static(canonical, dtype, shape, batch_dim)

-    def get_global_var_static(self, canonical, dtype, shape):
+    def get_global_var_static(self, canonical, dtype, shape, batch_dim=None):
         """Get var corresponding to the canonical name"""
-        name = self.get_name_static(canonical, dtype, shape)
+        name = self.get_name_static(canonical, dtype, shape, batch_dim)
         return self.mod.get_global_var(name)

     def get_type_static(self, canonical, dtype, shape):
61 changes: 31 additions & 30 deletions tests/python/frontend/tensorflow2/test_functional_models.py
@@ -484,8 +484,6 @@ def get_input(self):
                 in_tens[1] = np.zeros((3,), dtype="float32")
                 return in_tens

-            """2D array as input"""
-
             @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), dtype=tf.float32)])
             def func(self, x):
                 dtype = tf.float32
@@ -513,8 +511,6 @@ def get_input(self):
                 in_tens[1, :, :] = np.zeros((3, 4), dtype="float32")
                 return in_tens

-            """2D array as input"""
-
             @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32)])
             def func(self, x):
                 dtype = tf.float32
@@ -531,18 +527,8 @@ def func(self, x):
         run_model_graph(TensorList2D)
         run_func_graph(TensorList2D, runtime="vm")

-    run_test(
-        (
-            3,
-            4,
-        )
-    )
-    run_test(
-        (
-            -1,
-            -1,
-        )
-    )
+    run_test((3, 4))
+    run_test((-1, -1))


def test_tensorlist_stack_2d():
@@ -553,8 +539,6 @@ def get_input(self):
                 in_tens[1, :, :] = np.zeros((3, 4), dtype="float32")
                 return in_tens

-            """2D array as input"""
-
             @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32)])
             def func(self, x):
                 dtype = tf.float32
@@ -570,18 +554,35 @@ def func(self, x):
         run_model_graph(TensorListStack2D)
         run_func_graph(TensorListStack2D, runtime="vm")

-    run_test(
-        (
-            3,
-            4,
-        )
-    )
-    run_test(
-        (
-            -1,
-            -1,
-        )
-    )
+    run_test((3, 4))
+    run_test((-1, -1))
+
+
+def test_tensorlist_stack_unpack():
+    def run_test(elem_shape):
+        class TensorListStack2D(tf.Module):
+            def get_input(self):
+                in_tens = np.ones((1, 3, 4), dtype="float32")
+                return in_tens
+
+            @tf.function(input_signature=[tf.TensorSpec(shape=(1, 3, 4), dtype=tf.float32)])
+            def func(self, x):
+                dtype = tf.float32
+                tl = tf.raw_ops.TensorListReserve(
+                    element_shape=elem_shape, num_elements=1, element_dtype=dtype
+                )
+                tl = tf.raw_ops.TensorListSetItem(input_handle=tl, index=0, item=x[0, :, :])
+                output = tf.raw_ops.TensorListStack(
+                    input_handle=tl, element_shape=elem_shape, element_dtype=dtype, num_elements=1
+                )
+                output = tf.raw_ops.Unpack(value=output, num=1, axis=0)
+                return output
+
+        run_model_graph(TensorListStack2D)
+        run_func_graph(TensorListStack2D, runtime="vm")
+
+    run_test((3, 4))
+    run_test((-1, -1))


if __name__ == "__main__":
