Improve TV derivative. (#261)
clonker committed Oct 13, 2022
1 parent f5ac3bb commit a3967ef
Showing 3 changed files with 48 additions and 25 deletions.
61 changes: 42 additions & 19 deletions deeptime/util/diff.py
@@ -1,4 +1,5 @@
 import numpy as _np
+import numpy as np
 from scipy import sparse as _sparse
 
 from deeptime.util.data import sliding_window
@@ -134,7 +135,8 @@ def _cumtrapz_operator(xs):
     return _sparse.csc_matrix((data, (row_data, col_data)), shape=(n - 1, n))
 
 
-def tv_derivative(xs, ys, u0=None, alpha=10., tol=None, maxit=1000, fd_window_radius=5, verbose=False):
+def tv_derivative(xs, ys, u0=None, alpha=10., tol=None, maxit=1000, fd_window_radius=5, epsilon=1e-6,
+                  sparse=True, progress=None):
     r""" Total-variation regularized derivative. Note that this is currently only implemented for one-dimensional
     functions. See :footcite:`chartrand2011numerical` for theory and algorithmic details.
@@ -159,8 +161,13 @@ def tv_derivative(xs, ys, u0=None, alpha=10., tol=None, maxit=1000, fd_window_ra
     fd_window_radius : int, default=5
         Radius in which the finite differences are computed. For example, a value of `2` means that the local gradient
         at :math:`x_n` is approximated using grid nodes :math:`x_{n-2}, x_{n-1}, x_n, x_{n+1}, x_{n+2}`.
-    verbose : bool, default=False
-        Print convergence information.
+    epsilon : float, default=1e-6
+        Small constant that is added to the norm of the current iterate of the TV-regularized derivative so it can
+        safely be normalized.
+    sparse : bool, default=True
+        Whether to use sparse matrices for finite difference, integration, and normalization operators.
+    progress : ProgressBar, optional, default=None
+        Optional progress bar, tested for tqdm.
 
     Returns
     -------
@@ -171,14 +178,14 @@ def tv_derivative(xs, ys, u0=None, alpha=10., tol=None, maxit=1000, fd_window_ra
     ----------
     .. footbibliography::
     """
+    from deeptime.util.platform import handle_progress_bar
     assert alpha > 0, "Regularization parameter may only be positive."
+    progress = handle_progress_bar(progress)
     data = _np.asarray(ys, dtype=_np.float64).squeeze()
     xs = _np.asarray(xs, dtype=_np.float64).squeeze()
     n = data.shape[0]
     assert xs.shape[0] == n, "the grid must have the same dimension as data"
 
-    epsilon = 1e-6
-
     # grid of midpoints between xs, extrapolating first and last node:
     #
     # x--|--x--|---x---|-x-|-x
@@ -191,14 +198,21 @@ def tv_derivative(xs, ys, u0=None, alpha=10., tol=None, maxit=1000, fd_window_ra
     assert midpoints.shape[0] == n + 1
 
     diff = finite_difference_operator_midpoints(midpoints, k=1, window_radius=fd_window_radius)
+    if not sparse:
+        diff = diff.toarray()
     assert diff.shape[0] == n
     assert diff.shape[1] == n + 1
 
-    diff_t = diff.transpose().tocsc()
+    if sparse:
+        diff_t = diff.transpose().tocsc()
+    else:
+        diff_t = diff.T
     assert diff.shape[0] == n
     assert diff.shape[1] == n + 1
 
     A = _cumtrapz_operator(midpoints)
+    if not sparse:
+        A = A.toarray()
     AT = A.T
     ATA = AT @ A
 
@@ -208,23 +222,32 @@ def tv_derivative(xs, ys, u0=None, alpha=10., tol=None, maxit=1000, fd_window_ra
     if len(u0) == n:
         u0 = _np.concatenate(([0], .5 * (u0[1:] + u0[:-1]), [0]))
     u = u0
-    Aadj_offset = AT * (data[0] - data)
-
-    E_n = _sparse.dia_matrix((n, n), dtype=xs.dtype)
-    midpoints_diff = _np.diff(midpoints)
-
-    for ii in range(1, maxit + 1):
-        E_n.setdiag(midpoints_diff * (1. / _np.sqrt(_np.diff(u) ** 2.0 + epsilon)))
-        L = diff_t * E_n * diff
-        g = ATA.dot(u) + Aadj_offset + alpha * L * u
+    Aadj_offset = AT @ (data[0] - data)
+
+    if sparse:
+        E_n = _sparse.dia_matrix((n, n), dtype=xs.dtype)
+    else:
+        E_n = np.zeros((n, n), dtype=xs.dtype)
+    midpoints_diff = _np.gradient(midpoints, edge_order=2)
+
+    for _ in progress(range(1, maxit + 1)):
+        diagonal = midpoints_diff * (1. / _np.sqrt(_np.gradient(u, edge_order=2) ** 2.0 + epsilon))
+        if sparse:
+            E_n.setdiag(diagonal)
+        else:
+            np.fill_diagonal(E_n, diagonal)
+        L = diff_t @ E_n @ diff
+        g = ATA @ u + Aadj_offset + alpha * L @ u
 
         # solve linear equation.
-        s = _np.linalg.solve((alpha * L + ATA).todense().astype(_np.float64), -g.astype(_np.float64))
+        lhs = alpha * L + ATA
+        if sparse:
+            lhs = lhs.toarray()
+        s = _np.linalg.solve(lhs.astype(_np.float64), -g.astype(_np.float64))
 
         relative_change = _np.linalg.norm(s[0]) / _np.linalg.norm(u)
-        if verbose:
-            print(f'iteration {ii:4d}: relative change = {relative_change:.3e},'
-                  f' gradient norm = {_np.linalg.norm(g):.3e}')
+        #print(f'iteration {ii:4d}: relative change = {relative_change:.3e},'
+        #      f' gradient norm = {_np.linalg.norm(g):.3e}')
 
         # Update current solution
         u = u + s
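Not part of the commit, but as a reading aid for the loop above: with A the cumulative-trapezoid integration operator, D the midpoint finite-difference operator, f the data vector and f_1 its first entry, each pass assembles the lagged-diffusivity step of :footcite:`chartrand2011numerical`,

    g_k = A^\top \bigl( A u_k - (f - f_1) \bigr) + \alpha L_k u_k,
    \qquad L_k = D^\top E_k D,
    \qquad (E_k)_{ii} = \frac{\Delta x_i}{\sqrt{(\nabla u_k)_i^2 + \epsilon}},

and then solves

    (\alpha L_k + A^\top A)\, s_k = -g_k, \qquad u_{k+1} = u_k + s_k,

where \Delta x_i is the local midpoint spacing and the new epsilon parameter keeps the normalization in E_k well defined where \nabla u_k vanishes.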
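For completeness, a minimal usage sketch of the new signature (assuming this commit is installed and, for the progress reporting, that tqdm is available; the data below is illustrative, not from the repository):

import numpy as np
from tqdm import tqdm  # optional; the docstring says progress bars are tested for tqdm
import deeptime.util.diff as diff

xs = np.linspace(0, 2 * np.pi, 200)
ys = np.sin(xs) + np.random.normal(0., .1, xs.shape)  # noisy samples of sin(x)

# dense operators plus a tqdm progress bar; sparse=True (the default) uses scipy.sparse
dydx = diff.tv_derivative(xs, ys, alpha=0.01, tol=1e-5, maxit=500,
                          fd_window_radius=5, epsilon=1e-6, sparse=False, progress=tqdm)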
2 changes: 1 addition & 1 deletion examples/methods/plot_tv_derivative.py
@@ -13,7 +13,7 @@
 x0 = np.linspace(0, 2.0 * np.pi, 200)
 testf = np.sin(x0) + np.random.normal(0.0, np.sqrt(noise_variance), x0.shape)
 true_deriv = np.cos(x0)
-df_tv = diff.tv_derivative(x0, testf, alpha=0.01, tol=1e-5, verbose=True, fd_window_radius=5)
+df_tv = diff.tv_derivative(x0, testf, alpha=0.001, tol=1e-5, fd_window_radius=5, sparse=False)
 
 plt.figure()
 plt.plot(x0, np.sin(x0), label=r'$f(x) = \sin(x)$')
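A note on the changed value above: alpha weighs the total-variation penalty, so dropping it from 0.01 to 0.001 smooths the recovered derivative less. A quick illustrative comparison, reusing the example's variables (a sketch, not part of the commit):

for alpha in (0.001, 0.1):
    d = diff.tv_derivative(x0, testf, alpha=alpha, tol=1e-5, fd_window_radius=5, sparse=False)
    plt.plot(x0, d, label=rf'$\alpha = {alpha}$')  # larger alpha -> smoother estimate
plt.plot(x0, true_deriv, 'k--', label=r'true derivative $\cos(x)$')
plt.legend()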
10 changes: 5 additions & 5 deletions tests/util/test_diff.py
@@ -1,19 +1,19 @@
 import numpy as np
+import pytest
 from numpy.testing import assert_, assert_array_almost_equal
 
 import deeptime.util.diff as diff
 
 
-def test_tv_derivative(capsys):
+@pytest.mark.parametrize('sparse', [False, True])
+def test_tv_derivative(sparse):
     noise_variance = .08 * .08
     x0 = np.linspace(0, 2.0 * np.pi, 400)
     testf = np.sin(x0) + np.random.normal(0.0, np.sqrt(noise_variance), x0.shape)
     true_deriv = np.cos(x0)
-    df = diff.tv_derivative(x0, testf, alpha=0.01, tol=1e-5, verbose=True, fd_window_radius=5)
-    captured = capsys.readouterr()
+    df = diff.tv_derivative(x0, testf, alpha=0.01, tol=1e-5, fd_window_radius=5, sparse=sparse)
     max_diff = np.max(np.abs(df - true_deriv))
     assert_(max_diff < 0.5)
-    assert_("relative change" in captured.out)
 
-    df2 = diff.tv_derivative(x0, testf, u0=df, alpha=0.01, tol=1e-5, verbose=False, fd_window_radius=5)
+    df2 = diff.tv_derivative(x0, testf, u0=df, alpha=0.01, tol=1e-5, fd_window_radius=5)
     assert_array_almost_equal(df, df2, decimal=1)
