feat: add params to update_fn's signature
XuehaiPan committed Aug 26, 2022
1 parent 0c30794 commit 4a985ea
Showing 9 changed files with 33 additions and 22 deletions.

torchopt/_src/accelerated_op/adam_op/adam_op.py (1 addition, 0 deletions)
@@ -112,6 +112,7 @@ def __init__(
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-8,
+        *,
         eps_root: float = 0.0,
         inplace: bool = True,
     ) -> None:
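
Note: the bare `*` above makes `eps_root` and `inplace` keyword-only while leaving `b1`, `b2`, and `eps` positional, which is why the `AdamOp(...)` call site in torchopt/_src/transform.py below switches to keyword arguments. A minimal sketch of the calling convention after this commit (argument values are illustrative):

# Still valid: b1, b2, and eps may be passed positionally.
op = AdamOp(0.9, 0.999, 1e-8, eps_root=0.0, inplace=True)

# TypeError after this commit: eps_root and inplace can no longer be passed positionally.
op = AdamOp(0.9, 0.999, 1e-8, 0.0, True)
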

torchopt/_src/base.py (8 additions, 3 deletions)
@@ -31,7 +31,7 @@
 # ==============================================================================
 
 from abc import abstractmethod
-from typing import Callable, NamedTuple, Tuple
+from typing import Callable, NamedTuple, Optional, Tuple
 
 from typing_extensions import Protocol
 
@@ -82,7 +82,12 @@ class TransformUpdateFn(Protocol):  # pylint: disable=too-few-public-methods

     @abstractmethod
     def __call__(
-        self, updates: Updates, state: OptState, inplace: bool = True
+        self,
+        updates: Updates,
+        state: OptState,
+        *,
+        params: Optional[Params] = None,
+        inplace: bool = True,
     ) -> Tuple[Updates, OptState]:
         """The `update` function.
@@ -140,7 +145,7 @@ def identity() -> GradientTransformation:
     def init_fn(_):
         return EmptyState()
 
-    def update_fn(updates, state, inplace=True):  # pylint: disable=unused-argument
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
         return updates, state
 
     return GradientTransformation(init_fn, update_fn)
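
The new `params` keyword exists so that transforms whose update rule depends on the current parameter values (weight decay is the classic example) can receive them through this shared protocol. A hypothetical transform written against the new signature, not part of this commit, as if it lived in torchopt/_src/transform.py where `pytree`, `EmptyState`, and `GradientTransformation` are in scope:

def add_decayed_weights(weight_decay: float = 0.0) -> GradientTransformation:
    """Hypothetical transform that consumes `params`: adds `weight_decay * p` to each update."""

    def init_fn(params):
        del params
        return EmptyState()

    def update_fn(updates, state, *, params=None, inplace=True):
        if params is None:
            raise ValueError('add_decayed_weights requires params to be passed to update()')
        # Combine each gradient with its parameter; multi-tree tree_map as used elsewhere in this diff.
        updates = pytree.tree_map(lambda g, p: g.add(p, alpha=weight_decay), updates, params)
        return updates, state

    return GradientTransformation(init_fn, update_fn)
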

torchopt/_src/clip.py (1 addition, 1 deletion)
@@ -42,7 +42,7 @@ def init_fn(params):
         del params
         return ClipState()
 
-    def update_fn(updates, state, inplace=True):
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
         available_updates = []
         for g in updates:
             if g is not None:

torchopt/_src/combine.py (2 additions, 2 deletions)
@@ -52,15 +52,15 @@ def chain(*args: base.GradientTransformation) -> base.GradientTransformation:
     def init_fn(params):
         return tuple(fn(params) for fn in init_fns)
 
-    def update_fn(updates, state, inplace=True):
+    def update_fn(updates, state, *, params=None, inplace=True):
         if len(update_fns) != len(state):
             raise ValueError(
                 'The number of updates and states has to be the same in chain! Make sure you have '
                 'called init first!'
             )
         new_state = []
         for s, fn in zip(state, update_fns):  # pylint: disable=invalid-name
-            updates, new_s = fn(updates, s, inplace=inplace)
+            updates, new_s = fn(updates, s, params=params, inplace=inplace)
             new_state.append(new_s)
         return updates, tuple(new_state)
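
With this change, `chain` forwards `params` to every component transform, so a composed optimizer can mix transforms that ignore `params` with ones that need it. A usage sketch reusing the hypothetical `add_decayed_weights` from above; it assumes a pytree of `params` with matching `grads`, and the learning-rate and decay values are illustrative:

optim = chain(
    scale_by_adam(),            # ignores params
    add_decayed_weights(1e-4),  # hypothetical transform that consumes params
    scale(-1e-3),               # ignores params
)
state = optim.init(params)
updates, state = optim.update(grads, state, params=params, inplace=False)
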

torchopt/_src/hook.py (1 addition, 1 deletion)
@@ -36,7 +36,7 @@ def register_hook(hook) -> GradientTransformation:
     def init_fn(_):
         return EmptyState()
 
-    def update_fn(updates, state, inplace=True):  # pylint: disable=unused-argument
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
         def f(g):
             return g.register_hook(hook) if g is not None else None
 

torchopt/_src/optimizer/base.py (4 additions, 4 deletions)
@@ -102,10 +102,10 @@ def step(self, closure=None):
         def f(p):
             return p.grad
 
-        for i, (param, state) in enumerate(zip(self.param_groups, self.state_groups)):
-            grad = pytree.tree_map(f, param)
-            updates, new_state = self.impl.update(grad, state, inplace=True)
-            self.param_groups[i] = apply_updates(param, updates)
+        for i, (params, state) in enumerate(zip(self.param_groups, self.state_groups)):
+            grad = pytree.tree_map(f, params)
+            updates, new_state = self.impl.update(grad, state, params=params, inplace=True)
+            self.param_groups[i] = apply_updates(params, updates)
             self.state_groups[i] = new_state
 
         return loss
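
Unrolled, the loop in `step` above is the plain functional round trip: read `.grad` from each parameter, run the composed `update` with the live parameters, then write the result back. The same pattern outside the class, as a sketch in which `impl` stands for any GradientTransformation and `params`/`state` are an initialized parameter pytree and optimizer state:

grads = pytree.tree_map(lambda p: p.grad, params)                        # gather gradients
updates, state = impl.update(grads, state, params=params, inplace=True)  # thread live params through
params = apply_updates(params, updates)                                  # in-place by default
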

torchopt/_src/optimizer/meta/base.py (6 additions, 1 deletion)
@@ -60,7 +60,12 @@ def step(self, loss: torch.Tensor):
             flattened_params, container_tree = pytree.tree_flatten(param_container)
             flattened_params = tuple(flattened_params)
             grad = torch.autograd.grad(loss, flattened_params, create_graph=True, allow_unused=True)
-            updates, new_state = self.impl.update(grad, new_state, inplace=False)
+            updates, new_state = self.impl.update(
+                grad,
+                new_state,
+                params=flattened_params,
+                inplace=False,
+            )
             self.state_groups[i] = new_state
             new_params = apply_updates(flattened_params, updates, inplace=False)
             unflattened_new_params = container_tree.unflatten(new_params)
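
Here the differentiable path is the reason `inplace=False` matters: gradients are taken with `create_graph=True` and the update is applied out of place, so the new parameters stay attached to the autograd graph, and `params=flattened_params` now reaches any transform that needs the current values. A condensed sketch of one differentiable step followed by an outer gradient, where `meta_loss_fn` is a hypothetical outer objective and `impl` any GradientTransformation:

grads = torch.autograd.grad(loss, flattened_params, create_graph=True, allow_unused=True)
updates, state = impl.update(grads, state, params=flattened_params, inplace=False)
new_params = apply_updates(flattened_params, updates, inplace=False)  # autograd graph preserved

meta_grads = torch.autograd.grad(meta_loss_fn(new_params), flattened_params)  # flows through the update
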

torchopt/_src/transform.py (9 additions, 9 deletions)
@@ -73,7 +73,7 @@ def init_fn(params):
         del params
         return ScaleState()
 
-    def update_fn(updates, state, inplace=True):
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
         if inplace:
 
             def f(g):
@@ -114,7 +114,7 @@ def init_fn(params):
         )
         return ScaleByScheduleState(count=zero)
 
-    def update_fn(updates, state, inplace=True):
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
         step_size = step_size_fn(state.count)
         if inplace:
             updates = pytree.tree_map(lambda g, step_size: g.mul_(step_size), updates, step_size)
@@ -125,7 +125,7 @@ def update_fn(updates, state, inplace=True):
     return base.GradientTransformation(init_fn, update_fn)
 
 
-def _update_moment(updates, moments, decay, order, inplace=True):
+def _update_moment(updates, moments, decay, *, order, inplace=True):
     """Compute the exponential moving average of the ``order``-th moment."""
     assert order in (1, 2)
 
@@ -215,7 +215,7 @@ def init_fn(params):
         )
         return ScaleByAdamState(mu=mu, nu=nu, count=zero)
 
-    def update_fn(updates, state, inplace=True):
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
         mu = _update_moment(updates, state.mu, b1, order=1, inplace=inplace)
         nu = _update_moment(updates, state.nu, b2, order=2, inplace=inplace)
         count_inc = inc_count(updates, state.count)
@@ -281,12 +281,12 @@ def init_fn(params):
         )
         return ScaleByAdamState(mu=mu, nu=nu, count=zero)
 
-    def update_fn(updates, state, inplace=True):
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
         count_inc = inc_count(updates, state.count)
 
         treedef = pytree.tree_structure(updates)
 
-        op = AdamOp(b1, b2, eps, eps_root, inplace)
+        op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace)
         out = pytree.tree_map(op, state.mu, state.nu, updates, count_inc)
 
         new_mu, new_nu, new_updates = pytree.tree_transpose(treedef, TRIPLE_PYTREEDEF, out)
@@ -334,7 +334,7 @@ def init_fn(params):
             )
         )
 
-    def update_fn(updates, state, inplace=True):
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
         if nesterov:
             if inplace:
 
@@ -410,7 +410,7 @@ def init_fn(params):
         nu = pytree.tree_map(lambda n: torch.full_like(n, initial_scale), params)  # second moment
         return ScaleByRmsState(nu=nu)
 
-    def update_fn(updates, state, inplace=True):
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
         nu = _update_moment(updates, state.nu, alpha, order=2, inplace=inplace)
 
         if inplace:
@@ -461,7 +461,7 @@ def init_fn(params):
         nu = pytree.tree_map(lambda n: torch.full_like(n, initial_scale), params)  # second moment
         return ScaleByRStdDevState(mu=mu, nu=nu)
 
-    def update_fn(updates, state, inplace=True):
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
         mu = _update_moment(updates, state.mu, alpha, order=1, inplace=inplace)
         nu = _update_moment(updates, state.nu, alpha, order=2, inplace=inplace)
 
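
Making the new argument keyword-only in every rewritten `update_fn` is also a safety choice: an old positional call such as `update_fn(updates, state, True)` would otherwise silently bind `True` to `params`; with the bare `*` it raises a `TypeError` instead. A quick illustration against `scale` from this file, assuming `params` and `grads` pytrees and an illustrative factor:

optim = scale(0.1)
state = optim.init(params)
updates, state = optim.update(grads, state, inplace=False)  # OK: keyword form
updates, state = optim.update(grads, state, False)          # TypeError: takes 2 positional arguments
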

torchopt/_src/update.py (1 addition, 1 deletion)
@@ -35,7 +35,7 @@


 def apply_updates(
-    params: 'base.Params', updates: 'base.Updates', inplace: bool = True
+    params: 'base.Params', updates: 'base.Updates', *, inplace: bool = True
 ) -> 'base.Params':
     """Applies an update to the corresponding parameters.
