-
Notifications
You must be signed in to change notification settings - Fork 3.5k
/
test_forward.py
5722 lines (4767 loc) · 199 KB
/
test_forward.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, inconsistent-return-statements
"""
TFLite testcases
================
This module tests the conversion of TFLite operators to Relay.
"""
from __future__ import print_function
from functools import partial
import platform
import os
import tempfile
import typing
from packaging import version as package_version
import pytest
import numpy as np
from PIL import Image
from tflite.BuiltinOperator import BuiltinOperator
try:
import tensorflow.compat.v1 as tf
# tensorflow.python.framework.ops module itself is not part of
# TensorFlow's public API: the precise contents of that module
# may vary from one version to the next
import tensorflow.compat.v1 as ops
except ImportError:
import tensorflow as tf
import tensorflow as ops
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import variables
from tensorflow import raw_ops
try:
from tensorflow import lite as interpreter_wrapper
except ImportError:
from tensorflow.contrib import lite as interpreter_wrapper
import tvm
import tvm.relay.testing.tf as tf_testing
from tvm.contrib.download import download_testdata
from tvm import relay, ir
from tvm.contrib import graph_executor
from relay.utils.tag_span import _set_span, _create_span, _verify_structural_equal_with_span
#######################################################################
# Generic run functions for TVM & TFLite
# --------------------------------------
def convert_to_list(x):
    """Return ``x`` itself when it is already a list, otherwise the one-element list ``[x]``."""
    return x if isinstance(x, list) else [x]
#######################################################################
# Get a real image for e2e testing
# --------------------------------
def get_real_image(im_height, im_width, quantized=True):
    """Download the InceptionV1 elephant test image as a batched NHWC array.

    Parameters
    ----------
    im_height, im_width : int
        Spatial size of the returned image.
    quantized : bool
        When True the pixels are returned as ``uint8``; otherwise as ``float32``
        (raw pixel values, no normalization is applied).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(1, im_height, im_width, 3)``.
    """
    # Plain concatenation: os.path.join would insert a backslash on Windows,
    # producing an invalid URL.
    repo_base = "https://github.com/dmlc/web-data/raw/main/tensorflow/models/InceptionV1/"
    img_name = "elephant-299.jpg"
    image_url = repo_base + img_name
    img_path = download_testdata(image_url, img_name, module="data")
    # PIL's resize takes (width, height); passing (height, width) would swap the
    # spatial dims of non-square outputs before the reshape below.
    image = Image.open(img_path).resize((im_width, im_height))
    x = np.array(image).astype("uint8") if quantized else np.array(image).astype("float32")
    data = np.reshape(x, (1, im_height, im_width, 3))
    return data
def pre_processed_image(height, width):
    """Download the elephant test image and preprocess it with TensorFlow ops.

    The JPEG is decoded with 3 channels, converted to float32 if needed,
    center-cropped to 87.5% of its size, resized to ``(height, width)`` and given
    a leading batch dimension of 1.

    Returns
    -------
    tf.Tensor
        The preprocessed image tensor.
    """
    # Plain concatenation: os.path.join would insert a backslash on Windows,
    # producing an invalid URL.
    repo_base = "https://github.com/dmlc/web-data/raw/main/tensorflow/models/InceptionV1/"
    img_name = "elephant-299.jpg"
    image_url = repo_base + img_name
    img_path = download_testdata(image_url, img_name, module="data")
    image = tf.io.read_file(img_path)
    image = tf.image.decode_jpeg(image, channels=3)
    with tf.name_scope("eval_image"):
        if image.dtype != tf.float32:
            image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        image = tf.image.central_crop(image, central_fraction=0.875)
    # Resize the image to the specified height and width.
    image = tf.image.resize(image, [height, width], align_corners=False)
    image = tf.expand_dims(image, axis=0)
    return image
def get_real_image_object_detection(im_height, im_width):
    """Download the street test image used for detection models as a uint8 NHWC batch.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of shape ``(1, im_height, im_width, 3)``.
    """
    # Plain concatenation: os.path.join would insert a backslash on Windows,
    # producing an invalid URL.
    repo_base = "https://github.com/dmlc/web-data/raw/main/gluoncv/detection/"
    img_name = "street_small.jpg"
    image_url = repo_base + img_name
    img_path = download_testdata(image_url, img_name, module="data")
    # PIL's resize takes (width, height); passing (height, width) would swap the
    # spatial dims of non-square outputs before the reshape below.
    image = Image.open(img_path).resize((im_width, im_height))
    x = np.array(image).astype("uint8")
    data = np.reshape(x, (1, im_height, im_width, 3))
    return data
def vmobj_to_list(obj):
    """Flatten a value returned by the Relay VM / debug executor into a Python list.

    - ``tvm.nd.NDArray`` becomes a single nested-list entry.
    - ``ADT`` containers are flattened field by field.
    - ``ConstructorValue`` objects are unpacked by constructor name: ``Cons`` cells
      are concatenated head-then-tail, ``Nil`` is empty, ``tensor_nil*`` yields
      ``[0]`` and other ``tensor*`` constructors yield their first field as numpy.

    Raises
    ------
    RuntimeError
        If ``obj`` (or a constructor inside it) is not one of the kinds above.
    """
    if isinstance(obj, tvm.nd.NDArray):
        return [obj.numpy().tolist()]

    if isinstance(obj, tvm.runtime.container.ADT):
        flattened = []
        for field in obj:
            flattened.extend(vmobj_to_list(field))
        return flattened

    if not isinstance(obj, tvm.relay.backend.interpreter.ConstructorValue):
        raise RuntimeError(f"Unknown object type: {type(obj)}")

    ctor_name = obj.constructor.name_hint
    if ctor_name == "Cons":
        tail = vmobj_to_list(obj.fields[1])
        head = vmobj_to_list(obj.fields[0])
        head.extend(tail)
        return head
    if ctor_name == "Nil":
        return []
    if "tensor_nil" in ctor_name:
        return [0]
    if "tensor" in ctor_name:
        return [obj.fields[0].numpy()]
    raise RuntimeError(f"Unknown object type: {ctor_name}")
def _quantize_keras_model(
    keras_model,
    representative_data_gen,
    is_float_input=False,
    is_float_output=False,
    int_quant_dtype=tf.int8,
):
    """Quantize ``keras_model`` with the TFLite converter and return the converted model.

    Parameters
    ----------
    keras_model : tf.keras.Model
        Model handed to ``TFLiteConverter.from_keras_model``.
    representative_data_gen : callable
        Representative dataset the converter uses for calibration.
    is_float_input, is_float_output : bool
        Keep the model input / output in float. When False, the converter's
        ``inference_input_type`` / ``inference_output_type`` is set to the
        boundary dtype selected from ``int_quant_dtype``.
    int_quant_dtype : tf.DType
        ``tf.int8`` or ``tf.int16`` (16-bit activations, 8-bit weights).

    Returns
    -------
    The result of ``converter.convert()``.

    Raises
    ------
    RuntimeError
        If ``int_quant_dtype`` is neither ``tf.int8`` nor ``tf.int16``.
    """
    converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model)

    if int_quant_dtype == tf.int8:
        optimization = tf.lite.Optimize.OPTIMIZE_FOR_SIZE
        op_set = tf.lite.OpsSet.TFLITE_BUILTINS_INT8
        boundary_dtype = tf.uint8
    elif int_quant_dtype == tf.int16:
        optimization = tf.lite.Optimize.DEFAULT
        op_set = tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
        boundary_dtype = tf.uint16
    else:
        raise RuntimeError(
            f"Invalid quantized dtype {int_quant_dtype}. Supported types: int8, int16."
        )
    # NOTE(review): the boundary dtypes are unsigned (uint8 for int8, uint16 for
    # int16); kept unchanged here -- confirm this is what the callers intend.

    converter.optimizations = [optimization]
    converter.representative_dataset = representative_data_gen
    converter.target_spec.supported_ops = [op_set]

    # NOTE: If representative dataset is provided, and inference input type is not set,
    # then converter will self add quant & dequant Op accordingly.
    if not is_float_input:
        converter.inference_input_type = boundary_dtype
    if not is_float_output:
        converter.inference_output_type = boundary_dtype
    return converter.convert()
def run_tvm_graph(
    tflite_model_buf,
    input_data,
    input_node,
    num_output=1,
    target="llvm",
    out_names=None,
    mode="graph_executor",
    op_converter=relay.frontend.tflite.OperatorConverter,
):
    """Import a TFLite flatbuffer into Relay, compile it and execute it on TVM.

    Parameters
    ----------
    tflite_model_buf : bytes
        Serialized TFLite model.
    input_data : numpy.ndarray or list of numpy.ndarray
        Input values, in the same order as ``input_node``.
    input_node : str or list of str
        Names of the model inputs.
    num_output : int
        Number of outputs read back in graph-executor mode.
    target : str
        Compilation target; also selects the execution device in every mode.
    out_names : list of str, optional
        Only used to sanity-check ``num_output`` in graph-executor mode.
    mode : str
        ``"graph_executor"`` builds with ``relay.build``; ``"debug"`` or ``"vm"``
        runs through ``relay.create_executor``.
    op_converter : type
        Operator converter class passed to ``relay.frontend.from_tflite``.

    Returns
    -------
    list
        Graph-executor mode: one numpy array per output. Debug/VM mode: the
        executor result flattened by ``vmobj_to_list``.
    """
    # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
    try:
        import tflite.Model

        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
    except AttributeError:
        import tflite

        tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
    except ImportError as exc:
        raise ImportError("The tflite package must be installed") from exc

    input_data = convert_to_list(input_data)
    input_node = convert_to_list(input_node)

    shape_dict = {}
    dtype_dict = {}
    for i, node in enumerate(input_node):
        shape_dict[node] = input_data[i].shape
        dtype_dict[node] = input_data[i].dtype.name

    with tvm.testing.disable_span_filling():
        mod, params = relay.frontend.from_tflite(
            tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict, op_converter=op_converter
        )
    # Span filling must not change the imported module's structure.
    with tvm.testing.enable_span_filling():
        mod_with_span, _ = relay.frontend.from_tflite(
            tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict, op_converter=op_converter
        )
    tvm.ir.assert_structural_equal(mod["main"], mod_with_span["main"])

    if mode in ["debug", "vm"]:
        inputs = []
        for param in mod["main"].params:
            if param.name_hint in input_node:
                inputs.append(tvm.nd.array(input_data[input_node.index(param.name_hint)]))
            else:
                # Interpreter doesn't bind constants, so still need to find in params
                inputs.append(tvm.nd.array(params[param.name_hint]))
        # Honor the requested target instead of always running on llvm/CPU.
        executor = relay.create_executor(
            mode, mod=mod, device=tvm.device(target, 0), target=target
        )
        result = executor.evaluate()(*inputs)
        return vmobj_to_list(result)

    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target, params=params)
    dev = tvm.device(target, 0)
    module = graph_executor.GraphModule(lib["default"](dev))
    for i, node in enumerate(input_node):
        module.set_input(node, tvm.nd.array(input_data[i]))
    module.run()
    assert out_names is None or num_output == len(
        out_names
    ), f"out_names: {out_names} num_output: {num_output}"
    return [module.get_output(i).numpy() for i in range(num_output)]
def run_tflite_graph(tflite_model_buf, input_data):
    """Execute a TFLite flatbuffer with the TFLite interpreter and return its outputs.

    Parameters
    ----------
    tflite_model_buf : bytes
        Serialized TFLite model.
    input_data : numpy.ndarray or list of numpy.ndarray
        One array per model input, in the interpreter's input order.

    Returns
    -------
    list of numpy.ndarray
        One array per model output, in the interpreter's output order.

    Raises
    ------
    AssertionError
        If the number of arrays differs from the model's number of inputs.
    """
    input_data = convert_to_list(input_data)
    interpreter = interpreter_wrapper.Interpreter(model_content=tflite_model_buf)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Validate before touching the interpreter: with too few arrays the resize loop
    # below would otherwise fail with an unhelpful IndexError.
    assert len(input_data) == len(
        input_details
    ), f"expected {len(input_details)} inputs, got {len(input_data)}"

    # Resize every input tensor to the shape of the data provided for it.
    for detail, data in zip(input_details, input_data):
        interpreter.resize_tensor_input(detail["index"], data.shape)
    interpreter.allocate_tensors()

    for detail, data in zip(input_details, input_data):
        interpreter.set_tensor(detail["index"], data)

    interpreter.invoke()
    return [interpreter.get_tensor(detail["index"]) for detail in output_details]
def compare_tflite_with_tvm(
    in_data: typing.List[np.ndarray],
    in_name: typing.List[str],
    input_tensors: typing.List,
    output_tensors: typing.List,
    init_global_variables: bool = False,
    out_names=None,
    quantized=False,
    input_range=None,
    mode="graph_executor",
    experimental_new_converter=False,
    fp16_quantized=False,
    int_quant_dtype=tf.uint8,
):
    """Convert the current TF1 graph to TFLite, run it on TFLite and TVM, and compare.

    Parameters
    ----------
    in_data : numpy.ndarray or list of numpy.ndarray
        Input values fed to both runtimes.
    in_name : str or list of str
        Input tensor names; any ``:N`` suffix is stripped to get the TVM input name.
    input_tensors, output_tensors : list of tf.Tensor
        Graph endpoints passed to ``TFLiteConverter.from_session``.
    init_global_variables : bool
        Run ``global_variables_initializer`` before converting.
    out_names : str or list of str, optional
        ``len(out_names)`` outputs are read back from TVM (``None`` means one output).
    quantized : bool
        Quantize during conversion; ``int_quant_dtype`` selects int16 (16x8),
        int8, or (any other value) uint8 quantization.
    input_range : dict, optional
        Required when ``quantized``: maps every converter input array name to its
        float ``(min, max)`` range, used to derive the input quantization stats.
    mode : str
        TVM execution mode forwarded to ``run_tvm_graph``.
    experimental_new_converter : bool
        Use the MLIR-based converter.
    fp16_quantized : bool
        Float16 weight quantization (only consulted when ``quantized`` is False).

    Raises
    ------
    ValueError
        If a quantized input has a zero-width range (min == max).
    """
    in_data = convert_to_list(in_data)
    in_name = convert_to_list(in_name)
    out_names = convert_to_list(out_names)
    # TVM input names are the tensor names without the ":<output index>" suffix.
    in_node = [name.split(":")[0] for name in in_name]
    with tf.Session() as sess:
        if init_global_variables:
            sess.run(variables.global_variables_initializer())
        # convert to tflite model
        converter = tf.lite.TFLiteConverter.from_session(sess, input_tensors, output_tensors)
        if len(input_tensors) > 1:
            if len(input_tensors[0].shape) <= 4 and len(input_tensors[1].shape) <= 4:
                converter._experimental_disable_batchmatmul_unfold = True
            else:
                converter._experimental_disable_batchmatmul_unfold = False
        converter.experimental_new_converter = experimental_new_converter
        if quantized:
            if int_quant_dtype == tf.int16:
                converter.optimizations = [tf.lite.Optimize.DEFAULT]
                converter.target_spec.supported_ops = [
                    tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
                ]
            elif int_quant_dtype == tf.int8:
                converter.inference_type = tf.lite.constants.INT8
            else:
                # default to uint8 quantization
                converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
            # calculate the mean and quantization scale for every input tensor,
            # with respect to its fp32 input range, defined in fake_quant.
            # s = 255/(fmax-fmin); m = -fmin*s (the zero point)
            input_stats = {}
            for array_name in converter.get_input_arrays():
                fmin = input_range[array_name][0]
                fmax = input_range[array_name][1]
                if fmax == fmin:
                    # A zero-width range has no finite scale; fail loudly rather than
                    # continue with an undefined quantization scale.
                    raise ValueError(
                        f"Min and max of the input range for tensor {array_name} can't be equal"
                    )
                quant_scale = 255 / (fmax - fmin)
                input_stats[array_name] = (-fmin * quant_scale, quant_scale)
            converter.quantized_input_stats = input_stats
        elif fp16_quantized:
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            converter.target_spec.supported_types = [tf.float16]
        tflite_model_buffer = converter.convert()
        tflite_output = run_tflite_graph(tflite_model_buffer, in_data)
        for device in ["llvm"]:
            if not tvm.testing.device_enabled(device):
                print(f"Skip because {device} is not enabled")
                continue
            tvm_output = run_tvm_graph(
                tflite_model_buffer,
                in_data,
                in_node,
                target=device,
                num_output=len(out_names),
                out_names=out_names,
                mode=mode,
            )
            # WARNING: the results could well be random values clipped to 0 or 255 because of badly
            # tuned output range for the specific operator. While adding test ensure that we aren't
            # getting only clipped values in output tensors that still pass the assertion.
            # For reference see _test_elemwise_qnn_out_range()
            # Quantized results may differ by one quantization step (atol=1).
            atol = 1 if quantized and not fp16_quantized else 1e-5
            for i, expected in enumerate(tflite_output):
                tvm.testing.assert_allclose(expected, tvm_output[i], atol=atol, rtol=1e-5)
def with_fused_activation_function(input_tensor, fn_name):
    """Apply the activation named ``fn_name`` to ``input_tensor``.

    ``None`` and ``"NONE"`` return the tensor unchanged. The other accepted names
    (RELU, RELU6, RELU_N1_TO_1, TANH) follow TFLite's fused-activation naming.

    Raises
    ------
    AssertionError
        If ``fn_name`` is not one of the accepted names.
    """
    if fn_name is None or fn_name == "NONE":
        return input_tensor
    # Lambdas defer building the op until the matching name has been found.
    activations = (
        ("RELU", lambda t: nn_ops.relu(t)),
        ("RELU6", lambda t: nn_ops.relu6(t)),
        ("RELU_N1_TO_1", lambda t: math_ops.maximum(-1, math_ops.minimum(t, 1))),
        ("TANH", lambda t: math_ops.tanh(t)),
    )
    for name, activation in activations:
        if fn_name == name:
            return activation(input_tensor)
    raise AssertionError(f"Unknown fused_activation_function {fn_name}")
def _test_split(in_shape, axis, num_splits, dtype):
    """Compare one SPLIT conversion between TFLite and TVM.

    ``num_splits`` is either the number of equal pieces or a list of piece sizes;
    input values are drawn uniformly from [-5, 5).
    """
    input_values = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=in_shape, dtype=dtype, name="in_data")
        pieces = array_ops.split(placeholder, num_splits, axis=axis)
        piece_count = len(num_splits) if isinstance(num_splits, list) else num_splits
        output_names = [f"out_{idx}:0" for idx in range(piece_count)]
        compare_tflite_with_tvm(
            [input_values], ["in_data"], [placeholder], pieces, out_names=output_names
        )
def test_forward_split():
    """test split layer"""
    cases = [
        # rank 1
        ((3,), 0, 1, "float32"),
        ((3,), 0, 3, "float32"),
        ((6,), 0, 3, "float32"),
        # rank 2
        ((6, 2), 0, 3, "float32"),
        ((2, 6), 1, 6, "float32"),
    ]
    # rank 3 (the int32 case only runs on TF >= 1.14)
    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
        cases.append(((6, 2, 4), 0, 2, "int32"))
    cases.extend(
        [
            ((2, 6, 4), 1, 3, "float32"),
            ((2, 4, 6), 2, 1, "float32"),
            # rank 4
            ((6, 1, 3, 5), 0, 3, "float32"),
            ((1, 6, 3, 5), 1, 3, "float32"),
            ((1, 3, 6, 5), 2, 3, "float32"),
            ((1, 3, 5, 6), 3, 3, "float32"),
            # split along negative axis
            ((6, 1, 3, 5), -4, 3, "float32"),
            ((1, 6, 3, 5), -3, 3, "float32"),
            ((1, 3, 6, 5), -2, 3, "float32"),
            ((1, 3, 5, 6), -1, 3, "float32"),
            # size_splits split
            ((6,), 0, [1, 2, 3], "float32"),
            ((3, 6, 4), -2, [1, 4, 1], "float32"),
        ]
    )
    for in_shape, axis, num_splits, dtype in cases:
        _test_split(in_shape, axis, num_splits, dtype)
#######################################################################
# slice
# -----
def _test_slice(data, begin, size):
    """Compare one SLICE of ``data`` (start ``begin``, extent ``size``) on TFLite and TVM."""
    graph = tf.Graph()
    with graph.as_default():
        placeholder = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
        sliced = array_ops.slice(placeholder, begin, size)
        compare_tflite_with_tvm(data, "Placeholder:0", [placeholder], [sliced])
def test_forward_slice():
    """SLICE"""
    cases = [
        (np.arange(4, dtype=np.float32).reshape((4,)), [0], [2]),
        (np.arange(18, dtype=np.int32).reshape((3, 2, 3)), [1, 0, 0], [1, 1, 3]),
    ]
    # tflite 1.13 outputs nonsense values if size[i] == -1
    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
        cases += [
            (np.arange(8, dtype=np.int32).reshape((2, 4)), [0, 1], [-1, -1]),
            (np.arange(5, dtype=np.int32).reshape((5,)), [4], [-1]),
        ]
    for data, begin, size in cases:
        _test_slice(data, begin=begin, size=size)
#######################################################################
# Topk
# ----
def _test_topk(in_shape, k=1):
    """Compare the values output of ``top_k`` (k largest along the last axis).

    Only the values tensor is compared; the indices output is not checked.
    """
    values = np.random.uniform(size=in_shape).astype("float32")
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=values.shape, dtype=values.dtype)
        top_values, _ = nn_ops.top_k(placeholder, k, name="TopK")
        compare_tflite_with_tvm(values, "Placeholder:0", [placeholder], [top_values])
def test_forward_topk():
    """TOPK"""
    _test_topk((3,), 1)
    _test_topk((3,), 3)
    _test_topk((3, 5, 7), 3)
    # k spanning the whole last axis (size 7) of a rank-3 input.
    _test_topk((3, 5, 7), 7)
#######################################################################
# Gather
# ------
def _test_gather(dshape, indices, axis, dtype, quantized=False, oob=False, wrap_idx=False):
    """One iteration of Gather: compare TFLite and TVM on a single configuration.

    Parameters
    ----------
    dshape : tuple of int
        Shape of the data tensor (values drawn uniformly from [1, 10)).
    indices : array-like of int
        Indices to gather; converted to an int32 array.
    axis : int or None
        Gather axis. A falsy value (``None`` or ``0``) calls ``gather`` without an
        explicit axis.
    dtype : str
        Data dtype on the float path; the quantized path always uses uint8.
    quantized : bool
        Run the quantized path with input range (-100, 100).
    oob : bool
        The indices are out of bounds: a ``ValueError`` raised while converting or
        comparing is tolerated instead of failing the test.
    wrap_idx : bool
        Feed ``indices`` through an extra placeholder input instead of embedding
        them in the graph as a constant.
    """
    indices = np.asarray(indices).astype("int32")
    data = np.random.uniform(1, 10, size=dshape)
    data = data.astype(np.uint8) if quantized else data.astype(dtype)
    with tf.Graph().as_default():
        if wrap_idx:
            in_name = "in_indices"
            indices_expr = array_ops.placeholder(
                shape=indices.shape, dtype=indices.dtype, name=in_name
            )
            # Keep each runtime input as one list element: ``[data] + indices`` with an
            # ndarray would broadcast-add the arrays instead of concatenating lists.
            extra_values = [indices]
            extra_names = [in_name + ":0"]
            extra_tensors = [indices_expr]
        else:
            indices_expr = indices
            extra_values, extra_names, extra_tensors = [], [], []
        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in_data")
        if axis:
            out = array_ops.gather(in_data, indices_expr, axis=axis)
        else:
            out = array_ops.gather(in_data, indices_expr)  # tflite conversion fails for None axis
        input_range = None
        if quantized:
            input_range = {"in_data": (-100, 100)}
            if wrap_idx:
                # compare_tflite_with_tvm looks up a range for every converter input,
                # including the indices placeholder. NOTE(review): assumes the converter
                # leaves the int32 indices unquantized -- confirm.
                input_range["in_indices"] = (-100, 100)
        try:
            compare_tflite_with_tvm(
                [data] + extra_values,
                ["in_data:0"] + extra_names,
                [in_data] + extra_tensors,
                [out],
                quantized=quantized,
                input_range=input_range,
            )
        except ValueError:
            if not oob:
                raise


def test_forward_gather():
    """GATHER"""
    for quantized in [False, True]:
        for wrap_idx in [False, True]:
            # ``wrap_idx`` must be passed by keyword: the sixth positional parameter of
            # _test_gather is ``oob``, which would silently tolerate ValueErrors.
            _test_gather((4,), [1], 0, "float32", quantized, wrap_idx=wrap_idx)
            _test_gather((4,), [1], None, "int32", quantized, wrap_idx=wrap_idx)
            _test_gather((1, 4), [0], 0, "int32", quantized, wrap_idx=wrap_idx)
            _test_gather((4,), [[[1, 0], [0, 1]]], 0, "float32", quantized, wrap_idx=wrap_idx)
            _test_gather((2, 2), [[[1, 0], [0, 1]]], 1, "int32", quantized, wrap_idx=wrap_idx)
            _test_gather(
                (2, 2), [[[1, 0], [0, 1]]], None, "float32", quantized, wrap_idx=wrap_idx
            )
            _test_gather((3, 3, 3), [[[1, 0]]], 0, "int32", quantized, wrap_idx=wrap_idx)
            _test_gather((3, 3, 3), [[[1, 0]]], 2, "int32", quantized, wrap_idx=wrap_idx)
            _test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, "float32", quantized, wrap_idx=wrap_idx)
            _test_gather((3, 3, 3), [[[2, 1]]], -1, "int32", quantized, wrap_idx=wrap_idx)
        # Out of boundary error cannot be tested with wrapped index
        _test_gather((4,), [16], 0, "float32", quantized, oob=True)
        _test_gather((1, 3, 3), [12], 0, "int32", quantized, oob=True)
        _test_gather((1, 3, 3), [20], 1, "float32", quantized, oob=True)
        _test_gather((1, 3, 3), [20, 20], 2, "float32", quantized, oob=True)
#######################################################################
# Gather_ND
# ---------
def _test_gather_nd(data, indices):
    """Compare GATHER_ND with both ``data`` and ``indices`` fed as graph inputs."""
    with tf.Graph().as_default():
        data_ph = tf.placeholder(shape=data.shape, dtype=data.dtype, name="data")
        indices_ph = tf.placeholder(shape=indices.shape, dtype=indices.dtype, name="indices")
        gathered = tf.gather_nd(data_ph, indices_ph)
        compare_tflite_with_tvm(
            [data, indices],
            ["data:0", "indices:0"],
            [data_ph, indices_ph],
            [gathered],
        )
def test_forward_gather_nd():
    """GATHER_ND"""
    cases = (
        (
            np.array([[[1.2, 2.0], [3.1, 4.1]], [[5.1, 6.1], [7.1, 8.1]]]).astype("float32"),
            [[0, 1], [1, 0]],
        ),
        (np.reshape(np.arange(30), [5, 6]).astype("int32"), [[1, 2]]),
        (
            np.reshape(np.arange(12), [2, 3, 2]).astype("int32"),
            [[[0, 0], [0, 1]], [[1, 0], [1, 1]]],
        ),
        (np.reshape(np.arange(4), [4]).astype("float32"), [1]),
        (np.reshape(np.arange(4), [1, 4]).astype("float32"), [0]),
        (np.reshape(np.arange(4), [1, 4]).astype("float32"), [0, 3]),
    )
    for data, raw_indices in cases:
        _test_gather_nd(data, np.asarray(raw_indices).astype("int32"))
#######################################################################
# StridedSlice
# ------------
def _test_stridedslice(
    ip_shape,
    begin,
    end,
    stride,
    dtype,
    begin_mask=0,
    end_mask=0,
    new_axis_mask=0,
    shrink_axis_mask=0,
    ellipsis_mask=0,
    quantized=False,
):
    """One iteration of a Stridedslice: compare TFLite and TVM on one configuration.

    ``begin``/``end``/``stride`` and the ``*_mask`` arguments are forwarded unchanged
    to ``array_ops.strided_slice``. When ``quantized`` the data is uint8 and the
    input range is (-100, 100).
    """
    if quantized:
        # uniform(size=...) draws from [0, 1), which truncates to all zeros as uint8 and
        # makes the comparison trivial; sample the full uint8 range instead.
        data = np.random.randint(0, 256, size=ip_shape).astype(np.uint8)
    else:
        data = np.random.uniform(size=ip_shape).astype(dtype)
    with tf.Graph().as_default():
        in_data = tf.placeholder(dtype, ip_shape, name="in_data")
        out = array_ops.strided_slice(
            in_data,
            begin,
            end,
            stride,
            begin_mask=begin_mask,
            end_mask=end_mask,
            new_axis_mask=new_axis_mask,
            shrink_axis_mask=shrink_axis_mask,
            ellipsis_mask=ellipsis_mask,
        )
        input_range = {"in_data": (-100, 100)} if quantized else None
        compare_tflite_with_tvm(
            [data], ["in_data:0"], [in_data], [out], quantized=quantized, input_range=input_range
        )
def test_forward_stridedslice():
    """test StridedSlice"""
    cases = (
        # (ip_shape, begin, end, stride, extra keyword arguments)
        ((1, 3, 3), [0, 0, 0], [3, 3, 3], [1, 1, 1], {"shrink_axis_mask": 7}),
        ((1, 3, 3), [0, 0, 0], [3, 3, 3], [1, 1, 1], {"shrink_axis_mask": 5}),
        # ip_shape is the plain int 2 here, i.e. a 1-D input of length 2
        (2, [1], [1], [1], {"shrink_axis_mask": 1}),
        ((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], {}),
        ((3, 4), [1, 0], [4, 4], [1, 1], {"shrink_axis_mask": 0}),
        ((4, 4), [1, 0], [4, 4], [1, 1], {"shrink_axis_mask": 2}),
        ((3, 4), [-1, 0], [0, 3], [1, 1], {"shrink_axis_mask": 1}),
    )
    for quantized in [False, True]:
        for ip_shape, begin, end, stride, extra in cases:
            _test_stridedslice(
                ip_shape, begin, end, stride, "float32", quantized=quantized, **extra
            )
#######################################################################
# transpose
# ---------
def _test_forward_transpose(ishape, axes=()):
    """Compare TRANSPOSE of a random float32 tensor of shape ``ishape``.

    An empty ``axes`` calls ``transpose`` without a permutation argument.
    """
    values = np.random.uniform(size=ishape).astype(np.float32)
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=values.shape, dtype=values.dtype)
        transposed = (
            array_ops.transpose(placeholder, axes) if axes else array_ops.transpose(placeholder)
        )
        compare_tflite_with_tvm(values, "Placeholder:0", [placeholder], [transposed])
def test_forward_transpose():
    """Compare TRANSPOSE with the default permutation and with explicit ones."""
    cases = (
        ((2, 2), ()),
        ((2, 3, 4), ()),
        ((7, 8, 8, 10), ()),
        ((2, 3, 4), (1, 2, 0)),
        ((2, 3, 4), (0, 1, 2)),
        ((2, 3, 4, 5), (3, 0, 1, 2)),
        ((2, 3, 4, 5), ()),
    )
    for ishape, axes in cases:
        _test_forward_transpose(ishape, axes)
#######################################################################
# Cast
# ----
def _test_cast(data, cast_dtype, use_mlir=False):
    """Compare CAST of ``data`` to ``cast_dtype``; ``use_mlir`` selects the new converter."""
    with tf.Graph().as_default():
        source = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
        casted = math_ops.cast(source, cast_dtype)
        compare_tflite_with_tvm(
            data,
            "Placeholder:0",
            [source],
            [casted],
            experimental_new_converter=use_mlir,
        )
def test_forward_cast():
    """CAST"""
    conversions = (
        (np.float32, tf.int32),
        (np.float32, tf.uint8),
        (np.int32, tf.int64),
    )
    for use_mlir in [False, True]:
        for src_dtype, cast_dtype in conversions:
            data = np.arange(6.0, dtype=src_dtype).reshape((1, 6))
            _test_cast(data, cast_dtype=cast_dtype, use_mlir=use_mlir)
#######################################################################
# Batch Mat Mul
# ----
def _test_batch_matmul(
    a_shape, b_shape, dtype, out_dtype, adjoint_a=False, adjoint_b=False, quantized=False
):
    """Compare one BatchMatMulV3 (A @ B with optional adjoints) between TFLite and TVM.

    Inputs are drawn uniformly from [0, 5) and converted with the MLIR converter;
    when ``quantized`` both inputs use the range (-100, 100).
    """
    with tf.Graph().as_default():
        a = array_ops.placeholder(shape=a_shape, dtype=dtype, name="A")
        b = array_ops.placeholder(shape=b_shape, dtype=dtype, name="B")
        # (A stray debug print of tf.__version__ was removed from here.)
        result = raw_ops.BatchMatMulV3(
            x=a, y=b, Tout=out_dtype, adj_x=adjoint_a, adj_y=adjoint_b, name="batchmatmul"
        )
        input_range = {"A": (-100, 100), "B": (-100, 100)} if quantized else None
        a_np = np.random.uniform(high=5.0, size=a_shape).astype(dtype)
        b_np = np.random.uniform(high=5.0, size=b_shape).astype(dtype)
        compare_tflite_with_tvm(
            [a_np, b_np],
            [a.name, b.name],
            [a, b],
            [result],
            experimental_new_converter=True,
            quantized=quantized,
            input_range=input_range,
        )
@pytest.mark.parametrize("config", [("int8", "int32", True), ("float32", "float32", False)])
def test_forward_batch_matmul(config):
    """BATCH_MAT_MUL with and without adjoints, on 3-D/4-D and broadcast batch shapes.

    ``config`` is ``(input dtype, output dtype, quantized)``.
    """
    dtype, out_dtype, quantized = config
    cases = (
        # (a_shape, b_shape, adjoint_a, adjoint_b)
        ((3, 5, 4), (3, 4, 5), False, False),
        ((3, 5, 4), (3, 4, 5), True, True),
        ((3, 5, 4), (3, 5, 4), True, False),
        ((2, 3, 5, 4), (1, 3, 5, 4), True, False),
        ((3, 5, 4), (3, 5, 4), False, True),
        ((2, 3, 5, 4), (1, 3, 5, 4), False, True),
        ((3, 4, 5, 6), (3, 4, 6, 5), False, False),
    )
    for a_shape, b_shape, adjoint_a, adjoint_b in cases:
        _test_batch_matmul(
            a_shape,
            b_shape,
            dtype=dtype,
            out_dtype=out_dtype,
            adjoint_a=adjoint_a,
            adjoint_b=adjoint_b,
            quantized=quantized,
        )
    # 5-D inputs such as (2, 3, 4, 5, 6) x (2, 3, 4, 6, 5) are not exercised:
    # BatchMatMul doesn't support larger than 4D tensors.
#######################################################################
# Tile
# ----
def _test_forward_tile(in_shape, reps, dtype):
    """Compare TILE of a random tensor (values in [-5, 5)) repeated ``reps`` times."""
    values = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
    graph = tf.Graph()
    with graph.as_default():
        placeholder = array_ops.placeholder(shape=values.shape, dtype=values.dtype)
        tiled = array_ops.tile(placeholder, reps)
        compare_tflite_with_tvm(values, "Placeholder:0", [placeholder], [tiled])
def test_forward_tile():
    """Compare TILE on a 1-D int32 input and a 2-D float32 input."""
    for in_shape, reps, dtype in (((2,), (3,), "int32"), ((2, 2), (2, 3), "float32")):
        _test_forward_tile(in_shape, reps, dtype)
######################################################################
# BatchToSpaceND
# --------------
def _test_batch_to_space_nd(input_shape, block_shape, crops, dtype="int32"):
    """Compare one BATCH_TO_SPACE_ND conversion (values drawn from [0, 5))."""
    values = np.random.uniform(0, 5, size=input_shape).astype(dtype)
    graph = tf.Graph()
    with graph.as_default():
        placeholder = array_ops.placeholder(shape=input_shape, dtype=dtype)
        rearranged = array_ops.batch_to_space_nd(placeholder, block_shape, crops)
        compare_tflite_with_tvm(values, "Placeholder:0", [placeholder], [rearranged])
def test_forward_batch_to_space_nd():
    """BATCH_TO_SPACE_ND with and without cropping."""
    # test cases: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d
    cases = (
        ([4, 1, 1, 1], [2, 2], [[0, 0], [0, 0]]),
        ([4, 1, 1, 3], [2, 2], [[0, 0], [0, 0]]),
        ([4, 2, 2, 1], [2, 2], [[0, 0], [0, 0]]),
        ([4, 3, 3, 1], [2, 2], [[0, 1], [0, 1]]),
    )
    for input_shape, block_shape, crops in cases:
        _test_batch_to_space_nd(input_shape=input_shape, block_shape=block_shape, crops=crops)
######################################################################
# SpaceToBatchND
# --------------
def _test_space_to_batch_nd(input_shape, block_shape, paddings, dtype="int32"):
    """Compare one SPACE_TO_BATCH_ND conversion (values drawn from [0, 5))."""
    values = np.random.uniform(0, 5, size=input_shape).astype(dtype)
    graph = tf.Graph()
    with graph.as_default():
        placeholder = array_ops.placeholder(shape=input_shape, dtype=dtype)
        rearranged = array_ops.space_to_batch_nd(placeholder, block_shape, paddings)
        compare_tflite_with_tvm(values, "Placeholder:0", [placeholder], [rearranged])
def test_forward_space_to_batch_nd():
    """SPACE_TO_BATCH_ND with and without padding."""
    # test cases: https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd
    cases = (
        ([1, 2, 2, 1], [2, 2], [[0, 0], [0, 0]]),
        ([1, 2, 2, 3], [2, 2], [[0, 0], [0, 0]]),
        ([1, 4, 4, 1], [2, 2], [[0, 0], [0, 0]]),
        ([2, 2, 4, 1], [2, 2], [[0, 0], [2, 0]]),
    )
    for input_shape, block_shape, paddings in cases:
        _test_space_to_batch_nd(
            input_shape=input_shape, block_shape=block_shape, paddings=paddings
        )
#######################################################################
# Pooling
# -------
def _test_pooling_iteration(input_shape, **kwargs):
    """Compare one ``nn_ops.pool`` configuration (``kwargs`` forwarded as-is).

    The float32 input holds the strictly negative values -1, -2, ..., -N.
    """
    element_count = np.prod(input_shape)
    values = -np.arange(element_count, dtype=np.float32).reshape(input_shape) - 1
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=input_shape, dtype="float32")
        pooled = nn_ops.pool(placeholder, **kwargs)
        compare_tflite_with_tvm(values, "Placeholder:0", [placeholder], [pooled])
def _test_pooling(input_shape, **kwargs):
    """Run a single pooling comparison; thin wrapper over ``_test_pooling_iteration``."""
    return _test_pooling_iteration(input_shape, **kwargs)
def test_forward_pooling():
    """Pooling: AVG and MAX with SAME padding over several window/stride shapes."""
    configs = (
        # (input_shape, window_shape, strides)
        ([2, 9, 10, 2], [1, 1], [1, 1]),
        ([2, 10, 9, 2], [1, 1], [1, 1]),
        ([2, 9, 10, 2], [2, 1], [1, 1]),
        ([2, 10, 9, 2], [2, 3], [2, 1]),
    )
    for pool_type in ["AVG", "MAX"]:
        for input_shape, window_shape, strides in configs:
            _test_pooling(
                input_shape=input_shape,
                window_shape=window_shape,
                padding="SAME",
                pooling_type=pool_type,
                dilation_rate=[1, 1],
                strides=strides,
            )
def _test_l2_pool2d(input_shape, ksize, strides, padding, data_format, fused_func_name=None):
    """Compare the L2-pool pattern ``sqrt(avg_pool(x ** 2))`` between TFLite and TVM.

    The input holds -1, 0, 1, ..., N-2; ``fused_func_name`` optionally appends an
    activation via ``with_fused_activation_function``.
    """
    values = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
    with tf.Graph().as_default():
        in_data = tf.placeholder(dtype=tf.float32, name="input", shape=input_shape)
        squared = tf.square(in_data)
        pooled = tf.nn.avg_pool(
            squared,
            ksize=ksize,
            strides=strides,
            padding=padding,
            data_format=data_format,
        )
        out = with_fused_activation_function(tf.sqrt(pooled), fused_func_name)
        compare_tflite_with_tvm(values, "input", [in_data], [out])
def test_forward_l2_pool2d():
    """Compare L2 pooling with SAME/VALID padding, with and without fused activations."""
    cases = (
        # (input_shape, ksize, strides, padding, fused activation)
        ([1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], "SAME", "RELU6"),
        ([2, 9, 10, 2], [1, 1, 1, 1], [1, 1, 1, 1], "SAME", "RELU6"),
        ([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 1, 1], "SAME", None),
        ([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 2, 1], "SAME", None),
        ([1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], "VALID", "RELU"),
        ([2, 9, 10, 2], [1, 1, 1, 1], [1, 1, 1, 1], "VALID", None),
        ([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 1, 1], "VALID", None),
        ([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 2, 1], "VALID", "RELU6"),
    )
    for input_shape, ksize, strides, padding, fused in cases:
        _test_l2_pool2d(input_shape, ksize, strides, padding, "NHWC", fused)
#######################################################################
# Convolution
# -----------