From 37e82681f78b72851c5bf8ad1c4416f3ec8b7ec8 Mon Sep 17 00:00:00 2001 From: Bo Liu Date: Wed, 9 Nov 2022 15:54:19 +0800 Subject: [PATCH] feat(linear_solve): matrix inversion linear solver with neumann series approximation (#98) * feat(linear_solve): matrix inversion linear solver with neumann series approximation Co-authored-by: Xuehai Pan --- .pylintrc | 3 +- CHANGELOG.md | 1 + docs/source/api/api.rst | 4 +- docs/source/spelling_wordlist.txt | 1 + tests/test_implicit.py | 260 ++++++++++++++++++++++++++++- torchopt/linalg/__init__.py | 3 +- torchopt/linalg/cg.py | 57 ++----- torchopt/linalg/ns.py | 161 ++++++++++++++++++ torchopt/linalg/utils.py | 55 ++++++ torchopt/linear_solve/__init__.py | 3 +- torchopt/linear_solve/cg.py | 32 +++- torchopt/linear_solve/inv.py | 122 ++++++++++++++ torchopt/linear_solve/normal_cg.py | 34 +++- torchopt/linear_solve/utils.py | 52 +++++- torchopt/pytree.py | 122 +++++++++++++- 15 files changed, 831 insertions(+), 79 deletions(-) create mode 100644 torchopt/linalg/ns.py create mode 100644 torchopt/linalg/utils.py create mode 100644 torchopt/linear_solve/inv.py diff --git a/.pylintrc b/.pylintrc index 73dbcdf0..6b86faaf 100644 --- a/.pylintrc +++ b/.pylintrc @@ -267,7 +267,8 @@ good-names=i, lr, mu, nu, - x + x, + y # Good variable names regexes, separated by a comma. If names match any regex, # they will always be accepted diff --git a/CHANGELOG.md b/CHANGELOG.md index d0747f6c..278ad1d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Add matrix inversion linear solver with neumann series approximation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#98](https://github.com/metaopt/torchopt/pull/98). - Add if condition of number of threads for CPU OPs by [@JieRen98](https://github.com/JieRen98) in [#105](https://github.com/metaopt/torchopt/pull/105). - Add implicit MAML omniglot few-shot classification example with OOP APIs by [@XuehaiPan](https://github.com/XuehaiPan) in [#107](https://github.com/metaopt/torchopt/pull/107). - Add implicit MAML omniglot few-shot classification example by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#48](https://github.com/metaopt/torchopt/pull/48). diff --git a/docs/source/api/api.rst b/docs/source/api/api.rst index 1ec8e1d2..97d8af30 100644 --- a/docs/source/api/api.rst +++ b/docs/source/api/api.rst @@ -157,7 +157,7 @@ Implicit Meta-Gradient Module ------ -Linear system solving +Linear system solvers ===================== .. currentmodule:: torchopt.linear_solve @@ -166,12 +166,14 @@ Linear system solving solve_cg solve_normal_cg + solve_inv Indirect solvers ~~~~~~~~~~~~~~~~ .. autofunction:: solve_cg .. autofunction:: solve_normal_cg +.. 
autofunction:: solve_inv ------ diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt index 25f11953..cd8bb152 100644 --- a/docs/source/spelling_wordlist.txt +++ b/docs/source/spelling_wordlist.txt @@ -92,3 +92,4 @@ ints Karush Kuhn Tucker +Neumann diff --git a/tests/test_implicit.py b/tests/test_implicit.py index 661a6627..5655988a 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -24,6 +24,7 @@ import jaxopt import numpy as np import optax +import pytest import torch import torch.nn as nn import torch.nn.functional as F @@ -82,7 +83,6 @@ def get_model_torch( dataset = data.TensorDataset( torch.randint(0, 1, (BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)), - # torch.empty((BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS), dtype=dtype).uniform_(-1.0, +1.0), torch.randint(0, MODEL_NUM_CLASSES, (BATCH_SIZE * NUM_UPDATES,)), ) loader = data.DataLoader(dataset, BATCH_SIZE, shuffle=False) @@ -113,7 +113,9 @@ def get_rr_dataset_torch() -> data.DataLoader: inner_lr=[2e-2, 2e-3], inner_update=[20, 50, 100], ) -def test_imaml(dtype: torch.dtype, lr: float, inner_lr: float, inner_update: int) -> None: +def test_imaml_solve_normal_cg( + dtype: torch.dtype, lr: float, inner_lr: float, inner_update: int +) -> None: np_dtype = helpers.dtype_torch2numpy(dtype) jax_model, jax_params = get_model_jax(dtype=np_dtype) @@ -136,7 +138,10 @@ def imaml_objective_torchopt(params, meta_params, data): return loss @torchopt.diff.implicit.custom_root( - functorch.grad(imaml_objective_torchopt, argnums=0), argnums=1, has_aux=True + functorch.grad(imaml_objective_torchopt, argnums=0), + argnums=1, + has_aux=True, + solve=torchopt.linear_solve.solve_normal_cg(), ) def inner_solver_torchopt(params, meta_params, data): # Initial functional optimizer based on TorchOpt @@ -167,7 +172,11 @@ def imaml_objective_jax(params, meta_params, x, y): loss = loss + regularization_loss return loss - @jaxopt.implicit_diff.custom_root(jax.grad(imaml_objective_jax, argnums=0), has_aux=True) + @jaxopt.implicit_diff.custom_root( + jax.grad(imaml_objective_jax, argnums=0), + has_aux=True, + solve=jaxopt.linear_solve.solve_normal_cg, + ) def inner_solver_jax(params, meta_params, x, y): """Solve ridge regression by conjugate gradient.""" # Initial functional optimizer based on torchopt @@ -225,6 +234,134 @@ def outer_level(p, xs, ys): helpers.assert_all_close(p, p_ref) +@helpers.parametrize( + dtype=[torch.float64, torch.float32], + lr=[1e-3, 1e-4], + inner_lr=[2e-2, 2e-3], + inner_update=[20, 50, 100], + ns=[False, True], +) +def test_imaml_solve_inv( + dtype: torch.dtype, + lr: float, + inner_lr: float, + inner_update: int, + ns: bool, +) -> None: + np_dtype = helpers.dtype_torch2numpy(dtype) + + jax_model, jax_params = get_model_jax(dtype=np_dtype) + model, loader = get_model_torch(device='cpu', dtype=dtype) + + fmodel, params = functorch.make_functional(model) + optim = torchopt.sgd(lr) + optim_state = optim.init(params) + + optim_jax = optax.sgd(lr) + optim_state_jax = optim_jax.init(jax_params) + + def imaml_objective_torchopt(params, meta_params, data): + x, y, f = data + y_pred = f(params, x) + regularization_loss = 0 + for p1, p2 in zip(params, meta_params): + regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2)) + loss = F.cross_entropy(y_pred, y) + regularization_loss + return loss + + @torchopt.diff.implicit.custom_root( + functorch.grad(imaml_objective_torchopt, argnums=0), + argnums=1, + solve=torchopt.linear_solve.solve_inv(ns=ns), + ) + def inner_solver_torchopt(params, 
meta_params, data): + # Initial functional optimizer based on TorchOpt + x, y, f = data + optimizer = torchopt.sgd(lr=inner_lr) + opt_state = optimizer.init(params) + with torch.enable_grad(): + # Temporarily enable gradient computation for conducting the optimization + for _ in range(inner_update): + pred = f(params, x) + loss = F.cross_entropy(pred, y) # compute loss + # Compute regularization loss + regularization_loss = 0 + for p1, p2 in zip(params, meta_params): + regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2)) + final_loss = loss + regularization_loss + grads = torch.autograd.grad(final_loss, params) # compute gradients + updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates + params = torchopt.apply_updates(params, updates, inplace=True) + return params + + def imaml_objective_jax(params, meta_params, x, y): + y_pred = jax_model(params, x) + loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(y_pred, y)) + regularization_loss = 0 + for p1, p2 in zip(params.values(), meta_params.values()): + regularization_loss += 0.5 * jnp.sum(jnp.square((p1 - p2))) + loss = loss + regularization_loss + return loss + + @jaxopt.implicit_diff.custom_root( + jax.grad(imaml_objective_jax, argnums=0), + solve=jaxopt.linear_solve.solve_normal_cg, + ) + def inner_solver_jax(params, meta_params, x, y): + """Solve ridge regression by conjugate gradient.""" + # Initial functional optimizer based on torchopt + optimizer = optax.sgd(inner_lr) + opt_state = optimizer.init(params) + + def compute_loss(params, meta_params, x, y): + pred = jax_model(params, x) + loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(pred, y)) + # Compute regularization loss + regularization_loss = 0 + for p1, p2 in zip(params.values(), meta_params.values()): + regularization_loss += 0.5 * jnp.sum(jnp.square((p1 - p2))) + final_loss = loss + regularization_loss + return final_loss + + for i in range(inner_update): + grads = jax.grad(compute_loss)(params, meta_params, x, y) # compute gradients + updates, opt_state = optimizer.update(grads, opt_state) # get updates + params = optax.apply_updates(params, updates) + return params + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + data = (xs, ys, fmodel) + meta_params_copy = pytree.tree_map( + lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), params + ) + optimal_params = inner_solver_torchopt(meta_params_copy, params, data) + outer_loss = fmodel(optimal_params, xs).mean() + + grads = torch.autograd.grad(outer_loss, params) + updates, optim_state = optim.update(grads, optim_state) + params = torchopt.apply_updates(params, updates) + + xs = xs.numpy() + ys = ys.numpy() + + def outer_level(p, xs, ys): + optimal_params = inner_solver_jax(copy.deepcopy(p), p, xs, ys) + outer_loss = jax_model(optimal_params, xs).mean() + return outer_loss + + grads_jax = jax.grad(outer_level, argnums=0)(jax_params, xs, ys) + updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates + jax_params = optax.apply_updates(jax_params, updates_jax) + + jax_params_as_tensor = tuple( + nn.Parameter(torch.tensor(np.asarray(jax_params[j]), dtype=dtype)) for j in jax_params + ) + + for p, p_ref in zip(params, jax_params_as_tensor): + helpers.assert_all_close(p, p_ref) + + @helpers.parametrize( dtype=[torch.float64, torch.float32], lr=[1e-3, 1e-4], @@ -341,7 +478,7 @@ def outer_level(p, xs, ys): dtype=[torch.float64, torch.float32], lr=[1e-3, 1e-4], ) -def test_rr( +def test_rr_solve_cg( dtype: 
torch.dtype, lr: float, ) -> None: @@ -371,7 +508,7 @@ def ridge_objective_torch(params, l2reg, data): return 0.5 * torch.mean(torch.square(residuals)) + regularization_loss @torchopt.diff.implicit.custom_root(functorch.grad(ridge_objective_torch, argnums=0), argnums=1) - def ridge_solver_torch(params, l2reg, data): + def ridge_solver_torch_cg(params, l2reg, data): """Solve ridge regression by conjugate gradient.""" X_tr, y_tr = data @@ -393,7 +530,7 @@ def ridge_objective_jax(params, l2reg, X_tr, y_tr): return 0.5 * jnp.mean(jnp.square(residuals)) + regularization_loss @jaxopt.implicit_diff.custom_root(jax.grad(ridge_objective_jax, argnums=0)) - def ridge_solver_jax(params, l2reg, X_tr, y_tr): + def ridge_solver_jax_cg(params, l2reg, X_tr, y_tr): """Solve ridge regression by conjugate gradient.""" def matvec(u): @@ -413,7 +550,112 @@ def matvec(u): xq = xq.to(dtype=dtype) yq = yq.to(dtype=dtype) - w_fit = ridge_solver_torch(init_params_torch, l2reg_torch, (xs, ys)) + w_fit = ridge_solver_torch_cg(init_params_torch, l2reg_torch, (xs, ys)) + outer_loss = F.mse_loss(xq @ w_fit, yq) + + grads, *_ = torch.autograd.grad(outer_loss, l2reg_torch) + updates, optim_state = optim.update(grads, optim_state) + l2reg_torch = torchopt.apply_updates(l2reg_torch, updates) + + xs = jnp.array(xs.numpy(), dtype=np_dtype) + ys = jnp.array(ys.numpy(), dtype=np_dtype) + xq = jnp.array(xq.numpy(), dtype=np_dtype) + yq = jnp.array(yq.numpy(), dtype=np_dtype) + + def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq): + w_fit = ridge_solver_jax_cg(params_jax, l2reg_jax, xs, ys) + y_pred = xq @ w_fit + loss_value = jnp.mean(jnp.square(y_pred - yq)) + return loss_value + + grads_jax = jax.grad(outer_level, argnums=1)(init_params_jax, l2reg_jax, xs, ys, xq, yq) + updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates + l2reg_jax = optax.apply_updates(l2reg_jax, updates_jax) + + l2reg_jax_as_tensor = torch.tensor(np.asarray(l2reg_jax), dtype=dtype) + helpers.assert_all_close(l2reg_torch, l2reg_jax_as_tensor) + + +@helpers.parametrize( + dtype=[torch.float64, torch.float32], + lr=[1e-3, 1e-4], + ns=[True, False], +) +def test_rr_solve_inv( + dtype: torch.dtype, + lr: float, + ns: bool, +) -> None: + if dtype == torch.float64 and ns: + pytest.skip('Neumann Series test skips torch.float64 due to numerical stability.') + helpers.seed_everything(42) + np_dtype = helpers.dtype_torch2numpy(dtype) + input_size = 10 + + init_params_torch = torch.randn(input_size, dtype=dtype) + l2reg_torch = torch.rand(1, dtype=dtype).squeeze_().requires_grad_(True) + + init_params_jax = jnp.array(init_params_torch.detach().numpy(), dtype=np_dtype) + l2reg_jax = jnp.array(l2reg_torch.detach().numpy(), dtype=np_dtype) + + loader = get_rr_dataset_torch() + + optim = torchopt.sgd(lr) + optim_state = optim.init(l2reg_torch) + + optim_jax = optax.sgd(lr) + optim_state_jax = optim_jax.init(l2reg_jax) + + def ridge_objective_torch(params, l2reg, data): + """Ridge objective function.""" + X_tr, y_tr = data + residuals = X_tr @ params - y_tr + regularization_loss = 0.5 * l2reg * torch.sum(torch.square(params)) + return 0.5 * torch.mean(torch.square(residuals)) + regularization_loss + + @torchopt.diff.implicit.custom_root(functorch.grad(ridge_objective_torch, argnums=0), argnums=1) + def ridge_solver_torch_inv(params, l2reg, data): + """Solve ridge regression by conjugate gradient.""" + X_tr, y_tr = data + + def matvec(u): + return X_tr.T @ (X_tr @ u) + + solve = torchopt.linear_solve.solve_inv( + matvec=matvec, + 
b=X_tr.T @ y_tr, + ridge=len(y_tr) * l2reg.item(), + ns=ns, + ) + + return solve(matvec=matvec, b=X_tr.T @ y_tr) + + def ridge_objective_jax(params, l2reg, X_tr, y_tr): + """Ridge objective function.""" + residuals = X_tr @ params - y_tr + regularization_loss = 0.5 * l2reg * jnp.sum(jnp.square(params)) + return 0.5 * jnp.mean(jnp.square(residuals)) + regularization_loss + + @jaxopt.implicit_diff.custom_root(jax.grad(ridge_objective_jax, argnums=0)) + def ridge_solver_jax_inv(params, l2reg, X_tr, y_tr): + """Solve ridge regression by conjugate gradient.""" + + def matvec(u): + return X_tr.T @ ((X_tr @ u)) + + return jaxopt.linear_solve.solve_inv( + matvec=matvec, + b=X_tr.T @ y_tr, + ridge=len(y_tr) * l2reg.item(), + ) + + for xs, ys, xq, yq in loader: + xs = xs.to(dtype=dtype) + ys = ys.to(dtype=dtype) + xq = xq.to(dtype=dtype) + yq = yq.to(dtype=dtype) + + w_fit = ridge_solver_torch_inv(init_params_torch, l2reg_torch, (xs, ys)) outer_loss = F.mse_loss(xq @ w_fit, yq) grads, *_ = torch.autograd.grad(outer_loss, l2reg_torch) @@ -426,7 +668,7 @@ def matvec(u): yq = jnp.array(yq.numpy(), dtype=np_dtype) def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq): - w_fit = ridge_solver_jax(params_jax, l2reg_jax, xs, ys) + w_fit = ridge_solver_jax_inv(params_jax, l2reg_jax, xs, ys) y_pred = xq @ w_fit loss_value = jnp.mean(jnp.square(y_pred - yq)) return loss_value diff --git a/torchopt/linalg/__init__.py b/torchopt/linalg/__init__.py index 4fff4df2..20dc16aa 100644 --- a/torchopt/linalg/__init__.py +++ b/torchopt/linalg/__init__.py @@ -32,6 +32,7 @@ """Linear algebra functions.""" from torchopt.linalg.cg import cg +from torchopt.linalg.ns import ns, ns_inv -__all__ = ['cg'] +__all__ = ['cg', 'ns', 'ns_inv'] diff --git a/torchopt/linalg/cg.py b/torchopt/linalg/cg.py index 28307666..94daee53 100644 --- a/torchopt/linalg/cg.py +++ b/torchopt/linalg/cg.py @@ -34,51 +34,23 @@ # pylint: disable=invalid-name from functools import partial -from typing import Callable, List, Optional, Union +from typing import Callable, Optional, Union import torch from torchopt import pytree +from torchopt.linalg.utils import cat_shapes, normalize_matvec +from torchopt.pytree import tree_vdot_real from torchopt.typing import TensorTree __all__ = ['cg'] -def _vdot_real_kernel(x: torch.Tensor, y: torch.Tensor) -> float: - """Computes dot(x.conj(), y).real.""" - x = x.contiguous().view(-1) - y = y.contiguous().view(-1) - prod = torch.dot(x.real, y.real).item() - if x.is_complex() and y.is_complex(): - prod += torch.dot(x.imag, y.imag).item() - return prod - - -def tree_vdot_real(tree_x: TensorTree, tree_y: TensorTree) -> float: - """Computes dot(tree_x.conj(), tree_y).real.sum().""" - leaves_x, treespec = pytree.tree_flatten(tree_x) - leaves_y = treespec.flatten_up_to(tree_y) - return sum(map(_vdot_real_kernel, leaves_x, leaves_y)) # type: ignore[arg-type] - - def _identity(x: TensorTree) -> TensorTree: return x -def _normalize_matvec( - f: Union[Callable[[TensorTree], TensorTree], torch.Tensor] -) -> Callable[[TensorTree], TensorTree]: - """Normalize an argument for computing matrix-vector products.""" - if callable(f): - return f - - assert isinstance(f, torch.Tensor) - if f.ndim != 2 or f.shape[0] != f.shape[1]: - raise ValueError(f'linear operator must be a square matrix, but has shape: {f.shape}') - return partial(torch.matmul, f) - - # pylint: disable-next=too-many-locals def _cg_solve( A: Callable[[TensorTree], TensorTree], @@ -126,37 +98,32 @@ def body_fn(value): return x_final -def _shapes(tree: TensorTree) -> 
List[int]: - flattened = pytree.tree_leaves(tree) - return pytree.tree_leaves([tuple(term.shape) for term in flattened]) # type: ignore[arg-type] - - def _isolve( _isolve_solve: Callable, - A: Union[torch.Tensor, Callable[[TensorTree], TensorTree]], + A: Union[TensorTree, Callable[[TensorTree], TensorTree]], b: TensorTree, x0: Optional[TensorTree] = None, *, rtol: float = 1e-5, atol: float = 0.0, maxiter: Optional[int] = None, - M: Optional[Union[torch.Tensor, Callable[[TensorTree], TensorTree]]] = None, + M: Optional[Union[TensorTree, Callable[[TensorTree], TensorTree]]] = None, ) -> TensorTree: if x0 is None: x0 = pytree.tree_map(torch.zeros_like, b) if maxiter is None: - size = sum(_shapes(b)) + size = sum(cat_shapes(b)) maxiter = 10 * size # copied from SciPy if M is None: M = _identity - A = _normalize_matvec(A) - M = _normalize_matvec(M) + A = normalize_matvec(A) + M = normalize_matvec(M) - if _shapes(x0) != _shapes(b): + if cat_shapes(x0) != cat_shapes(b): raise ValueError( - 'arrays in x0 and b must have matching shapes: ' f'{_shapes(x0)} vs {_shapes(b)}' + f'Tensors in x0 and b must have matching shapes: {cat_shapes(x0)} vs. {cat_shapes(b)}.' ) isolve_solve = partial(_isolve_solve, x0=x0, rtol=rtol, atol=atol, maxiter=maxiter, M=M) @@ -166,14 +133,14 @@ def _isolve( def cg( - A: Union[torch.Tensor, Callable[[TensorTree], TensorTree]], + A: Union[TensorTree, Callable[[TensorTree], TensorTree]], b: TensorTree, x0: Optional[TensorTree] = None, *, rtol: float = 1e-5, atol: float = 0.0, maxiter: Optional[int] = None, - M: Optional[Union[torch.Tensor, Callable[[TensorTree], TensorTree]]] = None, + M: Optional[Union[TensorTree, Callable[[TensorTree], TensorTree]]] = None, ) -> TensorTree: """Use Conjugate Gradient iteration to solve ``Ax = b``. diff --git a/torchopt/linalg/ns.py b/torchopt/linalg/ns.py new file mode 100644 index 00000000..4da8ef9f --- /dev/null +++ b/torchopt/linalg/ns.py @@ -0,0 +1,161 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Neumann Series Matrix Inversion Approximation to solve ``Ax = b``.""" + +# pylint: disable=invalid-name + +import functools +from typing import Callable, Optional, Union + +import torch + +from torchopt import pytree +from torchopt.linalg.utils import cat_shapes, normalize_matvec +from torchopt.typing import TensorTree + + +__all__ = ['ns', 'ns_inv'] + + +def _ns_solve( + A: torch.Tensor, + b: torch.Tensor, + maxiter: int, + alpha: Optional[float] = None, +) -> torch.Tensor: + """Uses Neumann Series Matrix Inversion Approximation to solve ``Ax = b``.""" + if A.ndim != 2 or A.shape[0] != A.shape[1]: + raise ValueError(f'`A` must be a square matrix, but has shape: {A.shape}') + + inv_A_hat_b = b + v = b + if alpha is not None: + # A^{-1} = a [I - (I - a A)]^{-1} = a [I + (I - a A) + (I - a A)^2 + (I - a A)^3 + ...] 
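+        # Each pass maps v <- (I - alpha * A) @ v, so after k passes v = (I - alpha * A)^k @ b;
+        # `inv_A_hat_b` accumulates the partial sums of the series applied to `b` and is
+        # rescaled by `alpha` once after the loop.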
+ for _ in range(maxiter): + v = v - alpha * (A @ v) + inv_A_hat_b = inv_A_hat_b + v + inv_A_hat_b = alpha * inv_A_hat_b + else: + # A^{-1} = [I - (I - A)]^{-1} = I + (I - A) + (I - A)^2 + (I - A)^3 + ... + for _ in range(maxiter): + v = v - A @ v + inv_A_hat_b = inv_A_hat_b + v + + return inv_A_hat_b + + +def ns( + A: Union[TensorTree, Callable[[TensorTree], TensorTree]], + b: TensorTree, + maxiter: Optional[int] = None, + *, + alpha: Optional[float] = None, +) -> TensorTree: + """Uses Neumann Series Matrix Inversion Approximation to solve ``Ax = b``. + + Args: + A: (tensor or tree of tensors or function) + 2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when + called like ``A(x)``. ``A`` must represent a hermitian, positive definite matrix, and + must return array(s) with the same structure and shape as its argument. + b: (tensor or tree of tensors) + Right hand side of the linear system representing a single vector. Can be stored as an + array or Python container of array(s) with any shape. + maxiter: (integer, optional) + Maximum number of iterations. Iteration will stop after maxiter steps even if the + specified tolerance has not been achieved. + alpha: (float, optional) + Decay coefficient. + + Returns: + The Neumann Series (NS) matrix inversion approximation. + """ + if maxiter is None: + maxiter = 10 + + if not callable(A): + return pytree.tree_map(functools.partial(_ns_solve, maxiter=maxiter, alpha=alpha), A, b) + + matvec = normalize_matvec(A) + inv_A_hat_b = b + v = b + if alpha is not None: + # A^{-1} = a [I - (I - a A)]^{-1} = a [I + (I - a A) + (I - a A)^2 + (I - a A)^3 + ...] + for _ in range(maxiter): + # v = v - alpha * (A @ v) + v = pytree.tree_sub_scalar_mul(v, matvec(v), alpha=alpha) + # inv_A_hat_b = inv_A_hat_b + v + inv_A_hat_b = pytree.tree_add(inv_A_hat_b, v) + # inv_A_hat_b = alpha * inv_A_hat_b + inv_A_hat_b = pytree.tree_scalar_mul(alpha, inv_A_hat_b) + else: + # A^{-1} = [I - (I - A)]^{-1} = I + (I - A) + (I - A)^2 + (I - A)^3 + ... + for _ in range(maxiter): + # v = v - A @ v + v = pytree.tree_sub(v, matvec(v)) + # inv_A_hat_b = inv_A_hat_b + v + inv_A_hat_b = pytree.tree_add(inv_A_hat_b, v) + + return inv_A_hat_b + + +def _ns_inv(A: torch.Tensor, maxiter: int, alpha: Optional[float] = None): + """Uses Neumann Series iteration to solve ``A^{-1}``.""" + if A.ndim != 2 or A.shape[0] != A.shape[1]: + raise ValueError(f'`A` must be a square matrix, but has shape: {A.shape}') + + I = torch.eye(*A.shape, out=torch.empty_like(A)) + inv_A_hat = torch.zeros_like(A) + if alpha is not None: + # A^{-1} = a [I - (I - a A)]^{-1} = a [I + (I - a A) + (I - a A)^2 + (I - a A)^3 + ...] + M = I - alpha * A + for rank in range(maxiter): + inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank) + inv_A_hat = alpha * inv_A_hat + else: + # A^{-1} = [I - (I - A)]^{-1} = I + (I - A) + (I - A)^2 + (I - A)^3 + ... + M = I - A + for rank in range(maxiter): + inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank) + return inv_A_hat + + +def ns_inv( + A: TensorTree, + maxiter: Optional[int] = None, + *, + alpha: Optional[float] = None, +) -> TensorTree: + """Uses Neumann Series iteration to solve ``A^{-1}``. + + Args: + A: (tensor or tree of tensors or function) + 2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when + called like ``A(x)``. ``A`` must represent a hermitian, positive definite matrix, and + must return array(s) with the same structure and shape as its argument. 
+ maxiter: (integer, optional) + Maximum number of iterations. Iteration will stop after maxiter steps even if the + specified tolerance has not been achieved. + alpha: (float, optional) + Decay coefficient. + + Returns: + The Neumann Series (NS) matrix inversion approximation. + """ + if maxiter is None: + size = sum(cat_shapes(A)) + maxiter = 10 * size # copied from SciPy + + return pytree.tree_map(functools.partial(_ns_inv, maxiter=maxiter, alpha=alpha), A) diff --git a/torchopt/linalg/utils.py b/torchopt/linalg/utils.py new file mode 100644 index 00000000..f2440b9a --- /dev/null +++ b/torchopt/linalg/utils.py @@ -0,0 +1,55 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for linear algebra.""" + +import itertools +from typing import Callable, Tuple, Union + +import torch + +from torchopt import pytree +from torchopt.typing import TensorTree + + +def cat_shapes(tree: TensorTree) -> Tuple[int, ...]: + """Concatenates the shapes of the leaves of a tree of tensors.""" + leaves = pytree.tree_leaves(tree) + return tuple(itertools.chain.from_iterable(tuple(leaf.shape) for leaf in leaves)) + + +def normalize_matvec( + matvec: Union[TensorTree, Callable[[TensorTree], TensorTree]] +) -> Callable[[TensorTree], TensorTree]: + """Normalizes an argument for computing matrix-vector product.""" + if callable(matvec): + return matvec + + mat_flat, treespec = pytree.tree_flatten(matvec) + for mat in mat_flat: + if not isinstance(mat, torch.Tensor) or mat.ndim != 2 or mat.shape[0] != mat.shape[1]: + raise TypeError(f'Linear operator must be a square matrix, but has shape: {mat.shape}') + + def _matvec(x: TensorTree) -> TensorTree: + x_flat = pytree.tree_leaves(x) + if len(x_flat) != len(mat_flat): + raise ValueError( + f'`x` must have the same number of leaves as `matvec`, ' + f'but has {len(x_flat)} leaves and `matvec` has {len(mat_flat)} leaves' + ) + + y_flat = map(torch.matmul, mat_flat, x_flat) + return pytree.tree_unflatten(treespec, y_flat) + + return _matvec diff --git a/torchopt/linear_solve/__init__.py b/torchopt/linear_solve/__init__.py index e27ba4c0..8d9115d3 100644 --- a/torchopt/linear_solve/__init__.py +++ b/torchopt/linear_solve/__init__.py @@ -32,7 +32,8 @@ """Linear algebra solvers.""" from torchopt.linear_solve.cg import solve_cg +from torchopt.linear_solve.inv import solve_inv from torchopt.linear_solve.normal_cg import solve_normal_cg -__all__ = ['solve_cg', 'solve_normal_cg'] +__all__ = ['solve_cg', 'solve_normal_cg', 'solve_inv'] diff --git a/torchopt/linear_solve/cg.py b/torchopt/linear_solve/cg.py index 9c150038..2ffc8217 100644 --- a/torchopt/linear_solve/cg.py +++ b/torchopt/linear_solve/cg.py @@ -53,7 +53,7 @@ def _solve_cg( ) -> TensorTree: """Solves ``A x = b`` using conjugate gradient. - It assumes that ``A`` is a hermitian, positive definite matrix. + This assumes that ``A`` is a hermitian, positive definite matrix. 
Args: matvec: A function that returns the product between ``A`` and a vector. @@ -75,5 +75,33 @@ def _solve_cg( def solve_cg(**kwargs): - """Wrapper for :func:`solve_cg`.""" + """A wrapper that returns a solver function to solve ``A x = b`` using conjugate gradient. + + This assumes that ``A`` is a hermitian, positive definite matrix. + + Args: + ridge: Optional ridge regularization. Solves the equation for ``(A + ridge * I) @ x = b``. + init: Optional initialization to be used by conjugate gradient. + **kwargs: Additional keyword arguments for the conjugate gradient solver + :func:`torchopt.linalg.cg`. + + Returns: + A solver function with signature ``(matvec, b) -> x`` that solves ``A x = b`` using + conjugate gradient where ``matvec(v) = A v``. + + See Also: + Conjugate gradient iteration :func:`torchopt.linalg.cg`. + + Example:: + + >>> A = {'a': torch.eye(5, 5), 'b': torch.eye(3, 3)} + >>> x = {'a': torch.randn(5), 'b': torch.randn(3)} + >>> def matvec(x: TensorTree) -> TensorTree: + ... return {'a': A['a'] @ x['a'], 'b': A['b'] @ x['b']} + >>> b = matvec(x) + >>> solver = solve_cg(init={'a': torch.zeros(5), 'b': torch.zeros(3)}) + >>> x_hat = solver(matvec, b) + >>> assert torch.allclose(x_hat['a'], x['a']) and torch.allclose(x_hat['b'], x['b']) + + """ return functools.partial(_solve_cg, **kwargs) diff --git a/torchopt/linear_solve/inv.py b/torchopt/linear_solve/inv.py new file mode 100644 index 00000000..bf36f40e --- /dev/null +++ b/torchopt/linear_solve/inv.py @@ -0,0 +1,122 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/google/jaxopt/blob/main/jaxopt/_src/linear_solve.py +# ============================================================================== +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Linear algebra solver for ``A x = b`` using matrix inversion.""" + +# pylint: disable=invalid-name + +import functools +from typing import Callable, Optional + +import torch + +from torchopt import linalg, pytree +from torchopt.linear_solve.utils import make_ridge_matvec, materialize_matvec +from torchopt.typing import TensorTree + + +__all__ = ['solve_inv'] + + +def _solve_inv( + matvec: Callable[[TensorTree], TensorTree], # (x) -> A @ x + b: TensorTree, + ridge: Optional[float] = None, + ns: bool = False, + **kwargs, +) -> TensorTree: + """Solves ``A x = b`` using matrix inversion. + + If ``ns = False``, this assumes the matrix ``A`` is a constant matrix and will materialize it + in memory. + + Args: + matvec: A function that returns the product between ``A`` and a vector. + b: A tensor for the right hand side of the equation. + ridge: Optional ridge regularization. Solves the equation for ``(A + ridge * I) @ x = b``. + ns: Whether to use Neumann Series matrix inversion approximation. If :data:`False`, + materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve` instead. + **kwargs: Additional keyword arguments for the Neumann Series matrix inversion approximation + solver :func:`torchopt.linalg.ns`. + + Returns: + The solution with the same shape as ``b``. + """ + if ridge is not None: + # (x) -> A @ x + ridge * x + # i.e. (x) -> (A + ridge * I) @ x + matvec = make_ridge_matvec(matvec, ridge=ridge) + + b_flat = pytree.tree_leaves(b) + if len(b_flat) == 1 and b_flat[0].ndim == 0: + A, *_ = materialize_matvec(matvec, b) + return pytree.tree_truediv(b, A) + + if ns: + return linalg.ns(matvec, b, **kwargs) + + A, _, tree_ravel, tree_unravel = materialize_matvec(matvec, b) + return tree_unravel(pytree.tree_map(torch.linalg.solve, A, tree_ravel(b))) + + +def solve_inv(**kwargs): + """A wrapper that returns a solver function to solve ``A x = b`` using matrix inversion. + + If ``ns = False``, this assumes the matrix ``A`` is a constant matrix and will materialize it + in memory. + + Args: + ridge: Optional ridge regularization. Solves the equation for ``(A + ridge * I) @ x = b``. + ns: Whether to use Neumann Series matrix inversion approximation. If :data:`False`, + materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve` instead. + **kwargs: Additional keyword arguments for the Neumann Series matrix inversion approximation + solver :func:`torchopt.linalg.ns`. + + Returns: + A solver function with signature ``(matvec, b) -> x`` that solves ``A x = b`` using matrix + inversion where ``matvec(v) = A v``. + + See Also: + Neumann Series matrix inversion approximation :func:`torchopt.linalg.ns`. + + Example:: + + >>> A = {'a': torch.eye(5, 5), 'b': torch.eye(3, 3)} + >>> x = {'a': torch.randn(5), 'b': torch.randn(3)} + >>> def matvec(x: TensorTree) -> TensorTree: + ... return {'a': A['a'] @ x['a'], 'b': A['b'] @ x['b']} + >>> b = matvec(x) + >>> solver = solve_inv(ns=True, maxiter=10) + >>> x_hat = solver(matvec, b) + >>> assert torch.allclose(x_hat['a'], x['a']) and torch.allclose(x_hat['b'], x['b']) + + """ + return functools.partial(_solve_inv, **kwargs) diff --git a/torchopt/linear_solve/normal_cg.py b/torchopt/linear_solve/normal_cg.py index 38cb2834..3646d7f4 100644 --- a/torchopt/linear_solve/normal_cg.py +++ b/torchopt/linear_solve/normal_cg.py @@ -61,7 +61,8 @@ def _solve_normal_cg( b: A tree of tensors for the right hand side of the equation. 
ridge: Optional ridge regularization. Solves the equation for ``(A.T @ A + ridge * I) @ x = A.T @ b``. init: Optional initialization to be used by normal conjugate gradient. - **kwargs: Additional keyword arguments for the conjugate gradient solver. + **kwargs: Additional keyword arguments for the conjugate gradient solver + :func:`torchopt.linalg.cg`. Returns: The solution with the same structure as ``b``. @@ -86,5 +87,34 @@ def _solve_normal_cg( def solve_normal_cg(**kwargs): - """Wrapper for :func:`solve_normal_cg`.""" + """A wrapper that returns a solver function to solve ``A^T A x = A^T b`` using conjugate gradient. + + This can be used to solve ``A x = b`` using conjugate gradient when ``A`` is not hermitian, + positive definite. + + Args: + ridge: Optional ridge regularization. Solves the equation for ``(A.T @ A + ridge * I) @ x = A.T @ b``. + init: Optional initialization to be used by normal conjugate gradient. + **kwargs: Additional keyword arguments for the conjugate gradient solver + :func:`torchopt.linalg.cg`. + + Returns: + A solver function with signature ``(matvec, b) -> x`` that solves ``A^T A x = A^T b`` using + conjugate gradient where ``matvec(v) = A v``. + + See Also: + Conjugate gradient iteration :func:`torchopt.linalg.cg`. + + Example:: + + >>> A = {'a': torch.randn(5, 5), 'b': torch.randn(3, 3)} + >>> x = {'a': torch.randn(5), 'b': torch.randn(3)} + >>> def matvec(x: TensorTree) -> TensorTree: + ... return {'a': A['a'] @ x['a'], 'b': A['b'] @ x['b']} + >>> b = matvec(x) + >>> solver = solve_normal_cg(init={'a': torch.zeros(5), 'b': torch.zeros(3)}) + >>> x_hat = solver(matvec, b) + >>> assert torch.allclose(x_hat['a'], x['a']) and torch.allclose(x_hat['b'], x['b']) + + """ return functools.partial(_solve_normal_cg, **kwargs) diff --git a/torchopt/linear_solve/utils.py b/torchopt/linear_solve/utils.py index 5fede71d..a7e93e65 100644 --- a/torchopt/linear_solve/utils.py +++ b/torchopt/linear_solve/utils.py @@ -31,9 +31,7 @@ # ============================================================================== """Utilities for linear algebra solvers.""" -# pylint: disable=invalid-name - -from typing import Callable +from typing import Callable, Tuple import functorch @@ -41,11 +39,6 @@ from torchopt.typing import TensorTree -def tree_add(tree_x: TensorTree, tree_y: TensorTree, alpha: float = 1.0) -> TensorTree: - """Computes tree_x + alpha * tree_y.""" - return pytree.tree_map(lambda x, y: x.add(y, alpha=alpha), tree_x, tree_y) - - def make_rmatvec( matvec: Callable[[TensorTree], TensorTree], example_x: TensorTree ) -> Callable[[TensorTree], TensorTree]: @@ -75,6 +68,47 @@ def make_ridge_matvec( def ridge_matvec(y: TensorTree) -> TensorTree: """Computes ``A.T @ A @ v + ridge * v`` from ``matvec(x) = A @ x``.""" - return tree_add(matvec(y), y, alpha=ridge) + return pytree.tree_add_scalar_mul(matvec(y), y, alpha=ridge) return ridge_matvec + + +def materialize_matvec( + matvec: Callable[[TensorTree], TensorTree], x: TensorTree +) -> Tuple[ + TensorTree, + Callable[[TensorTree], TensorTree], + Callable[[TensorTree], TensorTree], + Callable[[TensorTree], TensorTree], +]: + """Materializes the matrix ``A`` used in ``matvec(x) = A @ x``.""" + x_flat, treespec = pytree.tree_flatten(x) + shapes = tuple(t.shape for t in x_flat) + + if all(t.ndim == 1 for t in x_flat): + + def tree_ravel(x: TensorTree) -> TensorTree: + return x + + def tree_unravel(y: TensorTree) -> TensorTree: + return y + + matvec_ravel = matvec + + else: + + def tree_ravel(x: TensorTree) -> TensorTree: + return 
pytree.tree_map(lambda t: t.contiguous().view(-1), x) + + def tree_unravel(y: TensorTree) -> TensorTree: + shapes_iter = iter(shapes) + return pytree.tree_map(lambda t: t.contiguous().view(next(shapes_iter)), y) + + def matvec_ravel(y: TensorTree) -> TensorTree: + return tree_ravel(matvec(tree_unravel(y))) + + nargs = len(x_flat) + jacobian_tree = functorch.jacfwd(matvec_ravel)(tree_ravel(x)) + jacobian_flat = pytree.tree_leaves(jacobian_tree) + jacobian_diag = [jacobian_flat[i + i * nargs] for i in range(nargs)] + return pytree.tree_unflatten(treespec, jacobian_diag), matvec_ravel, tree_ravel, tree_unravel diff --git a/torchopt/pytree.py b/torchopt/pytree.py index 65aba6a2..f1dd26e0 100644 --- a/torchopt/pytree.py +++ b/torchopt/pytree.py @@ -14,17 +14,35 @@ # ============================================================================== """The PyTree utilities.""" +import functools +import operator from typing import Callable, List, Optional, Tuple import optree import optree.typing as typing # pylint: disable=unused-import +import torch import torch.distributed.rpc as rpc from optree import * # pylint: disable=wildcard-import,unused-wildcard-import -from torchopt.typing import Future, PyTree, RRef, T +from torchopt.typing import Future, RRef, Scalar, T, TensorTree -__all__ = [*optree.__all__, 'tree_flatten_as_tuple', 'tree_wait'] +__all__ = [ + *optree.__all__, + 'tree_flatten_as_tuple', + 'tree_pos', + 'tree_neg', + 'tree_add', + 'tree_add_scalar_mul', + 'tree_sub', + 'tree_sub_scalar_mul', + 'tree_mul', + 'tree_matmul', + 'tree_scalar_mul', + 'tree_truediv', + 'tree_vdot_real', + 'tree_wait', +] def tree_flatten_as_tuple( @@ -48,10 +66,98 @@ def tree_flatten_as_tuple( return tuple(leaves), treespec +def acc_add(*args: T) -> T: + """Accumulate addition.""" + return functools.reduce(operator.add, args) + + +def acc_mul(*args: T) -> T: + """Accumulate multiplication.""" + return functools.reduce(operator.mul, args) + + +def acc_matmul(*args: T) -> T: + """Accumulate matrix multiplication.""" + return functools.reduce(operator.matmul, args) + + +def tree_pos(tree: PyTree[T]) -> PyTree[T]: + """Applies `operator.pos` over leaves.""" + return tree_map(operator.pos, tree) + + +def tree_neg(tree: PyTree[T]) -> PyTree[T]: + """Applies `operator.neg` over leaves.""" + return tree_map(operator.neg, tree) + + +def tree_add(*trees: PyTree[T]) -> PyTree[T]: + """Tree addition over leaves.""" + return tree_map(acc_add, *trees) + + +def tree_add_scalar_mul( + tree_x: TensorTree, tree_y: TensorTree, alpha: Optional[Scalar] = None +) -> TensorTree: + """Computes tree_x + alpha * tree_y.""" + if alpha is None: + return tree_map(lambda x, y: x.add(y), tree_x, tree_y) + return tree_map(lambda x, y: x.add(y, alpha=alpha), tree_x, tree_y) + + +def tree_sub(minuend_tree: PyTree[T], subtrahend_tree: PyTree[T]) -> PyTree[T]: + """Tree subtraction over leaves.""" + return tree_map(operator.sub, minuend_tree, subtrahend_tree) + + +def tree_sub_scalar_mul( + tree_x: TensorTree, tree_y: TensorTree, alpha: Optional[Scalar] = None +) -> TensorTree: + """Computes tree_x - alpha * tree_y.""" + if alpha is None: + return tree_map(lambda x, y: x.sub(y), tree_x, tree_y) + return tree_map(lambda x, y: x.sub(y, alpha=alpha), tree_x, tree_y) + + +def tree_mul(*trees: PyTree[T]) -> PyTree[T]: + """Tree multiplication over leaves.""" + return tree_map(acc_mul, *trees) + + +def tree_matmul(*trees: PyTree[T]) -> PyTree[T]: + """Tree matrix multiplication over leaves.""" + return tree_map(acc_matmul, *trees) + + +def 
tree_scalar_mul(scalar: Scalar, multiplicand_tree: PyTree[T]) -> PyTree[T]: + """Tree scalar multiplication over leaves.""" + return tree_map(lambda x: scalar * x, multiplicand_tree) + + +def tree_truediv(dividend_tree: PyTree[T], divisor_tree: PyTree[T]) -> PyTree[T]: + """Tree division over leaves.""" + return tree_map(operator.truediv, dividend_tree, divisor_tree) + + +def _vdot_real_kernel(x: torch.Tensor, y: torch.Tensor) -> float: + """Computes dot(x.conj(), y).real.""" + x = x.contiguous().view(-1) + y = y.contiguous().view(-1) + vdot = torch.dot(x.real, y.real).item() + if x.is_complex() and y.is_complex(): + vdot += torch.dot(x.imag, y.imag).item() + return vdot + + +def tree_vdot_real(tree_x: TensorTree, tree_y: TensorTree) -> float: + """Computes dot(tree_x.conj(), tree_y).real.sum().""" + leaves_x, treespec = tree_flatten(tree_x) + leaves_y = treespec.flatten_up_to(tree_y) + return sum(map(_vdot_real_kernel, leaves_x, leaves_y)) # type: ignore[arg-type] + + def tree_wait(future_tree: PyTree[Future[T]]) -> PyTree[T]: r"""Convert a tree of :class:`Future`\s to a tree of results.""" - import torch # pylint: disable=import-outside-toplevel - futures, treespec = tree_flatten(future_tree) results = torch.futures.wait_all(futures) @@ -61,7 +167,7 @@ def tree_wait(future_tree: PyTree[Future[T]]) -> PyTree[T]: if rpc.is_available(): - def tree_as_rref(tree: PyTree[T]) -> 'PyTree[RRef[T]]': + def tree_as_rref(tree: PyTree[T]) -> PyTree[RRef[T]]: r"""Convert a tree of local objects to a tree of :class:`RRef`\s.""" # pylint: disable-next=import-outside-toplevel,redefined-outer-name,reimported from torch.distributed.rpc import RRef @@ -69,17 +175,17 @@ def tree_as_rref(tree: PyTree[T]) -> 'PyTree[RRef[T]]': return tree_map(RRef, tree) def tree_to_here( - rref_tree: 'PyTree[RRef[T]]', + rref_tree: PyTree[RRef[T]], timeout: float = rpc.api.UNSET_RPC_TIMEOUT, ) -> PyTree[T]: r"""Convert a tree of :class:`RRef`\s to a tree of local objects.""" return tree_map(lambda x: x.to_here(timeout=timeout), rref_tree) - def tree_local_value(rref_tree: 'PyTree[RRef[T]]'): + def tree_local_value(rref_tree: PyTree[RRef[T]]) -> PyTree[T]: r"""Return the local value of a tree of :class:`RRef`\s.""" return tree_map(lambda x: x.local_value(), rref_tree) __all__.extend(['tree_as_rref', 'tree_to_here']) -del Callable, List, Optional, Tuple, optree, rpc, PyTree, T, RRef +del Callable, List, Optional, Tuple, optree, rpc, Scalar, T, RRef
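Usage sketch for the ``solve_inv`` solver added above (illustrative only: the matrix, ``alpha``, ``maxiter``, and tolerances are arbitrary choices, assuming a well-conditioned symmetric positive-definite system so that the truncated Neumann series converges)::

    >>> import torch
    >>> import torchopt
    >>> R = torch.randn(5, 5)
    >>> A = 0.1 * R @ R.T + torch.eye(5)  # symmetric positive definite with a modest condition number
    >>> x = torch.randn(5)
    >>> b = A @ x
    >>> def matvec(v):
    ...     return A @ v
    ...
    >>> # Exact path: materialize A via functorch.jacfwd and call torch.linalg.solve
    >>> solver = torchopt.linear_solve.solve_inv(ns=False)
    >>> x_hat = solver(matvec, b)
    >>> assert torch.allclose(x_hat, x, atol=1e-4)
    >>> # Approximate path: x is estimated by the truncated series alpha * sum_k (I - alpha * A)^k @ b
    >>> solver_ns = torchopt.linear_solve.solve_inv(ns=True, maxiter=100, alpha=0.3)
    >>> x_ns = solver_ns(matvec, b)
    >>> assert torch.allclose(x_ns, x, atol=1e-3)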