Skip to content

Commit

Permalink
More type-checking updates
Browse files Browse the repository at this point in the history
  • Loading branch information
mattwthompson committed Oct 10, 2024
1 parent 8d4bfeb commit 64a5e7e
Show file tree
Hide file tree
Showing 13 changed files with 105 additions and 181 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ jobs:
yammbs/
- name: Run mypy
if: ${{ matrix.python-version == '3.10' }}
run: python -m mypy -p "yammbs" --exclude "yammbs/_tests"
if: ${{ matrix.python-version == '3.12' }}
run: mypy --install-types && python -m mypy -p "yammbs" --exclude "yammbs/_tests"

- name: CodeCov
uses: codecov/codecov-action@v4
Expand Down
34 changes: 22 additions & 12 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,24 @@ versionfile_build = yammbs/_version.py
tag_prefix = 'v'

[mypy]
check_untyped_defs = true
disallow_any_generics = true
disallow_incomplete_defs = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_ignores = true
plugins = pydantic.mypy,sqlmypy
warn_unused_configs = True
warn_unused_ignores = True
warn_incomplete_stub = True
show_error_codes = True

[mypy-testing.*]
disallow_untyped_defs = false
[mypy-openeye]
ignore_missing_imports = True

[mypy-tests.*]
disallow_untyped_defs = false
[mypy-rdkit.Chem]
ignore_missing_imports = True

[mypy-openff.toolkit.*]
ignore_missing_imports = True

[mypy-openff.nagl.toolkits.openff]
ignore_missing_imports = True

[mypy-openmm.*]
ignore_missing_imports = True

Expand All @@ -83,8 +84,17 @@ ignore_missing_imports = True
[mypy-openmmforcefields.generators]
ignore_missing_imports = True

[mypy-espaloma]
[mypy-sqlalchemy]
ignore_missing_imports = True

[mypy-sqlalchemy.orm]
ignore_missing_imports = True

[mypy-geometric.*]
ignore_missing_imports = True

[mypy-qcelemental]
ignore_missing_imports = True

[mypy-openff.qcsubmit.results]
ignore_missing_imports = True
6 changes: 5 additions & 1 deletion yammbs/_base/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def _array_serializer(val: NDArray, nxt) -> list[float]:
return val.flatten().tolist()


CoordinateArray = Annotated[numpy.ndarray, BeforeValidator(_strip_units), WrapSerializer(_array_serializer)]
CoordinateArray = Annotated[
NDArray[numpy.float64],
BeforeValidator(_strip_units),
WrapSerializer(_array_serializer),
]

Array = CoordinateArray
5 changes: 3 additions & 2 deletions yammbs/_base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@

import hashlib
import json
from typing import Any, ClassVar, List, Union
from typing import Any, ClassVar, Union

import numpy
from numpy.typing import NDArray
from pydantic import BaseModel, ConfigDict

FloatArrayLike = Union[List, numpy.ndarray, float]
FloatArrayLike = Union[list[float], NDArray[numpy.float64], float]


def round_floats(
Expand Down
25 changes: 12 additions & 13 deletions yammbs/_db.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import logging
from typing import Dict, List

from sqlalchemy import Column, Float, ForeignKey, Integer, PickleType, String
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.orm import declarative_base, relationship # type: ignore[attr-defined]

from yammbs.models import MMConformerRecord, QMConformerRecord

Expand All @@ -13,7 +12,7 @@
LOGGER = logging.getLogger(__name__)


class DBQMConformerRecord(DBBase):
class DBQMConformerRecord(DBBase): # type: ignore
__tablename__ = "qm_conformers"

id = Column(Integer, primary_key=True, index=True)
Expand All @@ -26,7 +25,7 @@ class DBQMConformerRecord(DBBase):
energy = Column(Float, nullable=False)


class DBMMConformerRecord(DBBase):
class DBMMConformerRecord(DBBase): # type: ignore
__tablename__ = "mm_conformers"

id = Column(Integer, primary_key=True, index=True)
Expand All @@ -40,15 +39,15 @@ class DBMMConformerRecord(DBBase):
energy = Column(Float, nullable=False)


class DBMoleculeRecord(DBBase):
class DBMoleculeRecord(DBBase): # type: ignore
__tablename__ = "molecules"

id = Column(Integer, primary_key=True, index=True)

inchi_key = Column(String, nullable=False, index=True)
mapped_smiles = Column(String, nullable=False)

def store_qm_conformer_records(self, records: List[QMConformerRecord]):
def store_qm_conformer_records(self, records: list[QMConformerRecord]):
if not isinstance(records, list):
raise ValueError("records must be a list")
# TODO: match conformers?
Expand All @@ -60,7 +59,7 @@ def store_qm_conformer_records(self, records: List[QMConformerRecord]):
)
self.qm_conformers.append(db_record)

def store_mm_conformer_records(self, records: List[MMConformerRecord]):
def store_mm_conformer_records(self, records: list[MMConformerRecord]):
if not isinstance(records, list):
raise ValueError("records must be a list")
# TODO: match conformers?
Expand All @@ -74,7 +73,7 @@ def store_mm_conformer_records(self, records: List[MMConformerRecord]):
self.mm_conformers.append(db_record)


class DBGeneralProvenance(DBBase):
class DBGeneralProvenance(DBBase): # type: ignore
__tablename__ = "general_provenance"

key = Column(String, primary_key=True, index=True, unique=True)
Expand All @@ -83,7 +82,7 @@ class DBGeneralProvenance(DBBase):
parent_id = Column(Integer, ForeignKey("db_info.version"))


class DBSoftwareProvenance(DBBase):
class DBSoftwareProvenance(DBBase): # type: ignore
__tablename__ = "software_provenance"

key = Column(String, primary_key=True, index=True, unique=True)
Expand All @@ -92,7 +91,7 @@ class DBSoftwareProvenance(DBBase):
parent_id = Column(Integer, ForeignKey("db_info.version"))


class DBInformation(DBBase):
class DBInformation(DBBase): # type: ignore
"""A class which keeps track of the current database
settings.
"""
Expand All @@ -113,9 +112,9 @@ class DBInformation(DBBase):

def _match_conformers(
indexed_mapped_smiles: str,
db_conformers: List,
query_conformers: List,
) -> Dict[int, int]:
db_conformers: list,
query_conformers: list,
) -> dict[int, int]:
"""A method which attempts to match a set of new conformers to store with
conformers already present in the database by comparing the RMS of the
two sets.
Expand Down
37 changes: 23 additions & 14 deletions yammbs/_forcebalance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,21 @@
from numpy.typing import NDArray


def periodic_diff(a: float, b: float, v_periodic: float) -> float:
"""convenient function for computing the minimum difference in periodic coordinates
def periodic_diff(
a: NDArray[numpy.float64],
b: NDArray[numpy.float64],
v_periodic: float,
) -> NDArray[numpy.float64]:
"""
Convenience function for computing the minimum difference in periodic coordinates
Parameters
----------
a: np.ndarray or float
a
Reference values in a numpy array
b: np.ndarray or float
b
Target values in a numpy arrary
v_periodic: float > 0
v_periodic
Value of the periodic boundary
Returns
Expand All @@ -37,17 +43,20 @@ def periodic_diff(a: float, b: float, v_periodic: float) -> float:
return (a - b + h) % v_periodic - h


def compute_rmsd(ref: NDArray, tar: NDArray, v_periodic: float | None) -> float:
def compute_rmsd(
ref: NDArray[numpy.float64],
tar: NDArray[numpy.float64],
v_periodic: float | None = None,
) -> float:
"""
Compute the RMSD between two arrays, supporting periodic difference
"""

assert len(ref) == len(tar), "array length must match"
n = len(ref)
if n == 0:

if len(ref) == 0:
return 0.0
if v_periodic is not None:
diff = periodic_diff(ref, tar, v_periodic)
else:
diff = ref - tar
rmsd = numpy.sqrt(numpy.sum(diff**2) / n)
return rmsd

diff = ref - tar if v_periodic is None else periodic_diff(ref, tar, v_periodic)

return numpy.sqrt(numpy.sum(diff**2) / len(ref))
6 changes: 3 additions & 3 deletions yammbs/_minimize.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
import re
from multiprocessing import Pool
from typing import Union
from typing import Iterator

import numpy
import openmm
Expand Down Expand Up @@ -57,11 +57,11 @@ def _lazy_load_force_field(force_field_name: str) -> ForceField:


def _minimize_blob(
input: dict[str, dict[str, Union[str, numpy.ndarray]]],
input: dict[str, list],
force_field: str,
n_processes: int = 2,
chunksize=32,
) -> dict[str, list["MinimizationResult"]]:
) -> Iterator["MinimizationResult"]:
inputs = list()

inputs = [
Expand Down
81 changes: 2 additions & 79 deletions yammbs/_session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""A module for managing the database session."""

from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional
from typing import TYPE_CHECKING, Dict, List, NamedTuple

from yammbs._db import (
DB_VERSION,
Expand Down Expand Up @@ -115,7 +115,7 @@ def get_software_provenance(self):
def set_provenance(
self,
general_provenance: Dict[str, str],
software_provenance: Dict[str, str],
software_provenance: Dict[str, str | None],
):
self.db_info.general_provenance = [
DBGeneralProvenance(key=key, value=value) for key, value in general_provenance.items()
Expand Down Expand Up @@ -223,80 +223,3 @@ def _mm_conformer_already_exists(
)

return records.count() > 0

def store_records_with_smiles(
self,
inchi_key: str,
records: List["MoleculeRecord"],
existing_db_record: Optional[DBMoleculeRecord] = None,
):
"""Stores a set of records which all store information for molecules with the
same SMILES representation AND the same fixed hydrogen InChI key.
Parameters
----------
inchi_key: str
The **fixed hydrogen** InChI key representation of the molecule stored in
the records.
records: List[MoleculeRecord]
The records to store.
existing_db_record: Optional[DBMoleculeRecord]
An optional existing DB record to check
"""

if existing_db_record is None:
existing_db_record = DBMoleculeRecord(
# inchi_key=inchi_key,
mapped_smiles=records[0].mapped_smiles,
qcarchive_id=records[0].qcarchive_id,
qcarchive_energy=records[0].qcarchive_energy,
)

# Retrieve the DB indexed SMILES that defines the ordering the atoms in each
# record should have and re-order the incoming records to match.
expected_smiles = existing_db_record.mapped_smiles

conformer_records = [
conformer_record for record in records for conformer_record in record.reorder(expected_smiles).conformers
]

existing_db_record.store_conformer_records(conformer_records)
self.db.add(existing_db_record)

def store_records_with_inchi_key(
self,
inchi_key: str,
records: List["MoleculeRecord"],
):
"""Stores a set of records which all store information for molecules with the
same fixed hydrogen InChI key.
Parameters
----------
inchi_key: str
The **fixed hydrogen** InChI key representation of the molecule stored in
the records.
records: List[MoleculeRecord]
The records to store.
"""

existing_db_records: List[DBMoleculeRecord] = (
self.db.query(DBMoleculeRecord).filter(DBMoleculeRecord.inchi_key == inchi_key).all()
)

db_records_by_smiles = self.map_records_by_smiles(existing_db_records)
# Sanity check that no two DB records have the same InChI key AND the
# same canonical SMILES pattern.
multiple = [smiles for smiles, dbrecords in db_records_by_smiles.items() if len(dbrecords) > 1]
if multiple:
raise RuntimeError(
"The database is not self consistent."
"There are multiple records with the same InChI key and SMILES."
f"InChI key: {inchi_key} and SMILES: {multiple}",
)
db_records_by_smiles = {k: v[0] for k, v in db_records_by_smiles.items()}

records_by_smiles = self.map_records_by_smiles(records)
for smiles, smiles_records in records_by_smiles.items():
db_record = db_records_by_smiles.get(smiles, None)
self.store_records_with_smiles(inchi_key, smiles_records, db_record)
Loading

0 comments on commit 64a5e7e

Please sign in to comment.