Skip to content

Adapt Unittests to new RDKit SMILES + Linting #175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 14, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 33 additions & 20 deletions tests/test_elements/test_mol2mol/test_mol2mol_filter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
"""Test MolFilter, which invalidate molecules based on criteria defined in the respective filter."""
"""Unittest for MolFilters functionality.

MolFilters flag Molecules as invalid based on the criteria defined in the filter.

"""

import json
import tempfile
import unittest
from pathlib import Path
from typing import TYPE_CHECKING

from molpipeline import ErrorFilter, FilterReinserter, Pipeline
from molpipeline.any2mol import SmilesToMol
Expand All @@ -19,14 +24,16 @@
)
from molpipeline.utils.comparison import compare_recursive
from molpipeline.utils.json_operations import recursive_from_json, recursive_to_json
from molpipeline.utils.molpipeline_types import FloatCountRange, IntOrIntCountRange

if TYPE_CHECKING:
from molpipeline.utils.molpipeline_types import FloatCountRange, IntOrIntCountRange

# pylint: disable=duplicate-code # test case molecules are allowed to be duplicated
SMILES_ANTIMONY = "[SbH6+3]"
SMILES_BENZENE = "c1ccccc1"
SMILES_CHLOROBENZENE = "Clc1ccccc1"
SMILES_CL_BR = "NC(Cl)(Br)C(=O)O"
SMILES_METAL_AU = "OC[C@H]1OC(S[Au])[C@H](O)[C@@H](O)[C@@H]1O"
SMILES_METAL_AU = "OC[C@H]1OC([S][Au])[C@H](O)[C@@H](O)[C@@H]1O"

SMILES_LIST = [
SMILES_ANTIMONY,
Expand Down Expand Up @@ -80,7 +87,7 @@ def test_element_filter(self) -> None:
6: 6,
1: (5, 6),
17: (0, 1),
}
},
},
"result": [SMILES_BENZENE, SMILES_CHLOROBENZENE],
},
Expand All @@ -99,14 +106,15 @@ def test_json_roundtrip(self) -> None:
-----
The ElementFilter has to be saved as a JSON file and then loaded back,
because JSON serialization (json.dump) turns the integer dictionary keys into
strings; only a full roundtrip shows whether they are restored correctly.

"""
element_filter = ElementFilter()
json_object = recursive_to_json(element_filter)
with tempfile.TemporaryDirectory() as temp_folder:
temp_file_path = Path(temp_folder) / "test.json"
with open(temp_file_path, "w", encoding="UTF-8") as out_file:
with temp_file_path.open("w", encoding="UTF-8") as out_file:
json.dump(json_object, out_file)
with open(temp_file_path, encoding="UTF-8") as in_file:
with temp_file_path.open(encoding="UTF-8") as in_file:
loaded_json_object = json.load(in_file)
recreated_element_filter = recursive_from_json(loaded_json_object)

Expand All @@ -117,7 +125,8 @@ def test_json_roundtrip(self) -> None:
with self.subTest(param_name=param_name):
self.assertTrue(
compare_recursive(original_value, recreated_params[param_name]),
f"Original: {original_value}, Recreated: {recreated_params[param_name]}",
f"Original: {original_value}, "
f"Recreated: {recreated_params[param_name]}",
)


Expand All @@ -132,6 +141,7 @@ def _create_pipeline() -> Pipeline:
-------
Pipeline
Pipeline with a complex filter.

"""
element_filter_1 = ElementFilter({6: 6, 1: 6})
element_filter_2 = ElementFilter({6: 6, 1: 5, 17: 1})
Expand All @@ -140,18 +150,17 @@ def _create_pipeline() -> Pipeline:
(
("element_filter_1", element_filter_1),
("element_filter_2", element_filter_2),
)
),
)

pipeline = Pipeline(
return Pipeline(
[
("Smiles2Mol", SmilesToMol()),
("MultiElementFilter", multi_element_filter),
("Mol2Smiles", MolToSmiles()),
("ErrorFilter", ErrorFilter()),
],
)
return pipeline

def test_complex_filter(self) -> None:
"""Test if molecules are filtered correctly by allowed chemical elements."""
Expand All @@ -169,7 +178,7 @@ def test_complex_filter(self) -> None:
{
"params": {
"MultiElementFilter__mode": "any",
"MultiElementFilter__pipeline_filter_elements__element_filter_1__add_hydrogens": False,
"MultiElementFilter__pipeline_filter_elements__element_filter_1__add_hydrogens": False, # noqa: E501
},
"result": [SMILES_CHLOROBENZENE],
},
Expand Down Expand Up @@ -198,7 +207,7 @@ def test_complex_filter_non_unique_names(self) -> None:

with self.assertRaises(ValueError):
ComplexFilter(
(("filter_1", element_filter_1), ("filter_1", element_filter_2))
(("filter_1", element_filter_1), ("filter_1", element_filter_2)),
)


Expand Down Expand Up @@ -285,7 +294,11 @@ def test_smarts_smiles_filter_wrong_pattern(self) -> None:
SmilesFilter(smiles_pats)

def test_smarts_filter_parallel(self) -> None:
"""Test if molecules are filtered correctly by allowed SMARTS patterns in parallel."""
"""Test if molecules are filtered correctly.

This test runs the SmartsFilter in parallel.

"""
smarts_pats: dict[str, IntOrIntCountRange] = {
"c": (4, None),
"Cl": 1,
Expand Down Expand Up @@ -352,31 +365,31 @@ def test_descriptor_filter(self) -> None:
},
{
"params": {
"DescriptorsFilter__filter_elements": {"NumHAcceptors": (1.99, 4)}
"DescriptorsFilter__filter_elements": {"NumHAcceptors": (1.99, 4)},
},
"result": [SMILES_CL_BR],
},
{
"params": {
"DescriptorsFilter__filter_elements": {"NumHAcceptors": (2.01, 4)}
"DescriptorsFilter__filter_elements": {"NumHAcceptors": (2.01, 4)},
},
"result": [],
},
{
"params": {
"DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 2.00)}
"DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 2.00)},
},
"result": [SMILES_CL_BR],
},
{
"params": {
"DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 2.01)}
"DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 2.01)},
},
"result": [SMILES_CL_BR],
},
{
"params": {
"DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 1.99)}
"DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 1.99)},
},
"result": [],
},
Expand Down Expand Up @@ -409,7 +422,7 @@ def test_invalidate_mixtures(self) -> None:
("mol2smi", mol2smi),
("error_filter", error_filter),
("error_replacer", error_replacer),
]
],
)
mols_processed = pipeline.fit_transform(mol_list)
self.assertEqual(expected_invalidated_mol_list, mols_processed)
Expand All @@ -424,7 +437,7 @@ def test_inorganic_filter(self) -> None:
inorganics_filter = InorganicsFilter()
mol2smiles = MolToSmiles()
error_filter = ErrorFilter.from_element_list(
[smiles2mol, inorganics_filter, mol2smiles]
[smiles2mol, inorganics_filter, mol2smiles],
)
pipeline = Pipeline(
[
Expand Down
Loading