Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing backtracking linesearch support on gpus #814

Merged
merged 1 commit into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 29 additions & 24 deletions optax/_src/linesearch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
# ==============================================================================
"""Tests for `linesearch.py`."""

import contextlib
import functools
import io
import itertools
import math

Expand Down Expand Up @@ -160,17 +158,25 @@ def fn(params):
params = update.apply_updates(params, updates)
chex.assert_trees_all_close(final_params, params, atol=1e-2, rtol=1e-2)

@parameterized.product(jit=[True, False])
def test_recycling_value_and_grad(self, jit):
@chex.variants(
with_jit=True,
without_jit=True,
with_pmap=False,
with_device=True,
without_device=True,
)
def test_recycling_value_and_grad(self):
# A vmap or a pmap makes the cond in value_and_state_from_grad
# become a select and in that case this code cannot be optimal.
# So we skip the pmap test.
init_params = jnp.array([1.0, 10.0, 1.0])
final_params = jnp.array([1.0, -1.0, 1.0])

def fn(params):
jax.debug.print('function evaluated')
return jnp.sum((params - final_params) ** 2)

# Base learning rate ensures sufficient decrease, so the linesearch should
# not make more function evaluations than the total number of iterations.
value_and_grad = utils.value_and_grad_from_state(fn)

base_opt = alias.sgd(learning_rate=0.1)
solver = combine.chain(
base_opt,
Expand All @@ -181,36 +187,35 @@ def fn(params):
store_grad=True,
),
)
value_and_grad = utils.value_and_grad_from_state(fn)
init_state = solver.init(init_params)
max_iter = 40

update_fn = functools.partial(solver.update, value_fn=fn)
update_fn = jax.jit(update_fn)

def step_(params, state):
value, grad = value_and_grad(params, state=state)
def fake_fun(_):
return 1.0

fake_value_and_grad = utils.value_and_grad_from_state(fake_fun)

def step_(params, state, iter_num):
# Should still work as the value and grad are extracted from the state
value, grad = jax.lax.cond(
iter_num > 0,
lambda: fake_value_and_grad(params, state=state),
lambda: value_and_grad(params, state=state),
)
updates, state = update_fn(grad, state, params, value=value, grad=grad)
params = update.apply_updates(params, updates)
return params, state

if jit:
step = jax.jit(step_)
else:
step = step_
step = self.variant(step_)
params = init_params
state = init_state
stdout = io.StringIO()
with contextlib.redirect_stdout(stdout):
for _ in range(max_iter):
params, state = step(params, state)
for iter_num in range(max_iter):
params, state = step(params, state, iter_num)
params = jax.block_until_ready(params)
chex.assert_trees_all_close(final_params, params, atol=1e-2, rtol=1e-2)

num_evals = stdout.getvalue().count('function evaluated')
# There are two function call sites, so as the function may be compiled
# twice, we may get a total of max_iter + 1.
self.assertLessEqual(num_evals, max_iter + 1)

def test_armijo_sgd(self):
def fn(params, x, y):
return jnp.sum((x.dot(params) - y) ** 2)
Expand Down
44 changes: 20 additions & 24 deletions optax/_src/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
# ==============================================================================
"""Tests for `utils.py`."""

import contextlib
import io
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import chex
import jax
import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -141,7 +140,7 @@ def test_log_prob(self, shape):
self.assertEqual(probs.shape, ())


class HelpersTest(parameterized.TestCase):
class HelpersTest(chex.TestCase):

@parameterized.parameters([
(1, 1),
Expand Down Expand Up @@ -296,17 +295,20 @@ def check_values_found(state, values_found):
self.assertLen(values_found, 3)
check_values_found(state, values_found)

@parameterized.product(jit=[True, False])
def test_value_and_grad_from_state(self, jit):
@chex.variants(
with_jit=True,
without_jit=True,
with_pmap=False,
with_device=True,
without_device=True,
)
def test_value_and_grad_from_state(self):
def fn(x):
return jnp.sum(x**2)

value_and_grad_ = utils.value_and_grad_from_state(fn)

if jit:
value_and_grad = jax.jit(value_and_grad_)
else:
value_and_grad = value_and_grad_
value_and_grad = self.variant(value_and_grad_)

params = jnp.array([1.0, 2.0, 3.0])

Expand Down Expand Up @@ -338,22 +340,16 @@ def fn(x):
params = update.apply_updates(params, updates)
params = jax.block_until_ready(params)

def fn_chatty(x):
jax.debug.print('function evaluated')
return jnp.sum(x**2)
def false_fn(_):
return 1.

value_and_grad_ = utils.value_and_grad_from_state(fn_chatty)
if jit:
value_and_grad = jax.jit(value_and_grad_)
else:
value_and_grad = value_and_grad_

# At the second step we should not need to evaluate the function
stdout = io.StringIO()
with contextlib.redirect_stdout(stdout):
value_and_grad(params, state=state)
num_eval = stdout.getvalue().count('function evaluated')
self.assertEqual(num_eval, 0)
false_value_and_grad_ = utils.value_and_grad_from_state(false_fn)
false_value_and_grad = self.variant(false_value_and_grad_)

# At the second step we should not evaluate the function
# so in this case it should not return the output of false_fn
value, _ = false_value_and_grad(params, state=state)
self.assertNotEqual(value, 1.)

def test_extract_fns_kwargs(self):
def fn1(a, b):
Expand Down
Loading