Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion openfe/protocols/openmm_rfe/hybridtop_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,7 +1262,6 @@ def _execute(
**inputs,
) -> dict[str, Any]:
log_system_probe(logging.INFO, paths=[ctx.scratch])

# Get the relevant inputs
system = deserialize(setup_results.outputs["system"])
positions = to_openmm(np.load(setup_results.outputs["positions"]) * offunit.nm)
Expand Down
27 changes: 21 additions & 6 deletions openfecli/commands/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,16 @@ def _get_names(result: dict) -> tuple[str, str]:

# TODO: I don't like this [0][0] indexing, but I can't think of a better way currently
protocol_data = list(result["protocol_result"]["data"].values())[0][0]

name_A = protocol_data["inputs"]["ligandmapping"]["componentA"]["molprops"]["ofe-name"]
name_B = protocol_data["inputs"]["ligandmapping"]["componentB"]["molprops"]["ofe-name"]
try:
name_A = protocol_data["inputs"]["setup_results"]["inputs"]["ligandmapping"]["componentA"][
"molprops"
]["ofe-name"]
name_B = protocol_data["inputs"]["setup_results"]["inputs"]["ligandmapping"]["componentB"][
"molprops"
]["ofe-name"]
except KeyError:
name_A = protocol_data["inputs"]["ligandmapping"]["componentA"]["molprops"]["ofe-name"]
name_B = protocol_data["inputs"]["ligandmapping"]["componentB"]["molprops"]["ofe-name"]

return str(name_A), str(name_B)

Expand All @@ -232,9 +239,17 @@ def _get_type(result: dict) -> Literal["vacuum", "solvent", "complex"]:
"""Determine the simulation type based on the component types."""

protocol_data = list(result["protocol_result"]["data"].values())[0][0]
component_types = [
x["__module__"] for x in protocol_data["inputs"]["stateA"]["components"].values()
]
try:
component_types = [
x["__module__"]
for x in protocol_data["inputs"]["setup_results"]["inputs"]["stateA"][
"components"
].values()
]
except KeyError:
component_types = [
x["__module__"] for x in protocol_data["inputs"]["stateA"]["components"].values()
]
if "gufe.components.solventcomponent" not in component_types:
return "vacuum"
elif "gufe.components.proteincomponent" in component_types:
Expand Down
117 changes: 85 additions & 32 deletions openfecli/tests/test_rbfe_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import os
from importlib import resources
from os import path
from pathlib import Path
from unittest import mock

import numpy as np
import pytest
from click.testing import CliRunner
from openff.units import unit
Expand Down Expand Up @@ -89,25 +91,6 @@ def test_plan_tyk2(tyk2_ligands, tyk2_protein, expected_transformations):
assert "n_protocol_repeats=3" in result.output


@pytest.fixture
def mock_execute(expected_transformations):
def fake_execute(*args, **kwargs):
return {
"repeat_id": kwargs["repeat_id"],
"generation": kwargs["generation"],
"nc": "file.nc",
"last_checkpoint": "checkpoint.nc",
"unit_estimate": 4.2 * unit.kilocalories_per_mole,
}

with mock.patch(
"openfe.protocols.openmm_rfe.equil_rfe_methods.RelativeHybridTopologyProtocolUnit._execute"
) as m:
m.side_effect = fake_execute

yield m


@pytest.fixture
def ref_gather():
return """\
Expand All @@ -124,25 +107,95 @@ def ref_gather():
"""


def test_run_tyk2(tyk2_ligands, tyk2_protein, expected_transformations, mock_execute, ref_gather):
@pytest.fixture
def fake_setup_execute_results():
"""Use for mocking the expensive _execute step and instead directly return plausible results."""

def _fake_execute_results(*args, **kwargs):
return {
"repeat_id": kwargs["repeat_id"],
"generation": kwargs["generation"],
"system": Path("system.xml.bz2"),
"positions": Path("positions.npy"),
"pdb_structure": Path("hybrid_system.pdb"),
"selection_indices": np.arange(50),
}

return _fake_execute_results


@pytest.fixture
def fake_sim_execute_results():
"""Use for mocking the expensive _execute step and instead directly return plausible results."""

def _fake_execute_results(*args, **kwargs):
return {
"repeat_id": kwargs["repeat_id"],
"generation": kwargs["generation"],
"nc": Path("file.nc"),
"checkpoint": Path("chk.chk"),
}

return _fake_execute_results


@pytest.fixture
def fake_analysis_execute_results():
"""Use for mocking the expensive _execute step and instead directly return plausible results."""

def _fake_execute_results(*args, **kwargs):
return {
"repeat_id": kwargs["repeat_id"],
"generation": kwargs["generation"],
"pdb_structure": Path("hybrid_system.pdb"),
"checkpoint": Path("chk.chk"),
"selection_indices": np.arange(50),
"unit_estimate": 4.2 * unit.kilocalories_per_mole,
}

return _fake_execute_results


def test_run_tyk2(
tyk2_ligands,
tyk2_protein,
expected_transformations,
fake_setup_execute_results,
fake_sim_execute_results,
fake_analysis_execute_results,
ref_gather,
):
runner = CliRunner()
with runner.isolated_filesystem():
result = runner.invoke(
result_setup = runner.invoke(
plan_rbfe_network,
[
"-M", tyk2_ligands,
"-p", tyk2_protein,
],
) # fmt: skip

assert_click_success(result)

for f in expected_transformations:
fn = path.join("alchemicalNetwork/transformations", f)
result2 = runner.invoke(quickrun, [fn])
assert_click_success(result2)

gather_result = runner.invoke(gather, ["--report", "ddg", ".", "--tsv"])

assert_click_success(gather_result)
assert gather_result.stdout == ref_gather
assert_click_success(result_setup)
with (
mock.patch(
"openfe.protocols.openmm_rfe.hybridtop_units.HybridTopologySetupUnit._execute",
side_effect=fake_setup_execute_results,
),
mock.patch(
"openfe.protocols.openmm_rfe.hybridtop_units.HybridTopologyMultiStateSimulationUnit._execute",
side_effect=fake_sim_execute_results,
),
mock.patch(
"openfe.protocols.openmm_rfe.hybridtop_units.HybridTopologyMultiStateAnalysisUnit._execute",
side_effect=fake_analysis_execute_results,
),
):
for f in expected_transformations:
fn = path.join("alchemicalNetwork/transformations", f)
result_run = runner.invoke(quickrun, [fn])
assert_click_success(result_run)

result_gather = runner.invoke(gather, ["--report", "ddg", ".", "--tsv"])

assert_click_success(result_gather)
assert result_gather.stdout == ref_gather
Loading