Renamed tree_scalar_mul to tree_scale and tree_add_scalar_mul to tree_add_scale with deprecation warnings (#1196)
shreyans413 committed Mar 6, 2025
1 parent 2e66ce8 commit d3cd178
Showing 15 changed files with 96 additions and 51 deletions.
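For readers migrating call sites, here is a minimal sketch of the rename (an illustration only, not part of the commit): the new names are drop-in replacements with identical semantics, and the old names now emit a DeprecationWarning before delegating to them.

import jax.numpy as jnp
import optax.tree_utils as otu

params = {'w': jnp.ones(3), 'b': jnp.zeros(2)}
grads = {'w': jnp.full(3, 0.5), 'b': jnp.ones(2)}

# Old call sites -> new call sites (leaf-wise semantics unchanged):
#   otu.tree_scalar_mul(s, tree)     -> otu.tree_scale(s, tree)
#   otu.tree_add_scalar_mul(x, s, y) -> otu.tree_add_scale(x, s, y)
scaled = otu.tree_scale(0.1, params)               # 0.1 * params
stepped = otu.tree_add_scale(params, -0.1, grads)  # params - 0.1 * grads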
8 changes: 4 additions & 4 deletions docs/api/utilities.rst
@@ -96,7 +96,7 @@ Tree
.. autosummary::
NamedTupleKey
tree_add
-tree_add_scalar_mul
+tree_add_scale
tree_batch_shape
tree_cast
tree_div
@@ -112,7 +112,7 @@ Tree
tree_ones_like
tree_random_like
tree_split_key_like
-tree_scalar_mul
+tree_scale
tree_set
tree_sub
tree_sum
@@ -130,7 +130,7 @@ Tree add

Tree add and scalar multiply
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-.. autofunction:: tree_add_scalar_mul
+.. autofunction:: tree_add_scale

Tree batch reshaping
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -194,7 +194,7 @@ Tree with random values

Tree scalar multiply
~~~~~~~~~~~~~~~~~~~~
-.. autofunction:: tree_scalar_mul
+.. autofunction:: tree_scale

Set values in a tree
~~~~~~~~~~~~~~~~~~~~
12 changes: 6 additions & 6 deletions examples/contrib/reduce_on_plateau.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions examples/perturbations.ipynb
@@ -863,7 +863,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": null,
"metadata": {
"id": "EkwIh76L1Azl"
},
@@ -880,7 +880,7 @@
"source": [
"eta = 1e-1\n",
"\n",
"loss_step = pert_loss(otu.tree_add_scalar_mul(tree_a, -eta, grad), rng)\n",
"loss_step = pert_loss(otu.tree_add_scale(tree_a, -eta, grad), rng)\n",
"\n",
"print(f'initial loss value = {init_loss:.3f}')\n",
"print(f'loss after gradient step = {loss_step:.3f}')"
4 changes: 2 additions & 2 deletions optax/_src/linear_algebra.py
@@ -29,7 +29,7 @@

def _normalize_tree(x):
# divide by the L2 norm of the tree weights.
-return otu.tree_scalar_mul(1.0 / otu.tree_l2_norm(x), x)
+return otu.tree_scale(1.0 / otu.tree_l2_norm(x), x)


def global_norm(updates: base.PyTree) -> chex.Array:
@@ -42,7 +42,7 @@ def global_norm(updates: base.PyTree) -> chex.Array:
def _power_iteration_cond_fun(error_tolerance, num_iters, loop_vars):
normalized_eigvec, unnormalized_eigvec, eig, iter_num = loop_vars
residual = otu.tree_sub(
-unnormalized_eigvec, otu.tree_scalar_mul(eig, normalized_eigvec)
+unnormalized_eigvec, otu.tree_scale(eig, normalized_eigvec)
)
residual_norm = otu.tree_l2_norm(residual)
converged = jnp.abs(residual_norm / eig) < error_tolerance
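As an aside, the _normalize_tree idiom touched above reduces to one line; a sketch with an illustrative pytree (only tree_scale and tree_l2_norm come from the diff):

import jax.numpy as jnp
import optax.tree_utils as otu

x = {'a': jnp.array([3.0, 4.0])}                     # global l2 norm = 5
unit = otu.tree_scale(1.0 / otu.tree_l2_norm(x), x)  # divide by the norm
print(otu.tree_l2_norm(unit))                        # ~1.0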
8 changes: 4 additions & 4 deletions optax/_src/linesearch.py
@@ -345,7 +345,7 @@ def body_fn(
learning_rate = jnp.where(
iter_num > 0, decrease_factor * learning_rate, learning_rate
)
-new_params = otu.tree_add_scalar_mul(params, learning_rate, updates)
+new_params = otu.tree_add_scale(params, learning_rate, updates)

value_fn_ = functools.partial(value_fn, **fn_kwargs)
if store_grad:
@@ -424,7 +424,7 @@ def body_fn(
"Using a stepsize of 0 to avoid infinite or nan values.",
)
# At the end, we just scale the updates with the learning rate found.
-new_updates = otu.tree_scalar_mul(new_learning_rate, updates)
+new_updates = otu.tree_scale(new_learning_rate, updates)
info = BacktrackingLinesearchInfo(
num_linesearch_steps=search_state.iter_num,
decrease_error=search_state.decrease_error,
@@ -700,7 +700,7 @@ def _value_and_slope_on_line(
* ``slope_step`` is the derivative of the function in terms of the
stepsize at the step.
"""
-step = otu.tree_add_scalar_mul(params, stepsize, updates)
+step = otu.tree_add_scale(params, stepsize, updates)
value_step, grad_step = value_and_grad_fn(step, **fn_kwargs)
slope_step = otu.tree_real(otu.tree_vdot(otu.tree_conj(grad_step), updates))
return step, value_step, grad_step, slope_step
@@ -1605,7 +1605,7 @@ def update_fn(
init_state,
)
learning_rate = final_state.stepsize
-scaled_updates = otu.tree_scalar_mul(learning_rate, updates)
+scaled_updates = otu.tree_scale(learning_rate, updates)
info_step = ZoomLinesearchInfo(
num_linesearch_steps=final_state.count,
decrease_error=final_state.decrease_error,
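All four linesearch call sites above share one pattern: a trial point or final update is formed as params + stepsize * updates. A hedged sketch of the backtracking variant (the objective and the 0.5 decrease factor are illustrative, not the module's actual logic):

import jax
import jax.numpy as jnp
import optax.tree_utils as otu

def loss(p):  # illustrative objective
  return jnp.sum(p['w'] ** 2)

params = {'w': jnp.array([1.0, -2.0])}
updates = jax.tree.map(lambda g: -g, jax.grad(loss)(params))  # descent direction

stepsize = 1.0
while loss(otu.tree_add_scale(params, stepsize, updates)) >= loss(params):
  stepsize *= 0.5  # shrink until the trial point decreases the loss
params = otu.tree_add_scale(params, stepsize, updates)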
12 changes: 6 additions & 6 deletions optax/_src/transform.py
@@ -622,15 +622,15 @@ def update_fn(updates, state, params=None):
m = otu.tree_update_moment(g, state.m, b1, 1)
v = otu.tree_update_moment(diff, state.v, b2, 1)

-sq = otu.tree_add_scalar_mul(g, 1 - b2, diff)
+sq = otu.tree_add_scale(g, 1 - b2, diff)
n = otu.tree_update_moment_per_elem_norm(sq, state.n, b3, 2)

t = numerics.safe_increment(state.t)
m_hat = otu.tree_bias_correction(m, b1, t)
v_hat = otu.tree_bias_correction(v, b2, t)
n_hat = otu.tree_bias_correction(n, b3, t)

-u = otu.tree_add_scalar_mul(m_hat, 1 - b2, v_hat)
+u = otu.tree_add_scale(m_hat, 1 - b2, v_hat)
denom = jax.tree.map(lambda n_hat: jnp.sqrt(n_hat + eps_root) + eps, n_hat)
u = otu.tree_div(u, denom)

@@ -1463,7 +1463,7 @@ def update_fn(
jnp.array(0.0),
jnp.minimum(gap / (grad_sq_norm + eps), max_learning_rate),
)
-updates = otu.tree_scalar_mul(step, updates)
+updates = otu.tree_scale(step, updates)
return updates, state

return base.GradientTransformationExtraArgs(base.init_empty_state, update_fn)
@@ -1543,22 +1543,22 @@ def right_product(vec, idx):
lambda x: x[idx], (diff_params_memory, diff_updates_memory)
)
alpha = rhos[idx] * otu.tree_real(otu.tree_vdot(dwi, vec))
-vec = otu.tree_add_scalar_mul(vec, -alpha, dui)
+vec = otu.tree_add_scale(vec, -alpha, dui)
return vec, alpha

precond_updates, alphas = jax.lax.scan(
right_product, updates, indices, reverse=True
)

-precond_updates = otu.tree_scalar_mul(identity_scale, precond_updates)
+precond_updates = otu.tree_scale(identity_scale, precond_updates)

def left_product(vec, idx_alpha):
idx, alpha = idx_alpha
dwi, dui = jax.tree.map(
lambda x: x[idx], (diff_params_memory, diff_updates_memory)
)
beta = rhos[idx] * otu.tree_real(otu.tree_vdot(dui, vec))
-vec = otu.tree_add_scalar_mul(vec, alpha - beta, dwi)
+vec = otu.tree_add_scale(vec, alpha - beta, dwi)
return vec, beta

precond_updates, _ = jax.lax.scan(
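The transform.py call sites use tree_add_scale as a pytree "axpy" primitive (v <- v + a * u), e.g. in the L-BFGS right_product above. A sketch with illustrative values:

import jax.numpy as jnp
import optax.tree_utils as otu

vec = {'w': jnp.array([1.0, 1.0])}
direction = {'w': jnp.array([2.0, -2.0])}
alpha = 0.25

vec = otu.tree_add_scale(vec, -alpha, direction)  # v - alpha*u: {'w': [0.5, 1.5]}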
2 changes: 1 addition & 1 deletion optax/contrib/_cocob.py
@@ -57,7 +57,7 @@ def scale_by_cocob(
def init_fn(params):
init_adapt = otu.tree_zeros_like(params)
init_scale = otu.tree_ones_like(params)
-init_scale = otu.tree_scalar_mul(eps, init_scale)
+init_scale = otu.tree_scale(eps, init_scale)
return COCOBState(
init_particles=params,
cumulative_gradients=init_adapt,
4 changes: 2 additions & 2 deletions optax/contrib/_dog.py
@@ -112,7 +112,7 @@ def update_fn(
state.first_step, init_learning_rate, learning_rate
)

-new_updates = otu.tree_scalar_mul(learning_rate, updates)
+new_updates = otu.tree_scale(learning_rate, updates)
return new_updates, DoGState(
first_step=jnp.asarray(False),
init_params=init_params,
@@ -266,7 +266,7 @@ def update_fn(
)
learning_rate = estim_sq_dist / (jnp.sqrt(weighted_sq_norm_grads) + eps)

-new_updates = otu.tree_scalar_mul(learning_rate, updates)
+new_updates = otu.tree_scale(learning_rate, updates)
return new_updates, state._replace(
estim_sq_dist=estim_sq_dist,
weighted_sq_norm_grads=weighted_sq_norm_grads,
4 changes: 2 additions & 2 deletions optax/perturbations/_make_pert.py
@@ -140,7 +140,7 @@ def _compute_residuals(
]
# creates [inputs + Z_1, ..., inputs + Z_num_samples]
inputs_pert = _tree_vmap(
-lambda z: otu.tree_add_scalar_mul(inputs, sigma, z), samples
+lambda z: otu.tree_add_scale(inputs, sigma, z), samples
)
# applies fun: [fun(inputs + Z_1), ..., fun(inputs + Z_num_samples)]
outputs_pert = _tree_vmap(fun, inputs_pert)
@@ -185,7 +185,7 @@ def fun_perturb_jvp(
# TODO(qberthet): implement with the jvp of the grad log prob.
# computes 1/M * sum_i fun(inputs + sigma * Z_i) < - grad log_prob(Z_i), g>
tangent_out = _tree_mean_across([
-otu.tree_scalar_mul(-scalar_dot_prod, output)
+otu.tree_scale(-scalar_dot_prod, output)
for scalar_dot_prod, output in zip(list_dot_prods, outputs_pert)
])
return tangent_out
2 changes: 1 addition & 1 deletion optax/perturbations/_make_pert_test.py
@@ -141,7 +141,7 @@ def apply_element_tree(tree):
self.rng_jax, self.example_tree, sampler=_make_pert.Normal().sample
)
tree_out_noisy = apply_element_tree(
-otu.tree_add_scalar_mul(self.example_tree, 1e-4, tree_noise)
+otu.tree_add_scale(self.example_tree, 1e-4, tree_noise)
)
chex.assert_trees_all_close(tree_out, tree_out_noisy, rtol=1e-4)

4 changes: 2 additions & 2 deletions optax/projections/_projections.py
@@ -250,7 +250,7 @@ def projection_l2_sphere(tree: Any, scale: float = 1.0) -> Any:
.. versionadded:: 0.2.4
"""
factor = scale / otu.tree_l2_norm(tree)
-return otu.tree_scalar_mul(factor, tree)
+return otu.tree_scale(factor, tree)


def projection_l2_ball(tree: Any, scale: float = 1.0) -> Any:
@@ -278,7 +278,7 @@ def projection_l2_ball(tree: Any, scale: float = 1.0) -> Any:
return jax.lax.cond(
l2_norm <= scale,
lambda tree: tree,
-lambda tree: otu.tree_scalar_mul(factor, tree),
+lambda tree: otu.tree_scale(factor, tree),
operand=tree,
)

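A sketch of the two projections' semantics, assuming they are exported under optax.projections as in the optax docs: the sphere projection always rescales to norm scale, while the ball projection is the identity inside the ball.

import jax.numpy as jnp
from optax import projections as proj

inside = {'a': jnp.array([0.3, 0.4])}    # norm 0.5 <= 1: returned unchanged
outside = {'a': jnp.array([3.0, 4.0])}   # norm 5 > 1: rescaled by 1/5

print(proj.projection_l2_ball(inside))   # {'a': [0.3, 0.4]}
print(proj.projection_l2_ball(outside))  # {'a': [0.6, 0.8]}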
2 changes: 1 addition & 1 deletion optax/transforms/_adding.py
@@ -103,7 +103,7 @@ def update_fn(updates, state, params=None):
noise = otu.tree_random_like(
sample_key, target_tree=updates, sampler=jax.random.normal
)
-updates = otu.tree_add_scalar_mul(
+updates = otu.tree_add_scale(
tree_x=updates, scalar=standard_deviation, tree_y=noise
)
return updates, AddNoiseState(count=count_inc, rng_key=rng_key)
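This noise-injection site is the one call that passes arguments by keyword; a sketch of the pattern (the 0.01 standard deviation is illustrative, the keyword names come from the diff):

import jax
import jax.numpy as jnp
import optax.tree_utils as otu

updates = {'w': jnp.ones(3)}
key = jax.random.PRNGKey(0)
noise = otu.tree_random_like(key, target_tree=updates, sampler=jax.random.normal)
noisy = otu.tree_add_scale(tree_x=updates, scalar=0.01, tree_y=noise)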
4 changes: 2 additions & 2 deletions optax/tree_utils/__init__.py
@@ -27,7 +27,7 @@
from optax.tree_utils._state_utils import tree_map_params
from optax.tree_utils._state_utils import tree_set
from optax.tree_utils._tree_math import tree_add
-from optax.tree_utils._tree_math import tree_add_scalar_mul
+from optax.tree_utils._tree_math import tree_add_scale
from optax.tree_utils._tree_math import tree_batch_shape
from optax.tree_utils._tree_math import tree_bias_correction
from optax.tree_utils._tree_math import tree_clip
@@ -41,7 +41,7 @@
from optax.tree_utils._tree_math import tree_mul
from optax.tree_utils._tree_math import tree_ones_like
from optax.tree_utils._tree_math import tree_real
-from optax.tree_utils._tree_math import tree_scalar_mul
+from optax.tree_utils._tree_math import tree_scale
from optax.tree_utils._tree_math import tree_sub
from optax.tree_utils._tree_math import tree_sum
from optax.tree_utils._tree_math import tree_update_infinity_moment
54 changes: 50 additions & 4 deletions optax/tree_utils/_tree_math.py
@@ -22,7 +22,7 @@
import jax
import jax.numpy as jnp
from optax._src import numerics

+import warnings

def tree_add(tree_x: Any, tree_y: Any, *other_trees: Any) -> Any:
r"""Add two (or more) pytrees.
@@ -81,7 +81,7 @@ def tree_div(tree_x: Any, tree_y: Any) -> Any:
return jax.tree.map(operator.truediv, tree_x, tree_y)


-def tree_scalar_mul(
+def tree_scale(
scalar: Union[float, jax.Array],
tree: Any,
) -> Any:
@@ -98,8 +98,30 @@ def tree_scalar_mul(
"""
return jax.tree.map(lambda x: scalar * x, tree)


-def tree_add_scalar_mul(
+def tree_scalar_mul(
+    scalar: Union[float, jax.Array],
+    tree: Any,
+) -> Any:
+  r"""Deprecated alias for tree_scale.
+
+  This function is deprecated and will be removed in a future version.
+  Use ``tree_scale`` instead.
+
+  Args:
+    scalar: scalar value.
+    tree: pytree.
+
+  Returns:
+    A pytree with the same structure as ``tree``.
+  """
+  warnings.warn(
+      "tree_scalar_mul is deprecated and will be removed in a future version. Use tree_scale instead.",
+      DeprecationWarning,
+      stacklevel=2
+  )
+  return tree_scale(scalar, tree)
+
+
+def tree_add_scale(
tree_x: Any, scalar: Union[float, jax.Array], tree_y: Any
) -> Any:
r"""Add two trees, where the second tree is scaled by a scalar.
@@ -122,6 +144,30 @@ def tree_add_scalar_mul(
is_leaf=lambda x: x is None,
)

+def tree_add_scalar_mul(
+    tree_x: Any,
+    scalar: Union[float, jax.Array],
+    tree_y: Any,
+) -> Any:
+  r"""Deprecated alias for tree_add_scale.
+
+  This function is deprecated and will be removed in a future version.
+  Use ``tree_add_scale`` instead.
+
+  Args:
+    tree_x: first pytree.
+    scalar: scalar value.
+    tree_y: second pytree.
+
+  Returns:
+    A pytree with the same structure as ``tree_x`` and ``tree_y``.
+  """
+  warnings.warn(
+      "tree_add_scalar_mul is deprecated and will be removed in a future version. Use tree_add_scale instead.",
+      DeprecationWarning,
+      stacklevel=2
+  )
+  return tree_add_scale(tree_x, scalar, tree_y)
+
_vdot = functools.partial(jnp.vdot, precision=jax.lax.Precision.HIGHEST)

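Finally, the two renamed helpers satisfy a simple identity worth keeping in mind when reviewing the call sites above: tree_add_scale(x, s, y) computes x + s * y leaf-wise, i.e. tree_add composed with tree_scale. A quick check (all three functions are re-exported from optax.tree_utils):

import jax.numpy as jnp
import optax.tree_utils as otu

x = {'a': jnp.ones(2)}
y = {'a': jnp.full(2, 3.0)}
lhs = otu.tree_add_scale(x, 0.5, y)
rhs = otu.tree_add(x, otu.tree_scale(0.5, y))
assert bool(jnp.allclose(lhs['a'], rhs['a']))  # both are [2.5, 2.5]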