Skip to content
This repository was archived by the owner on Dec 8, 2024. It is now read-only.

Commit 8f997cf

Browse files
authored
Fixed ARAD wrapper bug (#6)
1 parent 9b2e8ce commit 8f997cf

File tree

2 files changed

+81
-7
lines changed

2 files changed

+81
-7
lines changed

src/wrappers.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .fkernels import fget_vector_kernels_laplacian
2828

2929
from .arad_kernels import get_atomic_kernels_arad
30+
from .arad_kernels import get_atomic_symmetric_kernels_arad
3031

3132

3233
def get_atomic_kernels_laplacian(mols1, mols2, sigmas):
@@ -100,13 +101,13 @@ def get_atomic_kernels_gaussian(mols1, mols2, sigmas):
100101
def arad_kernels(mols1, mols2, sigmas,
101102
width=0.2, cut_distance=5.0, r_width=1.0, c_width=0.5):
102103

103-
amax = mol1[0].arad_descriptor.shape[1]
104+
amax = mols1[0].arad_representation.shape[0]
104105

105106
nm1 = len(mols1)
106107
nm2 = len(mols2)
107108

108-
X1 = np.array([mol.arad_descriptor for mol in mols1]).reshape((nm1,amax,5,amax))
109-
X2 = np.array([mol.arad_descriptor for mol in mols2]).reshape((nm2,amax,5,amax))
109+
X1 = np.array([mol.arad_representation for mol in mols1]).reshape((nm1,amax,5,amax))
110+
X2 = np.array([mol.arad_representation for mol in mols2]).reshape((nm2,amax,5,amax))
110111

111112
Z1 = [mol.nuclear_charges for mol in mols1]
112113
Z2 = [mol.nuclear_charges for mol in mols2]
@@ -120,15 +121,15 @@ def arad_kernels(mols1, mols2, sigmas,
120121
def arad_symmetric_kernels(mols1, sigmas,
121122
width=0.2, cut_distance=5.0, r_width=1.0, c_width=0.5):
122123

123-
amax = mol1[0].arad_descriptor.shape[1]
124+
amax = mols1[0].arad_representation.shape[0]
124125

125126
nm1 = len(mols1)
126127

127-
X1 = np.array([mol.arad_descriptor for mol in mols1]).reshape((nm1,amax,5,amax))
128+
X1 = np.array([mol.arad_representation for mol in mols1]).reshape((nm1,amax,5,amax))
128129

129130
Z1 = [mol.nuclear_charges for mol in mols1]
130131

131-
K = get_symmetric_atomic_kernels_arad(X1, Z1, sigmas, \
132+
K = get_atomic_symmetric_kernels_arad(X1, Z1, sigmas, \
132133
width=width, cut_distance=cut_distance, r_width=r_width, c_width=c_width)
133134

134135
return K

tests/test_wrappers.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,74 @@
1-
# Test
1+
import qml
2+
import numpy as np
3+
4+
from qml.wrappers import arad_kernels, arad_symmetric_kernels
5+
6+
def get_energies(filename):
7+
""" Returns a dictionary with heats of formation for each xyz-file.
8+
"""
9+
10+
f = open(filename, "r")
11+
lines = f.readlines()
12+
f.close()
13+
14+
energies = dict()
15+
16+
for line in lines:
17+
tokens = line.split()
18+
19+
xyz_name = tokens[0]
20+
hof = float(tokens[1])
21+
22+
energies[xyz_name] = hof
23+
24+
return energies
25+
26+
27+
def test_arad_wrapper():
28+
29+
# Parse file containing PBE0/def2-TZVP heats of formation and xyz filenames
30+
data = get_energies("tests/hof_qm7.txt")
31+
32+
# Generate a list of qml.Compound() objects
33+
mols = []
34+
35+
for xyz_file in sorted(data.keys())[:50]:
36+
37+
# Initialize the qml.Compound() objects
38+
mol = qml.Compound(xyz="tests/qm7/" + xyz_file)
39+
40+
# Associate a property (heat of formation) with the object
41+
mol.properties = data[xyz_file]
42+
43+
# This is a Molecular Coulomb matrix sorted by row norm
44+
mol.generate_arad_representation(size=23)
45+
46+
mols.append(mol)
47+
48+
49+
# Shuffle molecules
50+
np.random.seed(666)
51+
np.random.shuffle(mols)
52+
53+
# Make training and test sets
54+
n_test = 10
55+
n_train = 40
56+
57+
training = mols[:n_train]
58+
test = mols[-n_test:]
59+
60+
sigmas = [10.0, 100.0]
61+
62+
63+
K1 = arad_symmetric_kernels(training, sigmas)
64+
assert np.all(K1 > 0.0), "ERROR: ARAD symmetric kernel negative"
65+
assert np.invert(np.all(np.isnan(K1))), "ERROR: ARAD symmetric kernel contains NaN"
66+
67+
68+
K2 = arad_kernels(training, test, sigmas)
69+
assert np.all(K2 > 0.0), "ERROR: ARAD symmetric kernel negative"
70+
assert np.invert(np.all(np.isnan(K2))), "ERROR: ARAD symmetric kernel contains NaN"
71+
72+
if __name__ == "__main__":
73+
74+
test_arad_wrapper()

0 commit comments

Comments
 (0)