diff --git a/datamol/viz/_lasso_highlight.py b/datamol/viz/_lasso_highlight.py index 29b87d10..7298d54f 100644 --- a/datamol/viz/_lasso_highlight.py +++ b/datamol/viz/_lasso_highlight.py @@ -6,7 +6,7 @@ # - possibility to do this for multiple target molecules at once # - have the option to write to a file like to_image -from typing import List, Iterator, Tuple, Union, Optional, Any, cast +from typing import List, Dict, Iterator, Tuple, Union, Optional, Any, cast from collections import defaultdict from collections import namedtuple @@ -400,6 +400,10 @@ def lasso_highlight_image( line_width: int = 2, scale_padding: float = 1.0, verbose: bool = False, + highlight_atoms: Optional[List[List[int]]] = None, + highlight_bonds: Optional[List[List[int]]] = None, + highlight_atom_colors: Optional[List[Dict[int, DatamolColor]]] = None, + highlight_bond_colors: Optional[List[Dict[int, DatamolColor]]] = None, **kwargs: Any, ): """Create an image of a list of molecules with substructure matches using lasso-based highlighting. @@ -408,7 +412,7 @@ def lasso_highlight_image( Args: target_molecules: One or a list of molecules to be highlighted. search_molecules: The substructure to be highlighted. - atom_indices: Atom indices to be highlighted substructure. + atom_indices: Atom indices to be highlighted as substructure using the lasso visualization. legends: A string or a list of string as legend for every molecules. n_cols: Number of molecules per column. mol_size: The size of the image to be returned @@ -421,6 +425,10 @@ def lasso_highlight_image( line_width: width of drawn lines. scale_padding: Padding around the molecule when drawing to scale. verbose: Whether to print the verbose information. + highlight_atoms: The atoms to highlight, a list for each molecule. It's the `highlightAtoms` argument of the RDKit drawer object. + highlight_bonds: The bonds to highlight, a list for each molecule. It's the `highlightBonds` argument of the RDKit drawer object. + highlight_atom_colors: The colors to use for highlighting atoms, a list of dict mapping atom index to color for each molecule. + highlight_bond_colors: The colors to use for highlighting bonds, a list of dict mapping bond index to color for each molecule. **kwargs: Additional arguments to pass to the drawing function. See RDKit documentation related to `MolDrawOptions` for more details at https://www.rdkit.org/docs/source/rdkit.Chem.Draw.rdMolDraw2D.html. @@ -551,9 +559,38 @@ def lasso_highlight_image( # EN: the following is edge-case free after trying 6 different logics, but may break if RDKit changes the way it draws molecules scaling_val = Point2D(scale_padding, scale_padding) + if isinstance(highlight_atoms, list) and isinstance(highlight_atoms[0], int): + highlight_atoms = [highlight_atoms] * len(target_molecules) + if isinstance(highlight_bonds, list) and isinstance(highlight_bonds[0], int): + highlight_bonds = [highlight_bonds] * len(target_molecules) + if isinstance(highlight_atom_colors, dict): + highlight_atom_colors = [highlight_atom_colors] * len(target_molecules) + if isinstance(highlight_bond_colors, dict): + highlight_bond_colors = [highlight_bond_colors] * len(target_molecules) + + # make sure we are using rdkit colors + if highlight_atom_colors is not None: + highlight_atom_colors = [ + {k: to_rdkit_color(v) for k, v in _.items()} for _ in highlight_atom_colors + ] + if highlight_bond_colors is not None: + highlight_bond_colors = [ + {k: to_rdkit_color(v) for k, v in _.items()} for _ in highlight_bond_colors + ] + + kwargs["highlightAtoms"] = highlight_atoms + kwargs["highlightBonds"] = highlight_bonds + kwargs["highlightAtomColors"] = highlight_atom_colors + kwargs["highlightBondColors"] = highlight_bond_colors + try: - drawer.DrawMolecules(mols_to_draw, legends=legends, **kwargs) - except Exception: + drawer.DrawMolecules( + mols_to_draw, + legends=legends, + **kwargs, + ) + except Exception as e: + logger.error(e) raise ValueError( "Failed to draw molecules. Some arguments neither match expected MolDrawOptions, nor DrawMolecule inputs. Please check the input arguments." ) @@ -567,8 +604,18 @@ def lasso_highlight_image( h_pos, w_pos = np.unravel_index(ind, (n_rows, n_cols)) offset_x = int(w_pos * mol_size[0]) offset_y = int(h_pos * mol_size[1]) + + ind_kwargs = kwargs.copy() + if isinstance(ind_kwargs["highlightAtoms"], list): + ind_kwargs["highlightAtoms"] = ind_kwargs["highlightAtoms"][ind] + if isinstance(ind_kwargs["highlightAtomColors"], list): + ind_kwargs["highlightAtomColors"] = ind_kwargs["highlightAtomColors"][ind] + if isinstance(ind_kwargs["highlightBonds"], list): + ind_kwargs["highlightBonds"] = ind_kwargs["highlightBonds"][ind] + if isinstance(ind_kwargs["highlightBondColors"], list): + ind_kwargs["highlightBondColors"] = ind_kwargs["highlightBondColors"][ind] drawer.SetOffset(offset_x, offset_y) - drawer.DrawMolecule(mol, legend=legends[ind], **kwargs) + drawer.DrawMolecule(mol, legend=legends[ind], **ind_kwargs) offset = None if draw_mols_same_scale: offset = drawer.Offset() diff --git a/datamol/viz/utils.py b/datamol/viz/utils.py index c2d8cdb2..c0e6ad75 100644 --- a/datamol/viz/utils.py +++ b/datamol/viz/utils.py @@ -141,6 +141,12 @@ def to_rdkit_color(color: Optional[DatamolColor]) -> Optional[RDKitColor]: Args: color: A datamol color: hex, rgb, rgba or None. """ + if color is None: + return None + if isinstance(color, str): return mcolors.to_rgba(color) # type: ignore + if isinstance(color, (tuple, list)) and len(color) in [3, 4] and any(x > 1 for x in color): + return tuple(x / 255 if i < 3 else x for i, x in enumerate(color)) + return color diff --git a/tests/test_viz_lasso_highlight.py b/tests/test_viz_lasso_highlight.py index 542c9464..a56ba6be 100644 --- a/tests/test_viz_lasso_highlight.py +++ b/tests/test_viz_lasso_highlight.py @@ -17,6 +17,30 @@ def test_from_mol(): assert dm.lasso_highlight_image(mol, smarts_list) +def test_with_highlight(): + smi = "CO[C@@H](O)C1=C(O[C@H](F)Cl)C(C#N)=C1ONNC[NH3+]" + mol = dm.to_mol(smi) + smarts_list = "CONN" + highlight_atoms = [4, 5, 6] + highlight_bonds = [1, 2, 3, 4] + highlight_atom_colors = {4: (230, 230, 250), 5: (230, 230, 250), 6: (230, 230, 250)} + highlight_bond_colors = { + 1: (230, 230, 250), + 2: (230, 230, 250), + 3: (230, 230, 250), + 4: (230, 230, 250), + } + assert dm.lasso_highlight_image( + mol, + smarts_list, + highlight_atoms=highlight_atoms, + highlight_bonds=highlight_bonds, + highlight_atom_colors=highlight_atom_colors, + highlight_bond_colors=highlight_bond_colors, + continuousHighlight=False, + ) + + def test_original_working_solution_list_single_str(): smi = "CO[C@@H](O)C1=C(O[C@H](F)Cl)C(C#N)=C1ONNC[NH3+]" smarts_list = ["CONN"]