Skip to content

Commit 938e0a4

Browse files
committed
add MCMC after sampling
1 parent 5c6103f commit 938e0a4

File tree

3 files changed

+337
-187
lines changed

3 files changed

+337
-187
lines changed

src/main.py

+67-10
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
from elements import element_dict, element_list
1212
from transformer import make_transformer
1313
from train import train
14-
from sample import sample_crystal
14+
from sample import sample_crystal, make_update_lattice
1515
from loss import make_loss_fn
1616
import checkpoint
1717
from wyckoff import mult_table
18+
from mcmc import make_mcmc_step
1819

1920
import argparse
2021
parser = argparse.ArgumentParser(description='')
@@ -69,9 +70,13 @@
6970
group.add_argument('--T1', type=float, default=None, help='temperature used for sampling the first atom type')
7071
group.add_argument('--num_io_process', type=int, default=40, help='number of process used in multiprocessing io')
7172
group.add_argument('--num_samples', type=int, default=1000, help='number of test samples')
72-
group.add_argument('--use_foriloop', action='store_true', help='use lax.fori_loop in sampling')
7373
group.add_argument('--output_filename', type=str, default='output.csv', help='outfile to save sampled structures')
7474

75+
group = parser.add_argument_group('MCMC parameters')
76+
group.add_argument('--mcmc', action='store_true', help='use MCMC to sample')
77+
group.add_argument('--nsweeps', type=int, default=10, help='number of sweeps')
78+
group.add_argument('--mc_width', type=float, default=0.1, help='width of MCMC step')
79+
7580
args = parser.parse_args()
7681

7782
key = jax.random.PRNGKey(42)
@@ -91,24 +96,59 @@
9196
assert (args.spacegroup is not None) # for inference we need to specify space group
9297
test_data = GLXYZAW_from_file(args.test_path, args.atom_types, args.wyck_types, args.n_max, args.num_io_process)
9398

99+
# jnp.set_printoptions(threshold=jnp.inf) # print full array
100+
constraints = jnp.arange(0, args.n_max, 1)
101+
94102
if args.elements is not None:
95-
idx = [element_dict[e] for e in args.elements]
96-
atom_mask = [1] + [1 if a in idx else 0 for a in range(1, args.atom_types)]
97-
atom_mask = jnp.array(atom_mask)
98-
print ('sampling structure formed by these elements:', args.elements)
99-
print (atom_mask)
103+
# check whether the elements argument is a path to a JSON file
104+
if args.elements[0].endswith('.json'):
105+
import json
106+
with open(args.elements[0], 'r') as f:
107+
_data = json.load(f)
108+
atoms_list = _data["atom_mask"]
109+
_constraints = _data["constraints"]
110+
print(_constraints)
111+
112+
for old_val, new_val in _constraints:
113+
constraints = jnp.where(constraints == new_val, constraints[old_val], constraints) # update constraints
114+
print("constraints", constraints)
115+
116+
assert len(atoms_list) == len(args.wyckoff)
117+
print ('sampling structure formed by these elements:', atoms_list)
118+
atom_mask = []
119+
for elements in atoms_list:
120+
idx = [element_dict[e] for e in elements]
121+
atom_mask_ = [1] + [1 if a in idx else 0 for a in range(1, args.atom_types)]
122+
atom_mask.append(atom_mask_)
123+
124+
# padding 0 until the atom_mask shape is (args.n_max, args.atom_types)
125+
atom_mask = jnp.array(atom_mask)
126+
atom_mask = jnp.pad(atom_mask, ((0, args.n_max-atom_mask.shape[0]), (0, 0)), mode='constant')
127+
print(atom_mask)
128+
else:
129+
idx = [element_dict[e] for e in args.elements]
130+
atom_mask = [1] + [1 if a in idx else 0 for a in range(1, args.atom_types)]
131+
atom_mask = jnp.array(atom_mask)
132+
atom_mask = jnp.stack([atom_mask] * args.n_max, axis=0)
133+
print ('sampling structure formed by these elements:', args.elements)
134+
print (atom_mask)
135+
100136
else:
101137
if args.remove_radioactive:
102138
from elements import radioactive_elements_dict, noble_gas_dict
103139
# remove radioactive elements and noble gas
104140
atom_mask = [1] + [1 if i not in radioactive_elements_dict.values() and i not in noble_gas_dict.values() else 0 for i in range(1, args.atom_types)]
105141
atom_mask = jnp.array(atom_mask)
142+
atom_mask = jnp.stack([atom_mask] * args.n_max, axis=0)
106143
print('sampling structure formed by non-radioactive elements and non-noble gas')
107144
print(atom_mask)
108145

109146
else:
110147
atom_mask = jnp.zeros((args.atom_types), dtype=int) # we will do nothing to a_logit in sampling
111-
print(f'there is total {jnp.sum(atom_mask)-1} elements')
148+
atom_mask = jnp.stack([atom_mask] * args.n_max, axis=0)
149+
print(atom_mask)
150+
# print(f'there is total {jnp.sum(atom_mask)-1} elements')
151+
print(atom_mask.shape)
112152

113153
if args.wyckoff is not None:
114154
idx = [letter_to_number(w) for w in args.wyckoff]
@@ -221,6 +261,11 @@
221261
else:
222262
T1 = args.temperature
223263

264+
mc_steps = args.nsweeps * args.n_max
265+
print("mc_steps", mc_steps)
266+
mcmc = make_mcmc_step(params, n_max=args.n_max, atom_types=args.atom_types, atom_mask=atom_mask, constraints=constraints)
267+
update_lattice = make_update_lattice(transformer, params, args.atom_types, args.Kl, args.top_p, args.temperature)
268+
224269
num_batches = math.ceil(args.num_samples / args.batchsize)
225270
name, extension = args.output_filename.rsplit('.', 1)
226271
filename = os.path.join(output_path,
@@ -230,7 +275,19 @@
230275
end_idx = min(start_idx + args.batchsize, args.num_samples)
231276
n_sample = end_idx - start_idx
232277
key, subkey = jax.random.split(key)
233-
XYZ, A, W, M, L = sample_crystal(subkey, transformer, params, args.n_max, n_sample, args.atom_types, args.wyck_types, args.Kx, args.Kl, args.spacegroup, w_mask, atom_mask, args.top_p, args.temperature, T1, args.use_foriloop)
278+
XYZ, A, W, M, L = sample_crystal(subkey, transformer, params, args.n_max, n_sample, args.atom_types, args.wyck_types, args.Kx, args.Kl, args.spacegroup, w_mask, atom_mask, args.top_p, args.temperature, T1, constraints)
279+
280+
G = args.spacegroup * jnp.ones((n_sample), dtype=int)
281+
if args.mcmc:
282+
x = (G, L, XYZ, A, W)
283+
key, subkey = jax.random.split(key)
284+
x, acc = mcmc(logp_fn, x_init=x, key=subkey, mc_steps=mc_steps, mc_width=args.mc_width)
285+
print("acc", acc)
286+
287+
G, L, XYZ, A, W = x
288+
key, subkey = jax.random.split(key)
289+
L = update_lattice(subkey, G, XYZ, A, W)
290+
234291
print ("XYZ:\n", XYZ) # fractional coordinate
235292
print ("A:\n", A) # element type
236293
print ("W:\n", W) # Wyckoff positions
@@ -256,7 +313,7 @@
256313
angle = angle * (jnp.pi / 180) # to rad
257314
L = jnp.concatenate([length, angle], axis=-1)
258315

259-
G = args.spacegroup * jnp.ones((n_sample), dtype=int)
316+
# G = args.spacegroup * jnp.ones((n_sample), dtype=int)
260317
logp_w, logp_xyz, logp_a, logp_l = jax.jit(logp_fn, static_argnums=7)(params, key, G, L, XYZ, A, W, False)
261318

262319
data['logp_w'] = np.array(logp_w).tolist()

src/mcmc.py

+148
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import jax
2+
import jax.numpy as jnp
3+
from functools import partial
4+
5+
from wyckoff import fc_mask_table
6+
from von_mises import sample_von_mises
7+
8+
9+
def get_fc_mask(g, w):
    """Return a boolean mask combining site occupancy with ``fc_mask_table``.

    An entry is True only where ``w > 0`` for that site AND the row
    ``fc_mask_table[g-1, w]`` is True. Used below to decide which coordinate
    entries an MCMC proposal is allowed to change.

    NOTE(review): presumably ``g`` is a 1-based space-group number and ``w`` a
    per-site Wyckoff index with 0 meaning padding -- confirm against wyckoff.py.
    """
    occupied = (w > 0)[:, None]  # broadcast the per-site flag over the last axis
    return jnp.logical_and(occupied, fc_mask_table[g - 1, w])
10+
11+
def make_mcmc_step(params, n_max, atom_types, atom_mask=None, constraints=None):
    """Build a jitted Metropolis sampler over atom types and coordinates.

    Args:
        params: model parameters, forwarded unchanged to ``logp_fn``.
        n_max: number of sites per structure; step ``i`` updates site
            ``i % n_max``.
        atom_types: number of atom-type classes the proposal draws from.
        atom_mask: per-site proposal weights, indexed as ``atom_mask[site]``
            (assumed shape ``(n_max, atom_types)`` -- TODO confirm with caller).
            ``None`` or an all-zero mask is replaced by all ones, i.e. every
            atom type may be proposed at every site.
        constraints: one label per site; sites sharing a label always receive
            the same proposed atom type. Defaults to ``arange(n_max)``, which
            makes every site independent.

    Returns:
        The function ``mcmc(logp_fn, x_init, key, mc_steps, mc_width)``.
    """
    if atom_mask is None or jnp.all(atom_mask == 0):
        atom_mask = jnp.ones((n_max, atom_types))

    if constraints is None:
        constraints = jnp.arange(0, n_max, 1)

    def update_A(i, A, a, constraints):
        """Write proposal ``a`` into every column of ``A`` tied to site ``i``.

        ``A`` is (batch, n_sites) and ``a`` is (batch,). Assumes ``constraints``
        has one entry per column of ``A``.
        """
        # A single broadcasted select replaces a per-column fori_loop + cond.
        tied = constraints == constraints[i]
        return jnp.where(tied[None, :], a[:, None].astype(A.dtype), A)

    @partial(jax.jit, static_argnums=0)
    def mcmc(logp_fn, x_init, key, mc_steps, mc_width):
        """Run ``mc_steps`` Metropolis updates on a batch of structures.

        Each step targets site ``i % n_max`` and jointly proposes a new atom
        type and a perturbed coordinate for that site; G, L and W are never
        modified. The target log-probability is ``logp_w + logp_xyz + logp_a``
        (the fourth value returned by ``logp_fn`` is ignored).

        Args:
            logp_fn: callable invoked as
                ``logp_fn(params, key, G, L, XYZ, A, W, False)`` and returning
                four per-sample log-probabilities. Static under ``jax.jit``.
            x_init: tuple ``(G, L, XYZ, A, W)`` of batched arrays; ``A`` is
                indexed as (batch, site) and ``XYZ`` as (batch, site, coord).
            key: PRNG key.
            mc_steps: total number of single-site Monte Carlo steps.
            mc_width: proposal width; the coordinate perturbation is drawn with
                ``sample_von_mises(key, 0, 1/mc_width**2, shape)``.

        Returns:
            ``(x, accept_rate)``: the final state tuple, structured like
            ``x_init``, and the acceptance count normalised by
            ``mc_steps * batch`` times the fraction of non-zero entries in the
            final ``A``.
        """
        def step(i, state):
            site = i % n_max

            def propose_and_accept(state):
                x, logp, key, num_accepts = state
                G, L, XYZ, A, W = x
                key, key_proposal_A, key_proposal_XYZ, key_accept, key_logp = jax.random.split(key, 5)

                # Only propose atom types that are allowed at this site.
                p_normalized = atom_mask[site] / jnp.sum(atom_mask[site])
                _a = jax.random.choice(key_proposal_A, a=atom_types, p=p_normalized, shape=(A.shape[0],))
                _A = update_A(site, A, _a, constraints)
                # Entries that are 0 in A are kept at 0.
                A_proposal = jnp.where(A == 0, A, _A)

                fc_mask = jax.vmap(get_fc_mask, in_axes=(0, 0))(G, W)
                _xyz = XYZ[:, site] + sample_von_mises(key_proposal_XYZ, 0, 1/mc_width**2, XYZ[:, site].shape)
                _XYZ = XYZ.at[:, site].set(_xyz)
                _XYZ -= jnp.floor(_XYZ)  # wrap to [0, 1)
                # Entries outside fc_mask keep their current value.
                XYZ_proposal = jnp.where(fc_mask, _XYZ, XYZ)
                x_proposal = (G, L, XYZ_proposal, A_proposal, W)

                logp_w, logp_xyz, logp_a, _ = logp_fn(params, key_logp, *x_proposal, False)
                logp_proposal = logp_w + logp_xyz + logp_a

                # Metropolis test: accept with probability min(1, ratio).
                ratio = jnp.exp(logp_proposal - logp)
                accept = jax.random.uniform(key_accept, ratio.shape) < ratio

                A_new = jnp.where(accept[:, None], A_proposal, A)  # update atom types
                XYZ_new = jnp.where(accept[:, None, None], XYZ_proposal, XYZ)  # update atom positions
                x_new = (G, L, XYZ_new, A_new, W)
                logp_new = jnp.where(accept, logp_proposal, logp)
                # Count an acceptance only for samples whose entry at this site is non-zero.
                num_accepts += jnp.sum(accept * jnp.where(A[:, site] == 0, 0, 1))
                return x_new, logp_new, key, num_accepts

            A = state[0][3]
            # Skip the step when this site's column of A is zero for the whole batch.
            return jax.lax.cond(A[:, site].sum() != 0,
                                propose_and_accept,
                                lambda s: s,
                                state)

        key, subkey = jax.random.split(key)
        logp_w, logp_xyz, logp_a, _ = logp_fn(params, subkey, *x_init, False)
        logp_init = logp_w + logp_xyz + logp_a

        x, logp, key, num_accepts = jax.lax.fori_loop(0, mc_steps, step, (x_init, logp_init, key, 0.))
        A = x[3]
        # Fraction of entries in the final A that are non-zero.
        scale = jnp.sum(A != 0) / (A.shape[0] * n_max)
        accept_rate = num_accepts / (scale * mc_steps * x[0].shape[0])
        return x, accept_rate

    return mcmc
104+
105+
106+
if __name__ == "__main__":
    # Smoke test: run a few short MCMC chains on a handful of structures from a
    # small CSV file, using a freshly initialised transformer.
    from utils import GLXYZAW_from_file
    from loss import make_loss_fn
    from transformer import make_transformer

    # Model / data hyper-parameters.
    atom_types = 119
    n_max = 21
    wyck_types = 28
    Nf = 5
    Kx = 16
    Kl = 4
    dropout_rate = 0.3

    csv_file = '../data/mini.csv'
    G, L, XYZ, A, W = GLXYZAW_from_file(csv_file, atom_types, wyck_types, n_max)

    key = jax.random.PRNGKey(42)

    params, transformer = make_transformer(
        key, Nf, Kx, Kl, n_max, 128, 4, 4, 8, 16, 16,
        atom_types, wyck_types, dropout_rate,
    )

    loss_fn, logp_fn = make_loss_fn(n_max, atom_types, wyck_types, Kx, Kl, transformer)

    # MCMC sampling test
    mc_steps = 21
    mc_width = 0.1
    # Use the first five structures as the initial state.
    x_init = (G[:5], L[:5], XYZ[:5], A[:5], W[:5])

    # Evaluate logp once on the initial batch (result is not used afterwards).
    value = jax.jit(logp_fn, static_argnums=7)(params, key, *x_init, False)

    jnp.set_printoptions(threshold=jnp.inf)
    mcmc = make_mcmc_step(params, n_max=n_max, atom_types=atom_types)

    # Five independent chains, each restarted from x_init with a fresh key.
    for i in range(5):
        key, subkey = jax.random.split(key)
        x, acc = mcmc(logp_fn, x_init=x_init, key=subkey, mc_steps=mc_steps, mc_width=mc_width)
        print(i, acc)

    # Compare the initial state with the result of the last chain.
    print("check if the atom type is changed")
    print(x_init[3])
    print(x[3])

    print("check if the atom position is changed")
    print(x_init[2])
    print(x[2])

0 commit comments

Comments
 (0)