Skip to content

Commit

Permalink
fix Wp word issue & impl option for other methods
Browse files Browse the repository at this point in the history
  • Loading branch information
yihengwuKP committed Aug 9, 2024
1 parent ca18225 commit 4bf3ac3
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 6 deletions.
5 changes: 3 additions & 2 deletions pysages/methods/abf.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ class ABF(GriddedSamplingMethod):
`restraints.upper`.
use_np_pinv: Optional[Bool] = False
If set to True, the Wp will be calculated using np.linalg.pinv(Jxi.T)@p
rather than solve_pos_def(Jxi @ Jxi.T, Jxi @ p).
If set to True, the product W times momentum p
will be calculated using pseudo-inverse from numpy
rather than using the solving function from scipy
This is computationally more expensive but numerically more stable.
"""

Expand Down
13 changes: 12 additions & 1 deletion pysages/methods/cff.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@ class CFF(NNSamplingMethod):
If provided, indicate that harmonic restraints will be applied when any
collective variable lies outside the box from `restraints.lower` to
`restraints.upper`.
use_np_pinv: Optional[Bool] = False
If set to True, the product W times momentum p
will be calculated using pseudo-inverse from numpy
rather than using the solving function from scipy
This is computationally more expensive but numerically more stable.
"""

snapshot_flags = {"positions", "indices", "momenta"}
Expand All @@ -171,6 +177,7 @@ def __init__(self, cvs, grid, topology, kT, **kwargs):
self.fmodel = MLP(dims, dims, topology, transform=scale)
self.optimizer = kwargs.get("optimizer", default_optimizer)
self.foptimizer = kwargs.get("foptimizer", default_foptimizer)
self.use_np_pinv = self.kwargs.get("use_np_pinv", False)

def build(self, snapshot, helpers):
return _cff(self, snapshot, helpers)
Expand All @@ -180,6 +187,7 @@ def _cff(method: CFF, snapshot, helpers):
cv = method.cv
grid = method.grid
train_freq = method.train_freq
use_np_pinv = method.use_np_pinv
dt = snapshot.dt

# Neural network paramters
Expand Down Expand Up @@ -221,7 +229,10 @@ def update(state, data):
xi, Jxi = cv(data)
#
p = data.momenta
Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p)
if use_np_pinv:
Wp = np.linalg.pinv(Jxi.T) @ p
else:
Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p)
dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt
#
I_xi = get_grid_index(xi)
Expand Down
13 changes: 12 additions & 1 deletion pysages/methods/funn.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,12 @@ class FUNN(NNSamplingMethod):
If provided, indicate that harmonic restraints will be applied when any
collective variable lies outside the box from `restraints.lower` to
`restraints.upper`.
use_np_pinv: Optional[Bool] = False
If set to True, the product W times momentum p
will be calculated using pseudo-inverse from numpy
rather than using the solving function from scipy
This is computationally more expensive but numerically more stable.
"""

snapshot_flags = {"positions", "indices", "momenta"}
Expand All @@ -142,6 +148,7 @@ def __init__(self, cvs, grid, topology, **kwargs):
self.model = MLP(dims, dims, topology, transform=scale)
default_optimizer = LevenbergMarquardt(reg=L2Regularization(1e-6))
self.optimizer = kwargs.get("optimizer", default_optimizer)
self.use_np_pinv = self.kwargs.get("use_np_pinv", False)

def build(self, snapshot, helpers):
return _funn(self, snapshot, helpers)
Expand All @@ -151,6 +158,7 @@ def _funn(method, snapshot, helpers):
cv = method.cv
grid = method.grid
train_freq = method.train_freq
use_np_pinv = method.use_np_pinv

dt = snapshot.dt
dims = grid.shape.size
Expand Down Expand Up @@ -186,7 +194,10 @@ def update(state, data):
xi, Jxi = cv(data)
#
p = data.momenta
Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p)
if use_np_pinv:
Wp = np.linalg.pinv(Jxi.T) @ p
else:
Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p)
dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt
#
I_xi = get_grid_index(xi)
Expand Down
13 changes: 12 additions & 1 deletion pysages/methods/sirens.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@ class Sirens(NNSamplingMethod):
If provided, indicate that harmonic restraints will be applied when any
collective variable lies outside the box from `restraints.lower` to
`restraints.upper`.
use_np_pinv: Optional[Bool] = False
If set to True, the product W times momentum p
will be calculated using pseudo-inverse from numpy
rather than using the solving function from scipy
This is computationally more expensive but numerically more stable.
"""

snapshot_flags = {"positions", "indices", "momenta"}
Expand All @@ -172,6 +178,7 @@ def __init__(self, cvs, grid, topology, **kwargs):
scale = partial(_scale, grid=grid)
self.model = Siren(dims, 1, topology, transform=scale)
self.optimizer = optimizer
self.use_np_pinv = self.kwargs.get("use_np_pinv", False)

def __check_init_invariants__(self, mode, kT, optimizer):
if mode not in ("abf", "cff"):
Expand All @@ -196,6 +203,7 @@ def _sirens(method: Sirens, snapshot, helpers):
cv = method.cv
grid = method.grid
train_freq = method.train_freq
use_np_pinv = method.use_np_pinv
dt = snapshot.dt

# Neural network paramters
Expand Down Expand Up @@ -244,7 +252,10 @@ def update(state, data):
xi, Jxi = cv(data)
#
p = data.momenta
Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p)
if use_np_pinv:
Wp = np.linalg.pinv(Jxi.T) @ p
else:
Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p)
dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt
#
I_xi = get_grid_index(xi)
Expand Down
13 changes: 12 additions & 1 deletion pysages/methods/spectral_abf.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ class SpectralABF(GriddedSamplingMethod):
If provided, indicate that harmonic restraints will be applied when any
collective variable lies outside the box from `restraints.lower` to
`restraints.upper`.
use_np_pinv: Optional[Bool] = False
If set to True, the product W times momentum p
will be calculated using pseudo-inverse from numpy
rather than using the solving function from scipy
This is computationally more expensive but numerically more stable.
"""

snapshot_flags = {"positions", "indices", "momenta"}
Expand All @@ -135,6 +141,7 @@ def __init__(self, cvs, grid, **kwargs):
self.fit_threshold = self.kwargs.get("fit_threshold", 500)
self.grid = self.grid if self.grid.is_periodic else convert(self.grid, Grid[Chebyshev])
self.model = SpectralGradientFit(self.grid)
self.use_np_pinv = self.kwargs.get("use_np_pinv", False)

def build(self, snapshot, helpers, *_args, **_kwargs):
"""
Expand All @@ -148,6 +155,7 @@ def _spectral_abf(method, snapshot, helpers):
grid = method.grid
fit_freq = method.fit_freq
fit_threshold = method.fit_threshold
use_np_pinv = method.use_np_pinv

dt = snapshot.dt
dims = grid.shape.size
Expand Down Expand Up @@ -181,7 +189,10 @@ def update(state, data):
xi, Jxi = cv(data)
#
p = data.momenta
Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p)
if use_np_pinv:
Wp = np.linalg.pinv(Jxi.T) @ p
else:
Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p)
# Second order backward finite difference
dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt
#
Expand Down

0 comments on commit 4bf3ac3

Please sign in to comment.