Skip to content

Commit

Permalink
UPDATE: optimize speed of DMD by storing solutions
Browse files Browse the repository at this point in the history
  • Loading branch information
MarekWadinger committed Mar 21, 2024
1 parent f3858ea commit 2f117dd
Showing 1 changed file with 78 additions and 42 deletions.
120 changes: 78 additions & 42 deletions river/decomposition/odmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
x(t) = Phi exp(diag(ln(Lambda) / dt) * t) Phi^+ x(0) (MIT lecture)
continuous time eigenvalues exp(Lambda * dt) (Zhang et al. 2019)
- [ ] Figure out how to use as both MiniBatchRegressor and MiniBatchTransformer
- [ ] Find out why some values of A change sign between consecutive updates
References:
[^1]: Zhang, H., Clarence Worth Rowley, Deem, E.A. and Cattafesta, L.N.
Expand All @@ -29,9 +30,8 @@
import numpy as np
import pandas as pd
import scipy as sp
from scipy.sparse.linalg._eigen.arpack.arpack import ArpackNoConvergence

from river.base import MiniBatchRegressor
from river.base import MiniBatchRegressor, MiniBatchTransformer

from .osvd import OnlineSVD

Expand All @@ -41,7 +41,7 @@
]


class OnlineDMD(MiniBatchRegressor):
class OnlineDMD(MiniBatchRegressor, MiniBatchTransformer):
"""Online Dynamic Mode Decomposition (DMD).
This regressor is a class that implements online dynamic mode decomposition
Expand Down Expand Up @@ -155,15 +155,18 @@ def __init__(
w: float = 1.0,
initialize: int = 1,
exponential_weighting: bool = False,
eig_rtol: float | None = None,
seed: int | None = None,
) -> None:
self.r = int(r)
if self.r != 0:
self._svd = OnlineSVD(n_components=self.r, force_orth=True)
# Forcing orthogonality makes the results more unstable
self._svd = OnlineSVD(n_components=self.r, force_orth=False)
self.w = float(w)
assert self.w > 0 and self.w <= 1
self.initialize = int(initialize)
self.exponential_weighting = exponential_weighting
self.eig_rtol = eig_rtol
self.seed = seed

np.random.seed(self.seed)
Expand All @@ -175,50 +178,75 @@ def __init__(
self._P: np.ndarray
self._Y: np.ndarray # for xi computation

self._A_last: np.ndarray
self._A_allclose: bool = False
self._n_cached: int = 0 # TODO: remove before merge
self._n_computed: int = 0 # TODO: remove before merge

# Properties to be reset at each update
self._eig: tuple(np.ndarray, np.ndarray) | None = None
self._modes: np.ndarray | None = None
self._xi: np.ndarray | None = None

@property
def eig(self) -> tuple[np.ndarray, np.ndarray]:
"""Compute and return DMD eigenvalues and DMD modes at current step"""
# TODO: need to check if SVD is initialized in case r < m. Otherwise, transformation will fail.
try:
if self._eig is None:
# TODO: need to check if SVD is initialized in case r < m. Otherwise, transformation will fail.
# TODO: explore faster ways to compute eig
# TODO: find out whether Phi should have imaginary part
Lambda, Phi = sp.linalg.eig(self.A, check_finite=False)
except ArpackNoConvergence:
Lambda, Phi = sp.linalg.schur(self.A, check_finite=False)
# TODO: Figure out if we need to sort indices in descending order
if not np.array_equal(Lambda, sorted(Lambda, reverse=True)):
sort_idx = np.argsort(Lambda)[::-1]
Lambda = Lambda[sort_idx]
Phi = Phi[:, sort_idx]
return Lambda, Phi

sort_idx = np.argsort(Lambda)
if not np.array_equal(sort_idx, range(len(Lambda))):
sort_idx = sort_idx[::-1]
Lambda = Lambda[sort_idx]
Phi = Phi[:, sort_idx]
self._eig = Lambda, Phi
self._n_computed += 1
return self._eig

@property
def modes(self) -> np.ndarray:
"""Reconstruct high dimensional DMD modes"""
_, Phi = self.eig
if self.r < self.m:
# Sign of eigenvectors and singular vectors may change based on underlying algorithm initialization
# TODO: verify sign of singular values
return np.abs(self._svd._U) @ np.diag(self._svd._S) @ np.abs(Phi)
else:
return np.abs(Phi)
if self._modes is None:
L, Phi = self.eig
if self.r < self.m:
# Sign of eigenvectors and singular vectors may change based on underlying algorithm initialization
# TODO: shall we use discrete time singlar values or continuous time singlar values?
self._modes = self._svd._U @ np.diag(self._svd._S) @ Phi
else:
self._modes = Phi
return self._modes.real

@property
def xi(self) -> np.ndarray:
"""Amlitudes of the singular values of the input matrix."""
Lambda, Phi = self.eig
# Compute Discrete temporal dynamics matrix (Vandermonde matrix).
C = np.vander(Lambda, self.n_seen, increasing=True)
# xi = self.Phi.conj().T @ self._Y @ np.linalg.pinv(self.C)
if self._xi is None:
Lambda, Phi = self.eig
# Compute Discrete temporal dynamics matrix (Vandermonde matrix).
C = np.vander(Lambda, self.n_seen, increasing=True)
# xi = self.Phi.conj().T @ self._Y @ np.linalg.pinv(self.C)

from scipy.optimize import minimize
from scipy.optimize import minimize

def objective_function(x):
return np.linalg.norm(
self._Y[:, : self.r].T - Phi @ np.diag(x) @ C, "fro"
) + 0.5 * np.linalg.norm(x, 1)
def objective_function(x):
return np.linalg.norm(
self._Y[:, : self.r].T - Phi @ np.diag(x) @ C, "fro"
) + 0.5 * np.linalg.norm(x, 1)

# Minimize the objective function
xi = minimize(objective_function, np.ones(self.r)).x
self._xi = xi
return self._xi

@property
def A_allclose(self) -> bool:
"""Check if A has changed since last update of eigenvalues"""
if self.eig_rtol is None:
return False
return np.allclose(np.abs(self._A_last), np.abs(self.A), rtol=1, atol=1)

# Minimize the objective function
xi = minimize(objective_function, np.ones(self.r)).x
return xi

def _init_update(self) -> None:
if self.initialize > 0 and self.initialize < self.m:
Expand All @@ -230,6 +258,7 @@ def _init_update(self) -> None:
self.r = self.m

self.A = np.random.randn(self.r, self.r)
self._A_last = self.A.copy()
self._X_init = np.empty((self.initialize, self.m))
self._Y_init = np.empty((self.initialize, self.m))
self._Y = np.empty((0, self.m))
Expand Down Expand Up @@ -280,6 +309,15 @@ def _update_A_P(
# ensure P is SPD by taking its symmetric part
self._P = (self._P + self._P.T) / 2

# Reset properties
if not self.A_allclose:
self._eig = None
self._A_last = self.A.copy()
else:
self._n_cached += 1

self._modes = None

def update(
self,
x: dict | np.ndarray,
Expand Down Expand Up @@ -371,7 +409,7 @@ def revert(
raise RuntimeError(
f"Cannot revert {self.__class__.__name__} before "
"initialization. If used with Rolling or TimeRolling, window "
f"size should be increased to {self.initialize}."
f"size should be increased to {self.initialize + 1 if y is None else 0}."
)
if y is None:
# raise ValueError("revert method not implemented for y = None.")
Expand Down Expand Up @@ -479,7 +517,7 @@ def learn_many(
if self.r == 0:
self.r = self.m

assert np.linalg.matrix_rank(X) >= self.m
assert np.linalg.matrix_rank(X) >= self.r
# Exponential weighting factor - older snapshots are weighted less
if self.exponential_weighting:
weights = (np.sqrt(self.w) ** np.arange(n - 1, -1, -1))[
Expand Down Expand Up @@ -587,7 +625,7 @@ def truncation_error(self, X: np.ndarray, Y: np.ndarray) -> float:
Y_hat = self.A @ X.T
return float(np.linalg.norm(Y - Y_hat.T) / np.linalg.norm(Y))

def transform_one(self, x: dict | np.ndarray) -> np.ndarray:
def transform_one(self, x: dict | np.ndarray) -> dict:
"""
Transforms the given input sample.
Expand All @@ -600,10 +638,11 @@ def transform_one(self, x: dict | np.ndarray) -> np.ndarray:
if isinstance(x, dict):
x = np.array(list(x.values()))

M = self.modes
return x @ M
return dict(zip(range(self.r), x @ self.modes))

def transform_many(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
def transform_many(
self, X: np.ndarray | pd.DataFrame
) -> np.ndarray | pd.DataFrame:
"""
Transforms the given input sequence.
Expand All @@ -613,9 +652,6 @@ def transform_many(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
Returns:
np.ndarray: The transformed input.
"""
if isinstance(X, pd.DataFrame):
X = X.values

M = self.modes
return X @ M

Expand Down

0 comments on commit 2f117dd

Please sign in to comment.