@@ -117,6 +117,8 @@ def __init__(self, model, subgraph, exp_tab):
117117 'PRELU' : self .convert_prelu ,
118118 'TRANSPOSE_CONV' : self .convert_transpose_conv ,
119119 'SQUARED_DIFFERENCE' : self .convert_squared_difference ,
120+ 'GATHER' : self .convert_gather ,
121+ 'STRIDED_SLICE' : self .convert_strided_slice ,
120122 }
121123
122124 def check_unsupported_ops (self ):
@@ -792,6 +794,147 @@ def convert_not_equal(self, op):
792794 'TFlite quantized NOT_EQUAL operator is not supported yet.' )
793795 return self ._convert_elemwise (_op .not_equal , op )
794796
797+ def convert_gather (self , op ):
798+ """Method to Convert TFLite GATHER operator"""
799+ try :
800+ from tflite .BuiltinOptions import BuiltinOptions
801+ from tflite .GatherOptions import GatherOptions
802+ from tflite .TensorType import TensorType
803+ except ImportError :
804+ raise ImportError ("The tflite package must be installed" )
805+
806+ input_tensors = self .get_input_tensors (op )
807+ data = self .get_expr (input_tensors [0 ].tensor_idx )
808+
809+ indices = input_tensors [1 ]
810+ indices_type = indices .tensor .Type ()
811+ assert indices_type in (TensorType .INT32 , TensorType .INT64 )
812+ indices_type_str = self .get_tensor_type_str (indices_type )
813+ indices = self .exp_tab .new_const (self .get_tensor_value (indices ),
814+ dtype = indices_type_str )
815+
816+ assert op .BuiltinOptionsType () == BuiltinOptions .GatherOptions
817+ op_options = op .BuiltinOptions ()
818+ gather_options = GatherOptions ()
819+ gather_options .Init (op_options .Bytes , op_options .Pos )
820+ axis = gather_options .Axis ()
821+
822+ out = _op .take (data , indices , axis = axis )
823+ return out
824+
825+ def convert_strided_slice (self , op ):
826+ """Method to Convert TFLite STRIDED_SLICE operator"""
827+ try :
828+ from tflite .BuiltinOptions import BuiltinOptions
829+ from tflite .StridedSliceOptions import StridedSliceOptions
830+ except ImportError :
831+ raise ImportError ("The tflite package must be installed" )
832+
833+ input_tensors = self .get_input_tensors (op )
834+ data_expr = self .get_expr (input_tensors [0 ].tensor_idx )
835+
836+ begin = list (self .get_tensor_value (input_tensors [1 ]))
837+ end = list (self .get_tensor_value (input_tensors [2 ]))
838+ stride = list (self .get_tensor_value (input_tensors [3 ]))
839+
840+ assert op .BuiltinOptionsType () == BuiltinOptions .StridedSliceOptions
841+ op_options = op .BuiltinOptions ()
842+ options = StridedSliceOptions ()
843+ options .Init (op_options .Bytes , op_options .Pos )
844+ begin_mask = options .BeginMask ()
845+ end_mask = options .EndMask ()
846+ ellipsis_mask = options .EllipsisMask ()
847+ new_axis_mask = options .NewAxisMask ()
848+ shrink_axis_mask = options .ShrinkAxisMask ()
849+
850+ data_shape = list (input_tensors [0 ].tensor .ShapeAsNumpy ())
851+ data_dim = len (data_shape )
852+ stride_dim = len (list (input_tensors [3 ].tensor .ShapeAsNumpy ()))
853+
854+ def _transform_mask (stride_dim , ellipsis_mask ):
855+ """Handle mask inputs to create new begin, end, stride and output shape"""
856+ m_begin = [0 ] * data_dim
857+ m_end = [0 ] * data_dim
858+ m_stride = [0 ] * data_dim
859+ fshape_indices = []
860+ #Count new axis after ellipsis_mask, consider while applying ellipsis_mask.
861+ ellipsis_seen = False
862+ new_axes_after_ellipsis = 0
863+ for i in range (stride_dim ):
864+ mask = 1 << i
865+ if ellipsis_seen and (mask & new_axis_mask ) != 0 :
866+ new_axes_after_ellipsis += 1
867+ if (mask & ellipsis_mask ) != 0 :
868+ ellipsis_seen = True
869+ if not ellipsis_seen :
870+ #Used later for extending the stride attributes in the below loop.
871+ ellipsis_mask |= (1 << stride_dim )
872+ stride_dim += 1
873+ final_index = 0
874+ for index in range (stride_dim ):
875+ mask = 1 << index
876+ if mask & ellipsis_mask :
877+ #Identify the end index for applying ellipsis_mask
878+ to_index = min (((data_dim - (stride_dim - index )) + 1 \
879+ + new_axes_after_ellipsis ), data_dim )
880+ for i in range (final_index , to_index ):
881+ m_begin [final_index ] = 0
882+ m_end [final_index ] = data_shape [final_index ]
883+ m_stride [final_index ] = 1
884+ fshape_indices .append (final_index )
885+ final_index += 1
886+ elif mask & new_axis_mask :
887+ fshape_indices .append (- 1 )
888+ elif not mask & new_axis_mask :
889+ if final_index == len (m_begin ):
890+ break
891+ if mask & begin_mask :
892+ m_begin [final_index ] = data_shape [final_index ] \
893+ if stride [index ] < 0 else 0
894+ elif begin [index ]:
895+ m_begin [final_index ] = begin [index ]
896+ if mask & end_mask :
897+ m_end [final_index ] = 0 if stride [index ] < 0 \
898+ else data_shape [final_index ]
899+ elif end [index ]:
900+ m_end [final_index ] = end [index ]
901+ m_stride [final_index ] = stride [index ]
902+ if mask & shrink_axis_mask :
903+ #Tensorflow make axis with shrink_axis_mask as dimension 1
904+ m_begin [final_index ] = data_shape [final_index ] + begin [index ] \
905+ if begin [index ] < 0 else begin [index ]
906+ m_end [final_index ] = begin [index ] + 1
907+ m_stride [final_index ] = 1
908+ fshape_indices .append (- 2 )
909+ else :
910+ fshape_indices .append (final_index )
911+
912+ final_index += 1
913+ return m_begin , m_end , m_stride , fshape_indices
914+
915+ fshape_indices = None
916+ if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask :
917+ begin , end , stride , fshape_indices = _transform_mask (stride_dim , ellipsis_mask )
918+
919+ out = _op .strided_slice (data_expr , begin = begin , end = end , strides = stride )
920+ out_shape = _infer_shape (out )
921+ if not fshape_indices :
922+ fshape_indices = range (len (out_shape ))
923+
924+ #Create final output shape.
925+ final_output = []
926+ for gather_index in fshape_indices :
927+ if gather_index == - 1 :
928+ final_output .append (1 )
929+ elif gather_index == - 2 :
930+ pass
931+ else :
932+ final_output .append (out_shape [gather_index ])
933+
934+ if not final_output :
935+ return out
936+ return _op .reshape (out , newshape = tuple (final_output ))
937+
795938 def convert_zeros_like (self , op ):
796939 """Convert TFLite ZEROS LIKE"""
797940 try :
0 commit comments