Skip to content

Commit

Permalink
improved xtb driver and testing suite
Browse files Browse the repository at this point in the history
  • Loading branch information
caseyolen committed Dec 2, 2023
1 parent 1c219bc commit a276df3
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 22 deletions.
9 changes: 3 additions & 6 deletions molli/external/xtb.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Any, Generator, Callable
from ..chem import Molecule
from ..config import BACKUP_DIR
from ..config import SCRATCH_DIR
from subprocess import run, PIPE
from pathlib import Path
import attrs
Expand Down Expand Up @@ -134,6 +132,8 @@ def __contains__(self, key):

class XTBDriver:
def __init__(self, nprocs: int = 1) -> None:
from ..config import BACKUP_DIR
from ..config import SCRATCH_DIR
self.nprocs = nprocs
self.backup_dir = BACKUP_DIR
self.scratch_dir = SCRATCH_DIR
Expand All @@ -152,14 +152,11 @@ def optimize(
crit: str = "normal",
xtbinp: str = "",
maxiter: int = 50,
# xyz_name: str = "mol", # do we need this anymore?

):
assert isinstance(M, Molecule), "User did not pass a Molecule object!"
# print(self.nprocs)
inp = JobInput(
M.name,
command=f"""xtb input.xyz --{method} --opt {crit} --charge {M.charge} {"--input param.inp" if xtbinp else ""} -P {self.nprocs}""",
command=f"""xtb input.xyz --{method} --opt {crit} --charge {M.charge} --iterations {maxiter} {"--input param.inp" if xtbinp else ""} -P {self.nprocs}""",
files={"input.xyz": M.dumps_xyz().encode()}
)

Expand Down
78 changes: 62 additions & 16 deletions molli_test/test_external_xtb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,51 @@

import numpy as np
import molli as ml
import subprocess
import os
from pathlib import Path
import shutil

from joblib import delayed,Parallel
from molli.external import XTBDriver
from molli.config import BACKUP_DIR
from molli.config import SCRATCH_DIR

# how to handle this?
# try:
# from rdkit.Chem.PropertyMol import PropertyMol
# except:
# _RDKIT_INSTALLED = False
# else:
# _RDKIT_INSTALLED = True
# try:
# from rdkit.Chem.Draw import IPythonConsole
# except:
# _IPYTHON_INSTALLED = False
# else:
# from molli.external import _rdkit
# _IPYTHON_INSTALLED = True

# a little clunky but will check if xTB is installed
try:
out = subprocess.run(['xtb', '--version'], capture_output=True)
if out.stderr == b'normal termination of xtb\n':
_XTB_INSTALLED = True
except FileNotFoundError:
_XTB_INSTALLED = False

_CURRENT_BACKUP_DIR = BACKUP_DIR
_CURRENT_SCRATCH_DIR = SCRATCH_DIR
_TEST_BACKUP_DIR: Path = ml.config.HOME / "test_backup"
_TEST_SCRATCH_DIR: Path = ml.config.HOME / "test_scratch"

def prep_dirs():
ml.config.BACKUP_DIR = _TEST_BACKUP_DIR
ml.config.SCRATCH_DIR = _TEST_SCRATCH_DIR

def cleanup_dirs():
shutil.rmtree(_TEST_BACKUP_DIR)
shutil.rmtree(_TEST_SCRATCH_DIR)
ml.config.BACKUP_DIR = _CURRENT_BACKUP_DIR
ml.config.SCRATCH_DIR = _CURRENT_SCRATCH_DIR

class XTBTC(ut.TestCase):
"""This test suite is for the basic installation stuff"""

# @ut.skipUnless(_RDKIT_INSTALLED, "RDKit is not installed in current environment.")
# @ut.skipUnless(_IPYTHON_INSTALLED, "IPython is not installed in current environment.")

@ut.skipUnless(_XTB_INSTALLED, "xtb is not installed in current environment.")
def test_xtb_optimize(self):

prep_dirs()

# test with cinchonidine library

mlib1 = ml.MoleculeLibrary(ml.files.cinchonidines)
#Cinchonidine Charges = 1
for m in mlib1:
Expand All @@ -52,4 +70,32 @@ def test_xtb_optimize(self):
for m1, m2 in zip(mlib1, res):
self.assertNotAlmostEqual(np.linalg.norm(m1.coords - m2.coords), 0) # make sure the atom coordinates have moved

cleanup_dirs()

@ut.skipUnless(_XTB_INSTALLED, "xtb is not installed in current environment.")
def test_xtb_energy(self):

prep_dirs()

# test with cinchonidine library
mlib1 = ml.MoleculeLibrary(ml.files.cinchonidines)
#Cinchonidine Charges = 1
for m in mlib1:
m.charge = 1

# testing in serial works fine
xtb = XTBDriver(nprocs=4)
res = [xtb.energy(m) for m in mlib1]

# we will spot check several of these based on output from separate call to xtb
for i, energy in enumerate(res):
if mlib1[i].name == '1_5_c':
self.assertEqual(energy, -105.591624613587)
if mlib1[i].name == '2_12_c':
self.assertEqual(energy, -102.911928497077)
if mlib1[i].name == '10_5_c':
self.assertEqual(energy, -116.035733938867)

cleanup_dirs()


0 comments on commit a276df3

Please sign in to comment.