diff --git a/torchopt/_src/accelerated_op/adam_op/adam_op.py b/torchopt/_src/accelerated_op/adam_op/adam_op.py
index a04f1048..7f5bfc0f 100644
--- a/torchopt/_src/accelerated_op/adam_op/adam_op.py
+++ b/torchopt/_src/accelerated_op/adam_op/adam_op.py
@@ -112,6 +112,7 @@ def __init__(
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-8,
+        *,
         eps_root: float = 0.0,
         inplace: bool = True,
     ) -> None:
diff --git a/torchopt/_src/base.py b/torchopt/_src/base.py
index d6d101ec..55201d52 100644
--- a/torchopt/_src/base.py
+++ b/torchopt/_src/base.py
@@ -31,7 +31,7 @@
 # ==============================================================================
 
 from abc import abstractmethod
-from typing import Callable, NamedTuple, Tuple
+from typing import Callable, NamedTuple, Optional, Tuple
 
 from typing_extensions import Protocol
 
@@ -82,7 +82,12 @@ class TransformUpdateFn(Protocol):  # pylint: disable=too-few-public-methods
 
     @abstractmethod
     def __call__(
-        self, updates: Updates, state: OptState, inplace: bool = True
+        self,
+        updates: Updates,
+        state: OptState,
+        *,
+        params: Optional[Params] = None,
+        inplace: bool = True,
     ) -> Tuple[Updates, OptState]:
         """The `update` function.
 
@@ -140,7 +145,7 @@ def identity() -> GradientTransformation:
     def init_fn(_):
         return EmptyState()
 
-    def update_fn(updates, state, inplace=True):  # pylint: disable=unused-argument
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
         return updates, state
 
     return GradientTransformation(init_fn, update_fn)
diff --git a/torchopt/_src/clip.py b/torchopt/_src/clip.py
index 5da4c313..67e667c2 100644
--- a/torchopt/_src/clip.py
+++ b/torchopt/_src/clip.py
@@ -42,7 +42,7 @@ def init_fn(params):
         del params
         return ClipState()
 
-    def update_fn(updates, state, inplace=True):
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
         available_updates = []
         for g in updates:
             if g is not None:
diff --git a/torchopt/_src/combine.py b/torchopt/_src/combine.py
index 79d1cb72..c0576f51 100644
--- a/torchopt/_src/combine.py
+++ b/torchopt/_src/combine.py
@@ -52,7 +52,7 @@ def chain(*args: base.GradientTransformation) -> base.GradientTransformation:
     def init_fn(params):
         return tuple(fn(params) for fn in init_fns)
 
-    def update_fn(updates, state, inplace=True):
+    def update_fn(updates, state, *, params=None, inplace=True):
         if len(update_fns) != len(state):
             raise ValueError(
                 'The number of updates and states has to be the same in chain! Make sure you have '
@@ -60,7 +60,7 @@ def update_fn(updates, state, inplace=True):
             )
         new_state = []
         for s, fn in zip(state, update_fns):  # pylint: disable=invalid-name
-            updates, new_s = fn(updates, s, inplace=inplace)
+            updates, new_s = fn(updates, s, params=params, inplace=inplace)
             new_state.append(new_s)
         return updates, tuple(new_state)
 
diff --git a/torchopt/_src/hook.py b/torchopt/_src/hook.py
index d31ccf91..408e17b4 100644
--- a/torchopt/_src/hook.py
+++ b/torchopt/_src/hook.py
@@ -36,7 +36,7 @@ def register_hook(hook) -> GradientTransformation:
     def init_fn(_):
         return EmptyState()
 
-    def update_fn(updates, state, inplace=True):  # pylint: disable=unused-argument
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
         def f(g):
             return g.register_hook(hook) if g is not None else None
 
diff --git a/torchopt/_src/optimizer/base.py b/torchopt/_src/optimizer/base.py
index b071d5d0..14f6bf86 100644
--- a/torchopt/_src/optimizer/base.py
+++ b/torchopt/_src/optimizer/base.py
@@ -102,10 +102,10 @@ def step(self, closure=None):
         def f(p):
             return p.grad
 
-        for i, (param, state) in enumerate(zip(self.param_groups, self.state_groups)):
-            grad = pytree.tree_map(f, param)
-            updates, new_state = self.impl.update(grad, state, inplace=True)
-            self.param_groups[i] = apply_updates(param, updates)
+        for i, (params, state) in enumerate(zip(self.param_groups, self.state_groups)):
+            grad = pytree.tree_map(f, params)
+            updates, new_state = self.impl.update(grad, state, params=params, inplace=True)
+            self.param_groups[i] = apply_updates(params, updates)
             self.state_groups[i] = new_state
 
         return loss
diff --git a/torchopt/_src/optimizer/meta/base.py b/torchopt/_src/optimizer/meta/base.py
index 6fe677d9..395fa17e 100644
--- a/torchopt/_src/optimizer/meta/base.py
+++ b/torchopt/_src/optimizer/meta/base.py
@@ -60,7 +60,12 @@ def step(self, loss: torch.Tensor):
             flattened_params, container_tree = pytree.tree_flatten(param_container)
             flattened_params = tuple(flattened_params)
             grad = torch.autograd.grad(loss, flattened_params, create_graph=True, allow_unused=True)
-            updates, new_state = self.impl.update(grad, new_state, inplace=False)
+            updates, new_state = self.impl.update(
+                grad,
+                new_state,
+                params=flattened_params,
+                inplace=False,
+            )
             self.state_groups[i] = new_state
             new_params = apply_updates(flattened_params, updates, inplace=False)
             unflattened_new_params = container_tree.unflatten(new_params)
diff --git a/torchopt/_src/transform.py b/torchopt/_src/transform.py
index 9b5d1fc7..98a1cd58 100644
--- a/torchopt/_src/transform.py
+++ b/torchopt/_src/transform.py
@@ -73,7 +73,7 @@ def init_fn(params):
         del params
         return ScaleState()
 
-    def update_fn(updates, state, inplace=True):
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
         if inplace:
 
             def f(g):
@@ -114,7 +114,7 @@ def init_fn(params):
         )
         return ScaleByScheduleState(count=zero)
 
-    def update_fn(updates, state, inplace=True):
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
         step_size = step_size_fn(state.count)
         if inplace:
             updates = pytree.tree_map(lambda g, step_size: g.mul_(step_size), updates, step_size)
@@ -125,7 +125,7 @@ def update_fn(updates, state, inplace=True):
 
     return base.GradientTransformation(init_fn, update_fn)
 
 
-def _update_moment(updates, moments, decay, order, inplace=True):
+def _update_moment(updates, moments, decay, *, order, inplace=True):
     """Compute the exponential moving average of the ``order``-th moment."""
     assert order in (1, 2)
@@ -215,7 +215,7 @@ def init_fn(params):
         )
         return ScaleByAdamState(mu=mu, nu=nu, count=zero)
 
-    def update_fn(updates, state, inplace=True):
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
         mu = _update_moment(updates, state.mu, b1, order=1, inplace=inplace)
         nu = _update_moment(updates, state.nu, b2, order=2, inplace=inplace)
         count_inc = inc_count(updates, state.count)
@@ -281,12 +281,12 @@ def init_fn(params):
         )
         return ScaleByAdamState(mu=mu, nu=nu, count=zero)
 
-    def update_fn(updates, state, inplace=True):
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
         count_inc = inc_count(updates, state.count)
 
         treedef = pytree.tree_structure(updates)
 
-        op = AdamOp(b1, b2, eps, eps_root, inplace)
+        op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace)
         out = pytree.tree_map(op, state.mu, state.nu, updates, count_inc)
 
         new_mu, new_nu, new_updates = pytree.tree_transpose(treedef, TRIPLE_PYTREEDEF, out)
@@ -334,7 +334,7 @@ def init_fn(params):
             )
         )
 
-    def update_fn(updates, state, inplace=True):
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
         if nesterov:
             if inplace:
 
@@ -410,7 +410,7 @@ def init_fn(params):
         nu = pytree.tree_map(lambda n: torch.full_like(n, initial_scale), params)  # second moment
         return ScaleByRmsState(nu=nu)
 
-    def update_fn(updates, state, inplace=True):
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
         nu = _update_moment(updates, state.nu, alpha, order=2, inplace=inplace)
 
         if inplace:
@@ -461,7 +461,7 @@ def init_fn(params):
         nu = pytree.tree_map(lambda n: torch.full_like(n, initial_scale), params)  # second moment
         return ScaleByRStdDevState(mu=mu, nu=nu)
 
-    def update_fn(updates, state, inplace=True):
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
         mu = _update_moment(updates, state.mu, alpha, order=1, inplace=inplace)
         nu = _update_moment(updates, state.nu, alpha, order=2, inplace=inplace)
 
diff --git a/torchopt/_src/update.py b/torchopt/_src/update.py
index d0cfd112..753292d7 100644
--- a/torchopt/_src/update.py
+++ b/torchopt/_src/update.py
@@ -35,7 +35,7 @@
 
 
 def apply_updates(
-    params: 'base.Params', updates: 'base.Updates', inplace: bool = True
+    params: 'base.Params', updates: 'base.Updates', *, inplace: bool = True
 ) -> 'base.Params':
     """Applies an update to the corresponding parameters.
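
For reference, a minimal sketch (not part of this diff) of a downstream transformation that consumes the new keyword-only `params` argument. It reuses `EmptyState` and `GradientTransformation` from `torchopt._src.base` as updated above; the transformation name `add_decayed_weights`, its coefficient, and the assumption that `updates`/`params` are flat tuples of tensors (as in the `MetaOptimizer.step` call site) are illustrative only.

from torchopt._src.base import EmptyState, GradientTransformation


def add_decayed_weights(weight_decay: float = 1e-4) -> GradientTransformation:
    """Hypothetical transformation that adds ``weight_decay * param`` to each update."""

    def init_fn(params):
        del params
        return EmptyState()

    def update_fn(updates, state, *, params=None, inplace=True):
        # The update rule needs the parameter values themselves, which is exactly
        # what the new keyword-only `params` argument threads through `chain`.
        if params is None:
            raise ValueError('add_decayed_weights requires `params` to be passed to `update`')

        def f(g, p):
            if g is None:
                return g
            return g.add_(p, alpha=weight_decay) if inplace else g.add(p, alpha=weight_decay)

        updates = tuple(map(f, updates, params))
        return updates, state

    return GradientTransformation(init_fn, update_fn)


# Callers pass the parameters keyword-only, mirroring the updated call sites above:
#     updates, new_state = impl.update(grad, state, params=flattened_params, inplace=False)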