-
Notifications
You must be signed in to change notification settings - Fork 0
/
addition_broadcast_benchmarking.py
58 lines (47 loc) · 1.42 KB
/
addition_broadcast_benchmarking.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import time
import random
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
from pbar import printProgress
def add_number(length, repeat=1):
    """Benchmark adding a scalar to every element: Python loop vs NumPy vs JAX.

    For array sizes ``10, 20, ..., 10 * (length - 1)`` this times ``repeat``
    scalar additions with each backend and plots the mean seconds per
    addition on the current matplotlib figure (one line per backend,
    labelled with ``repeat``). Legend and ``plt.show()`` are left to the
    caller so several runs can share one figure.

    Args:
        length: sizes are ``n * 10`` for ``n`` in ``range(1, length)``.
        repeat: number of timed additions per backend and size; must be >= 1.

    Returns:
        ``(samples, loops_times, numpy_times, jax_times)`` -- the array sizes
        and the matching mean seconds per addition for each backend.

    Raises:
        ValueError: if ``repeat`` is less than 1 (no timing would be taken).
    """
    if repeat < 1:
        raise ValueError(f"repeat must be >= 1, got {repeat}")
    value_to_add = 1432
    samples = []
    loops_times = []
    numpy_times = []
    jax_times = []
    for n in range(1, length):
        printProgress(n / length)
        size = n * 10
        samples.append(size)
        array = list(range(size))
        np_array = np.array(array)
        jnp_array = jnp.array(array)

        # Pure-Python element-wise loop; total over all repetitions, then
        # averaged (the original kept only the last repetition's time).
        start = time.perf_counter()
        for _ in range(repeat):
            for j, _ in enumerate(array):
                array[j] += value_to_add
        loops_times.append((time.perf_counter() - start) / repeat)

        # NumPy vectorised add (allocates a new array on every call).
        start = time.perf_counter()
        for _ in range(repeat):
            np_array = np.add(np_array, value_to_add)
        numpy_times.append((time.perf_counter() - start) / repeat)

        # JAX: one untimed warm-up so per-shape tracing/compilation is not
        # measured; it also waits for the host->device copy of jnp_array.
        jnp.add(jnp_array, value_to_add).block_until_ready()
        start = time.perf_counter()
        for _ in range(repeat):
            jnp_array = jnp.add(jnp_array, value_to_add)
        # JAX dispatches asynchronously; without blocking we would only
        # time the dispatch, not the actual computation.
        jnp_array.block_until_ready()
        jax_times.append((time.perf_counter() - start) / repeat)

    plt.title('add_number')
    plt.xlabel('Array size')
    plt.ylabel('Mean time per addition (s)')
    plt.plot(samples, loops_times, label=f"loops_{repeat}")
    plt.plot(samples, numpy_times, label=f"numpy_{repeat}")
    plt.plot(samples, jax_times, label=f"jax_{repeat}")
    return samples, loops_times, numpy_times, jax_times
# Run the benchmark twice on one figure (500 sizes, 20 then 40 repetitions),
# then add a single legend covering every plotted line and display it.
for repetitions in (20, 40):
    add_number(500, repetitions)
plt.legend()
plt.show()