Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,20 @@ defaults:
jobs:
tests:
runs-on: ${{ matrix.OS }}
name: "💻-${{matrix.os }} 🐍-${{ matrix.python-version }} 🗃️${{ matrix.pydantic-version }}"
name: "💻-${{matrix.os }} 🐍-${{ matrix.python-version }} 🎨-${{ matrix.py3Dmol }}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good idea! we will need to update the branch rules after we merge this PR in. I am happy with it for now, but it might be better to call it something like "min deps" so we can keep adding things to the install step. We can wait until the next optional dep we add to do this.

strategy:
fail-fast: false
matrix:
os: ['ubuntu-latest', macos-latest]
pydantic-version: [">1"]
py3Dmol: ["no"]
python-version:
- "3.11"
- "3.12"
- "3.13"
include:
- os: "ubuntu-latest"
python-version: "3.11"
pydantic-version: "<2"
python-version: "3.12"
py3Dmol: "yes"

env:
OE_LICENSE: ${{ github.workspace }}/oe_license.txt
Expand All @@ -58,6 +58,11 @@ jobs:
python=${{ matrix.python-version }}
init-shell: bash

- name: "Install py3Dmol (optional)"
if: ${{ matrix.py3Dmol == 'yes' }}
run: micromamba install py3Dmol


- name: "Install"
run: python -m pip install --no-deps -e .

Expand Down
86 changes: 85 additions & 1 deletion gufe/mapping/ligandatommapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,23 @@
from __future__ import annotations

import json
from typing import Any
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union

import numpy as np
from numpy.typing import NDArray
from rdkit import Chem

from gufe.components import SmallMoleculeComponent
from gufe.visualization.mapping_visualization import draw_mapping

from ..tokenization import JSON_HANDLER
from ..utils import requires_package
from ..visualization import mapping_visualization as viz
from . import AtomMapping

if TYPE_CHECKING:
import py3Dmol


class LigandAtomMapping(AtomMapping):
"""
Expand Down Expand Up @@ -194,6 +200,84 @@ def draw_to_file(self, fname: str, d2d=None):
)
)

@requires_package("py3Dmol")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that using this decorator will actually only raise an issue with py3Dmol missing when you actually use the function.

def view_3d(
self,
spheres: Optional[bool] = True,
show_atomIDs: Optional[bool] = False,
style: Optional[str] = "stick",
shift: Optional[Union[Tuple[float, float, float], NDArray[np.float64]]] = None,
) -> py3Dmol.view:
"""
Render relative transformation edge in 3D using py3Dmol.

By default matching atoms will be annotated using colored spheres.

py3Dmol is an optional dependency, it can be installed with:
pip install py3Dmol

Parameters
----------
spheres : bool, optional
Whether or not to show matching atoms as spheres.
show_atomIDs: bool, optional
Whether or not to show atom ids in the mapping visualization
style : str, optional
Style in which to represent the molecules in py3Dmol.
shift : Tuple of floats, optional
Amount to shift molB by in order to visualize the two ligands.
If None, the default shift will be estimated as the largest
intraMol distance of both mols.

Returns
-------
view : py3Dmol.view
View of the system containing both molecules in the edge.
"""
import py3Dmol

if shift is None:
shift = np.array([viz._get_max_dist_in_x(self) * 1.5, 0, 0])
else:
shift = np.array(shift)

molA = self.componentA.to_rdkit()
molB = self.componentB.to_rdkit()

# 0 * shift is the centrepoint
# shift either side of the mapping +- a shift to clear the centre view
lmol = viz._translate(molA, -1 * shift)
rmol = viz._translate(molB, +1 * shift)

view = py3Dmol.view(width=600, height=600)
view.addModel(Chem.MolToMolBlock(lmol), "molA")
view.addModel(Chem.MolToMolBlock(rmol), "molB")

if spheres:
viz._add_spheres(view, lmol, rmol, self.componentA_to_componentB)

if show_atomIDs:
view.addPropertyLabels(
"index",
{"not": {"resn": ["molA_overlay", "molA_overlay"]}},
{
"fontColor": "black",
"font": "sans-serif",
"fontSize": "10",
"showBackground": "false",
"alignment": "center",
},
)

# middle fig
view.addModel(Chem.MolToMolBlock(molA), "molA_overlay")
view.addModel(Chem.MolToMolBlock(molB), "molB_overlay")

view.setStyle({style: {}})

view.zoomTo()
return view

def with_annotations(self, annotations: dict[str, Any]) -> LigandAtomMapping:
"""Create a new mapping based on this one with extra annotations.

Expand Down
15 changes: 15 additions & 0 deletions gufe/tests/test_ligandatommapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@

from .test_tokenization import GufeTokenizableTestsMixin

try:
import py3Dmol

HAS_PY3DMOL = True
except ImportError:
HAS_PY3DMOL = False


def mol_from_smiles(smiles: str) -> gufe.SmallMoleculeComponent:
m = Chem.AddHs(Chem.MolFromSmiles(smiles))
Expand Down Expand Up @@ -289,6 +296,14 @@ def test_with_fancy_annotations(simple_mapping):
assert m == m2


@pytest.mark.skipif(not HAS_PY3DMOL, reason="optional dep py3Dmol not found")
def test_visualize_3D_mapping(simple_mapping):
"""
smoke test just checking if nothing goes horribly wrong
"""
simple_mapping.view_3d()


class TestLigandAtomMappingBoundsChecks:
@pytest.fixture
def molA(self):
Expand Down
Empty file added gufe/tests/test_mapping
Empty file.
8 changes: 7 additions & 1 deletion gufe/tests/test_mapping_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pytest
from rdkit import Chem

import gufe
from gufe.visualization.mapping_visualization import (
_get_unique_bonds_and_atoms,
_match_elements,
Expand All @@ -13,6 +12,13 @@
draw_unhighlighted_molecule,
)

try:
import py3Dmol

HAS_PY3DMOL = True
except ImportError:
HAS_PY3DMOL = False

# default colors currently used
_HIGHLIGHT_COLOR = (220 / 255, 50 / 255, 32 / 255, 1)
_CHANGED_ELEMENTS_COLOR = (0, 90 / 255, 181 / 255, 1)
Expand Down
38 changes: 38 additions & 0 deletions gufe/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe

import functools
import io
import warnings
from typing import Callable


class ensure_filelike:
Expand Down Expand Up @@ -51,3 +53,39 @@ def __enter__(self):
def __exit__(self, type, value, traceback):
if self.do_close:
self.context.close()


# taken from openfe who shamelessly borrowed from openff.toolkit
def requires_package(package_name: str) -> Callable:
"""
Helper function to denote that a function requires some optional
dependency. A function decorated with this decorator will raise
``MissingDependencyError`` if the package is not found by
``importlib.import_module()``.

Parameters
----------
package_name : str
The directory path to enter within the context
Raises
------
MissingDependencyError
"""

def test_import_for_require_package(function: Callable) -> Callable:
@functools.wraps(function)
def wrapper(*args, **kwargs):
import importlib

try:
importlib.import_module(package_name)
except (ImportError, ModuleNotFoundError):
raise ImportError(function.__name__ + " requires package: " + package_name)
except Exception as e:
raise e

return function(*args, **kwargs)

return wrapper

return test_import_for_require_package
103 changes: 102 additions & 1 deletion gufe/visualization/mapping_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,18 @@
# For details, see https://github.com/OpenFreeEnergy/gufe
from collections.abc import Collection
from itertools import chain
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Dict, Tuple, Union

import numpy as np
from numpy.typing import NDArray
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Geometry.rdGeometry import Point3D

from ..utils import requires_package

if TYPE_CHECKING:
import py3Dmol

# highlight core element changes differently from unique atoms
# RGBA color value needs to be between 0 and 1, so divide by 255
Expand Down Expand Up @@ -296,3 +304,96 @@ def draw_unhighlighted_molecule(mol, d2d=None):
bond_colors=[{}],
highlight_color=red,
)


def _translate(mol: Chem.Mol, shift: Union[Tuple[float, float, float], NDArray[np.float64]]) -> Chem.Mol:
"""
shifts the molecule by the shift vector

Parameters
----------
mol : Chem.Mol
rdkit mol that get shifted
shift : Tuple[float, float, float]
shift vector

Returns
-------
Chem.Mol
shifted Molecule (copy of original one)
"""
mol = Chem.Mol(mol)
conf = mol.GetConformer()
for i, atom in enumerate(mol.GetAtoms()):
x, y, z = conf.GetAtomPosition(i)
point = Point3D(x + shift[0], y + shift[1], z + shift[2])
conf.SetAtomPosition(i, point)
return mol


@requires_package("py3Dmol")
def _add_spheres(view, mol1: Chem.Mol, mol2: Chem.Mol, mapping: Dict[int, int]) -> None:
"""
will add spheres according to mapping to the view. (inplace!)

Parameters
----------
view : py3Dmol.view
view to be edited
mol1 : Chem.Mol
molecule 1 of the mapping
mol2 : Chem.Mol
molecule 2 of the mapping
mapping : Dict[int, int]
mapping of atoms from mol1 to mol2
"""
from matplotlib import pyplot as plt
from matplotlib.colors import rgb2hex

# Get colourmap of size mapping
cmap = plt.get_cmap("hsv", len(mapping))
for i, pair in enumerate(mapping.items()):
p1 = mol1.GetConformer().GetAtomPosition(pair[0])
p2 = mol2.GetConformer().GetAtomPosition(pair[1])
color = rgb2hex(cmap(i))
view.addSphere(
{
"center": {"x": p1.x, "y": p1.y, "z": p1.z},
"radius": 0.6,
"color": color,
"alpha": 0.8,
}
)
view.addSphere(
{
"center": {"x": p2.x, "y": p2.y, "z": p2.z},
"radius": 0.6,
"color": color,
"alpha": 0.8,
}
)


def _get_max_dist_in_x(atom_mapping) -> float:
"""helper function
find the correct mol shift, so no overlap happens in vis

Returns
-------
float
maximal size of mol in x dimension
"""
posA = atom_mapping.componentA.to_rdkit().GetConformer().GetPositions()
posB = atom_mapping.componentB.to_rdkit().GetConformer().GetPositions()
max_d = []

for pos in [posA, posB]:
d = np.zeros(shape=(len(pos), len(pos)))
for i, pA in enumerate(pos):
for j, pB in enumerate(pos[i:], start=i):
d[i, j] = (pB - pA)[0]

max_d.append(np.max(d))

estm = float(np.round(max(max_d), 1))
return estm if (estm > 5) else 5
23 changes: 23 additions & 0 deletions news/add_mapping_3d_viz.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
**Added:**

* Added ``LigandAtomMapping.view_3d()`` method (previously implemented as ``openfe.utils.visualization_3D.view_mapping()`` (`PR #646 <https://github.com/OpenFreeEnergy/gufe/pull/646>`_).

**Changed:**

* <news item>

**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* <news item>

**Security:**

* <news item>
Loading