
Commit a9b6d1a

replace testdata outputs with loss

1 parent 0dfb5e7 · commit a9b6d1a

File tree

1 file changed: +18 -39 lines changed


src/main.py (+18 -39)
@@ -162,48 +162,27 @@
     else:
         pass

-    print("\n========== Print out some test data for the given space group ==========")
+    print("\n========== Calculate the loss of test dataset ==========")
     import numpy as np
     np.set_printoptions(threshold=np.inf)

-    G, L, XYZ, A, W = test_data
-    print (G.shape, L.shape, XYZ.shape, A.shape, W.shape)
-
-    idx = jnp.where(G==args.spacegroup, size=5)
-    G = G[idx]
-    L = L[idx]
-    XYZ = XYZ[idx]
-    A = A[idx]
-    W = W[idx]
-
-    num_sites = jnp.sum(A!=0, axis=1)
-    print ("num_sites:", num_sites)
-    @jax.vmap
-    def lookup(G, W):
-        return mult_table[G-1, W] # (n_max, )
-    M = lookup(G, W) # (batchsize, n_max)
-    num_atoms = M.sum(axis=-1)
-    print ("num_atoms:", num_atoms)
-
-    print ("G:", G)
-    print ("A:\n", A)
-    for a in A:
-        print([element_list[i] for i in a])
-    print ("W:\n", W)
-    print ("XYZ:\n", XYZ)
-
-    outputs = jax.vmap(transformer, (None, None, 0, 0, 0, 0, 0, None), (0))(params, key, G, XYZ, A, W, M, False)
-    print ("outputs.shape", outputs.shape)
-
-    h_al = outputs[:, 1::5, :] # (:, n_max, :)
-    a_logit = h_al[:, :, :args.atom_types]
-    l_logit, mu, sigma = jnp.split(h_al[jnp.arange(h_al.shape[0]), num_sites,
-                                   args.atom_types:args.atom_types+args.Kl+2*6*args.Kl],
-                                   [args.Kl, args.Kl+6*args.Kl], axis=-1)
-    print ("L:\n", L)
-    print ("exp(l_logit):\n", jnp.exp(l_logit))
-    print ("mu:\n", mu.reshape(-1, args.Kl, 6))
-    print ("sigma:\n", sigma.reshape(-1, args.Kl, 6))
+    test_G, test_L, test_XYZ, test_A, test_W = test_data
+    print (test_G.shape, test_L.shape, test_XYZ.shape, test_A.shape, test_W.shape)
+    test_loss = 0
+    num_samples = len(test_L)
+    num_batches = math.ceil(num_samples / args.batchsize)
+    for batch_idx in range(num_batches):
+        start_idx = batch_idx * args.batchsize
+        end_idx = min(start_idx + args.batchsize, num_samples)
+        G, L, XYZ, A, W = test_G[start_idx:end_idx], \
+                          test_L[start_idx:end_idx], \
+                          test_XYZ[start_idx:end_idx], \
+                          test_A[start_idx:end_idx], \
+                          test_W[start_idx:end_idx]
+        loss, _ = jax.jit(loss_fn, static_argnums=7)(params, key, G, L, XYZ, A, W, False)
+        test_loss += loss
+    test_loss = test_loss / num_batches
+    print ("evaluating loss on test data:", test_loss)

     print("\n========== Start sampling ==========")
     jax.config.update("jax_enable_x64", True) # to get off compilation warning, and to prevent sample nan lattice
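
For context, the new evaluation loop follows a standard pattern: split the test set into ceil(num_samples / batchsize) batches, compute each batch's loss with a jitted loss function, and average the per-batch losses. Below is a minimal, self-contained sketch of that pattern, assuming a stand-in quadratic loss; loss_fn, evaluate, and the scalar params here are illustrative placeholders, not the repository's actual API.

import math

import jax
import jax.numpy as jnp

def loss_fn(params, x, is_train):
    # Stand-in quadratic loss; placeholder for the real model loss.
    return jnp.mean((x - params) ** 2)

# static_argnums marks the trailing boolean flag as a compile-time
# constant, mirroring how the commit jits loss_fn (static_argnums=7).
loss_jit = jax.jit(loss_fn, static_argnums=2)

def evaluate(params, data, batchsize):
    num_samples = len(data)
    num_batches = math.ceil(num_samples / batchsize)
    total = 0.0
    for batch_idx in range(num_batches):
        start = batch_idx * batchsize
        end = min(start + batchsize, num_samples)
        total += loss_jit(params, data[start:end], False)
    # Averaging per-batch losses weights a short final batch equally
    # with full batches; weighting each batch by its size would give
    # the exact per-sample mean.
    return total / num_batches

if __name__ == "__main__":
    data = jax.random.normal(jax.random.PRNGKey(0), (103,))
    print("test loss:", evaluate(0.5, data, batchsize=16))

One caveat of this design: a ragged final batch changes the input shape, so the jitted function compiles once more for that shape; padding or dropping the last batch are common alternatives when that extra compilation matters.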
