diff --git a/torchopt/linalg/ns.py b/torchopt/linalg/ns.py
index e78b208ff..74870a39d 100644
--- a/torchopt/linalg/ns.py
+++ b/torchopt/linalg/ns.py
@@ -16,6 +16,7 @@
 # pylint: disable=invalid-name
 
+import functools
 from typing import Callable, Optional, Union
 
 import torch
@@ -29,15 +30,48 @@
 __all__ = ['ns', 'ns_inv']
 
 
+def _ns_solve(
+    A: torch.Tensor,
+    b: torch.Tensor,
+    maxiter: int,
+    alpha: Optional[float] = None,
+) -> torch.Tensor:
+    """Use Neumann Series Matrix Inversion Approximation to solve ``Ax = b``."""
+    if A.ndim != 2:
+        raise ValueError(f'`A` must be a 2D tensor, but has shape: {A.shape}')
+
+    ndim = b.ndim
+    if ndim == 0:
+        raise ValueError(f'`b` must be a vector, but has shape: {b.shape}')
+    if ndim >= 2:
+        if any(size != 1 for size in b.shape[1:]):
+            raise ValueError(f'`b` must be a vector, but has shape: {b.shape}')
+        b = b[(...,) + (0,) * (ndim - 1)]  # squeeze trailing dimensions
+
+    inv_A_hat_b = b
+    term = b
+    if alpha is not None:
+        for _ in range(maxiter):
+            term = term - alpha * (A @ term)
+            inv_A_hat_b = inv_A_hat_b + term
+    else:
+        for _ in range(maxiter):
+            term = term - A @ term
+            inv_A_hat_b = inv_A_hat_b + term
+
+    if ndim >= 2:
+        inv_A_hat_b = inv_A_hat_b[(...,) + (None,) * (ndim - 1)]  # unsqueeze trailing dimensions
+    return inv_A_hat_b
+
+
 def ns(
     A: Union[Callable[[TensorTree], TensorTree], torch.Tensor],
     b: TensorTree,
     maxiter: Optional[int] = None,
     *,
     alpha: Optional[float] = None,
-    dtype: Optional[torch.dtype] = None,
 ) -> TensorTree:
-    """Use Neumann Series Matrix Inversion Approximation to solve ``Ax = b``.
+    """Use the Neumann Series Matrix Inversion Approximation to solve ``Ax = b``.
 
     Args:
         A: (tensor or tree of tensors or function)
@@ -56,34 +90,36 @@ def ns(
     Returns:
         The Neumann Series (NS) matrix inversion approximation.
""" - shapes = cat_shapes(b) - if len(shapes) >= 2: - raise NotImplementedError - - matvec = normalize_matvec(A) - A = materialize_matvec(matvec, shapes, dtype=dtype) if maxiter is None: - # size = sum(shapes) - maxiter = 1 - - A_flat, treespec = pytree.tree_flatten(A) + maxiter = 10 b_flat = pytree.tree_leaves(b) + if len(b_flat) == 0: + raise ValueError('`b` must be a non-empty pytree.') + if len(b_flat) >= 2: + raise ValueError('`b` must be a pytree with a single leaf.') + b_leaf = b_flat[0] + if b_leaf.ndim >= 2 and any(size != 1 for size in b.shape[1:]): + raise ValueError(f'`b` must be a vector or a scalar, but has shape: {b_leaf.shape}') - if alpha is not None: - - def f(A, b): - return b - alpha * (A @ b) + matvec = normalize_matvec(A) + A: TensorTree = materialize_matvec(matvec, b) + return pytree.tree_map(functools.partial(_ns_solve, maxiter=maxiter, alpha=alpha), A, b) - else: - def f(A, b): - return b - A @ b +def _ns_inv(A: torch.Tensor, maxiter: int, alpha: Optional[float] = None): + """Uses Neumann Series iteration to solve ``A^{-1}``.""" + if A.ndim != 2: + raise ValueError(f'`A` must be a 2D tensor, but has shape: {A.shape}') - inv_A_hat_b_flat = list(b_flat) - for _ in range(maxiter): - b_flat = list(map(f, A_flat, b_flat)) - inv_A_hat_b_flat = list(map(torch.add, inv_A_hat_b_flat, b_flat)) - return pytree.tree_unflatten(treespec, inv_A_hat_b_flat) + I = torch.eye(*A.shape, out=torch.empty_like(A)) + inv_A_hat = torch.zeros_like(A) + if alpha is not None: + for rank in range(maxiter): + inv_A_hat = inv_A_hat + torch.linalg.matrix_power(I - alpha * A, rank) + else: + for rank in range(maxiter): + inv_A_hat = inv_A_hat + torch.linalg.matrix_power(I - A, rank) + return inv_A_hat def ns_inv( @@ -92,7 +128,7 @@ def ns_inv( *, alpha: Optional[float] = None, ) -> TensorTree: - """Use Neumann Series iteration to solve ``A^{-1}``. + """Uses Neumann Series iteration to solve ``A^{-1}``. 
     Args:
         A: (tensor or tree of tensors or function)
@@ -112,18 +148,15 @@
         size = sum(cat_shapes(A))
         maxiter = 10 * size
 
-    A_flat, treespec = pytree.tree_flatten(A)
-
-    I_flat = [torch.eye(*a.size(), out=torch.empty_like(a)) for a in A_flat]
-    inv_A_hat_flat = [torch.zeros_like(a) for a in A_flat]
-    if alpha is not None:
-        for rank in range(maxiter):
-            power = [torch.linalg.matrix_power(i - alpha * a, rank) for i, a in zip(I_flat, A_flat)]
-            inv_A_hat_flat = [inv_a + p for inv_a, p in zip(inv_A_hat_flat, power)]
-    else:
-        for rank in range(maxiter):
-            power = [torch.linalg.matrix_power(i - a, rank) for i, a in zip(I_flat, A_flat)]
-            inv_A_hat_flat = [inv_a + p for inv_a, p in zip(inv_A_hat_flat, power)]
-
-    inv_A_hat = pytree.tree_unflatten(treespec, inv_A_hat_flat)
-    return inv_A_hat
+    return pytree.tree_map(functools.partial(_ns_inv, maxiter=maxiter, alpha=alpha), A)
diff --git a/torchopt/linalg/utils.py b/torchopt/linalg/utils.py
index 2439b42fb..dd9683ec5 100644
--- a/torchopt/linalg/utils.py
+++ b/torchopt/linalg/utils.py
@@ -31,13 +31,13 @@ def cat_shapes(tree: TensorTree) -> Tuple[int, ...]:
 
 
 def normalize_matvec(
-    f: Union[Callable[[TensorTree], TensorTree], torch.Tensor]
+    matvec: Union[Callable[[TensorTree], TensorTree], torch.Tensor]
 ) -> Callable[[TensorTree], TensorTree]:
     """Normalize an argument for computing matrix-vector products."""
-    if callable(f):
-        return f
+    if callable(matvec):
+        return matvec
 
-    assert isinstance(f, torch.Tensor)
-    if f.ndim != 2 or f.shape[0] != f.shape[1]:
-        raise ValueError(f'linear operator must be a square matrix, but has shape: {f.shape}')
-    return partial(torch.matmul, f)
+    assert isinstance(matvec, torch.Tensor)
+    if matvec.ndim != 2 or matvec.shape[0] != matvec.shape[1]:
+        raise ValueError(f'linear operator must be a square matrix, but has shape: {matvec.shape}')
+    return partial(torch.matmul, matvec)
diff --git a/torchopt/linear_solve/inv.py b/torchopt/linear_solve/inv.py
index 340ec983e..068b4de68 100644
--- a/torchopt/linear_solve/inv.py
+++ b/torchopt/linear_solve/inv.py
@@ -55,7 +55,8 @@ def _solve_inv(
 ) -> TensorTree:
     """Solves ``A x = b`` using matrix inversion.
 
-    It will materialize the matrix ``A`` in memory.
+    It assumes the matrix ``A`` is constant and will materialize it
+    in memory.
 
     Args:
         matvec: A function that returns the product between ``A`` and a vector.
@@ -75,17 +76,15 @@ def _solve_inv(
     if len(b_flat) >= 2:
         raise ValueError('`b` must be a pytree with a single leaf.')
 
-    b_leaf: torch.Tensor = b_flat[0]
-    dtype = b_leaf.dtype
-    shape = b_leaf.shape
-    if len(shape) == 0:
-        return pytree.tree_truediv(b, materialize_matvec(matvec, shape=shape, dtype=dtype))
-    if len(shape) == 1:
+    b_leaf = b_flat[0]
+    if b_leaf.ndim == 0:
+        return pytree.tree_truediv(b, materialize_matvec(matvec, b))
+    if b_leaf.ndim == 1 or all(size == 1 for size in b_leaf.shape[1:]):
         if ns:
             return linalg.ns(matvec, b, **kwargs)
-        A = materialize_matvec(matvec, shape=shape, dtype=dtype)
+        A = materialize_matvec(matvec, b)
         return pytree.tree_map(lambda A, b: torch.linalg.inv(A) @ b, A, b)
-    raise NotImplementedError
+    raise ValueError(f'`b` must be a vector or a scalar, but has shape: {b_leaf.shape}')
 
 
 def solve_inv(**kwargs):
diff --git a/torchopt/linear_solve/utils.py b/torchopt/linear_solve/utils.py
index c72ee0cd6..a132b9f11 100644
--- a/torchopt/linear_solve/utils.py
+++ b/torchopt/linear_solve/utils.py
@@ -31,12 +31,9 @@
 # ==============================================================================
 """Utilities for linear algebra solvers."""
 
-# pylint: disable=invalid-name
-
-from typing import Callable, Optional, Tuple
+from typing import Callable
 
 import functorch
-import torch
 
 from torchopt import pytree
 from torchopt.typing import TensorTree
@@ -76,11 +73,6 @@ def ridge_matvec(y: TensorTree) -> TensorTree:
     return ridge_matvec
 
 
-def materialize_matvec(
-    matvec: Callable[[TensorTree], TensorTree],
-    shape: Tuple[int, ...],
-    dtype: Optional[torch.dtype] = None,
-) -> TensorTree:
-    """Materializes the matrix ``A`` used in ``matvec(x) = A x``."""
-    x = torch.zeros(shape, dtype=dtype)
+def materialize_matvec(matvec: Callable[[TensorTree], TensorTree], x: TensorTree) -> TensorTree:
+    """Materialize the matrix ``A`` used in ``matvec(x) = A @ x``."""
     return functorch.jacfwd(matvec)(x)
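
Reviewer note on what the refactored ``_ns_solve`` computes (not part of the patch): the truncated Neumann series ``x ~= sum_{k=0}^{K} (I - A)^k b`` converges to ``A^{-1} b`` whenever the spectral radius of ``I - A`` is below 1. A minimal sanity-check sketch against ``torch.linalg.solve``, assuming a synthetic well-conditioned matrix:

import torch

torch.manual_seed(0)

# Build a matrix whose Neumann series converges: A = I - N with the
# spectral norm of N well below 1, so that sum_k (I - A)^k b -> A^{-1} b.
n = 5
A = torch.eye(n) - 0.1 * torch.randn(n, n)
b = torch.randn(n)

# Truncated Neumann-series solve using matrix-vector products only,
# mirroring the `alpha=None` loop in `_ns_solve`.
x = b.clone()
term = b.clone()
for _ in range(50):
    term = term - A @ term  # term <- (I - A) @ term
    x = x + term

print(torch.allclose(x, torch.linalg.solve(A, b), atol=1e-5))  # expect: True

With ``alpha`` set, the loop instead sums ``(I - alpha * A)^k b``, whose limit is ``(alpha * A)^{-1} b``, i.e. the result carries an extra ``1 / alpha`` scaling.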
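
The indexing idiom in ``_ns_solve`` that squeezes trailing singleton dimensions and restores them afterwards is compact but non-obvious; a hypothetical standalone illustration:

import torch

b = torch.randn(3, 1, 1)  # a 'vector' with trailing singleton dimensions
ndim = b.ndim
v = b[(...,) + (0,) * (ndim - 1)]            # squeeze -> shape (3,)
restored = v[(...,) + (None,) * (ndim - 1)]  # unsqueeze -> shape (3, 1, 1)
print(v.shape, restored.shape)  # torch.Size([3]) torch.Size([3, 1, 1])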
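
``_ns_inv`` accumulates matrix powers instead of matrix-vector products: ``A^{-1} ~= sum_{k=0}^{K} (I - A)^k``, under the same convergence assumption. A sketch (again with a synthetic matrix, not part of the patch):

import torch

torch.manual_seed(0)
n = 4
A = torch.eye(n) - 0.1 * torch.randn(n, n)

# Truncated Neumann series for the inverse, mirroring the `alpha=None`
# branch of `_ns_inv`; note matrix_power(M, 0) is the identity.
I = torch.eye(n)
inv_A_hat = torch.zeros_like(A)
for rank in range(50):
    inv_A_hat = inv_A_hat + torch.linalg.matrix_power(I - A, rank)

print(torch.allclose(inv_A_hat, torch.linalg.inv(A), atol=1e-5))  # expect: True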
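
Why the new ``materialize_matvec`` works: the Jacobian of the linear map ``x -> A @ x`` is ``A`` itself, so ``functorch.jacfwd(matvec)(x)`` recovers ``A`` from the closure for any probe ``x`` of the right shape (the probe's values are irrelevant because the map is linear). A sketch with a hypothetical 2x2 matrix:

import functorch
import torch

A = torch.tensor([[2.0, 1.0],
                  [0.0, 3.0]])

def matvec(x: torch.Tensor) -> torch.Tensor:
    return A @ x  # linear map whose Jacobian is exactly A

x0 = torch.zeros(2)  # probe input; only its shape and dtype matter here
A_materialized = functorch.jacfwd(matvec)(x0)
print(torch.allclose(A_materialized, A))  # expect: True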