Skip to content

Commit

Permalink
add MCMC after sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
zdcao121 committed Jun 5, 2024
1 parent 5c6103f commit 938e0a4
Show file tree
Hide file tree
Showing 3 changed files with 337 additions and 187 deletions.
77 changes: 67 additions & 10 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
from elements import element_dict, element_list
from transformer import make_transformer
from train import train
from sample import sample_crystal
from sample import sample_crystal, make_update_lattice
from loss import make_loss_fn
import checkpoint
from wyckoff import mult_table
from mcmc import make_mcmc_step

import argparse
parser = argparse.ArgumentParser(description='')
Expand Down Expand Up @@ -69,9 +70,13 @@
group.add_argument('--T1', type=float, default=None, help='temperature used for sampling the first atom type')
group.add_argument('--num_io_process', type=int, default=40, help='number of process used in multiprocessing io')
group.add_argument('--num_samples', type=int, default=1000, help='number of test samples')
group.add_argument('--use_foriloop', action='store_true', help='use lax.fori_loop in sampling')
group.add_argument('--output_filename', type=str, default='output.csv', help='outfile to save sampled structures')

group = parser.add_argument_group('MCMC parameters')
group.add_argument('--mcmc', action='store_true', help='use MCMC to sample')
group.add_argument('--nsweeps', type=int, default=10, help='number of sweeps')
group.add_argument('--mc_width', type=float, default=0.1, help='width of MCMC step')

args = parser.parse_args()

key = jax.random.PRNGKey(42)
Expand All @@ -91,24 +96,59 @@
assert (args.spacegroup is not None) # for inference we need to specify space group
test_data = GLXYZAW_from_file(args.test_path, args.atom_types, args.wyck_types, args.n_max, args.num_io_process)

# jnp.set_printoptions(threshold=jnp.inf) # print full array
constraints = jnp.arange(0, args.n_max, 1)

if args.elements is not None:
idx = [element_dict[e] for e in args.elements]
atom_mask = [1] + [1 if a in idx else 0 for a in range(1, args.atom_types)]
atom_mask = jnp.array(atom_mask)
print ('sampling structure formed by these elements:', args.elements)
print (atom_mask)
# judge that if the input elements is a json file
if args.elements[0].endswith('.json'):
import json
with open(args.elements[0], 'r') as f:
_data = json.load(f)
atoms_list = _data["atom_mask"]
_constraints = _data["constraints"]
print(_constraints)

for old_val, new_val in _constraints:
constraints = jnp.where(constraints == new_val, constraints[old_val], constraints) # update constraints
print("constraints", constraints)

assert len(atoms_list) == len(args.wyckoff)
print ('sampling structure formed by these elements:', atoms_list)
atom_mask = []
for elements in atoms_list:
idx = [element_dict[e] for e in elements]
atom_mask_ = [1] + [1 if a in idx else 0 for a in range(1, args.atom_types)]
atom_mask.append(atom_mask_)

# padding 0 until the atom_mask shape is (args.n_max, args.atom_types)
atom_mask = jnp.array(atom_mask)
atom_mask = jnp.pad(atom_mask, ((0, args.n_max-atom_mask.shape[0]), (0, 0)), mode='constant')
print(atom_mask)
else:
idx = [element_dict[e] for e in args.elements]
atom_mask = [1] + [1 if a in idx else 0 for a in range(1, args.atom_types)]
atom_mask = jnp.array(atom_mask)
atom_mask = jnp.stack([atom_mask] * args.n_max, axis=0)
print ('sampling structure formed by these elements:', args.elements)
print (atom_mask)

else:
if args.remove_radioactive:
from elements import radioactive_elements_dict, noble_gas_dict
# remove radioactive elements and noble gas
atom_mask = [1] + [1 if i not in radioactive_elements_dict.values() and i not in noble_gas_dict.values() else 0 for i in range(1, args.atom_types)]
atom_mask = jnp.array(atom_mask)
atom_mask = jnp.stack([atom_mask] * args.n_max, axis=0)
print('sampling structure formed by non-radioactive elements and non-noble gas')
print(atom_mask)

else:
atom_mask = jnp.zeros((args.atom_types), dtype=int) # we will do nothing to a_logit in sampling
print(f'there is total {jnp.sum(atom_mask)-1} elements')
atom_mask = jnp.stack([atom_mask] * args.n_max, axis=0)
print(atom_mask)
# print(f'there is total {jnp.sum(atom_mask)-1} elements')
print(atom_mask.shape)

if args.wyckoff is not None:
idx = [letter_to_number(w) for w in args.wyckoff]
Expand Down Expand Up @@ -221,6 +261,11 @@
else:
T1 = args.temperature

mc_steps = args.nsweeps * args.n_max
print("mc_steps", mc_steps)
mcmc = make_mcmc_step(params, n_max=args.n_max, atom_types=args.atom_types, atom_mask=atom_mask, constraints=constraints)
update_lattice = make_update_lattice(transformer, params, args.atom_types, args.Kl, args.top_p, args.temperature)

num_batches = math.ceil(args.num_samples / args.batchsize)
name, extension = args.output_filename.rsplit('.', 1)
filename = os.path.join(output_path,
Expand All @@ -230,7 +275,19 @@
end_idx = min(start_idx + args.batchsize, args.num_samples)
n_sample = end_idx - start_idx
key, subkey = jax.random.split(key)
XYZ, A, W, M, L = sample_crystal(subkey, transformer, params, args.n_max, n_sample, args.atom_types, args.wyck_types, args.Kx, args.Kl, args.spacegroup, w_mask, atom_mask, args.top_p, args.temperature, T1, args.use_foriloop)
XYZ, A, W, M, L = sample_crystal(subkey, transformer, params, args.n_max, n_sample, args.atom_types, args.wyck_types, args.Kx, args.Kl, args.spacegroup, w_mask, atom_mask, args.top_p, args.temperature, T1, constraints)

G = args.spacegroup * jnp.ones((n_sample), dtype=int)
if args.mcmc:
x = (G, L, XYZ, A, W)
key, subkey = jax.random.split(key)
x, acc = mcmc(logp_fn, x_init=x, key=subkey, mc_steps=mc_steps, mc_width=args.mc_width)
print("acc", acc)

G, L, XYZ, A, W = x
key, subkey = jax.random.split(key)
L = update_lattice(subkey, G, XYZ, A, W)

print ("XYZ:\n", XYZ) # fractional coordinate
print ("A:\n", A) # element type
print ("W:\n", W) # Wyckoff positions
Expand All @@ -256,7 +313,7 @@
angle = angle * (jnp.pi / 180) # to rad
L = jnp.concatenate([length, angle], axis=-1)

G = args.spacegroup * jnp.ones((n_sample), dtype=int)
# G = args.spacegroup * jnp.ones((n_sample), dtype=int)
logp_w, logp_xyz, logp_a, logp_l = jax.jit(logp_fn, static_argnums=7)(params, key, G, L, XYZ, A, W, False)

data['logp_w'] = np.array(logp_w).tolist()
Expand Down
148 changes: 148 additions & 0 deletions src/mcmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import jax
import jax.numpy as jnp
from functools import partial

from wyckoff import fc_mask_table
from von_mises import sample_von_mises


get_fc_mask = lambda g, w: jnp.logical_and((w>0)[:, None], fc_mask_table[g-1, w])

def make_mcmc_step(params, n_max, atom_types, atom_mask=None, constraints=None):

if atom_mask is None or jnp.all(atom_mask == 0):
atom_mask = jnp.ones((n_max, atom_types))

if constraints is None:
constraints = jnp.arange(0, n_max, 1)

def update_A(i, A, a, constraints):
def body_fn(j, A):
A = jax.lax.cond(constraints[j] == constraints[i],
lambda _: A.at[:, j].set(a),
lambda _: A,
None)
return A

A = jax.lax.fori_loop(0, A.shape[1], body_fn, A)
return A

@partial(jax.jit, static_argnums=0)
def mcmc(logp_fn, x_init, key, mc_steps, mc_width):
"""
Markov Chain Monte Carlo sampling algorithm.
INPUT:
logp_fn: callable that evaluate log-probability of a batch of configuration x.
The signature is logp_fn(x), where x has shape (batch, n, dim).
x_init: initial value of x, with shape (batch, n, dim).
key: initial PRNG key.
mc_steps: total number of Monte Carlo steps.
mc_width: size of the Monte Carlo proposal.
OUTPUT:
x: resulting batch samples, with the same shape as `x_init`.
"""
def step(i, state):

def true_func(i, state):
x, logp, key, num_accepts = state
G, L, XYZ, A, W = x
key, key_proposal_A, key_proposal_XYZ, key_accept, key_logp = jax.random.split(key, 5)

p_normalized = atom_mask[i%n_max] / jnp.sum(atom_mask[i%n_max]) # only propose atom types that are allowed
_a = jax.random.choice(key_proposal_A, a=atom_types, p=p_normalized, shape=(A.shape[0], ))
# _A = A.at[:, i%n_max].set(_a)
_A = update_A(i%n_max, A, _a, constraints)
A_proposal = jnp.where(A == 0, A, _A)

fc_mask = jax.vmap(get_fc_mask, in_axes=(0, 0))(G, W)
_xyz = XYZ[:, i%n_max] + sample_von_mises(key_proposal_XYZ, 0, 1/mc_width**2, XYZ[:, i%n_max].shape)
_XYZ = XYZ.at[:, i%n_max].set(_xyz)
_XYZ -= jnp.floor(_XYZ) # wrap to [0, 1)
XYZ_proposal = jnp.where(fc_mask, _XYZ, XYZ)
x_proposal = (G, L, XYZ_proposal, A_proposal, W)

logp_w, logp_xyz, logp_a, _ = logp_fn(params, key_logp, *x_proposal, False)
logp_proposal = logp_w + logp_xyz + logp_a

ratio = jnp.exp((logp_proposal - logp))
accept = jax.random.uniform(key_accept, ratio.shape) < ratio

A_new = jnp.where(accept[:, None], A_proposal, A) # update atom types
XYZ_new = jnp.where(accept[:, None, None], XYZ_proposal, XYZ) # update atom positions
x_new = (G, L, XYZ_new, A_new, W)
logp_new = jnp.where(accept, logp_proposal, logp)
num_accepts += jnp.sum(accept*jnp.where(A[:, i%n_max]==0, 0, 1))
return x_new, logp_new, key, num_accepts

def false_func(i, state):
x, logp, key, num_accepts = state
return x, logp, key, num_accepts

x, logp, key, num_accepts = state
A = x[3]
x, logp, key, num_accepts = jax.lax.cond(A[:, i%n_max].sum() != 0,
lambda _: true_func(i, state),
lambda _: false_func(i, state),
None)
return x, logp, key, num_accepts

key, subkey = jax.random.split(key)
logp_w, logp_xyz, logp_a, _ = logp_fn(params, subkey, *x_init, False)
logp_init = logp_w + logp_xyz + logp_a
# print("logp_init", logp_init)

x, logp, key, num_accepts = jax.lax.fori_loop(0, mc_steps, step, (x_init, logp_init, key, 0.))
# print("logp", logp)
A = x[3]
scale = jnp.sum(A != 0)/(A.shape[0]*n_max)
accept_rate = num_accepts / (scale * mc_steps * x[0].shape[0])
return x, accept_rate

return mcmc


if __name__ == "__main__":
from utils import GLXYZAW_from_file
from loss import make_loss_fn
from transformer import make_transformer
atom_types = 119
n_max = 21
wyck_types = 28
Nf = 5
Kx = 16
Kl = 4
dropout_rate = 0.3

csv_file = '../data/mini.csv'
G, L, XYZ, A, W = GLXYZAW_from_file(csv_file, atom_types, wyck_types, n_max)

key = jax.random.PRNGKey(42)

params, transformer = make_transformer(key, Nf, Kx, Kl, n_max, 128, 4, 4, 8, 16, 16, atom_types, wyck_types, dropout_rate)

loss_fn, logp_fn = make_loss_fn(n_max, atom_types, wyck_types, Kx, Kl, transformer)

# MCMC sampling test
mc_steps = 21
mc_width = 0.1
x_init = (G[:5], L[:5], XYZ[:5], A[:5], W[:5])

value = jax.jit(logp_fn, static_argnums=7)(params, key, *x_init, False)

jnp.set_printoptions(threshold=jnp.inf)
mcmc = make_mcmc_step(params, n_max=n_max, atom_types=atom_types)

for i in range(5):
key, subkey = jax.random.split(key)
x, acc = mcmc(logp_fn, x_init=x_init, key=subkey, mc_steps=mc_steps, mc_width=mc_width)
print(i, acc)

print("check if the atom type is changed")
print(x_init[3])
print(x[3])

print("check if the atom position is changed")
print(x_init[2])
print(x[2])
Loading

0 comments on commit 938e0a4

Please sign in to comment.