Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 38 additions & 23 deletions openfe/protocols/openmm_rfe/_rfe_utils/multistate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@
logger = logging.getLogger(__name__)


class HybridCompatibilityMixin(object):
class HybridCompatibilityMixin:
"""
Mixin that allows the MultistateSampler to accommodate the situation where
unsampled endpoints have a different number of degrees of freedom.
"""

def __init__(self, *args, hybrid_factory=None, **kwargs):
self._hybrid_factory = hybrid_factory
def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs):
self._hybrid_system = hybrid_system
self._hybrid_positions = hybrid_positions
super(HybridCompatibilityMixin, self).__init__(*args, **kwargs)

def setup(self, reporter, lambda_protocol,
Expand Down Expand Up @@ -73,15 +74,17 @@ class creation of LambdaProtocol.
"""
n_states = len(lambda_protocol.lambda_schedule)

hybrid_system = self._factory.hybrid_system
lambda_zero_state = RelativeAlchemicalState.from_system(self._hybrid_system)

lambda_zero_state = RelativeAlchemicalState.from_system(hybrid_system)
thermostate = ThermodynamicState(
self._hybrid_system,
temperature=temperature
)

thermostate = ThermodynamicState(hybrid_system,
temperature=temperature)
compound_thermostate = CompoundThermodynamicState(
thermostate,
composable_states=[lambda_zero_state])
thermostate,
composable_states=[lambda_zero_state]
)

# create lists for storing thermostates and sampler states
thermodynamic_state_list = []
Expand All @@ -105,16 +108,20 @@ class creation of LambdaProtocol.
raise ValueError(errmsg)

# starting with the hybrid factory positions
box = hybrid_system.getDefaultPeriodicBoxVectors()
sampler_state = SamplerState(self._factory.hybrid_positions,
box_vectors=box)
box = self._hybrid_system.getDefaultPeriodicBoxVectors()
sampler_state = SamplerState(
self._hybrid_positions,
box_vectors=box
)

# Loop over the lambdas and create & store a compound thermostate at
# that lambda value
for lambda_val in lambda_schedule:
compound_thermostate_copy = copy.deepcopy(compound_thermostate)
compound_thermostate_copy.set_alchemical_parameters(
lambda_val, lambda_protocol)
lambda_val,
lambda_protocol
)
thermodynamic_state_list.append(compound_thermostate_copy)

# now generating a sampler_state for each thermodyanmic state,
Expand Down Expand Up @@ -143,7 +150,8 @@ class creation of LambdaProtocol.
# generating unsampled endstates
unsampled_dispersion_endstates = create_endstates(
copy.deepcopy(thermodynamic_state_list[0]),
copy.deepcopy(thermodynamic_state_list[-1]))
copy.deepcopy(thermodynamic_state_list[-1])
)
self.create(thermodynamic_states=thermodynamic_state_list,
sampler_states=sampler_state_list, storage=reporter,
unsampled_thermodynamic_states=unsampled_dispersion_endstates)
Expand All @@ -159,10 +167,13 @@ class HybridRepexSampler(HybridCompatibilityMixin,
number of positions
"""

def __init__(self, *args, hybrid_factory=None, **kwargs):
def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs):
super(HybridRepexSampler, self).__init__(
*args, hybrid_factory=hybrid_factory, **kwargs)
self._factory = hybrid_factory
*args,
hybrid_system=hybrid_system,
hybrid_positions=hybrid_positions,
**kwargs
)


class HybridSAMSSampler(HybridCompatibilityMixin, sams.SAMSSampler):
Expand All @@ -171,11 +182,13 @@ class HybridSAMSSampler(HybridCompatibilityMixin, sams.SAMSSampler):
of positions
"""

def __init__(self, *args, hybrid_factory=None, **kwargs):
def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs):
super(HybridSAMSSampler, self).__init__(
*args, hybrid_factory=hybrid_factory, **kwargs
*args,
hybrid_system=hybrid_system,
hybrid_positions=hybrid_positions,
**kwargs
)
self._factory = hybrid_factory


class HybridMultiStateSampler(HybridCompatibilityMixin,
Expand All @@ -184,11 +197,13 @@ class HybridMultiStateSampler(HybridCompatibilityMixin,
MultiStateSampler that supports unsample end states with a different
number of positions
"""
def __init__(self, *args, hybrid_factory=None, **kwargs):
def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs):
super(HybridMultiStateSampler, self).__init__(
*args, hybrid_factory=hybrid_factory, **kwargs
*args,
hybrid_system=hybrid_system,
hybrid_positions=hybrid_positions,
**kwargs
)
self._factory = hybrid_factory


def create_endstates(first_thermostate, last_thermostate):
Expand Down
12 changes: 7 additions & 5 deletions openfe/protocols/openmm_rfe/equil_rfe_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,15 +1128,17 @@ def run(
if sampler_settings.sampler_method.lower() == "repex":
sampler = _rfe_utils.multistate.HybridRepexSampler(
mcmc_moves=integrator,
hybrid_factory=hybrid_factory,
hybrid_system=hybrid_factory.hybrid_system,
hybrid_positions=hybrid_factory.hybrid_positions,
online_analysis_interval=rta_its,
online_analysis_target_error=early_termination_target_error,
online_analysis_minimum_iterations=rta_min_its,
)
elif sampler_settings.sampler_method.lower() == "sams":
sampler = _rfe_utils.multistate.HybridSAMSSampler(
mcmc_moves=integrator,
hybrid_factory=hybrid_factory,
hybrid_system=hybrid_factory.hybrid_system,
hybrid_positions=hybrid_factory.hybrid_positions,
online_analysis_interval=rta_its,
online_analysis_minimum_iterations=rta_min_its,
flatness_criteria=sampler_settings.sams_flatness_criteria,
Expand All @@ -1145,12 +1147,12 @@ def run(
elif sampler_settings.sampler_method.lower() == "independent":
sampler = _rfe_utils.multistate.HybridMultiStateSampler(
mcmc_moves=integrator,
hybrid_factory=hybrid_factory,
hybrid_system=hybrid_factory.hybrid_system,
hybrid_positions=hybrid_factory.hybrid_positions,
online_analysis_interval=rta_its,
online_analysis_target_error=early_termination_target_error,
online_analysis_minimum_iterations=rta_min_its,
)

else:
raise AttributeError(f"Unknown sampler {sampler_settings.sampler_method}")

Expand Down Expand Up @@ -1247,7 +1249,7 @@ def run(
if not dry: # pragma: no-cover
return {"nc": nc, "last_checkpoint": chk, **analyzer.unit_results_dict}
else:
return {"debug": {"sampler": sampler}}
return {"debug": {"sampler": sampler, "hybrid_factory": hybrid_factory}}

@staticmethod
def structural_analysis(scratch, shared) -> dict:
Expand Down
37 changes: 20 additions & 17 deletions openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,13 +236,14 @@ def test_dry_run_default_vacuum(
dag_unit = list(dag.protocol_units)[0]

with tmpdir.as_cwd():
sampler = dag_unit.run(dry=True)["debug"]["sampler"]
debug = dag_unit.run(dry=True)["debug"]
sampler = debug["sampler"]
assert isinstance(sampler, MultiStateSampler)
assert not sampler.is_periodic
assert sampler._thermodynamic_states[0].barostat is None

# Check hybrid OMM and MDTtraj Topologies
htf = sampler._hybrid_factory
htf = debug["hybrid_factory"]
# 16 atoms:
# 11 common atoms, 1 extra hydrogen in benzene, 4 extra in toluene
# 12 bonds in benzene + 4 extra toluene bonds
Expand Down Expand Up @@ -414,7 +415,7 @@ def test_dry_core_element_change(vac_settings, tmpdir):

with tmpdir.as_cwd():
sampler = dag_unit.run(dry=True)["debug"]["sampler"]
system = sampler._hybrid_factory.hybrid_system
system = sampler._hybrid_system
assert system.getNumParticles() == 12
# Average mass between nitrogen and carbon
assert system.getParticleMass(1) == 12.0127235 * omm_unit.amu
Expand Down Expand Up @@ -518,7 +519,7 @@ def tip4p_hybrid_factory(
shared_basepath=shared_temp,
)

return dag_unit_result["debug"]["sampler"]._factory
return dag_unit_result["debug"]["hybrid_factory"]


def test_tip4p_particle_count(tip4p_hybrid_factory):
Expand Down Expand Up @@ -624,7 +625,7 @@ def test_dry_run_ligand_system_cutoff(

with tmpdir.as_cwd():
sampler = dag_unit.run(dry=True)["debug"]["sampler"]
hs = sampler._factory.hybrid_system
hs = sampler._hybrid_system

nbfs = [
f
Expand Down Expand Up @@ -691,9 +692,10 @@ def test_dry_run_charge_backends(
dag_unit = list(dag.protocol_units)[0]

with tmpdir.as_cwd():
sampler = dag_unit.run(dry=True)["debug"]["sampler"]
htf = sampler._factory
hybrid_system = htf.hybrid_system
debug = dag_unit.run(dry=True)["debug"]
sampler = debug["sampler"]
htf = debug["hybrid_factory"]
hybrid_system = sampler._hybrid_system

# get the standard nonbonded force
nonbond = [f for f in hybrid_system.getForces() if isinstance(f, NonbondedForce)]
Expand Down Expand Up @@ -785,9 +787,10 @@ def check_propchgs(smc, charge_array):
dag_unit = list(dag.protocol_units)[0]

with tmpdir.as_cwd():
sampler = dag_unit.run(dry=True)["debug"]["sampler"]
htf = sampler._factory
hybrid_system = htf.hybrid_system
debug = dag_unit.run(dry=True)["debug"]
sampler = debug["sampler"]
htf = debug["hybrid_factory"]
hybrid_system = sampler._hybrid_system

# get the standard nonbonded force
nonbond = [f for f in hybrid_system.getForces() if isinstance(f, NonbondedForce)]
Expand Down Expand Up @@ -902,7 +905,7 @@ def test_dodecahdron_ligand_box(

with tmpdir.as_cwd():
sampler = dag_unit.run(dry=True)["debug"]["sampler"]
hs = sampler._factory.hybrid_system
hs = sampler._hybrid_system

vectors = hs.getDefaultPeriodicBoxVectors()

Expand Down Expand Up @@ -1598,7 +1601,7 @@ def tyk2_xml(tmp_path_factory):

dryrun = pu.run(dry=True, shared_basepath=tmp)

system = dryrun["debug"]["sampler"]._hybrid_factory.hybrid_system
system = dryrun["debug"]["sampler"]._hybrid_system

return ET.fromstring(XmlSerializer.serialize(system))

Expand Down Expand Up @@ -2153,8 +2156,8 @@ def test_dry_run_alchemwater_solvent(benzene_to_benzoic_mapping, solv_settings,
unit = list(dag.protocol_units)[0]

with tmpdir.as_cwd():
sampler = unit.run(dry=True)["debug"]["sampler"]
htf = sampler._factory
debug = unit.run(dry=True)["debug"]
htf = debug["hybrid_factory"]
_assert_total_charge(htf.hybrid_system, htf._atom_classes, 0, 0)

assert len(htf._atom_classes["core_atoms"]) == 14
Expand Down Expand Up @@ -2222,8 +2225,8 @@ def test_dry_run_complex_alchemwater_totcharge(
unit = list(dag.protocol_units)[0]

with tmpdir.as_cwd():
sampler = unit.run(dry=True)["debug"]["sampler"]
htf = sampler._factory
debug = unit.run(dry=True)["debug"]
htf = debug["hybrid_factory"]
_assert_total_charge(htf.hybrid_system, htf._atom_classes, chgA, chgB)

assert len(htf._atom_classes["core_atoms"]) == core_atoms
Expand Down
Loading