Skip to content

Commit

Permalink
Sirens sampling method (#330)
Browse files Browse the repository at this point in the history
Implements the Sirens sampling method described in here https://arxiv.org/abs/2202.01876.
  • Loading branch information
pabloferz authored Jul 31, 2024
1 parent 15a1700 commit e0d16d0
Show file tree
Hide file tree
Showing 10 changed files with 659 additions and 77 deletions.
1 change: 1 addition & 0 deletions pysages/methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from .harmonic_bias import HarmonicBias
from .metad import Metadynamics
from .restraints import CVRestraints
from .sirens import Sirens
from .spectral_abf import SpectralABF
from .spline_string import SplineString
from .umbrella_integration import UmbrellaIntegration
Expand Down
13 changes: 6 additions & 7 deletions pysages/methods/ann.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def __init__(self, cvs, grid, topology, kT, **kwargs):
default_optimizer = LevenbergMarquardt(reg=L2Regularization(1e-6))
self.optimizer = kwargs.get("optimizer", default_optimizer)

def build(self, snapshot, helpers):
def build(self, snapshot, helpers, *_args, **_kwargs):
return _ann(self, snapshot, helpers)


Expand Down Expand Up @@ -155,7 +155,7 @@ def update(state, data):
in_training_regime = ncalls > train_freq
# We only train every `train_freq` timesteps
in_training_step = in_training_regime & (ncalls % train_freq == 1)
hist, phi, prob, nn = learn_free_energy(state, in_training_step)
hist, prob, phi, nn = learn_free_energy(state, in_training_step)
# Compute the collective variable and its jacobian
xi, Jxi = cv(data)
I_xi = get_grid_index(xi)
Expand Down Expand Up @@ -208,10 +208,10 @@ def learn_free_energy(state):
#
hist = np.zeros_like(state.hist)
#
return hist, phi, prob, nn
return hist, prob, phi, nn

def skip_learning(state):
return state.hist, state.phi, state.prob, state.nn
return state.hist, state.prob, state.phi, state.nn

def _learn_free_energy(state, in_training_step):
return cond(in_training_step, learn_free_energy, skip_learning, state)
Expand Down Expand Up @@ -241,7 +241,7 @@ def predict_force(data):
params = pack(nn.params, layout)
return nn.std * f64(model_grad(params, f32(x)).flatten())

def zero_force(data):
def zero_force(_data):
return np.zeros(dims)

def estimate_force(xi, I_xi, nn, in_training_regime):
Expand Down Expand Up @@ -282,7 +282,6 @@ def analyze(result: Result[ANN]):
"""

method = result.method
states = result.states

grid = method.grid
mesh = (compute_mesh(grid) + 1) * grid.size / 2 + grid.lower
Expand All @@ -306,7 +305,7 @@ def fes_fn(x):
transpose = grid_transposer(grid)
d = mesh.shape[-1]

for s in states:
for s in result.states:
histograms.append(transpose(s.hist))
free_energies.append(transpose(s.phi.max() - s.phi))
nns.append(s.nn)
Expand Down
6 changes: 3 additions & 3 deletions pysages/methods/cff.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def update(state, data):
ncalls = state.ncalls + 1
in_training_regime = ncalls > train_freq
in_training_step = in_training_regime & (ncalls % train_freq == 1)
histp, fe, prob, nn, fnn = learn_free_energy(state, in_training_step)
histp, prob, fe, nn, fnn = learn_free_energy(state, in_training_step)
# Compute the collective variable and its jacobian
xi, Jxi = cv(data)
#
Expand Down Expand Up @@ -281,7 +281,7 @@ def train(nn, fnn, data):
return NNData(params, nn.mean, s), NNData(fparams, f_mean, s)

def skip_learning(state):
return state.hist, state.fe, state.prob, state.nn, state.fnn
return state.histp, state.prob, state.fe, state.nn, state.fnn

def learn_free_energy(state):
prob = state.prob + state.histp * np.exp(state.fe / kT)
Expand All @@ -294,7 +294,7 @@ def learn_free_energy(state):
fe = nn.std * model.apply(params, inputs).reshape(fe.shape)
fe = fe - fe.min()

return histp, fe, prob, nn, fnn
return histp, prob, fe, nn, fnn

def _learn_free_energy(state, in_training_step):
return cond(in_training_step, learn_free_energy, skip_learning, state)
Expand Down
Loading

0 comments on commit e0d16d0

Please sign in to comment.