"""
Script to benchmark the regular MAXIM model in TF and its JIT-compiled variant.
Expected outputs (benchmarked on my Mac locally):
```
Benchmarking TF model...
Average latency (seconds): 3.1694554823999987.
Benchmarking Jit-compiled TF model...
Average latency (seconds): 1.2475706969000029.
```
"""
import timeit

import numpy as np
import tensorflow as tf

from create_maxim_model import Model

INPUT_RESOLUTION = 256
MAXIM_S1 = Model("S-1")
DUMMY_INPUTS = tf.ones((1, INPUT_RESOLUTION, INPUT_RESOLUTION, 3))


def benchmark_regular_model():
    # Warmup runs so that one-time setup costs don't skew the timings.
    print("Benchmarking TF model...")
    for _ in range(2):
        _ = MAXIM_S1(DUMMY_INPUTS, training=False)

    # Timing.
    tf_runtimes = timeit.repeat(
        lambda: MAXIM_S1(DUMMY_INPUTS, training=False), number=1, repeat=10
    )
    print(f"Average latency (seconds): {np.mean(tf_runtimes)}.")


# `jit_compile=True` asks TensorFlow to compile the traced graph with XLA.
@tf.function(jit_compile=True)
def infer():
    return MAXIM_S1(DUMMY_INPUTS, training=False)


def benchmark_xla_model():
    # Warmup runs also trigger tracing and XLA compilation of `infer()`.
    print("Benchmarking Jit-compiled TF model...")
    for _ in range(2):
        _ = infer()

    # Timing.
    tf_runtimes = timeit.repeat(lambda: infer(), number=1, repeat=10)
    print(f"Average latency (seconds): {np.mean(tf_runtimes)}.")
if __name__ == "__main__":
benchmark_regular_model()
benchmark_xla_model()