feat(linear_solve): matrix inversion linear solver with neumann series approximation (#98)

* feat(linear_solve): matrix inversion linear solver with neumann series approximation
Co-authored-by: Xuehai Pan <XuehaiPan@outlook.com>
Benjamin-eecs authored Nov 9, 2022
1 parent 23253b1 commit 0185f2e
Showing 15 changed files with 831 additions and 79 deletions.
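Background for the feature: the Neumann-series approximation replaces an explicit matrix inverse with a truncated power series, A^{-1} ≈ sum_{k=0}^{K-1} (I - A)^k, which converges when the spectral radius of (I - A) is below one. A minimal PyTorch sketch of that identity (illustrative only, not the implementation added in this commit):

import torch

def neumann_inv_approx(A: torch.Tensor, num_terms: int = 64) -> torch.Tensor:
    """Approximate A^{-1} by the truncated Neumann series sum_k (I - A)^k.

    Assumes the spectral radius of (I - A) is strictly below 1 (e.g. after
    suitable scaling of A); otherwise the series diverges.
    """
    I = torch.eye(A.shape[0], dtype=A.dtype, device=A.device)
    M = I - A
    term = I.clone()    # (I - A)^0
    approx = I.clone()  # running partial sum
    for _ in range(num_terms - 1):
        term = term @ M
        approx = approx + term
    return approx

torch.manual_seed(0)
A = torch.eye(4) + 0.1 * torch.randn(4, 4)  # well-conditioned test matrix
b = torch.ones(4)
print(torch.allclose(neumann_inv_approx(A) @ b, torch.linalg.solve(A, b), atol=1e-4))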
3 changes: 2 additions & 1 deletion .pylintrc
@@ -267,7 +267,8 @@ good-names=i,
lr,
mu,
nu,
x
x,
y

# Good variable names regexes, separated by a comma. If names match any regex,
# they will always be accepted
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add matrix inversion linear solver with Neumann series approximation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#98](https://github.com/metaopt/torchopt/pull/98).
- Add if condition of number of threads for CPU OPs by [@JieRen98](https://github.com/JieRen98) in [#105](https://github.com/metaopt/torchopt/pull/105).
- Add implicit MAML omniglot few-shot classification example with OOP APIs by [@XuehaiPan](https://github.com/XuehaiPan) in [#107](https://github.com/metaopt/torchopt/pull/107).
- Add implicit MAML omniglot few-shot classification example by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#48](https://github.com/metaopt/torchopt/pull/48).
4 changes: 3 additions & 1 deletion docs/source/api/api.rst
@@ -157,7 +157,7 @@ Implicit Meta-Gradient Module

------

Linear system solving
Linear system solvers
=====================

.. currentmodule:: torchopt.linear_solve
@@ -166,12 +166,14 @@ Linear system solving

solve_cg
solve_normal_cg
solve_inv

Indirect solvers
~~~~~~~~~~~~~~~~

.. autofunction:: solve_cg
.. autofunction:: solve_normal_cg
.. autofunction:: solve_inv

------
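The new docs entry corresponds to torchopt.linear_solve.solve_inv. A hedged usage sketch, based only on the call pattern in tests/test_implicit.py below (the matrix A and right-hand side b are made up for illustration): solve_inv(...) returns a solver that is then invoked with a matvec and b, and the partially-configured solver can also be handed to custom_root through its solve argument.

import torch
import torchopt

torch.manual_seed(0)
A = torch.eye(5) + 0.05 * torch.randn(5, 5)  # illustrative well-conditioned system
b = torch.randn(5)

def matvec(u):
    return A @ u

# ns=True requests the Neumann-series variant (per the parametrized tests below).
solve = torchopt.linear_solve.solve_inv(ns=True)
x = solve(matvec=matvec, b=b)
print((matvec(x) - b).norm())  # residual; small if the series converges for this A

# The same partially-applied solver can be plugged into implicit differentiation,
# as the tests below do:
#   @torchopt.diff.implicit.custom_root(grad_fn, argnums=1,
#                                       solve=torchopt.linear_solve.solve_inv(ns=True))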

1 change: 1 addition & 0 deletions docs/source/spelling_wordlist.txt
@@ -92,3 +92,4 @@ ints
Karush
Kuhn
Tucker
Neumann
260 changes: 251 additions & 9 deletions tests/test_implicit.py
@@ -24,6 +24,7 @@
import jaxopt
import numpy as np
import optax
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -82,7 +83,6 @@ def get_model_torch(

dataset = data.TensorDataset(
torch.randint(0, 1, (BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)),
# torch.empty((BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS), dtype=dtype).uniform_(-1.0, +1.0),
torch.randint(0, MODEL_NUM_CLASSES, (BATCH_SIZE * NUM_UPDATES,)),
)
loader = data.DataLoader(dataset, BATCH_SIZE, shuffle=False)
@@ -113,7 +113,9 @@ def get_rr_dataset_torch() -> data.DataLoader:
inner_lr=[2e-2, 2e-3],
inner_update=[20, 50, 100],
)
def test_imaml(dtype: torch.dtype, lr: float, inner_lr: float, inner_update: int) -> None:
def test_imaml_solve_normal_cg(
dtype: torch.dtype, lr: float, inner_lr: float, inner_update: int
) -> None:
np_dtype = helpers.dtype_torch2numpy(dtype)

jax_model, jax_params = get_model_jax(dtype=np_dtype)
@@ -136,7 +138,10 @@ def imaml_objective_torchopt(params, meta_params, data):
return loss

@torchopt.diff.implicit.custom_root(
functorch.grad(imaml_objective_torchopt, argnums=0), argnums=1, has_aux=True
functorch.grad(imaml_objective_torchopt, argnums=0),
argnums=1,
has_aux=True,
solve=torchopt.linear_solve.solve_normal_cg(),
)
def inner_solver_torchopt(params, meta_params, data):
# Initial functional optimizer based on TorchOpt
@@ -167,7 +172,11 @@ def imaml_objective_jax(params, meta_params, x, y):
loss = loss + regularization_loss
return loss

@jaxopt.implicit_diff.custom_root(jax.grad(imaml_objective_jax, argnums=0), has_aux=True)
@jaxopt.implicit_diff.custom_root(
jax.grad(imaml_objective_jax, argnums=0),
has_aux=True,
solve=jaxopt.linear_solve.solve_normal_cg,
)
def inner_solver_jax(params, meta_params, x, y):
"""Solve ridge regression by conjugate gradient."""
# Initial functional optimizer based on torchopt
@@ -225,6 +234,134 @@ def outer_level(p, xs, ys):
helpers.assert_all_close(p, p_ref)


@helpers.parametrize(
dtype=[torch.float64, torch.float32],
lr=[1e-3, 1e-4],
inner_lr=[2e-2, 2e-3],
inner_update=[20, 50, 100],
ns=[False, True],
)
def test_imaml_solve_inv(
dtype: torch.dtype,
lr: float,
inner_lr: float,
inner_update: int,
ns: bool,
) -> None:
np_dtype = helpers.dtype_torch2numpy(dtype)

jax_model, jax_params = get_model_jax(dtype=np_dtype)
model, loader = get_model_torch(device='cpu', dtype=dtype)

fmodel, params = functorch.make_functional(model)
optim = torchopt.sgd(lr)
optim_state = optim.init(params)

optim_jax = optax.sgd(lr)
optim_state_jax = optim_jax.init(jax_params)

def imaml_objective_torchopt(params, meta_params, data):
x, y, f = data
y_pred = f(params, x)
regularization_loss = 0
for p1, p2 in zip(params, meta_params):
regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2))
loss = F.cross_entropy(y_pred, y) + regularization_loss
return loss

@torchopt.diff.implicit.custom_root(
functorch.grad(imaml_objective_torchopt, argnums=0),
argnums=1,
solve=torchopt.linear_solve.solve_inv(ns=ns),
)
def inner_solver_torchopt(params, meta_params, data):
# Initial functional optimizer based on TorchOpt
x, y, f = data
optimizer = torchopt.sgd(lr=inner_lr)
opt_state = optimizer.init(params)
with torch.enable_grad():
# Temporarily enable gradient computation for conducting the optimization
for _ in range(inner_update):
pred = f(params, x)
loss = F.cross_entropy(pred, y) # compute loss
# Compute regularization loss
regularization_loss = 0
for p1, p2 in zip(params, meta_params):
regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2))
final_loss = loss + regularization_loss
grads = torch.autograd.grad(final_loss, params) # compute gradients
updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates
params = torchopt.apply_updates(params, updates, inplace=True)
return params

def imaml_objective_jax(params, meta_params, x, y):
y_pred = jax_model(params, x)
loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(y_pred, y))
regularization_loss = 0
for p1, p2 in zip(params.values(), meta_params.values()):
regularization_loss += 0.5 * jnp.sum(jnp.square((p1 - p2)))
loss = loss + regularization_loss
return loss

@jaxopt.implicit_diff.custom_root(
jax.grad(imaml_objective_jax, argnums=0),
solve=jaxopt.linear_solve.solve_normal_cg,
)
def inner_solver_jax(params, meta_params, x, y):
"""Solve ridge regression by conjugate gradient."""
# Initial functional optimizer based on torchopt
optimizer = optax.sgd(inner_lr)
opt_state = optimizer.init(params)

def compute_loss(params, meta_params, x, y):
pred = jax_model(params, x)
loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(pred, y))
# Compute regularization loss
regularization_loss = 0
for p1, p2 in zip(params.values(), meta_params.values()):
regularization_loss += 0.5 * jnp.sum(jnp.square((p1 - p2)))
final_loss = loss + regularization_loss
return final_loss

for i in range(inner_update):
grads = jax.grad(compute_loss)(params, meta_params, x, y) # compute gradients
updates, opt_state = optimizer.update(grads, opt_state) # get updates
params = optax.apply_updates(params, updates)
return params

for xs, ys in loader:
xs = xs.to(dtype=dtype)
data = (xs, ys, fmodel)
meta_params_copy = pytree.tree_map(
lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), params
)
optimal_params = inner_solver_torchopt(meta_params_copy, params, data)
outer_loss = fmodel(optimal_params, xs).mean()

grads = torch.autograd.grad(outer_loss, params)
updates, optim_state = optim.update(grads, optim_state)
params = torchopt.apply_updates(params, updates)

xs = xs.numpy()
ys = ys.numpy()

def outer_level(p, xs, ys):
optimal_params = inner_solver_jax(copy.deepcopy(p), p, xs, ys)
outer_loss = jax_model(optimal_params, xs).mean()
return outer_loss

grads_jax = jax.grad(outer_level, argnums=0)(jax_params, xs, ys)
updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates
jax_params = optax.apply_updates(jax_params, updates_jax)

jax_params_as_tensor = tuple(
nn.Parameter(torch.tensor(np.asarray(jax_params[j]), dtype=dtype)) for j in jax_params
)

for p, p_ref in zip(params, jax_params_as_tensor):
helpers.assert_all_close(p, p_ref)


@helpers.parametrize(
dtype=[torch.float64, torch.float32],
lr=[1e-3, 1e-4],
@@ -341,7 +478,7 @@ def outer_level(p, xs, ys):
dtype=[torch.float64, torch.float32],
lr=[1e-3, 1e-4],
)
def test_rr(
def test_rr_solve_cg(
dtype: torch.dtype,
lr: float,
) -> None:
@@ -371,7 +508,7 @@ def ridge_objective_torch(params, l2reg, data):
return 0.5 * torch.mean(torch.square(residuals)) + regularization_loss

@torchopt.diff.implicit.custom_root(functorch.grad(ridge_objective_torch, argnums=0), argnums=1)
def ridge_solver_torch(params, l2reg, data):
def ridge_solver_torch_cg(params, l2reg, data):
"""Solve ridge regression by conjugate gradient."""
X_tr, y_tr = data

Expand All @@ -393,7 +530,7 @@ def ridge_objective_jax(params, l2reg, X_tr, y_tr):
return 0.5 * jnp.mean(jnp.square(residuals)) + regularization_loss

@jaxopt.implicit_diff.custom_root(jax.grad(ridge_objective_jax, argnums=0))
def ridge_solver_jax(params, l2reg, X_tr, y_tr):
def ridge_solver_jax_cg(params, l2reg, X_tr, y_tr):
"""Solve ridge regression by conjugate gradient."""

def matvec(u):
@@ -413,7 +550,112 @@ def matvec(u):
xq = xq.to(dtype=dtype)
yq = yq.to(dtype=dtype)

w_fit = ridge_solver_torch(init_params_torch, l2reg_torch, (xs, ys))
w_fit = ridge_solver_torch_cg(init_params_torch, l2reg_torch, (xs, ys))
outer_loss = F.mse_loss(xq @ w_fit, yq)

grads, *_ = torch.autograd.grad(outer_loss, l2reg_torch)
updates, optim_state = optim.update(grads, optim_state)
l2reg_torch = torchopt.apply_updates(l2reg_torch, updates)

xs = jnp.array(xs.numpy(), dtype=np_dtype)
ys = jnp.array(ys.numpy(), dtype=np_dtype)
xq = jnp.array(xq.numpy(), dtype=np_dtype)
yq = jnp.array(yq.numpy(), dtype=np_dtype)

def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq):
w_fit = ridge_solver_jax_cg(params_jax, l2reg_jax, xs, ys)
y_pred = xq @ w_fit
loss_value = jnp.mean(jnp.square(y_pred - yq))
return loss_value

grads_jax = jax.grad(outer_level, argnums=1)(init_params_jax, l2reg_jax, xs, ys, xq, yq)
updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates
l2reg_jax = optax.apply_updates(l2reg_jax, updates_jax)

l2reg_jax_as_tensor = torch.tensor(np.asarray(l2reg_jax), dtype=dtype)
helpers.assert_all_close(l2reg_torch, l2reg_jax_as_tensor)


@helpers.parametrize(
dtype=[torch.float64, torch.float32],
lr=[1e-3, 1e-4],
ns=[True, False],
)
def test_rr_solve_inv(
dtype: torch.dtype,
lr: float,
ns: bool,
) -> None:
if dtype == torch.float64 and ns:
pytest.skip('Neumann series test skipped for torch.float64 due to numerical instability.')
helpers.seed_everything(42)
np_dtype = helpers.dtype_torch2numpy(dtype)
input_size = 10

init_params_torch = torch.randn(input_size, dtype=dtype)
l2reg_torch = torch.rand(1, dtype=dtype).squeeze_().requires_grad_(True)

init_params_jax = jnp.array(init_params_torch.detach().numpy(), dtype=np_dtype)
l2reg_jax = jnp.array(l2reg_torch.detach().numpy(), dtype=np_dtype)

loader = get_rr_dataset_torch()

optim = torchopt.sgd(lr)
optim_state = optim.init(l2reg_torch)

optim_jax = optax.sgd(lr)
optim_state_jax = optim_jax.init(l2reg_jax)

def ridge_objective_torch(params, l2reg, data):
"""Ridge objective function."""
X_tr, y_tr = data
residuals = X_tr @ params - y_tr
regularization_loss = 0.5 * l2reg * torch.sum(torch.square(params))
return 0.5 * torch.mean(torch.square(residuals)) + regularization_loss

@torchopt.diff.implicit.custom_root(functorch.grad(ridge_objective_torch, argnums=0), argnums=1)
def ridge_solver_torch_inv(params, l2reg, data):
"""Solve ridge regression by conjugate gradient."""
X_tr, y_tr = data

def matvec(u):
return X_tr.T @ (X_tr @ u)

solve = torchopt.linear_solve.solve_inv(
matvec=matvec,
b=X_tr.T @ y_tr,
ridge=len(y_tr) * l2reg.item(),
ns=ns,
)

return solve(matvec=matvec, b=X_tr.T @ y_tr)

def ridge_objective_jax(params, l2reg, X_tr, y_tr):
"""Ridge objective function."""
residuals = X_tr @ params - y_tr
regularization_loss = 0.5 * l2reg * jnp.sum(jnp.square(params))
return 0.5 * jnp.mean(jnp.square(residuals)) + regularization_loss

@jaxopt.implicit_diff.custom_root(jax.grad(ridge_objective_jax, argnums=0))
def ridge_solver_jax_inv(params, l2reg, X_tr, y_tr):
"""Solve ridge regression by conjugate gradient."""

def matvec(u):
return X_tr.T @ (X_tr @ u)

return jaxopt.linear_solve.solve_inv(
matvec=matvec,
b=X_tr.T @ y_tr,
ridge=len(y_tr) * l2reg.item(),
)

for xs, ys, xq, yq in loader:
xs = xs.to(dtype=dtype)
ys = ys.to(dtype=dtype)
xq = xq.to(dtype=dtype)
yq = yq.to(dtype=dtype)

w_fit = ridge_solver_torch_inv(init_params_torch, l2reg_torch, (xs, ys))
outer_loss = F.mse_loss(xq @ w_fit, yq)

grads, *_ = torch.autograd.grad(outer_loss, l2reg_torch)
@@ -426,7 +668,7 @@ def matvec(u):
yq = jnp.array(yq.numpy(), dtype=np_dtype)

def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq):
w_fit = ridge_solver_jax(params_jax, l2reg_jax, xs, ys)
w_fit = ridge_solver_jax_inv(params_jax, l2reg_jax, xs, ys)
y_pred = xq @ w_fit
loss_value = jnp.mean(jnp.square(y_pred - yq))
return loss_value
3 changes: 2 additions & 1 deletion torchopt/linalg/__init__.py
@@ -32,6 +32,7 @@
"""Linear algebra functions."""

from torchopt.linalg.cg import cg
from torchopt.linalg.ns import ns, ns_inv


__all__ = ['cg']
__all__ = ['cg', 'ns', 'ns_inv']
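torchopt.linalg now also exports ns (a truncated Neumann-series solve of A x = b) and ns_inv (a Neumann-series approximation of A^{-1}) alongside cg. A rough, matrix-free sketch of the kind of iteration such a solve performs (an illustrative re-derivation under the spectral-radius assumption, not the functions added in this commit):

import torch

def neumann_solve(matvec, b, maxiter=100):
    """Approximate the solution of A x = b via x ≈ sum_k (I - A)^k b.

    `matvec(u)` computes A @ u; convergence assumes the spectral radius of
    (I - A) is below 1 (e.g. after rescaling A).
    """
    x = torch.zeros_like(b)
    term = b.clone()                 # (I - A)^0 b
    for _ in range(maxiter):
        x = x + term
        term = term - matvec(term)   # apply (I - A) to the previous term
    return x

torch.manual_seed(0)
A = torch.eye(6) + 0.05 * torch.randn(6, 6)
b = torch.randn(6)
x = neumann_solve(lambda u: A @ u, b)
print(torch.allclose(x, torch.linalg.solve(A, b), atol=1e-5))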