@@ -112,6 +112,8 @@ def __init__(self, model, subgraph, exp_tab):
112112 'PRELU' : self .convert_prelu ,
113113 'TRANSPOSE_CONV' : self .convert_transpose_conv ,
114114 'SQUARED_DIFFERENCE' : self .convert_squared_difference ,
115+ 'GATHER' : self .convert_gather ,
116+ 'STRIDED_SLICE' : self .convert_strided_slice ,
115117 }
116118
117119 def check_unsupported_ops (self ):
@@ -747,6 +749,156 @@ def convert_squared_difference(self, op):
747749 out = _op .power (difference , relay .const (2 , exp_type ))
748750 return out
749751
752+ def convert_gather (self , op ):
753+ # Check if the input tensor is quantized, call QNN op
754+ if self .is_quantized (op ):
755+ raise tvm .error .OpNotImplemented (
756+ 'TFlite quantized gather operator is not supported yet.' )
757+ input_tensors = self .get_input_tensors (op )
758+
759+ try :
760+ from tflite .BuiltinOptions import BuiltinOptions
761+ from tflite .GatherOptions import GatherOptions
762+ from tflite .TensorType import TensorType
763+ except ImportError :
764+ raise ImportError ("The tflite package must be installed" )
765+
766+ assert op .BuiltinOptionsType () == BuiltinOptions .GatherOptions
767+ op_options = op .BuiltinOptions ()
768+ gather_options = GatherOptions ()
769+ gather_options .Init (op_options .Bytes , op_options .Pos )
770+ axis = gather_options .Axis ()
771+
772+ data = self .get_expr (input_tensors [0 ].tensor_idx )
773+
774+ indices = input_tensors [1 ]
775+ indices_type = indices .tensor .Type ()
776+
777+ assert indices_type in (TensorType .INT32 , TensorType .INT64 )
778+ indices_type_str = self .get_tensor_type_str (indices_type )
779+ indices = self .exp_tab .new_const (self .get_tensor_value (indices ),
780+ dtype = indices_type_str )
781+ out = _op .take (data , indices , axis = axis )
782+ return out
783+
784+ def convert_strided_slice (self , op ):
785+ # Check if the input tensor is quantized, call QNN op
786+ if self .is_quantized (op ):
787+ raise tvm .error .OpNotImplemented (
788+ 'TFlite quantized strided slice operator is not supported yet.' )
789+ input_tensors = self .get_input_tensors (op )
790+
791+ try :
792+ from tflite .BuiltinOptions import BuiltinOptions
793+ from tflite .StridedSliceOptions import StridedSliceOptions
794+ except ImportError :
795+ raise ImportError ("The tflite package must be installed" )
796+
797+ data_expr = self .get_expr (input_tensors [0 ].tensor_idx )
798+
799+ begin = list (self .get_tensor_value (input_tensors [1 ]))
800+ end = list (self .get_tensor_value (input_tensors [2 ]))
801+ stride = list (self .get_tensor_value (input_tensors [3 ]))
802+
803+ assert op .BuiltinOptionsType () == BuiltinOptions .StridedSliceOptions
804+ op_options = op .BuiltinOptions ()
805+ options = StridedSliceOptions ()
806+ options .Init (op_options .Bytes , op_options .Pos )
807+ begin_mask = options .BeginMask ()
808+ end_mask = options .EndMask ()
809+ ellipsis_mask = options .EllipsisMask ()
810+ new_axis_mask = options .NewAxisMask ()
811+ shrink_axis_mask = options .ShrinkAxisMask ()
812+
813+ data_shape = list (input_tensors [0 ].tensor .ShapeAsNumpy ())
814+
815+ data_dim = len (data_shape )
816+ stride_dim = len (list (input_tensors [3 ].tensor .ShapeAsNumpy ()))
817+
818+ def _transform_mask (stride_dim , ellipsis_mask ):
819+ """Handle mask inputs to create new begin, end, stride and output shape"""
820+ m_begin = [0 ] * data_dim
821+ m_end = [0 ] * data_dim
822+ m_stride = [0 ] * data_dim
823+ fshape_indices = []
824+ #Count new axis after ellipsis_mask, consider while applying ellipsis_mask.
825+ ellipsis_seen = False
826+ new_axes_after_ellipsis = 0
827+ for i in range (stride_dim ):
828+ mask = 1 << i
829+ if ellipsis_seen and (mask & new_axis_mask ) != 0 :
830+ new_axes_after_ellipsis += 1
831+ if (mask & ellipsis_mask ) != 0 :
832+ ellipsis_seen = True
833+ if not ellipsis_seen :
834+ #Used later for extending the stride attributes in the below loop.
835+ ellipsis_mask |= (1 << stride_dim )
836+ stride_dim += 1
837+ final_index = 0
838+ for index in range (stride_dim ):
839+ mask = 1 << index
840+ if mask & ellipsis_mask :
841+ #Identify the end index for applying ellipsis_mask
842+ to_index = min (((data_dim - (stride_dim - index )) + 1 \
843+ + new_axes_after_ellipsis ), data_dim )
844+ for i in range (final_index , to_index ):
845+ m_begin [final_index ] = 0
846+ m_end [final_index ] = data_shape [final_index ]
847+ m_stride [final_index ] = 1
848+ fshape_indices .append (final_index )
849+ final_index += 1
850+ elif mask & new_axis_mask :
851+ fshape_indices .append (- 1 )
852+ elif not mask & new_axis_mask :
853+ if final_index == len (m_begin ):
854+ break
855+ if mask & begin_mask :
856+ m_begin [final_index ] = data_shape [final_index ] \
857+ if stride [index ] < 0 else 0
858+ elif begin [index ]:
859+ m_begin [final_index ] = begin [index ]
860+ if mask & end_mask :
861+ m_end [final_index ] = 0 if stride [index ] < 0 \
862+ else data_shape [final_index ]
863+ elif end [index ]:
864+ m_end [final_index ] = end [index ]
865+ m_stride [final_index ] = stride [index ]
866+ if mask & shrink_axis_mask :
867+ #Tensorflow make axis with shrink_axis_mask as dimension 1
868+ m_begin [final_index ] = data_shape [final_index ] + begin [index ] \
869+ if begin [index ] < 0 else begin [index ]
870+ m_end [final_index ] = begin [index ] + 1
871+ m_stride [final_index ] = 1
872+ fshape_indices .append (- 2 )
873+ else :
874+ fshape_indices .append (final_index )
875+
876+ final_index += 1
877+ return m_begin , m_end , m_stride , fshape_indices
878+
879+ fshape_indices = None
880+ if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask :
881+ begin , end , stride , fshape_indices = _transform_mask (stride_dim , ellipsis_mask )
882+
883+ out = _op .strided_slice (data_expr , begin = begin , end = end , strides = stride )
884+ out_shape = _infer_shape (out )
885+ if not fshape_indices :
886+ fshape_indices = range (len (out_shape ))
887+
888+ #Create final output shape.
889+ final_output = []
890+ for gather_index in fshape_indices :
891+ if gather_index == - 1 :
892+ final_output .append (1 )
893+ elif gather_index == - 2 :
894+ pass
895+ else :
896+ final_output .append (out_shape [gather_index ])
897+
898+ if not final_output :
899+ return out
900+ return _op .reshape (out , newshape = tuple (final_output ))
901+
750902 def convert_zeros_like (self , op ):
751903 """Convert TFLite ZEROS LIKE"""
752904 try :
0 commit comments