Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Nov 8, 2022
1 parent 4a7db3c commit 8865eb6
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 67 deletions.
112 changes: 72 additions & 40 deletions torchopt/linalg/ns.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

# pylint: disable=invalid-name

import functools
from typing import Callable, Optional, Union

import torch
Expand All @@ -29,15 +30,47 @@
__all__ = ['ns', 'ns_inv']


def _ns_solve(
A: torch.Tensor,
b: torch.Tensor,
maxiter: int,
alpha: Optional[float] = None,
) -> torch.Tensor:
"""Uses Neumann Series Matrix Inversion Approximation to solve ``Ax = b``."""
if A.ndim != 2:
raise ValueError(f'`A` must be a 2D tensor, but has shape: {A.shape}')
ndim = b.ndim
if ndim == 0:
raise ValueError(f'`b` must be a vector, but has shape: {b.shape}')
if ndim >= 2:
if any(size != 1 for size in b.shape[1:]):
raise ValueError(f'`b` must be a vector, but has shape: {b.shape}')
b = b[(...,) + (0,) * (ndim - 1)] # squeeze trailing dimensions

inv_A_hat_b = b
term = b
if alpha is not None:
for _ in range(maxiter):
term = term - alpha * (A @ term)
inv_A_hat_b = inv_A_hat_b + term
else:
for _ in range(maxiter):
term = term - A @ term
inv_A_hat_b = inv_A_hat_b + term

if ndim >= 2:
inv_A_hat_b = inv_A_hat_b[(...,) + (None,) * (ndim - 1)] # unqueeze trailing dimensions
return inv_A_hat_b


def ns(
A: Union[Callable[[TensorTree], TensorTree], torch.Tensor],
b: TensorTree,
maxiter: Optional[int] = None,
*,
alpha: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
) -> TensorTree:
"""Use Neumann Series Matrix Inversion Approximation to solve ``Ax = b``.
"""Uses Neumann Series Matrix Inversion Approximation to solve ``Ax = b``.
Args:
A: (tensor or tree of tensors or function)
Expand All @@ -56,34 +89,36 @@ def ns(
Returns:
The Neumann Series (NS) matrix inversion approximation.
"""
shapes = cat_shapes(b)
if len(shapes) >= 2:
raise NotImplementedError

matvec = normalize_matvec(A)
A = materialize_matvec(matvec, shapes, dtype=dtype)
if maxiter is None:
# size = sum(shapes)
maxiter = 1

A_flat, treespec = pytree.tree_flatten(A)
maxiter = 10
b_flat = pytree.tree_leaves(b)
if len(b_flat) == 0:
raise ValueError('`b` must be a non-empty pytree.')
if len(b_flat) >= 2:
raise ValueError('`b` must be a pytree with a single leaf.')
b_leaf = b_flat[0]
if b_leaf.ndim >= 2 and any(size != 1 for size in b.shape[1:]):
raise ValueError(f'`b` must be a vector or a scalar, but has shape: {b_leaf.shape}')

if alpha is not None:

def f(A, b):
return b - alpha * (A @ b)
matvec = normalize_matvec(A)
A: TensorTree = materialize_matvec(matvec, b)
return pytree.tree_map(functools.partial(_ns_solve, maxiter=maxiter, alpha=alpha), A, b)

else:

def f(A, b):
return b - A @ b
def _ns_inv(A: torch.Tensor, maxiter: int, alpha: Optional[float] = None):
"""Uses Neumann Series iteration to solve ``A^{-1}``."""
if A.ndim != 2:
raise ValueError(f'`A` must be a 2D tensor, but has shape: {A.shape}')

inv_A_hat_b_flat = list(b_flat)
for _ in range(maxiter):
b_flat = list(map(f, A_flat, b_flat))
inv_A_hat_b_flat = list(map(torch.add, inv_A_hat_b_flat, b_flat))
return pytree.tree_unflatten(treespec, inv_A_hat_b_flat)
I = torch.eye(*A.shape, out=torch.empty_like(A))
inv_A_hat = torch.zeros_like(A)
if alpha is not None:
for rank in range(maxiter):
inv_A_hat = inv_A_hat + torch.linalg.matrix_power(I - alpha * A, rank)
else:
for rank in range(maxiter):
inv_A_hat = inv_A_hat + torch.linalg.matrix_power(I - A, rank)
return inv_A_hat


def ns_inv(
Expand All @@ -92,7 +127,7 @@ def ns_inv(
*,
alpha: Optional[float] = None,
) -> TensorTree:
"""Use Neumann Series iteration to solve ``A^{-1}``.
"""Uses Neumann Series iteration to solve ``A^{-1}``.
Args:
A: (tensor or tree of tensors or function)
Expand All @@ -112,18 +147,15 @@ def ns_inv(
size = sum(cat_shapes(A))
maxiter = 10 * size

A_flat, treespec = pytree.tree_flatten(A)

I_flat = [torch.eye(*a.size(), out=torch.empty_like(a)) for a in A_flat]
inv_A_hat_flat = [torch.zeros_like(a) for a in A_flat]
if alpha is not None:
for rank in range(maxiter):
power = [torch.linalg.matrix_power(i - alpha * a, rank) for i, a in zip(I_flat, A_flat)]
inv_A_hat_flat = [inv_a + p for inv_a, p in zip(inv_A_hat_flat, power)]
else:
for rank in range(maxiter):
power = [torch.linalg.matrix_power(i - a, rank) for i, a in zip(I_flat, A_flat)]
inv_A_hat_flat = [inv_a + p for inv_a, p in zip(inv_A_hat_flat, power)]

inv_A_hat = pytree.tree_unflatten(treespec, inv_A_hat_flat)
return inv_A_hat
if isinstance(A, torch.Tensor):
I = torch.eye(*A.shape, out=torch.empty_like(A))
inv_A_hat = torch.zeros_like(A)
if alpha is not None:
for rank in range(maxiter):
inv_A_hat = inv_A_hat + torch.linalg.matrix_power(I - alpha * A, rank)
else:
for rank in range(maxiter):
inv_A_hat = inv_A_hat + torch.linalg.matrix_power(I - A, rank)
return inv_A_hat

return pytree.tree_map(functools.partial(_ns_inv, maxiter=maxiter, alpha=alpha), A)
14 changes: 7 additions & 7 deletions torchopt/linalg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ def cat_shapes(tree: TensorTree) -> Tuple[int, ...]:


def normalize_matvec(
f: Union[Callable[[TensorTree], TensorTree], torch.Tensor]
matvec: Union[Callable[[TensorTree], TensorTree], torch.Tensor]
) -> Callable[[TensorTree], TensorTree]:
"""Normalize an argument for computing matrix-vector products."""
if callable(f):
return f
if callable(matvec):
return matvec

assert isinstance(f, torch.Tensor)
if f.ndim != 2 or f.shape[0] != f.shape[1]:
raise ValueError(f'linear operator must be a square matrix, but has shape: {f.shape}')
return partial(torch.matmul, f)
assert isinstance(matvec, torch.Tensor)
if matvec.ndim != 2 or matvec.shape[0] != matvec.shape[1]:
raise ValueError(f'linear operator must be a square matrix, but has shape: {matvec.shape}')
return partial(torch.matmul, matvec)
17 changes: 8 additions & 9 deletions torchopt/linear_solve/inv.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def _solve_inv(
) -> TensorTree:
"""Solves ``A x = b`` using matrix inversion.
It will materialize the matrix ``A`` in memory.
It assumes the matrix ``A`` is a constant matrix and will materialize the
matrix ``A`` in memory.
Args:
matvec: A function that returns the product between ``A`` and a vector.
Expand All @@ -75,17 +76,15 @@ def _solve_inv(
if len(b_flat) >= 2:
raise ValueError('`b` must be a pytree with a single leaf.')

b_leaf: torch.Tensor = b_flat[0]
dtype = b_leaf.dtype
shape = b_leaf.shape
if len(shape) == 0:
return pytree.tree_truediv(b, materialize_matvec(matvec, shape=shape, dtype=dtype))
if len(shape) == 1:
b_leaf = b_flat[0]
if b_leaf.ndim == 0:
return pytree.tree_truediv(b, materialize_matvec(matvec, b))
if b_leaf.ndim == 1 or all(size == 1 for size in b_leaf.shape[1:]):
if ns:
return linalg.ns(matvec, b, **kwargs)
A = materialize_matvec(matvec, shape=shape, dtype=dtype)
A = materialize_matvec(matvec, b)
return pytree.tree_map(lambda A, b: torch.linalg.inv(A) @ b, A, b)
raise NotImplementedError
raise ValueError(f'`b` must be a vector or a scalar, but has shape: {b_leaf.shape}')


def solve_inv(**kwargs):
Expand Down
14 changes: 3 additions & 11 deletions torchopt/linear_solve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,9 @@
# ==============================================================================
"""Utilities for linear algebra solvers."""

# pylint: disable=invalid-name

from typing import Callable, Optional, Tuple
from typing import Callable

import functorch
import torch

from torchopt import pytree
from torchopt.typing import TensorTree
Expand Down Expand Up @@ -76,11 +73,6 @@ def ridge_matvec(y: TensorTree) -> TensorTree:
return ridge_matvec


def materialize_matvec(
matvec: Callable[[TensorTree], TensorTree],
shape: Tuple[int, ...],
dtype: Optional[torch.dtype] = None,
) -> TensorTree:
"""Materializes the matrix ``A`` used in ``matvec(x) = A x``."""
x = torch.zeros(shape, dtype=dtype)
def materialize_matvec(matvec: Callable[[TensorTree], TensorTree], x: TensorTree) -> TensorTree:
    """Build the explicit matrix ``A`` behind the linear map ``matvec(x) = A @ x``.

    The matrix is obtained as the forward-mode Jacobian of ``matvec`` evaluated
    at ``x``; for a linear map that Jacobian is ``A`` itself.

    Args:
        matvec: A function that returns the product between ``A`` and a vector.
        x: The input at which the Jacobian of ``matvec`` is evaluated.

    Returns:
        The Jacobian of ``matvec`` at ``x``.
    """
    jacobian_fn = functorch.jacfwd(matvec)
    return jacobian_fn(x)

0 comments on commit 8865eb6

Please sign in to comment.