-
Notifications
You must be signed in to change notification settings - Fork 81
/
Copy pathvmap_randomness.py
52 lines (30 loc) · 937 Bytes
/
vmap_randomness.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""
Interplay between jit, vmap, randomness and backend
"""
import tensorcircuit as tc
# Part 1: run the demo on the TensorFlow backend.
K = tc.set_backend("tensorflow")
# n: length of each input vector / noise vector; batch: leading batch dimension.
n, batch = 10, 100
print("tensorflow backend")
# ``K`` is rebound to the jax backend later in this script, but ``f`` only
# looks its backend up when it is traced. Bind the TensorFlow backend to a
# name that is never reassigned, so a later retrace of ``f`` (e.g. for a new
# input shape) cannot silently pick up the jax backend.
_tf_backend = K


# NOTE(review): the original remark here was "has serialization issue for
# random generation" on this backend; presumably the stateful generator ops
# are not vectorized under vmap and end up running sequentially -- confirm.
@_tf_backend.jit
def f(a, key):
    """Add noise of shape ``[n]`` to ``a``.

    The noise is drawn by the TensorFlow backend's ``stateful_randn``
    (presumably standard normal). ``key`` is the random state handed
    straight to that call.
    """
    return a + _tf_backend.stateful_randn(key, [n])
# Batched variant of ``f``. No vectorized_argnums is given, so presumably only
# ``a`` is vectorized and the single random state ``key`` is shared across
# the batch -- compare with the jax section below, which vectorizes both.
vf = K.jit(K.vmap(f))
key = K.get_random_state(42)

# Unbatched run: keep only the first of benchmark's three return values.
single_input = K.ones([n], dtype="float32")
r, _, _ = tc.utils.benchmark(f, single_input, key)
print(r)

# Batched run over ``batch`` rows; show just the first two rows.
batched_input = K.ones([batch, n], dtype="float32")
r, _, _ = tc.utils.benchmark(vf, batched_input, key)
print(r[:2])

# Part 2: switch the global backend to jax.
K = tc.set_backend("jax")
print("jax backend")
@K.jit
def f2(a, key):
    """Add noise of shape ``[n]`` drawn from ``key`` to ``a``.

    Same computation as ``f``, redefined so that it is wrapped by the jax
    backend's ``jit``. ``key`` is the random state passed to the backend's
    ``stateful_randn``.
    """
    noise = K.stateful_randn(key, [n])
    return a + noise
# Vectorize over both the inputs and the random states, so every batch
# element receives its own key.
vf2 = K.jit(K.vmap(f2, vectorized_argnums=(0, 1)))
key = K.get_random_state(42)

# Unbatched run with a single key.
single_input = K.ones([n], dtype="float32")
r, _, _ = tc.utils.benchmark(f2, single_input, key)
print(r)

# One random state per batch element, seeded 0 .. batch-1, stacked along axis 0.
keys = K.stack(list(map(K.get_random_state, range(batch))))
batched_input = K.ones([batch, n], dtype="float32")
r, _, _ = tc.utils.benchmark(vf2, batched_input, keys)
print(r[:2])  # first two rows of the batched result