@@ -294,69 +294,74 @@ def test_forward_topk():
294294# Gather
295295# ------
296296
297- def _test_gather (dshape , indices , axis , dtype ):
297+ def _test_gather (dshape , indices , axis , dtype , quantized = False , oob = False ):
298298 """ One iteration of Gather """
299- data = np .random .uniform (1 , 10 , size = dshape ).astype (dtype )
300299 indices = np .asarray (indices ).astype ('int32' )
301-
302- with tf .Graph ().as_default ():
303- in_data = array_ops .placeholder (shape = data .shape , dtype = data .dtype )
304- out = array_ops .gather (in_data , indices , axis = axis )
305- compare_tflite_with_tvm (data , 'Placeholder:0' , [in_data ], [out ])
306-
307- #Test quantized input
308- data = np .random .uniform (1 , 10 , size = dshape ).astype (np .uint8 )
300+ data = np .random .uniform (1 , 10 , size = dshape )
301+ data = data .astype (np .uint8 ) if quantized else data .astype (dtype )
309302 with tf .Graph ().as_default ():
310303 in_data = array_ops .placeholder (shape = data .shape , dtype = data .dtype , name = "in_data" )
311- out = array_ops .gather (in_data , indices , axis = axis )
312- compare_tflite_with_tvm ([data ], ['in_data:0' ], [in_data ], [out ], quantized = True )
304+ if axis :
305+ out = array_ops .gather (in_data , indices , axis = axis )
306+ else :
307+ out = array_ops .gather (in_data , indices ) #tflite conversion fails for None axis
308+ input_range = {'in_data' : (- 100 , 100 )} if quantized else None
309+ try :
310+ compare_tflite_with_tvm ([data ], ['in_data:0' ], [in_data ], [out ],
311+ quantized = quantized , input_range = input_range )
312+ except ValueError as e :
313+ if not oob :
314+ raise e
315+ except Exception as e :
316+ raise e
313317
314318def test_forward_gather ():
315319 """ GATHER """
316- _test_gather ((4 ,), [1 ], 0 , 'float32' )
317- _test_gather ((1 , 4 ), [0 ], 0 , 'int32' )
318- _test_gather ((4 ,), [[[1 , 0 ], [0 , 1 ]]], 0 , 'float32' )
319- _test_gather ((2 , 2 ), [[[1 , 0 ], [0 , 1 ]]], 0 , 'int32' )
320- _test_gather ((2 , 2 ), [[[1 , 0 ], [0 , 1 ]]], 1 , 'int32' )
321- _test_gather ((2 , 2 ), [[[1 , 0 ], [0 , 1 ]]], 0 , 'float32' )
322- _test_gather ((3 , 3 , 3 ), [[[1 , 0 ]]], 0 , 'int32' )
323- _test_gather ((3 , 3 , 3 ), [[[1 , 0 ]]], 2 , 'int32' )
324- _test_gather ((4 , 3 , 5 , 6 ), [[2 , 1 , 0 , 0 ]], 0 , 'float32' )
320+ for quantized in [False , True ]:
321+ _test_gather ((4 ,), [1 ], 0 , 'float32' , quantized )
322+ _test_gather ((4 ,), [1 ], None , 'int32' , quantized )
323+ _test_gather ((1 , 4 ), [0 ], 0 , 'int32' , quantized )
324+ _test_gather ((4 ,), [[[1 , 0 ], [0 , 1 ]]], 0 , 'float32' , quantized )
325+ _test_gather ((2 , 2 ), [[[1 , 0 ], [0 , 1 ]]], 1 , 'int32' , quantized )
326+ _test_gather ((2 , 2 ), [[[1 , 0 ], [0 , 1 ]]], None , 'float32' , quantized )
327+ _test_gather ((3 , 3 , 3 ), [[[1 , 0 ]]], 0 , 'int32' , quantized )
328+ _test_gather ((3 , 3 , 3 ), [[[1 , 0 ]]], 2 , 'int32' , quantized )
329+ _test_gather ((4 , 3 , 5 , 6 ), [[2 , 1 , 0 , 0 ]], 0 , 'float32' , quantized )
330+ _test_gather ((3 , 3 , 3 ), [[[2 , 1 ]]], - 1 , 'int32' , quantized )
331+ _test_gather ((4 ,), [16 ], 0 , 'float32' , quantized , oob = True )
332+ _test_gather ((1 , 3 , 3 ), [12 ], 0 , 'int32' , quantized , oob = True )
333+ _test_gather ((1 , 3 , 3 ), [20 ], 1 , 'float32' , quantized , oob = True )
334+ _test_gather ((1 , 3 , 3 ), [20 , 20 ], 2 , 'float32' , quantized , oob = True )
325335
326336#######################################################################
327337# StridedSlice
328338# ------------
329339
330340def _test_stridedslice (ip_shape , begin , end , stride , dtype ,
331341 begin_mask = 0 , end_mask = 0 , new_axis_mask = 0 ,
332- shrink_axis_mask = 0 , ellipsis_mask = 0 ):
342+ shrink_axis_mask = 0 , ellipsis_mask = 0 , quantized = False ):
333343 """ One iteration of a Stridedslice """
334344 data = np .random .uniform (size = ip_shape ).astype (dtype )
345+ data = data .astype (np .uint8 ) if quantized else data .astype (dtype )
335346 with tf .Graph ().as_default ():
336347 in_data = tf .placeholder (dtype , ip_shape , name = "in_data" )
337348 out = array_ops .strided_slice (in_data , begin , end , stride ,
338349 begin_mask = begin_mask ,
339- end_mask = end_mask , new_axis_mask = new_axis_mask ,
340- shrink_axis_mask = shrink_axis_mask ,
341- ellipsis_mask = ellipsis_mask )
342- compare_tflite_with_tvm (data , 'in_data:0' , [in_data ], [out ])
343-
344- #Test with quantized inputs
345- data = np .random .uniform (size = ip_shape ).astype (np .uint8 )
346- with tf .Graph ().as_default ():
347- in_data = tf .placeholder (dtype , ip_shape , name = "in_data" )
348- out = array_ops .strided_slice (in_data , begin , end , stride ,
349- begin_mask = begin_mask ,
350- end_mask = end_mask , new_axis_mask = new_axis_mask ,
350+ end_mask = end_mask ,
351+ new_axis_mask = new_axis_mask ,
351352 shrink_axis_mask = shrink_axis_mask ,
352353 ellipsis_mask = ellipsis_mask )
353- compare_tflite_with_tvm ([data ], ['in_data:0' ], [in_data ], [out ], quantized = True )
354+ input_range = {'in_data' : (- 100 , 100 )} if quantized else None
355+ compare_tflite_with_tvm ([data ], ['in_data:0' ], [in_data ], [out ], quantized = quantized ,
356+ input_range = input_range )
354357
355358def test_forward_stridedslice ():
356359 '''test StridedSlice'''
357- _test_stridedslice ((2 ), [1 ], [1 ], [1 ], 'float32' , shrink_axis_mask = 1 )
358- _test_stridedslice ((3 , 4 , 3 ), [1 , - 1 , 0 ], [4 , - 5 , 3 ], [2 , - 1 , 1 ], 'float32' )
359- _test_stridedslice ((3 , 4 ), [1 , 0 ], [4 , 4 ], [1 , 1 ], 'float32' , shrink_axis_mask = 1 )
360+ for quantized in [False , True ]:
361+ _test_stridedslice ((2 ), [1 ], [1 ], [1 ], 'float32' , shrink_axis_mask = 1 , quantized = quantized )
362+ _test_stridedslice ((3 , 4 , 3 ), [1 , - 1 , 0 ], [4 , - 5 , 3 ], [2 , - 1 , 1 ], 'float32' , quantized = quantized )
363+ _test_stridedslice ((3 , 4 ), [1 , 0 ], [4 , 4 ], [1 , 1 ], 'float32' , shrink_axis_mask = 0 , quantized = quantized )
364+ _test_stridedslice ((4 , 4 ), [1 , 0 ], [4 , 4 ], [1 , 1 ], 'float32' , shrink_axis_mask = 2 , quantized = quantized )
360365
361366#######################################################################
362367# transpose
0 commit comments