@@ -112,6 +112,8 @@ def __init__(self, model, subgraph, exp_tab):
112112 'PRELU' : self .convert_prelu ,
113113 'TRANSPOSE_CONV' : self .convert_transpose_conv ,
114114 'SQUARED_DIFFERENCE' : self .convert_squared_difference ,
115+ 'GATHER' : self .convert_gather ,
116+ 'STRIDED_SLICE' : self .convert_strided_slice ,
115117 }
116118
117119 def check_unsupported_ops (self ):
@@ -747,6 +749,158 @@ def convert_squared_difference(self, op):
747749 out = _op .power (difference , relay .const (2 , exp_type ))
748750 return out
749751
752+ def convert_gather (self , op ):
753+ """Method to Convert TFLite Gather operator"""
754+ # Check if the input tensor is quantized, call QNN op
755+ if self .is_quantized (op ):
756+ raise tvm .error .OpNotImplemented (
757+ 'TFlite quantized gather operator is not supported yet.' )
758+ input_tensors = self .get_input_tensors (op )
759+
760+ try :
761+ from tflite .BuiltinOptions import BuiltinOptions
762+ from tflite .GatherOptions import GatherOptions
763+ from tflite .TensorType import TensorType
764+ except ImportError :
765+ raise ImportError ("The tflite package must be installed" )
766+
767+ assert op .BuiltinOptionsType () == BuiltinOptions .GatherOptions
768+ op_options = op .BuiltinOptions ()
769+ gather_options = GatherOptions ()
770+ gather_options .Init (op_options .Bytes , op_options .Pos )
771+ axis = gather_options .Axis ()
772+
773+ data = self .get_expr (input_tensors [0 ].tensor_idx )
774+
775+ indices = input_tensors [1 ]
776+ indices_type = indices .tensor .Type ()
777+
778+ assert indices_type in (TensorType .INT32 , TensorType .INT64 )
779+ indices_type_str = self .get_tensor_type_str (indices_type )
780+ indices = self .exp_tab .new_const (self .get_tensor_value (indices ),
781+ dtype = indices_type_str )
782+ out = _op .take (data , indices , axis = axis )
783+ return out
784+
785+ def convert_strided_slice (self , op ):
786+ """Method to Convert TFLite Strided Slice operator"""
787+ # Check if the input tensor is quantized, call QNN op
788+ if self .is_quantized (op ):
789+ raise tvm .error .OpNotImplemented (
790+ 'TFlite quantized strided slice operator is not supported yet.' )
791+ input_tensors = self .get_input_tensors (op )
792+
793+ try :
794+ from tflite .BuiltinOptions import BuiltinOptions
795+ from tflite .StridedSliceOptions import StridedSliceOptions
796+ except ImportError :
797+ raise ImportError ("The tflite package must be installed" )
798+
799+ data_expr = self .get_expr (input_tensors [0 ].tensor_idx )
800+
801+ begin = list (self .get_tensor_value (input_tensors [1 ]))
802+ end = list (self .get_tensor_value (input_tensors [2 ]))
803+ stride = list (self .get_tensor_value (input_tensors [3 ]))
804+
805+ assert op .BuiltinOptionsType () == BuiltinOptions .StridedSliceOptions
806+ op_options = op .BuiltinOptions ()
807+ options = StridedSliceOptions ()
808+ options .Init (op_options .Bytes , op_options .Pos )
809+ begin_mask = options .BeginMask ()
810+ end_mask = options .EndMask ()
811+ ellipsis_mask = options .EllipsisMask ()
812+ new_axis_mask = options .NewAxisMask ()
813+ shrink_axis_mask = options .ShrinkAxisMask ()
814+
815+ data_shape = list (input_tensors [0 ].tensor .ShapeAsNumpy ())
816+
817+ data_dim = len (data_shape )
818+ stride_dim = len (list (input_tensors [3 ].tensor .ShapeAsNumpy ()))
819+
820+ def _transform_mask (stride_dim , ellipsis_mask ):
821+ """Handle mask inputs to create new begin, end, stride and output shape"""
822+ m_begin = [0 ] * data_dim
823+ m_end = [0 ] * data_dim
824+ m_stride = [0 ] * data_dim
825+ fshape_indices = []
826+ #Count new axis after ellipsis_mask, consider while applying ellipsis_mask.
827+ ellipsis_seen = False
828+ new_axes_after_ellipsis = 0
829+ for i in range (stride_dim ):
830+ mask = 1 << i
831+ if ellipsis_seen and (mask & new_axis_mask ) != 0 :
832+ new_axes_after_ellipsis += 1
833+ if (mask & ellipsis_mask ) != 0 :
834+ ellipsis_seen = True
835+ if not ellipsis_seen :
836+ #Used later for extending the stride attributes in the below loop.
837+ ellipsis_mask |= (1 << stride_dim )
838+ stride_dim += 1
839+ final_index = 0
840+ for index in range (stride_dim ):
841+ mask = 1 << index
842+ if mask & ellipsis_mask :
843+ #Identify the end index for applying ellipsis_mask
844+ to_index = min (((data_dim - (stride_dim - index )) + 1 \
845+ + new_axes_after_ellipsis ), data_dim )
846+ for i in range (final_index , to_index ):
847+ m_begin [final_index ] = 0
848+ m_end [final_index ] = data_shape [final_index ]
849+ m_stride [final_index ] = 1
850+ fshape_indices .append (final_index )
851+ final_index += 1
852+ elif mask & new_axis_mask :
853+ fshape_indices .append (- 1 )
854+ elif not mask & new_axis_mask :
855+ if final_index == len (m_begin ):
856+ break
857+ if mask & begin_mask :
858+ m_begin [final_index ] = data_shape [final_index ] \
859+ if stride [index ] < 0 else 0
860+ elif begin [index ]:
861+ m_begin [final_index ] = begin [index ]
862+ if mask & end_mask :
863+ m_end [final_index ] = 0 if stride [index ] < 0 \
864+ else data_shape [final_index ]
865+ elif end [index ]:
866+ m_end [final_index ] = end [index ]
867+ m_stride [final_index ] = stride [index ]
868+ if mask & shrink_axis_mask :
869+ #Tensorflow make axis with shrink_axis_mask as dimension 1
870+ m_begin [final_index ] = data_shape [final_index ] + begin [index ] \
871+ if begin [index ] < 0 else begin [index ]
872+ m_end [final_index ] = begin [index ] + 1
873+ m_stride [final_index ] = 1
874+ fshape_indices .append (- 2 )
875+ else :
876+ fshape_indices .append (final_index )
877+
878+ final_index += 1
879+ return m_begin , m_end , m_stride , fshape_indices
880+
881+ fshape_indices = None
882+ if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask :
883+ begin , end , stride , fshape_indices = _transform_mask (stride_dim , ellipsis_mask )
884+
885+ out = _op .strided_slice (data_expr , begin = begin , end = end , strides = stride )
886+ out_shape = _infer_shape (out )
887+ if not fshape_indices :
888+ fshape_indices = range (len (out_shape ))
889+
890+ #Create final output shape.
891+ final_output = []
892+ for gather_index in fshape_indices :
893+ if gather_index == - 1 :
894+ final_output .append (1 )
895+ elif gather_index == - 2 :
896+ pass
897+ else :
898+ final_output .append (out_shape [gather_index ])
899+
900+ if not final_output :
901+ return out
902+ return _op .reshape (out , newshape = tuple (final_output ))
903+
750904 def convert_zeros_like (self , op ):
751905 """Convert TFLite ZEROS LIKE"""
752906 try :
0 commit comments