
feat(linear_solve): matrix inversion linear solver with neumann series approximation #98

Merged
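This PR adds `torchopt.linear_solve.solve_inv`, a solver for `A x = b` that inverts the materialized matrix directly, with an optional truncated Neumann-series approximation of the inverse (`ns=True`). A minimal sketch of the Neumann-series idea, independent of the torchopt implementation: when the spectral radius of `I - A` is below 1, `A^{-1} b` is the limit of the partial sums of `(I - A)^k b`, which the loop below accumulates.

```python
# Minimal sketch of a truncated Neumann-series solve (illustrative only, not
# the torchopt implementation). Assumes the spectral radius of (I - A) is
# below 1 so that A^{-1} = sum_k (I - A)^k converges.
import torch

def neumann_solve(matvec, b, maxiter=50):
    """Approximate the solution of A x = b via x <- b + (I - A) x."""
    x = b.clone()
    for _ in range(maxiter):
        x = b + x - matvec(x)  # x_{k+1} = b + (I - A) x_k
    return x

A = torch.tensor([[1.0, 0.2], [0.1, 1.0]])
b = torch.tensor([1.0, 2.0])
x = neumann_solve(lambda v: A @ v, b)
print(torch.allclose(A @ x, b, atol=1e-6))  # True: x is close to A^{-1} b
```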
Changes from 36 commits

Commits (38)
d158462
feat(linear_solve): matrix inversion linear solver with neumann serie…
XuehaiPan Oct 13, 2022
e171f38
fix: fix ns
XuehaiPan Nov 6, 2022
db88ee3
fix: add test
Benjamin-eecs Nov 6, 2022
0946dbe
test: update tests
XuehaiPan Nov 6, 2022
b3cef2f
fix: add test
Benjamin-eecs Nov 6, 2022
e3a38a7
fix: add test
Benjamin-eecs Nov 7, 2022
de27c02
fix: add test
Benjamin-eecs Nov 7, 2022
183e769
fix: pass test
Benjamin-eecs Nov 7, 2022
155ab8d
fix: pass lint
Benjamin-eecs Nov 7, 2022
081d0b1
fix: pass lint
Benjamin-eecs Nov 7, 2022
7fc64b7
fix: pass lint
Benjamin-eecs Nov 7, 2022
1199814
fix: pass lint
Benjamin-eecs Nov 7, 2022
ede0cba
fix: update test
Benjamin-eecs Nov 7, 2022
54f8e01
chore: update CHANGELOG
Benjamin-eecs Nov 7, 2022
dc49c4f
merge: resolve conflicts
Benjamin-eecs Nov 7, 2022
d1d316c
fix: resolve comments
Benjamin-eecs Nov 7, 2022
31cb010
docs: add solve_inv
Benjamin-eecs Nov 7, 2022
909bd81
chore: update Makefile
Benjamin-eecs Nov 7, 2022
ff54c48
docs: update
Benjamin-eecs Nov 7, 2022
7cb77c8
wip
XuehaiPan Nov 7, 2022
03d6972
wip
XuehaiPan Nov 7, 2022
e8c4b38
wip
XuehaiPan Nov 7, 2022
fcf5148
wip
XuehaiPan Nov 7, 2022
8cd84bb
fix: pass test
Benjamin-eecs Nov 7, 2022
4a7db3c
wip
XuehaiPan Nov 8, 2022
8865eb6
wip
XuehaiPan Nov 8, 2022
afbcf11
wip
XuehaiPan Nov 8, 2022
77cd2ee
wip
XuehaiPan Nov 8, 2022
6753929
wip
XuehaiPan Nov 8, 2022
70ec405
feat: support normalize matvec with tensortree
XuehaiPan Nov 8, 2022
2bd55ba
feat: support implicit matvec
XuehaiPan Nov 8, 2022
564c7ec
feat: support implicit matvec
XuehaiPan Nov 8, 2022
4e7c1f9
fix: fix jacobian tree compose
XuehaiPan Nov 9, 2022
071026a
feat: multi-tensor support for solve_inv
XuehaiPan Nov 9, 2022
08de4bc
docs: update linear_solve docs
XuehaiPan Nov 9, 2022
5a0e458
chore: update ns_inv
XuehaiPan Nov 9, 2022
71ddf58
chore: add shortcuts
XuehaiPan Nov 9, 2022
67b774d
docs: update dictionary
XuehaiPan Nov 9, 2022
3 changes: 2 additions & 1 deletion .pylintrc
@@ -267,7 +267,8 @@ good-names=i,
lr,
mu,
nu,
x
x,
y

# Good variable names regexes, separated by a comma. If names match any regex,
# they will always be accepted
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add matrix inversion linear solver with Neumann series approximation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#98](https://github.com/metaopt/torchopt/pull/98).
- Add if condition of number of threads for CPU OPs by [@JieRen98](https://github.com/JieRen98) in [#105](https://github.com/metaopt/torchopt/pull/105).
- Add implicit MAML omniglot few-shot classification example with OOP APIs by [@XuehaiPan](https://github.com/XuehaiPan) in [#107](https://github.com/metaopt/torchopt/pull/107).
- Add implicit MAML omniglot few-shot classification example by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#48](https://github.com/metaopt/torchopt/pull/48).
4 changes: 3 additions & 1 deletion docs/source/api/api.rst
@@ -157,7 +157,7 @@ Implicit Meta-Gradient Module

------

Linear system solving
Linear system solvers
=====================

.. currentmodule:: torchopt.linear_solve
@@ -166,12 +166,14 @@ Linear system solving

solve_cg
solve_normal_cg
solve_inv

Indirect solvers
~~~~~~~~~~~~~~~~

.. autofunction:: solve_cg
.. autofunction:: solve_normal_cg
.. autofunction:: solve_inv

------

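The new entry follows the same factory pattern as `solve_cg` and `solve_normal_cg`: calling `solve_inv(...)` returns a solver function that takes `matvec` and `b`, as the tests below do. A short usage sketch based on those tests; only the `ns` and `ridge` keywords are taken from the diff, anything else here is assumed rather than documented.

```python
# Usage sketch mirroring the tests in this PR: solve_inv(...) returns a
# callable taking matvec and b. The ns/ridge keywords appear in the diff;
# other details are assumptions, not documented behaviour.
import torch
import torchopt

A = torch.tensor([[1.0, 0.2], [0.1, 1.0]])
b = torch.tensor([1.0, -1.0])

solve = torchopt.linear_solve.solve_inv(ns=True)  # Neumann-series variant
x = solve(matvec=lambda v: A @ v, b=b)            # approximately A^{-1} b

# As in test_imaml_solve_inv, the same factory can be passed to custom_root:
#   @torchopt.diff.implicit.custom_root(grad_fn, argnums=1,
#                                       solve=torchopt.linear_solve.solve_inv(ns=True))
```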
260 changes: 251 additions & 9 deletions tests/test_implicit.py
@@ -24,6 +24,7 @@
import jaxopt
import numpy as np
import optax
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -82,7 +83,6 @@ def get_model_torch(

dataset = data.TensorDataset(
torch.randint(0, 1, (BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS)),
# torch.empty((BATCH_SIZE * NUM_UPDATES, MODEL_NUM_INPUTS), dtype=dtype).uniform_(-1.0, +1.0),
torch.randint(0, MODEL_NUM_CLASSES, (BATCH_SIZE * NUM_UPDATES,)),
)
loader = data.DataLoader(dataset, BATCH_SIZE, shuffle=False)
@@ -113,7 +113,9 @@ def get_rr_dataset_torch() -> data.DataLoader:
inner_lr=[2e-2, 2e-3],
inner_update=[20, 50, 100],
)
def test_imaml(dtype: torch.dtype, lr: float, inner_lr: float, inner_update: int) -> None:
def test_imaml_solve_normal_cg(
dtype: torch.dtype, lr: float, inner_lr: float, inner_update: int
) -> None:
np_dtype = helpers.dtype_torch2numpy(dtype)

jax_model, jax_params = get_model_jax(dtype=np_dtype)
@@ -136,7 +138,10 @@ def imaml_objective_torchopt(params, meta_params, data):
return loss

@torchopt.diff.implicit.custom_root(
functorch.grad(imaml_objective_torchopt, argnums=0), argnums=1, has_aux=True
functorch.grad(imaml_objective_torchopt, argnums=0),
argnums=1,
has_aux=True,
solve=torchopt.linear_solve.solve_normal_cg(),
)
def inner_solver_torchopt(params, meta_params, data):
# Initial functional optimizer based on TorchOpt
@@ -167,7 +172,11 @@ def imaml_objective_jax(params, meta_params, x, y):
loss = loss + regularization_loss
return loss

@jaxopt.implicit_diff.custom_root(jax.grad(imaml_objective_jax, argnums=0), has_aux=True)
@jaxopt.implicit_diff.custom_root(
jax.grad(imaml_objective_jax, argnums=0),
has_aux=True,
solve=jaxopt.linear_solve.solve_normal_cg,
)
def inner_solver_jax(params, meta_params, x, y):
"""Solve ridge regression by conjugate gradient."""
# Initial functional optimizer based on torchopt
@@ -225,6 +234,134 @@ def outer_level(p, xs, ys):
helpers.assert_all_close(p, p_ref)


@helpers.parametrize(
dtype=[torch.float64, torch.float32],
lr=[1e-3, 1e-4],
inner_lr=[2e-2, 2e-3],
inner_update=[20, 50, 100],
ns=[False, True],
)
def test_imaml_solve_inv(
dtype: torch.dtype,
lr: float,
inner_lr: float,
inner_update: int,
ns: bool,
) -> None:
np_dtype = helpers.dtype_torch2numpy(dtype)

jax_model, jax_params = get_model_jax(dtype=np_dtype)
model, loader = get_model_torch(device='cpu', dtype=dtype)

fmodel, params = functorch.make_functional(model)
optim = torchopt.sgd(lr)
optim_state = optim.init(params)

optim_jax = optax.sgd(lr)
optim_state_jax = optim_jax.init(jax_params)

def imaml_objective_torchopt(params, meta_params, data):
x, y, f = data
y_pred = f(params, x)
regularization_loss = 0
for p1, p2 in zip(params, meta_params):
regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2))
loss = F.cross_entropy(y_pred, y) + regularization_loss
return loss

@torchopt.diff.implicit.custom_root(
functorch.grad(imaml_objective_torchopt, argnums=0),
argnums=1,
solve=torchopt.linear_solve.solve_inv(ns=ns),
)
def inner_solver_torchopt(params, meta_params, data):
# Initial functional optimizer based on TorchOpt
x, y, f = data
optimizer = torchopt.sgd(lr=inner_lr)
opt_state = optimizer.init(params)
with torch.enable_grad():
# Temporarily enable gradient computation for conducting the optimization
for _ in range(inner_update):
pred = f(params, x)
loss = F.cross_entropy(pred, y) # compute loss
# Compute regularization loss
regularization_loss = 0
for p1, p2 in zip(params, meta_params):
regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2))
final_loss = loss + regularization_loss
grads = torch.autograd.grad(final_loss, params) # compute gradients
updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates
params = torchopt.apply_updates(params, updates, inplace=True)
return params

def imaml_objective_jax(params, meta_params, x, y):
y_pred = jax_model(params, x)
loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(y_pred, y))
regularization_loss = 0
for p1, p2 in zip(params.values(), meta_params.values()):
regularization_loss += 0.5 * jnp.sum(jnp.square((p1 - p2)))
loss = loss + regularization_loss
return loss

@jaxopt.implicit_diff.custom_root(
jax.grad(imaml_objective_jax, argnums=0),
solve=jaxopt.linear_solve.solve_normal_cg,
)
def inner_solver_jax(params, meta_params, x, y):
"""Solve ridge regression by conjugate gradient."""
# Initial functional optimizer based on torchopt
optimizer = optax.sgd(inner_lr)
opt_state = optimizer.init(params)

def compute_loss(params, meta_params, x, y):
pred = jax_model(params, x)
loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(pred, y))
# Compute regularization loss
regularization_loss = 0
for p1, p2 in zip(params.values(), meta_params.values()):
regularization_loss += 0.5 * jnp.sum(jnp.square((p1 - p2)))
final_loss = loss + regularization_loss
return final_loss

for i in range(inner_update):
grads = jax.grad(compute_loss)(params, meta_params, x, y) # compute gradients
updates, opt_state = optimizer.update(grads, opt_state) # get updates
params = optax.apply_updates(params, updates)
return params

for xs, ys in loader:
xs = xs.to(dtype=dtype)
data = (xs, ys, fmodel)
meta_params_copy = pytree.tree_map(
lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), params
)
optimal_params = inner_solver_torchopt(meta_params_copy, params, data)
outer_loss = fmodel(optimal_params, xs).mean()

grads = torch.autograd.grad(outer_loss, params)
updates, optim_state = optim.update(grads, optim_state)
params = torchopt.apply_updates(params, updates)

xs = xs.numpy()
ys = ys.numpy()

def outer_level(p, xs, ys):
optimal_params = inner_solver_jax(copy.deepcopy(p), p, xs, ys)
outer_loss = jax_model(optimal_params, xs).mean()
return outer_loss

grads_jax = jax.grad(outer_level, argnums=0)(jax_params, xs, ys)
updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates
jax_params = optax.apply_updates(jax_params, updates_jax)

jax_params_as_tensor = tuple(
nn.Parameter(torch.tensor(np.asarray(jax_params[j]), dtype=dtype)) for j in jax_params
)

for p, p_ref in zip(params, jax_params_as_tensor):
helpers.assert_all_close(p, p_ref)


@helpers.parametrize(
dtype=[torch.float64, torch.float32],
lr=[1e-3, 1e-4],
@@ -341,7 +478,7 @@ def outer_level(p, xs, ys):
dtype=[torch.float64, torch.float32],
lr=[1e-3, 1e-4],
)
def test_rr(
def test_rr_solve_cg(
dtype: torch.dtype,
lr: float,
) -> None:
@@ -371,7 +508,7 @@ def ridge_objective_torch(params, l2reg, data):
return 0.5 * torch.mean(torch.square(residuals)) + regularization_loss

@torchopt.diff.implicit.custom_root(functorch.grad(ridge_objective_torch, argnums=0), argnums=1)
def ridge_solver_torch(params, l2reg, data):
def ridge_solver_torch_cg(params, l2reg, data):
"""Solve ridge regression by conjugate gradient."""
X_tr, y_tr = data

@@ -393,7 +530,7 @@ def ridge_objective_jax(params, l2reg, X_tr, y_tr):
return 0.5 * jnp.mean(jnp.square(residuals)) + regularization_loss

@jaxopt.implicit_diff.custom_root(jax.grad(ridge_objective_jax, argnums=0))
def ridge_solver_jax(params, l2reg, X_tr, y_tr):
def ridge_solver_jax_cg(params, l2reg, X_tr, y_tr):
"""Solve ridge regression by conjugate gradient."""

def matvec(u):
Expand All @@ -413,7 +550,112 @@ def matvec(u):
xq = xq.to(dtype=dtype)
yq = yq.to(dtype=dtype)

w_fit = ridge_solver_torch(init_params_torch, l2reg_torch, (xs, ys))
w_fit = ridge_solver_torch_cg(init_params_torch, l2reg_torch, (xs, ys))
outer_loss = F.mse_loss(xq @ w_fit, yq)

grads, *_ = torch.autograd.grad(outer_loss, l2reg_torch)
updates, optim_state = optim.update(grads, optim_state)
l2reg_torch = torchopt.apply_updates(l2reg_torch, updates)

xs = jnp.array(xs.numpy(), dtype=np_dtype)
ys = jnp.array(ys.numpy(), dtype=np_dtype)
xq = jnp.array(xq.numpy(), dtype=np_dtype)
yq = jnp.array(yq.numpy(), dtype=np_dtype)

def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq):
w_fit = ridge_solver_jax_cg(params_jax, l2reg_jax, xs, ys)
y_pred = xq @ w_fit
loss_value = jnp.mean(jnp.square(y_pred - yq))
return loss_value

grads_jax = jax.grad(outer_level, argnums=1)(init_params_jax, l2reg_jax, xs, ys, xq, yq)
updates_jax, optim_state_jax = optim_jax.update(grads_jax, optim_state_jax) # get updates
l2reg_jax = optax.apply_updates(l2reg_jax, updates_jax)

l2reg_jax_as_tensor = torch.tensor(np.asarray(l2reg_jax), dtype=dtype)
helpers.assert_all_close(l2reg_torch, l2reg_jax_as_tensor)


@helpers.parametrize(
dtype=[torch.float64, torch.float32],
lr=[1e-3, 1e-4],
ns=[True, False],
)
def test_rr_solve_inv(
dtype: torch.dtype,
lr: float,
ns: bool,
) -> None:
if dtype == torch.float64 and ns:
pytest.skip('Neumann Series test skips torch.float64 due to numerical stability.')
helpers.seed_everything(42)
np_dtype = helpers.dtype_torch2numpy(dtype)
input_size = 10

init_params_torch = torch.randn(input_size, dtype=dtype)
l2reg_torch = torch.rand(1, dtype=dtype).squeeze_().requires_grad_(True)

init_params_jax = jnp.array(init_params_torch.detach().numpy(), dtype=np_dtype)
l2reg_jax = jnp.array(l2reg_torch.detach().numpy(), dtype=np_dtype)

loader = get_rr_dataset_torch()

optim = torchopt.sgd(lr)
optim_state = optim.init(l2reg_torch)

optim_jax = optax.sgd(lr)
optim_state_jax = optim_jax.init(l2reg_jax)

def ridge_objective_torch(params, l2reg, data):
"""Ridge objective function."""
X_tr, y_tr = data
residuals = X_tr @ params - y_tr
regularization_loss = 0.5 * l2reg * torch.sum(torch.square(params))
return 0.5 * torch.mean(torch.square(residuals)) + regularization_loss

@torchopt.diff.implicit.custom_root(functorch.grad(ridge_objective_torch, argnums=0), argnums=1)
def ridge_solver_torch_inv(params, l2reg, data):
"""Solve ridge regression by conjugate gradient."""
X_tr, y_tr = data

def matvec(u):
return X_tr.T @ (X_tr @ u)

solve = torchopt.linear_solve.solve_inv(
matvec=matvec,
b=X_tr.T @ y_tr,
ridge=len(y_tr) * l2reg.item(),
ns=ns,
)

return solve(matvec=matvec, b=X_tr.T @ y_tr)

def ridge_objective_jax(params, l2reg, X_tr, y_tr):
"""Ridge objective function."""
residuals = X_tr @ params - y_tr
regularization_loss = 0.5 * l2reg * jnp.sum(jnp.square(params))
return 0.5 * jnp.mean(jnp.square(residuals)) + regularization_loss

@jaxopt.implicit_diff.custom_root(jax.grad(ridge_objective_jax, argnums=0))
def ridge_solver_jax_inv(params, l2reg, X_tr, y_tr):
"""Solve ridge regression by conjugate gradient."""

def matvec(u):
return X_tr.T @ ((X_tr @ u))

return jaxopt.linear_solve.solve_inv(
matvec=matvec,
b=X_tr.T @ y_tr,
ridge=len(y_tr) * l2reg.item(),
)

for xs, ys, xq, yq in loader:
xs = xs.to(dtype=dtype)
ys = ys.to(dtype=dtype)
xq = xq.to(dtype=dtype)
yq = yq.to(dtype=dtype)

w_fit = ridge_solver_torch_inv(init_params_torch, l2reg_torch, (xs, ys))
outer_loss = F.mse_loss(xq @ w_fit, yq)

grads, *_ = torch.autograd.grad(outer_loss, l2reg_torch)
Expand All @@ -426,7 +668,7 @@ def matvec(u):
yq = jnp.array(yq.numpy(), dtype=np_dtype)

def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq):
w_fit = ridge_solver_jax(params_jax, l2reg_jax, xs, ys)
w_fit = ridge_solver_jax_inv(params_jax, l2reg_jax, xs, ys)
y_pred = xq @ w_fit
loss_value = jnp.mean(jnp.square(y_pred - yq))
return loss_value
3 changes: 2 additions & 1 deletion torchopt/linalg/__init__.py
@@ -32,6 +32,7 @@
"""Linear algebra functions."""

from torchopt.linalg.cg import cg
from torchopt.linalg.ns import ns, ns_inv


__all__ = ['cg']
__all__ = ['cg', 'ns', 'ns_inv']
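
The `torchopt.linalg` additions split the approximation into two helpers: `ns` approximates the solution of `A x = b` with a truncated Neumann series, while `ns_inv` approximates the inverse matrix itself as a partial sum of `(I - A)^k`. A rough illustration of the latter; the helper below is a sketch, not the actual `torchopt.linalg.ns_inv` signature.

```python
# Illustrative sketch: approximate A^{-1} as sum_{k=0}^{maxiter} (I - A)^k.
# The real torchopt.linalg.ns_inv signature and internals may differ.
import torch

def ns_inv_sketch(A: torch.Tensor, maxiter: int = 10) -> torch.Tensor:
    identity = torch.eye(A.shape[-1], dtype=A.dtype)
    term = identity.clone()   # (I - A)^0
    inv = identity.clone()    # running partial sum of (I - A)^k
    for _ in range(maxiter):
        term = term @ (identity - A)
        inv = inv + term
    return inv

A = torch.tensor([[1.0, 0.2], [0.1, 1.0]])
print(torch.allclose(ns_inv_sketch(A, maxiter=50) @ A, torch.eye(2), atol=1e-6))  # True
```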