|
162 | 162 | else:
|
163 | 163 | pass
|
164 | 164 |
|
165 |
| - print("\n========== Print out some test data for the given space group ==========") |
| 165 | + print("\n========== Calculate the loss of test dataset ==========") |
166 | 166 | import numpy as np
|
167 | 167 | np.set_printoptions(threshold=np.inf)
|
168 | 168 |
|
169 |
| - G, L, XYZ, A, W = test_data |
170 |
| - print (G.shape, L.shape, XYZ.shape, A.shape, W.shape) |
171 |
| - |
172 |
| - idx = jnp.where(G==args.spacegroup,size=5) |
173 |
| - G = G[idx] |
174 |
| - L = L[idx] |
175 |
| - XYZ = XYZ[idx] |
176 |
| - A = A[idx] |
177 |
| - W = W[idx] |
178 |
| - |
179 |
| - num_sites = jnp.sum(A!=0, axis=1) |
180 |
| - print ("num_sites:", num_sites) |
181 |
| - @jax.vmap |
182 |
| - def lookup(G, W): |
183 |
| - return mult_table[G-1, W] # (n_max, ) |
184 |
| - M = lookup(G, W) # (batchsize, n_max) |
185 |
| - num_atoms = M.sum(axis=-1) |
186 |
| - print ("num_atoms:", num_atoms) |
187 |
| - |
188 |
| - print ("G:", G) |
189 |
| - print ("A:\n", A) |
190 |
| - for a in A: |
191 |
| - print([element_list[i] for i in a]) |
192 |
| - print ("W:\n",W) |
193 |
| - print ("XYZ:\n",XYZ) |
194 |
| - |
195 |
| - outputs = jax.vmap(transformer, (None, None, 0, 0, 0, 0, 0, None), (0))(params, key, G, XYZ, A, W, M, False) |
196 |
| - print ("outputs.shape", outputs.shape) |
197 |
| - |
198 |
| - h_al = outputs[:, 1::5, :] # (:, n_max, :) |
199 |
| - a_logit = h_al[:, :, :args.atom_types] |
200 |
| - l_logit, mu, sigma = jnp.split(h_al[jnp.arange(h_al.shape[0]), num_sites, |
201 |
| - args.atom_types:args.atom_types+args.Kl+2*6*args.Kl], |
202 |
| - [args.Kl, args.Kl+6*args.Kl], axis=-1) |
203 |
| - print ("L:\n",L) |
204 |
| - print ("exp(l_logit):\n", jnp.exp(l_logit)) |
205 |
| - print ("mu:\n", mu.reshape(-1, args.Kl, 6)) |
206 |
| - print ("sigma:\n", sigma.reshape(-1, args.Kl, 6)) |
| 169 | + test_G, test_L, test_XYZ, test_A, test_W = test_data |
| 170 | + print (test_G.shape, test_L.shape, test_XYZ.shape, test_A.shape, test_W.shape) |
| 171 | + test_loss = 0 |
| 172 | + num_samples = len(test_L) |
| 173 | + num_batches = math.ceil(num_samples / args.batchsize) |
| 174 | + for batch_idx in range(num_batches): |
| 175 | + start_idx = batch_idx * args.batchsize |
| 176 | + end_idx = min(start_idx + args.batchsize, num_samples) |
| 177 | + G, L, XYZ, A, W = test_G[start_idx:end_idx], \ |
| 178 | + test_L[start_idx:end_idx], \ |
| 179 | + test_XYZ[start_idx:end_idx], \ |
| 180 | + test_A[start_idx:end_idx], \ |
| 181 | + test_W[start_idx:end_idx] |
| 182 | + loss, _ = jax.jit(loss_fn, static_argnums=7)(params, key, G, L, XYZ, A, W, False) |
| 183 | + test_loss += loss |
| 184 | + test_loss = test_loss / num_batches |
| 185 | + print ("evaluating loss on test data:" , test_loss) |
207 | 186 |
|
208 | 187 | print("\n========== Start sampling ==========")
|
209 | 188 | jax.config.update("jax_enable_x64", True) # to get off compilation warning, and to prevent sample nan lattice
|
|
0 commit comments