import jax
import jax.numpy as jnp
from functools import partial

from wyckoff import fc_mask_table
from von_mises import sample_von_mises

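
# State layout, inferred from the code below: each configuration x is a tuple
# (G, L, XYZ, A, W) of batched arrays -- space-group index, lattice parameters,
# fractional coordinates, atom types, and Wyckoff labels. A value of 0 in A or
# W marks a padding site.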
# Mask of the fractional coordinates that are free to vary for space group g
# and Wyckoff letters w; padding sites (w == 0) are masked out entirely.
get_fc_mask = lambda g, w: jnp.logical_and((w > 0)[:, None], fc_mask_table[g - 1, w])

def make_mcmc_step(params, n_max, atom_types, atom_mask=None, constraints=None):

    # By default, allow every atom type at every site.
    if atom_mask is None or jnp.all(atom_mask == 0):
        atom_mask = jnp.ones((n_max, atom_types))

    # By default, give every site its own constraint index, so sites are
    # updated independently.
    if constraints is None:
        constraints = jnp.arange(0, n_max, 1)

    def update_A(i, A, a, constraints):
        """Set the atom type of site i to a, and copy the same type to every
        site that shares a constraint index with site i."""
        def body_fn(j, A):
            A = jax.lax.cond(constraints[j] == constraints[i],
                             lambda _: A.at[:, j].set(a),
                             lambda _: A,
                             None)
            return A

        A = jax.lax.fori_loop(0, A.shape[1], body_fn, A)
        return A

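    # For example (hypothetical values): with constraints = [0, 0, 2, ...],
    # proposing a new type for site 0 also overwrites site 1, so constrained
    # sites keep identical species throughout the chain.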

    @partial(jax.jit, static_argnums=0)
    def mcmc(logp_fn, x_init, key, mc_steps, mc_width):
        """
        Markov chain Monte Carlo sampling algorithm.

        INPUT:
            logp_fn: callable that evaluates the log-probability of a batch of
                     configurations. It is called as
                     logp_fn(params, key, G, L, XYZ, A, W, is_train) and returns
                     (logp_w, logp_xyz, logp_a, ...).
            x_init: initial configuration, a tuple (G, L, XYZ, A, W) of batched arrays.
            key: initial PRNG key.
            mc_steps: total number of Monte Carlo steps.
            mc_width: step size of the Monte Carlo proposal.

        OUTPUT:
            x: resulting batch samples, with the same shape as `x_init`.
            accept_rate: fraction of accepted proposals, counted over real atoms.
        """
        def step(i, state):

            def true_func(i, state):
                x, logp, key, num_accepts = state
                G, L, XYZ, A, W = x
                key, key_proposal_A, key_proposal_XYZ, key_accept, key_logp = jax.random.split(key, 5)

                # Only propose atom types that are allowed at this site.
                p_normalized = atom_mask[i % n_max] / jnp.sum(atom_mask[i % n_max])
                _a = jax.random.choice(key_proposal_A, a=atom_types, p=p_normalized, shape=(A.shape[0],))
                # Without constraints this would be A.at[:, i % n_max].set(_a).
                _A = update_A(i % n_max, A, _a, constraints)
                A_proposal = jnp.where(A == 0, A, _A)  # never overwrite padding sites

                # Perturb only the symmetry-free fractional coordinates.
                fc_mask = jax.vmap(get_fc_mask, in_axes=(0, 0))(G, W)
                _xyz = XYZ[:, i % n_max] + sample_von_mises(key_proposal_XYZ, 0, 1 / mc_width ** 2, XYZ[:, i % n_max].shape)
                _XYZ = XYZ.at[:, i % n_max].set(_xyz)
                _XYZ -= jnp.floor(_XYZ)  # wrap to [0, 1)
                XYZ_proposal = jnp.where(fc_mask, _XYZ, XYZ)
                x_proposal = (G, L, XYZ_proposal, A_proposal, W)

                logp_w, logp_xyz, logp_a, _ = logp_fn(params, key_logp, *x_proposal, False)
                logp_proposal = logp_w + logp_xyz + logp_a

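                # Metropolis test: accept with probability min(1, p'/p). As
                # written this assumes a symmetric proposal; the type proposal
                # draws from atom_mask independently of the current type, which
                # is symmetric when the mask is uniform.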
                ratio = jnp.exp(logp_proposal - logp)
                accept = jax.random.uniform(key_accept, ratio.shape) < ratio

                A_new = jnp.where(accept[:, None], A_proposal, A)  # update atom types
                XYZ_new = jnp.where(accept[:, None, None], XYZ_proposal, XYZ)  # update atom positions
                x_new = (G, L, XYZ_new, A_new, W)
                logp_new = jnp.where(accept, logp_proposal, logp)
                # Count acceptances on real atoms only.
                num_accepts += jnp.sum(accept * jnp.where(A[:, i % n_max] == 0, 0, 1))
                return x_new, logp_new, key, num_accepts

            def false_func(i, state):
                # Site i % n_max is empty in every batch entry; leave the state unchanged.
                x, logp, key, num_accepts = state
                return x, logp, key, num_accepts

            x, logp, key, num_accepts = state
            A = x[3]
            x, logp, key, num_accepts = jax.lax.cond(A[:, i % n_max].sum() != 0,
                                                     lambda _: true_func(i, state),
                                                     lambda _: false_func(i, state),
                                                     None)
            return x, logp, key, num_accepts

        # Evaluate the log-probability of the initial configuration once.
        key, subkey = jax.random.split(key)
        logp_w, logp_xyz, logp_a, _ = logp_fn(params, subkey, *x_init, False)
        logp_init = logp_w + logp_xyz + logp_a

        x, logp, key, num_accepts = jax.lax.fori_loop(0, mc_steps, step, (x_init, logp_init, key, 0.))
        A = x[3]
        # Normalize by the fraction of non-padding sites so the acceptance rate
        # refers to proposals on real atoms only.
        scale = jnp.sum(A != 0) / (A.shape[0] * n_max)
        accept_rate = num_accepts / (scale * mc_steps * x[0].shape[0])
        return x, accept_rate

    return mcmc

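
# Typical usage: build the sampler once with make_mcmc_step(...) and call the
# returned (jitted) mcmc repeatedly, threading the PRNG key through the calls,
# as in the smoke test below.
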
if __name__ == "__main__":
    from utils import GLXYZAW_from_file
    from loss import make_loss_fn
    from transformer import make_transformer

    atom_types = 119
    n_max = 21
    wyck_types = 28
    Nf = 5
    Kx = 16
    Kl = 4
    dropout_rate = 0.3

    csv_file = '../data/mini.csv'
    G, L, XYZ, A, W = GLXYZAW_from_file(csv_file, atom_types, wyck_types, n_max)

    key = jax.random.PRNGKey(42)

    params, transformer = make_transformer(key, Nf, Kx, Kl, n_max, 128, 4, 4, 8, 16, 16, atom_types, wyck_types, dropout_rate)

    loss_fn, logp_fn = make_loss_fn(n_max, atom_types, wyck_types, Kx, Kl, transformer)

    # MCMC sampling test
    mc_steps = 21
    mc_width = 0.1
    x_init = (G[:5], L[:5], XYZ[:5], A[:5], W[:5])

    value = jax.jit(logp_fn, static_argnums=7)(params, key, *x_init, False)

    jnp.set_printoptions(threshold=jnp.inf)
    mcmc = make_mcmc_step(params, n_max=n_max, atom_types=atom_types)

    for i in range(5):
        key, subkey = jax.random.split(key)
        x, acc = mcmc(logp_fn, x_init=x_init, key=subkey, mc_steps=mc_steps, mc_width=mc_width)
        print(i, acc)

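    # Rule of thumb (a general MCMC heuristic, not specific to this code):
    # acceptance rates very close to 0 or 1 suggest mc_width is too large or
    # too small, respectively.
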
    print("check if the atom type is changed")
    print(x_init[3])
    print(x[3])

    print("check if the atom position is changed")
    print(x_init[2])
    print(x[2])