@@ -271,69 +271,74 @@ def test_forward_slice():
271271# Gather
272272# ------
273273
274- def _test_gather (dshape , indices , axis , dtype ):
274+ def _test_gather (dshape , indices , axis , dtype , quantized = False , oob = False ):
275275 """ One iteration of Gather """
276- data = np .random .uniform (1 , 10 , size = dshape ).astype (dtype )
277276 indices = np .asarray (indices ).astype ('int32' )
278-
279- with tf .Graph ().as_default ():
280- in_data = array_ops .placeholder (shape = data .shape , dtype = data .dtype )
281- out = array_ops .gather (in_data , indices , axis = axis )
282- compare_tflite_with_tvm (data , 'Placeholder:0' , [in_data ], [out ])
283-
284- #Test quantized input
285- data = np .random .uniform (1 , 10 , size = dshape ).astype (np .uint8 )
277+ data = np .random .uniform (1 , 10 , size = dshape )
278+ data = data .astype (np .uint8 ) if quantized else data .astype (dtype )
286279 with tf .Graph ().as_default ():
287280 in_data = array_ops .placeholder (shape = data .shape , dtype = data .dtype , name = "in_data" )
288- out = array_ops .gather (in_data , indices , axis = axis )
289- compare_tflite_with_tvm ([data ], ['in_data:0' ], [in_data ], [out ], quantized = True )
281+ if axis :
282+ out = array_ops .gather (in_data , indices , axis = axis )
283+ else :
284+ out = array_ops .gather (in_data , indices ) #tflite conversion fails for None axis
285+ input_range = {'in_data' : (- 100 , 100 )} if quantized else None
286+ try :
287+ compare_tflite_with_tvm ([data ], ['in_data:0' ], [in_data ], [out ],
288+ quantized = quantized , input_range = input_range )
289+ except ValueError as e :
290+ if not oob :
291+ raise e
292+ except Exception as e :
293+ raise e
290294
291295def test_forward_gather ():
292296 """ GATHER """
293- _test_gather ((4 ,), [1 ], 0 , 'float32' )
294- _test_gather ((1 , 4 ), [0 ], 0 , 'int32' )
295- _test_gather ((4 ,), [[[1 , 0 ], [0 , 1 ]]], 0 , 'float32' )
296- _test_gather ((2 , 2 ), [[[1 , 0 ], [0 , 1 ]]], 0 , 'int32' )
297- _test_gather ((2 , 2 ), [[[1 , 0 ], [0 , 1 ]]], 1 , 'int32' )
298- _test_gather ((2 , 2 ), [[[1 , 0 ], [0 , 1 ]]], 0 , 'float32' )
299- _test_gather ((3 , 3 , 3 ), [[[1 , 0 ]]], 0 , 'int32' )
300- _test_gather ((3 , 3 , 3 ), [[[1 , 0 ]]], 2 , 'int32' )
301- _test_gather ((4 , 3 , 5 , 6 ), [[2 , 1 , 0 , 0 ]], 0 , 'float32' )
297+ for quantized in [False , True ]:
298+ _test_gather ((4 ,), [1 ], 0 , 'float32' , quantized )
299+ _test_gather ((4 ,), [1 ], None , 'int32' , quantized )
300+ _test_gather ((1 , 4 ), [0 ], 0 , 'int32' , quantized )
301+ _test_gather ((4 ,), [[[1 , 0 ], [0 , 1 ]]], 0 , 'float32' , quantized )
302+ _test_gather ((2 , 2 ), [[[1 , 0 ], [0 , 1 ]]], 1 , 'int32' , quantized )
303+ _test_gather ((2 , 2 ), [[[1 , 0 ], [0 , 1 ]]], None , 'float32' , quantized )
304+ _test_gather ((3 , 3 , 3 ), [[[1 , 0 ]]], 0 , 'int32' , quantized )
305+ _test_gather ((3 , 3 , 3 ), [[[1 , 0 ]]], 2 , 'int32' , quantized )
306+ _test_gather ((4 , 3 , 5 , 6 ), [[2 , 1 , 0 , 0 ]], 0 , 'float32' , quantized )
307+ _test_gather ((3 , 3 , 3 ), [[[2 , 1 ]]], - 1 , 'int32' , quantized )
308+ _test_gather ((4 ,), [16 ], 0 , 'float32' , quantized , oob = True )
309+ _test_gather ((1 , 3 , 3 ), [12 ], 0 , 'int32' , quantized , oob = True )
310+ _test_gather ((1 , 3 , 3 ), [20 ], 1 , 'float32' , quantized , oob = True )
311+ _test_gather ((1 , 3 , 3 ), [20 , 20 ], 2 , 'float32' , quantized , oob = True )
302312
303313#######################################################################
304314# StridedSlice
305315# ------------
306316
307317def _test_stridedslice (ip_shape , begin , end , stride , dtype ,
308318 begin_mask = 0 , end_mask = 0 , new_axis_mask = 0 ,
309- shrink_axis_mask = 0 , ellipsis_mask = 0 ):
319+ shrink_axis_mask = 0 , ellipsis_mask = 0 , quantized = False ):
310320 """ One iteration of a Stridedslice """
311321 data = np .random .uniform (size = ip_shape ).astype (dtype )
322+ data = data .astype (np .uint8 ) if quantized else data .astype (dtype )
312323 with tf .Graph ().as_default ():
313324 in_data = tf .placeholder (dtype , ip_shape , name = "in_data" )
314325 out = array_ops .strided_slice (in_data , begin , end , stride ,
315326 begin_mask = begin_mask ,
316- end_mask = end_mask , new_axis_mask = new_axis_mask ,
317- shrink_axis_mask = shrink_axis_mask ,
318- ellipsis_mask = ellipsis_mask )
319- compare_tflite_with_tvm (data , 'in_data:0' , [in_data ], [out ])
320-
321- #Test with quantized inputs
322- data = np .random .uniform (size = ip_shape ).astype (np .uint8 )
323- with tf .Graph ().as_default ():
324- in_data = tf .placeholder (dtype , ip_shape , name = "in_data" )
325- out = array_ops .strided_slice (in_data , begin , end , stride ,
326- begin_mask = begin_mask ,
327- end_mask = end_mask , new_axis_mask = new_axis_mask ,
327+ end_mask = end_mask ,
328+ new_axis_mask = new_axis_mask ,
328329 shrink_axis_mask = shrink_axis_mask ,
329330 ellipsis_mask = ellipsis_mask )
330- compare_tflite_with_tvm ([data ], ['in_data:0' ], [in_data ], [out ], quantized = True )
331+ input_range = {'in_data' : (- 100 , 100 )} if quantized else None
332+ compare_tflite_with_tvm ([data ], ['in_data:0' ], [in_data ], [out ], quantized = quantized ,
333+ input_range = input_range )
331334
332335def test_forward_stridedslice ():
333336 '''test StridedSlice'''
334- _test_stridedslice ((2 ), [1 ], [1 ], [1 ], 'float32' , shrink_axis_mask = 1 )
335- _test_stridedslice ((3 , 4 , 3 ), [1 , - 1 , 0 ], [4 , - 5 , 3 ], [2 , - 1 , 1 ], 'float32' )
336- _test_stridedslice ((3 , 4 ), [1 , 0 ], [4 , 4 ], [1 , 1 ], 'float32' , shrink_axis_mask = 1 )
337+ for quantized in [False , True ]:
338+ _test_stridedslice ((2 ), [1 ], [1 ], [1 ], 'float32' , shrink_axis_mask = 1 , quantized = quantized )
339+ _test_stridedslice ((3 , 4 , 3 ), [1 , - 1 , 0 ], [4 , - 5 , 3 ], [2 , - 1 , 1 ], 'float32' , quantized = quantized )
340+ _test_stridedslice ((3 , 4 ), [1 , 0 ], [4 , 4 ], [1 , 1 ], 'float32' , shrink_axis_mask = 0 , quantized = quantized )
341+ _test_stridedslice ((4 , 4 ), [1 , 0 ], [4 , 4 ], [1 , 1 ], 'float32' , shrink_axis_mask = 2 , quantized = quantized )
337342
338343#######################################################################
339344# transpose
0 commit comments