Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added improved node and edge matching #61

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ molli_xt.*.so
# Files not to be uploaded to repo
**/_temp_*

/.vscode/c_cpp_properties.json
/.vscode/c_cpp_properties.json
molli/chem/molli_dev.code-workspace
36 changes: 27 additions & 9 deletions molli/chem/atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,9 @@ def set_mol2_type(self, m2t: str):
match mol2_type:
case "4":
if self.element != Element.N:
raise NotImplementedError(f"{mol2_type} not implemented for {mol2_elt}, only N")
raise NotImplementedError(
f"{mol2_type} not implemented for {mol2_elt}, only N"
)
else:
self.atype = AtomType.N_Ammonium
self.geom = AtomGeom.R4_Tetrahedral
Expand Down Expand Up @@ -547,7 +549,9 @@ def set_mol2_type(self, m2t: str):
case _ if mol2_elt == "Du":
# This case if to handle Du.X
self.element = (
Element[mol2_type] if mol2_type in Element._member_names_ else Element.Unknown
Element[mol2_type]
if mol2_type in Element._member_names_
else Element.Unknown
)
self.atype = AtomType.Dummy

Expand Down Expand Up @@ -577,15 +581,21 @@ def get_mol2_type(self):
case Element.C, _, _:
if self.atype == AtomType.Aromatic:
return f"{self.element.symbol}.ar"
elif (self.atype == AtomType.C_Guanidinium) & (self.geom == AtomGeom.R3_Planar):
elif (self.atype == AtomType.C_Guanidinium) & (
self.geom == AtomGeom.R3_Planar
):
return f"{self.element.symbol}.cat"
else:
return f"{self.element.symbol}"

case Element.N, _, _:
if (self.atype == AtomType.N_Ammonium) & (self.geom == AtomGeom.R4_Tetrahedral):
if (self.atype == AtomType.N_Ammonium) & (
self.geom == AtomGeom.R4_Tetrahedral
):
return f"{self.element.symbol}.4"
elif (self.atype == AtomType.N_Amide) & (self.geom == AtomGeom.R3_Planar):
elif (self.atype == AtomType.N_Amide) & (
self.geom == AtomGeom.R3_Planar
):
return f"{self.element.symbol}.am"
elif self.atype == AtomType.Aromatic:
return f"{self.element.symbol}.ar"
Expand All @@ -601,9 +611,13 @@ def get_mol2_type(self):
return f"{self.element.symbol}"

case Element.S, _, _:
if (self.atype == AtomType.O_Sulfoxide) & (self.geom == AtomGeom.R3_Pyramidal):
if (self.atype == AtomType.O_Sulfoxide) & (
self.geom == AtomGeom.R3_Pyramidal
):
return f"{self.element.symbol}.O"
elif (self.atype == AtomType.O_Sulfone) & (self.geom == AtomGeom.R4_Tetrahedral):
elif (self.atype == AtomType.O_Sulfone) & (
self.geom == AtomGeom.R4_Tetrahedral
):
return f"{self.element.symbol}.O2"
else:
return f"{self.element.symbol}"
Expand Down Expand Up @@ -689,7 +703,9 @@ def __init__(
self._atoms = list(Atom(a) for a in atoms)

case _:
raise NotImplementedError(f"Cannot interpret {other} of type {type(other)}")
raise NotImplementedError(
f"Cannot interpret {other} of type {type(other)}"
)

def __repr__(self) -> str:
return f"{type(self).__name__}(name={self.name!r}, formula={self.formula!r})"
Expand Down Expand Up @@ -796,7 +812,9 @@ def index_atom(self, _a: Atom) -> int:
# ) -> Generator[Atom, None, None]:
# return map(self.get_atom, atoms)

def yield_atoms_by_element(self, elt: Element | str | int) -> Generator[Atom, None, None]:
def yield_atoms_by_element(
self, elt: Element | str | int
) -> Generator[Atom, None, None]:
for a in self.atoms:
if a.element == Element.get(elt):
yield a
Expand Down
218 changes: 215 additions & 3 deletions molli/chem/bond.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from __future__ import annotations
from . import Atom, Element, AtomLike, Promolecule, PromoleculeLike
from . import (
Atom,
AtomType,
AtomStereo,
Element,
AtomLike,
Promolecule,
PromoleculeLike,
)
from dataclasses import dataclass, field, KW_ONLY
from typing import Iterable, List, Generator, Tuple, Any
from typing import Iterable, List, Generator, Tuple, Any, Callable
from copy import deepcopy
from enum import IntEnum
from collections import deque
Expand All @@ -10,6 +18,7 @@
import attrs
from bidict import bidict
from functools import cache
import networkx as nx


class BondType(IntEnum):
Expand Down Expand Up @@ -229,7 +238,9 @@ def __init__(

if isinstance(other, Connectivity):
atom_map = {other.atoms[i]: self.atoms[i] for i in range(self.n_atoms)}
self._bonds = list(b.evolve(a1=atom_map[b.a1], a2=atom_map[b.a2]) for b in other.bonds)
self._bonds = list(
b.evolve(a1=atom_map[b.a1], a2=atom_map[b.a2]) for b in other.bonds
)
else:
self._bonds = list()

Expand Down Expand Up @@ -398,3 +409,204 @@ def is_bond_in_ring(self, _b: Bond):
return True

return False

def to_nxgraph(self) -> nx.Graph:
    """
    Builds a networkx graph that mirrors this Connectivity instance.

    Every atom becomes a node and every bond becomes an edge between its
    two atoms. The key/value pairs returned by `as_dict()` of each atom and
    each bond are stored as the attributes of the corresponding node or edge.

    Returns:
    --------
    graph: nx.Graph
        undirected graph whose nodes are the atoms of this instance
    """
    graph = nx.Graph()

    # (node, attribute dict) pairs; nodes are inserted in the order of self.atoms
    graph.add_nodes_from((atom, atom.as_dict()) for atom in self.atoms)

    # (u, v, attribute dict) triples, one per bond
    graph.add_edges_from(
        (bond.a1, bond.a2, bond.as_dict()) for bond in self.bonds
    )

    return graph

def find_cycle_containing_atom(self, start: AtomLike) -> list | None:
    """
    Finds the first cycle containing the "start" atom.

    Parameters:
    -----------
    start: ml.chem.AtomLike
        reference to the atom from which the search for cycles (loops)
        starts; it is resolved through `self.get_atoms`

    Returns:
    --------
    cycle: list | None
        the first cycle of the cycle basis rooted at "start" that contains
        the "start" atom, or None if no such cycle exists

    Notes:
    ------
    Current implementation relies on a networkx graph. Should be rewritten
    w/o any extra dependencies.
    """
    # `start` is an atom reference, so it must be resolved as an atom.
    # yield_atoms_by_element() would interpret it as an element instead
    # (e.g. an integer index would be read as an atomic number).
    (atom,) = self.get_atoms(start)
    nx_mol = self.to_nxgraph()

    for cycle in nx.cycle_basis(nx_mol, atom):
        if atom in cycle:
            return cycle

    # The atom is not a member of any basis cycle
    return None

@staticmethod
def _node_match(a1: dict, a2: dict) -> bool:
    """
    Callable helper function that compares attributes of the nodes (atoms)
    in nx.isomorphism.GraphMatcher().

    Parameters:
    -----------
    a1: dict
        attributes of the node from the first graph given to the matcher
        (the source connectivity when called through `match`)
    a2: dict
        attributes of the node from the second graph given to the matcher
        (the pattern when called through `match`); an "unknown" value
        (Element.Unknown, isotope None, AtomStereo.Unknown,
        AtomType.Unknown) acts as a wildcard for that attribute

    Returns:
    --------
    True if the nodes (atoms) are considered equal, False otherwise.

    Notes:
    ------
    "geom" and "label" are not compared.
    For further information refer to the nx.isomorphism.GraphMatcher()
    documentation.
    """
    # TODO: which attributes might be query?

    if a2["element"] != Element.Unknown and a1["element"] != a2["element"]:
        return False

    if a2["isotope"] is not None and a1["isotope"] != a2["isotope"]:
        return False

    # NOTE: no queries for stereo for now
    if a2["stereo"] != AtomStereo.Unknown and a1["stereo"] != a2["stereo"]:
        return False

    # The pattern side (a2) is the wildcard, like for every other attribute,
    # and the two sides are compared against each other. Comparing a2 with
    # itself could never reject a pair.
    # TODO: add groups for queries
    if a2["atype"] != AtomType.Unknown and a1["atype"] != a2["atype"]:
        return False

    return True

@staticmethod
def _edge_match(e1: dict, e2: dict) -> bool:  # TODO: needs improving!
    """
    Callable helper function that compares attributes of the edges (bonds)
    in nx.isomorphism.GraphMatcher().

    `e2` drives the comparison: its bond type selects the rule applied to
    `e1`, and its "stereo" / "label" values are only compared when they are
    not BondStereo.Unknown / None respectively.

    Returns True if the edges (bonds) are considered equal, False otherwise.
    Raises NotImplementedError for any bond type of `e2` that has no rule
    (including BondType.Dummy).
    For further information refer to the nx.isomorphism.GraphMatcher()
    documentation.
    """
    wanted_type = e2["btype"]

    if wanted_type == BondType.Unknown:
        # no restriction on the bond type
        pass
    elif wanted_type in (BondType.Single, BondType.Double, BondType.Triple):
        # should work for aromatic and resonating structures
        if e1["btype"] < wanted_type:
            return False
    elif wanted_type in (BondType.Aromatic, BondType.Amide):
        if e1["btype"] != wanted_type:
            return False
    elif wanted_type == BondType.NotConnected:
        return False
    else:
        # BondType.Dummy and anything unrecognized
        raise NotImplementedError

    wanted_stereo = e2["stereo"]
    if wanted_stereo != BondStereo.Unknown and e1["stereo"] != wanted_stereo:
        return False

    wanted_label = e2["label"]
    if wanted_label is not None and e1["label"] != wanted_label:
        return False

    return True

@staticmethod
def _edge_match_debug(e1: dict, e2: dict):
    """
    Verbose variant of `_edge_match`: prints both edge attribute
    dictionaries together with the verdict and returns that verdict.
    """
    verdict = Connectivity._edge_match(e1, e2)
    # Header and details are emitted as two newline-separated items
    print("Bonds", f"{e1}\n{e2}\n{verdict}\n", sep="\n")
    return verdict

@staticmethod
def _node_match_debug(a1, a2):
    """
    Verbose variant of `_node_match`: prints both node attribute
    dictionaries together with the verdict and returns that verdict.
    """
    verdict = Connectivity._node_match(a1, a2)
    # Header and details are emitted as two newline-separated items
    print("Atoms", f"{a1}\n{a2}\n{verdict}\n", sep="\n")
    return verdict

def match(
    self,
    pattern: Connectivity,
    /,
    *,
    node_match: Callable[[dict, dict], bool] | None = None,
    edge_match: Callable[[dict, dict], bool] | None = None,
) -> Generator[dict | Any, None, None]:
    """
    Searches this connectivity for subgraphs that match the pattern.
    Yields subgraph isomorphism mappings one by one.

    ```python
    for mapping in connectivity.match(pattern):
        ...
    ```
    Parameters:
    -----------
    pattern: Connectivity
        query-Connectivity (Molecule) to match with given Connectivity
    node_match: Callable[[dict, dict], bool] | None
        compares the attribute dictionaries of two nodes (atoms);
        `self._node_match` is used when not given
    edge_match: Callable[[dict, dict], bool] | None
        compares the attribute dictionaries of two edges (bonds);
        `self._edge_match` is used when not given

    Yields:
    -------
    mapping: dict
        the matcher's mapping with keys and values swapped, i.e. keyed by
        the nodes of the pattern graph
    """
    pattern_graph = pattern.to_nxgraph()
    source_graph = self.to_nxgraph()

    matcher = nx.isomorphism.GraphMatcher(
        source_graph,
        pattern_graph,
        node_match=node_match or self._node_match,
        edge_match=edge_match or self._edge_match,
    )

    for source_to_pattern in matcher.subgraph_isomorphisms_iter():
        # invert the mapping so that it is keyed by the pattern side
        yield {
            pattern_node: source_node
            for source_node, pattern_node in source_to_pattern.items()
        }

def connect(self, _a1: AtomLike, _a2: AtomLike, **kwds):
    """
    Creates a bond between two atoms, appends it to this instance and
    returns it.

    Both atom references are resolved through `self.get_atoms`; all extra
    keyword arguments are forwarded to the `Bond` constructor.
    """
    first, second = self.get_atoms(_a1, _a2)
    new_bond = Bond(first, second, **kwds)
    self.append_bond(new_bond)
    return new_bond
Loading
Loading