@@ -73,9 +73,33 @@ def get_tensor_array_shape(expr, dtype, prelude):
     return None


-def _get_name_static(canonical, dtype, shape):
-    """Get name for static shape tensor array op corresponding
-    to the canonical name"""
+def _get_name_static(canonical, dtype, shape, batch_dim=None):
+    """Get name for static shape tensor array op
+
+    By design, a static ADT tensor in TVM has a type name in the format
+    static_tensor_dim0_dim1_..._dimN_t,
+    or static_tensor_batch1_dim0_dim1_..._dimN_t if the tensorlist stack only has one item.
+
+    Parameters
+    ----------
+    canonical : String
+        Tensor array op name
+
+    dtype : str
+        Data type.
+
+    shape : tuple of (int, Any) or None
+        Tensor array shape
+
+    batch_dim : None or int
+        1 if the tensorlist stack only has one item.
+        None by default
+
+    Returns
+    -------
+    name : String
+        The tensor array op name
+    """
     dim_names = []
     for dim in shape:
         if isinstance(dim, Any):
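For illustration, here is how the documented mangling behaves for an assumed dtype and shape (a sketch; `_get_name_static` is the module-level helper being modified here):

```python
from tvm.relay.prelude import _get_name_static

# The ADT type name encodes the dtype and every static dimension.
assert _get_name_static("tensor_t", "float32", (3, 4)) == "static_tensor_float32_3_4_t"

# Op names follow canonical_dtype_shape; an empty shape mangles to "scalar".
assert _get_name_static("tensor_array_read", "float32", (3, 4)) == "tensor_array_read_float32_3_4"
assert _get_name_static("tensor_array_read", "int32", ()) == "tensor_array_read_int32_scalar"
```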
@@ -89,26 +113,31 @@ def _get_name_static(canonical, dtype, shape):
         shape_str = "scalar"
     if canonical == "tensor_t":
         return "static_tensor_{}_{}_t".format(dtype, shape_str)
-    return "{}_{}_{}".format(canonical, dtype, shape_str)
+    if batch_dim is None or canonical in ["tensor_constructor", "tensor_nil"]:
+        return "{}_{}_{}".format(canonical, dtype, shape_str)
+    if batch_dim != 1:
+        return "{}_{}_{}".format(canonical, dtype, shape_str)
+    return "{}_{}_batch{}_{}".format(canonical, dtype, str(batch_dim), shape_str)


 class StaticTensorArrayOps(object):
     """Contains tensor array related ops for fixed rank tensor array"""

-    def __init__(self, prelude, dtype, shape):
+    def __init__(self, prelude, dtype, shape, batch_dim=None):
         """Create tensor array ops registry"""
         self.prelude = prelude
         self.dtype = dtype
         self.shape = shape
+        self.batch_dim = batch_dim
         self.list, self.cons, self.nil = self.prelude.mod.get_type("List")

     def get_name(self, canonical):
         """Get name corresponding to the canonical name"""
-        return _get_name_static(canonical, self.dtype, self.shape)
+        return _get_name_static(canonical, self.dtype, self.shape, self.batch_dim)

     def get_global_var(self, canonical):
         """Get global corresponding to the canonical name"""
-        return self.prelude.get_global_var_static(canonical, self.dtype, self.shape)
+        return self.prelude.get_global_var_static(canonical, self.dtype, self.shape, self.batch_dim)

     def get_type(self, canonical):
         """Get type corresponding to the canonical name"""
@@ -262,9 +291,10 @@ def define_tensor_expand_dims(self):

         # Note: we set the added axis to be Any() instead of 1 because,
         # in the stack op, we need to recursively concatenate.
+        new_axis = Any() if self.batch_dim is None or self.batch_dim != 1 else self.batch_dim
         tensor_type_var, tensor_constructor, _ = self._get_adt_by_shape(
             [
-                Any(),
+                new_axis,
             ]
             + list(self.shape)
         )
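The double-negative condition reduces to a simple rule: keep the added axis dynamic unless the list is statically known to hold exactly one element. An equivalent standalone sketch (the helper name is hypothetical):

```python
from tvm.relay import Any

def _pick_new_axis(batch_dim):
    # Hypothetical helper equivalent to the inline conditional above:
    # only batch_dim == 1 yields a concrete axis, which keeps the
    # resulting ADT shape fully static.
    return batch_dim if batch_dim == 1 else Any()

assert _pick_new_axis(1) == 1   # static leading axis of extent 1
axis = _pick_new_axis(None)     # relay.Any(): extent unknown until runtime
```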
@@ -573,20 +603,27 @@ def define_tensor_array_stack(self):
         expand_dims_var = self.get_global_var("tensor_expand_dims")

         # Register tensor_concatenate for output_shape
+        new_axis = Any() if not self.batch_dim or self.batch_dim != 1 else self.batch_dim
         output_shape = [
-            Any(),
+            new_axis,
         ] + list(self.shape)
-
         _, _, output_ops = self._get_adt_by_shape(output_shape)
         output_ops.define_tensor_concatenate()
         concat_var = output_ops.get_global_var("tensor_concatenate")

         tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array)
-        tensors = self.prelude.foldl(
-            concat_var,
-            self.prelude.hd(tensor_array_expand_dims),
-            self.prelude.tl(tensor_array_expand_dims),
-        )
+        if self.batch_dim is not None and self.batch_dim == 1:
+            # only one element: stacking is the identity on the head
+            tensors = self.prelude.id(
+                self.prelude.hd(tensor_array_expand_dims),
+            )
+        else:
+            tensors = self.prelude.foldl(
+                concat_var,
+                self.prelude.hd(tensor_array_expand_dims),
+                self.prelude.tl(tensor_array_expand_dims),
+            )
+
         output_tensor_type_var, _, _ = self._get_adt_by_shape(output_shape)
         self.prelude.mod[stack_var] = Function(
             [tensor_array], tensors, output_tensor_type_var(), []
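The new branch is sound because stacking a one-element list is just `expand_dims` on its head: there is nothing to concatenate, so `prelude.id` suffices and the leading axis is statically 1. A value-level NumPy analogy (illustrative only, not the Relay implementation):

```python
import numpy as np

def stack_list(tensors):
    # Mirror of the Relay logic: expand_dims every element, then fold
    # the tail into the head with concatenate along axis 0.
    expanded = [t[np.newaxis, ...] for t in tensors]
    if len(expanded) == 1:
        return expanded[0]  # identity: no concatenation needed
    out = expanded[0]
    for t in expanded[1:]:
        out = np.concatenate([out, t], axis=0)
    return out

assert stack_list([np.zeros((3, 4))]).shape == (1, 3, 4)      # batch_dim == 1
assert stack_list([np.zeros((3, 4))] * 5).shape == (5, 3, 4)  # general case
```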
@@ -599,8 +636,9 @@ def define_tensor_array_gather(self):
         helper_name = self.get_name("tensor_array_gather_helper")
         helper_var = self._create_global_var(helper_name)

+        new_axis = Any() if self.batch_dim is None or self.batch_dim != 1 else self.batch_dim
         output_shape = [
-            Any(),
+            new_axis,
         ] + list(self.shape)
         output_tensor_type_var, _, _ = self._get_adt_by_shape(output_shape)
         stack_var = self.get_global_var("tensor_array_stack")
@@ -668,7 +706,7 @@ def register(self):

     def _get_adt_by_shape(self, shape):
         """Get ADT type and constructor with given shape."""
-        adt_ops = StaticTensorArrayOps(self.prelude, self.dtype, shape)
+        adt_ops = StaticTensorArrayOps(self.prelude, self.dtype, shape, self.batch_dim)
         adt_ops.define_tensor_adt()
         tensor_type_var = adt_ops.get_type("tensor_t")
         tensor_constructor = adt_ops.get_ctor("tensor_constructor")
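Propagating `batch_dim` here matters because ops such as `tensor_array_stack` build helper registries for derived shapes through `_get_adt_by_shape`; without it, those registries would mangle plain names and lookups for the `batch1` variants would miss. A quick sketch (dtype and shape assumed):

```python
import tvm
from tvm.relay.prelude import Prelude, StaticTensorArrayOps

prelude = Prelude(tvm.IRModule())

no_batch = StaticTensorArrayOps(prelude, "float32", (4,))
batch1 = StaticTensorArrayOps(prelude, "float32", (4,), batch_dim=1)
assert no_batch.get_name("tensor_concatenate") == "tensor_concatenate_float32_4"
assert batch1.get_name("tensor_concatenate") == "tensor_concatenate_float32_batch1_4"
```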
@@ -1482,13 +1520,13 @@ def get_tensor_ctor(self, canonical, dtype):
         ty = self.get_type("tensor_t", dtype)
         return self.get_ctor(ty.name_hint, canonical, dtype)

-    def get_name_static(self, canonical, dtype, shape):
+    def get_name_static(self, canonical, dtype, shape, batch_dim=None):
         """Get name corresponding to the canonical name"""
-        return _get_name_static(canonical, dtype, shape)
+        return _get_name_static(canonical, dtype, shape, batch_dim)

-    def get_global_var_static(self, canonical, dtype, shape):
+    def get_global_var_static(self, canonical, dtype, shape, batch_dim=None):
         """Get var corresponding to the canonical name"""
-        name = self.get_name_static(canonical, dtype, shape)
+        name = self.get_name_static(canonical, dtype, shape, batch_dim)
         return self.mod.get_global_var(name)

     def get_type_static(self, canonical, dtype, shape):
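End to end, a caller such as a frontend converter can register the batch-aware ops and then resolve their globals through these Prelude helpers; a minimal sketch with an assumed dtype and shape:

```python
import tvm
from tvm.relay.prelude import Prelude, StaticTensorArrayOps

mod = tvm.IRModule()
prelude = Prelude(mod)
StaticTensorArrayOps(prelude, "float32", (3,), batch_dim=1).register()

# Resolves the batch1 variant; pass batch_dim=None for the generic op.
stack = prelude.get_global_var_static("tensor_array_stack", "float32", (3,), batch_dim=1)
print(stack.name_hint)  # tensor_array_stack_float32_batch1_3
```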