Skip to content

Commit

Permalink
add w_mask to control the sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
zdcao121 committed May 28, 2024
1 parent 8021812 commit 5c6103f
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 12 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ We only consider symmetry inequivalent atoms. The remaining atoms are restored b

**Notebooks**: The quickest way to get started with _CrystalFormer_ is our notebooks in the Google Colab and Bohrium (Chinese version) platforms:

- CrystalFormer Quickstart [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1IMQV6OQgIGORE8FmSTmZuC5KgQwGCnDx?usp=sharing) [![Open In Bohrium](https://cdn.dp.tech/bohrium/web/static/images/open-in-bohrium.svg)](https://nb.bohrium.dp.tech/detail/68177247598): GUI notebook demonstrating the conditional generation of crystalline materials with _CrystalFormer_.

- CrystalFormer Quickstart [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1IMQV6OQgIGORE8FmSTmZuC5KgQwGCnDx?usp=sharing) [![Open In Bohrium](https://cdn.dp.tech/bohrium/web/static/images/open-in-bohrium.svg)](https://nb.bohrium.dp.tech/detail/68177247598): GUI notebook demonstrating the conditional generation of crystalline materials with _CrystalFormer_;
- CrystalFormer Application [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QdkELaQXAHR1zEu2fcdfgabuoP61_wbU?usp=sharing): Generating stable crystals with a given structure prototype. This workflow can be applied to tasks that are dominated by element substitution.

## Installation

Expand Down
2 changes: 1 addition & 1 deletion scripts/awl2struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def symmetrize_atoms(g, w, x):
#https://github.com/qzhu2017/PyXtal/blob/82e7d0eac1965c2713179eeda26a60cace06afc8/pyxtal/wyckoff_site.py#L115
def dist_to_op0x(coord):
diff = np.dot(symops[g-1, w, 0], np.array([*coord, 1])) - coord
diff -= np.floor(diff)
diff -= np.rint(diff)
return np.sum(diff**2)
# loc = np.argmin(jax.vmap(dist_to_op0x)(coords))
loc = np.argmin([dist_to_op0x(coord) for coord in coords])
Expand Down
17 changes: 17 additions & 0 deletions src/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@

element_dict = {value: index for index, value in enumerate(element_list)}

# Element symbols this module classifies as radioactive.
radioactive_elements = [
    'Tc', 'Pm', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', 'Th',
    'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm',
    'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds',
    'Rg', 'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og',
]
# symbol -> index in element_list, looked up through element_dict
radioactive_elements_dict = dict(
    (symbol, element_dict[symbol]) for symbol in radioactive_elements
)

# Element symbols this module classifies as noble gases.
noble_gas = ['He', 'Ne', 'Ar', 'Kr', 'Xe', 'Rn', 'Og']
# symbol -> index in element_list, looked up through element_dict
noble_gas_dict = dict(
    (symbol, element_dict[symbol]) for symbol in noble_gas
)


if __name__=="__main__":
print (len(element_list))
print (element_dict["H"])
Expand All @@ -38,5 +49,11 @@
aw_mask = [1] + [1 if ((i-1)%(atom_types-1)+1 in idx) else 0 for i in range(1, aw_types)] # 1 for possible elements
print (idx )
print (aw_mask)
print(radioactive_elements_dict)
print(noble_gas_dict)
atom_mask = [1] + [1 if i not in radioactive_elements_dict.values() and i not in noble_gas_dict.values() else 0 for i in range(1, atom_types)]
print('sampling structure formed by non-radioactive elements and non-noble gas')
print(atom_mask)



39 changes: 36 additions & 3 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import multiprocessing
import math

from utils import GLXYZAW_from_file, GLXA_to_csv
from utils import GLXYZAW_from_file, GLXA_to_csv, letter_to_number
from elements import element_dict, element_list
from transformer import make_transformer
from train import train
Expand Down Expand Up @@ -59,10 +59,14 @@
# Size of the multiplicity vocabulary; per the help text the count includes the 0 entry.
group.add_argument('--wyck_types', type=int, default=28, help='Number of possible multiplicities including 0')

# Command-line options that steer sampling from the model.
group = parser.add_argument_group('sampling parameters')
# Optional seed; None keeps whatever PRNG key the script already holds.
group.add_argument('--seed', type=int, default=None, help='random seed to sample')
group.add_argument('--spacegroup', type=int, help='The space group id to be sampled (1-230)')
# Constraints on what may be sampled: Wyckoff letters and chemical elements.
group.add_argument('--wyckoff', type=str, default=None, nargs='+', help='The Wyckoff positions to be sampled, e.g. a, b')
group.add_argument('--elements', type=str, default=None, nargs='+', help='name of the chemical elements, e.g. Bi, Ti, O')
group.add_argument('--remove_radioactive', action='store_true', help='remove radioactive elements and noble gas')
# Sampling sharpness: nucleus (top-p) cutoff and temperatures.
group.add_argument('--top_p', type=float, default=1.0, help='1.0 means un-modified logits, smaller value of p gives less diverse samples')
group.add_argument('--temperature', type=float, default=1.0, help='temperature used for sampling')
# None means "not given"; the script then falls back to --temperature.
group.add_argument('--T1', type=float, default=None, help='temperature used for sampling the first atom type')
group.add_argument('--num_io_process', type=int, default=40, help='number of process used in multiprocessing io')
group.add_argument('--num_samples', type=int, default=1000, help='number of test samples')
group.add_argument('--use_foriloop', action='store_true', help='use lax.fori_loop in sampling')
Expand Down Expand Up @@ -94,7 +98,28 @@
print ('sampling structure formed by these elements:', args.elements)
print (atom_mask)
else:
atom_mask = jnp.zeros((args.atom_types), dtype=int) # we will do nothing to a_logit in sampling
if args.remove_radioactive:
from elements import radioactive_elements_dict, noble_gas_dict
# remove radioactive elements and noble gas
atom_mask = [1] + [1 if i not in radioactive_elements_dict.values() and i not in noble_gas_dict.values() else 0 for i in range(1, args.atom_types)]
atom_mask = jnp.array(atom_mask)
print('sampling structure formed by non-radioactive elements and non-noble gas')
print(atom_mask)

else:
atom_mask = jnp.zeros((args.atom_types), dtype=int) # we will do nothing to a_logit in sampling
print(f'there is total {jnp.sum(atom_mask)-1} elements')

# Build w_mask from --wyckoff; it stays None when the option was not given.
if args.wyckoff is None:
    w_mask = None
else:
    # One number per requested Wyckoff letter, in the order given on the command line.
    idx = [letter_to_number(letter) for letter in args.wyckoff]
    # Right-pad with 0 so the sequence reaches length args.n_max.
    n_pad = args.n_max - len(idx)
    w_mask = jnp.array(idx + [0] * n_pad, dtype=int)
    print ('sampling structure formed by these Wyckoff positions:', args.wyckoff)
    print (w_mask)

################### Model #############################
params, transformer = make_transformer(key, args.Nf, args.Kx, args.Kl, args.n_max,
Expand Down Expand Up @@ -188,6 +213,14 @@
jax.config.update("jax_enable_x64", True) # to get off compilation warning, and to prevent sample nan lattice
#FYI, the error was [Compiling module extracted] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.

# An explicit --seed replaces the current PRNG key before sampling starts.
if args.seed is not None:
    key = jax.random.PRNGKey(args.seed) # reset key for sampling if seed is provided

# Temperature for the first atom type: --T1 when given, otherwise the general --temperature.
T1 = args.temperature if args.T1 is None else args.T1

num_batches = math.ceil(args.num_samples / args.batchsize)
name, extension = args.output_filename.rsplit('.', 1)
filename = os.path.join(output_path,
Expand All @@ -197,7 +230,7 @@
end_idx = min(start_idx + args.batchsize, args.num_samples)
n_sample = end_idx - start_idx
key, subkey = jax.random.split(key)
XYZ, A, W, M, L = sample_crystal(subkey, transformer, params, args.n_max, n_sample, args.atom_types, args.wyck_types, args.Kx, args.Kl, args.spacegroup, atom_mask, args.top_p, args.temperature, args.use_foriloop)
XYZ, A, W, M, L = sample_crystal(subkey, transformer, params, args.n_max, n_sample, args.atom_types, args.wyck_types, args.Kx, args.Kl, args.spacegroup, w_mask, atom_mask, args.top_p, args.temperature, T1, args.use_foriloop)
print ("XYZ:\n", XYZ) # fractional coordinate
print ("A:\n", A) # element type
print ("W:\n", W) # Wyckoff positions
Expand Down
22 changes: 17 additions & 5 deletions src/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def sample_x(key, h_x, Kx, top_p, temperature, batchsize):
x = (x+ jnp.pi)/(2.0*jnp.pi) # wrap into [0, 1]
return key, x

@partial(jax.jit, static_argnums=(1, 3, 4, 5, 6, 7, 8, 9, 11, 13))
def sample_crystal(key, transformer, params, n_max, batchsize, atom_types, wyck_types, Kx, Kl, g, atom_mask, top_p, temperature, use_foriloop):
@partial(jax.jit, static_argnums=(1, 3, 4, 5, 6, 7, 8, 9, 12, 14, 15))
def sample_crystal(key, transformer, params, n_max, batchsize, atom_types, wyck_types, Kx, Kl, g, w_mask, atom_mask, top_p, temperature, T1, use_foriloop):

if use_foriloop:

Expand All @@ -72,16 +72,22 @@ def body_fn(i, state):
w_logit = w_logit[:, :wyck_types]

key, subkey = jax.random.split(key)
if w_mask is not None:
w_logit = w_logit.at[:, w_mask[i]].set(w_logit[:, w_mask[i]] + 1e10)
w = sample_top_p(subkey, w_logit, top_p, temperature)
W = W.at[:, i].set(w)

# (2) A
h_al = inference(transformer, params, g, W, A, X, Y, Z)[:, 5*i+1] # (batchsize, output_size)
a_logit = h_al[:, :atom_types]

key, subkey = jax.random.split(key)
a_logit = a_logit + jnp.where(atom_mask, 1e10, 0.0) # enhance the probability of masked atoms (do not need to normalize since we only use it for sampling, not computing logp)
a = sample_top_p(subkey, a_logit, top_p, temperature)
_temp = jax.lax.cond(i==0,
true_fun=lambda x: jnp.array(T1, dtype=float),
false_fun=lambda x: temperature,
operand=None)
a = sample_top_p(subkey, a_logit, top_p, _temp) # use T1 for the first atom type
A = A.at[:, i].set(a)

lattice_params = h_al[:, atom_types:atom_types+Kl+2*6*Kl]
Expand Down Expand Up @@ -154,6 +160,8 @@ def body_fn(i, state):
w_logit = w_logit[:, :wyck_types]

key, subkey = jax.random.split(key)
if w_mask is not None:
w_logit = w_logit.at[:, w_mask[i]].set(w_logit[:, w_mask[i]] + 1e10)
w = sample_top_p(subkey, w_logit, top_p, temperature)

W = jnp.concatenate([W, w[:, None]], axis=1)
Expand All @@ -170,7 +178,11 @@ def body_fn(i, state):

key, subkey = jax.random.split(key)
a_logit = a_logit + jnp.where(atom_mask, 1e10, 0.0) # enhance the probability of masked atoms (do not need to normalize since we only use it for sampling, not computing logp)
a = sample_top_p(subkey, a_logit, top_p, temperature)
_temp = jax.lax.cond(i==0,
true_fun=lambda x: jnp.array(T1, dtype=float),
false_fun=lambda x: temperature,
operand=None)
a = sample_top_p(subkey, a_logit, top_p, _temp) # use T1 for the first atom type
A = jnp.concatenate([A, a[:, None]], axis=1)

lattice_params = h_al[:, atom_types:atom_types+Kl+2*6*Kl]
Expand Down
2 changes: 1 addition & 1 deletion src/wyckoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def symmetrize_atoms(g, w, x):
#https://github.com/qzhu2017/PyXtal/blob/82e7d0eac1965c2713179eeda26a60cace06afc8/pyxtal/wyckoff_site.py#L115
def dist_to_op0x(coord):
diff = jnp.dot(symops[g-1, w, 0], jnp.array([*coord, 1])) - coord
diff -= jnp.floor(diff)
diff -= jnp.rint(diff)
return jnp.sum(diff**2)
loc = jnp.argmin(jax.vmap(dist_to_op0x)(coords))
x = coords[loc].reshape(3,)
Expand Down

0 comments on commit 5c6103f

Please sign in to comment.