Commit

refactor: chain

XuehaiPan committed Aug 27, 2022
1 parent 1ec3d0f commit e8bd609

Showing 5 changed files with 84 additions and 22 deletions.
2 changes: 1 addition & 1 deletion tests/test_clip.py

@@ -45,7 +45,7 @@ def test_sgd(
 
     model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)
 
-    chain = torchopt.combine.chain(
+    chain = torchopt.chain(
         torchopt.clip.clip_grad_norm(max_norm=max_norm),
         torchopt.sgd(
             lr=lr,
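For context, this is how such a chained transformation is driven end to end with the functional API touched by this commit (`init`, `update`, `apply_updates`). A minimal sketch, not code from the diff, assuming a single-tensor parameter "tree" and that `apply_updates` accepts an `inplace` keyword:

```python
# Hypothetical usage sketch (not part of the diff): drive the chained
# transformation functionally on a single-tensor parameter "tree".
import torch
import torchopt

param = torch.randn(3, requires_grad=True)
optimizer = torchopt.chain(  # new top-level alias introduced by this commit
    torchopt.clip.clip_grad_norm(max_norm=1.0),
    torchopt.sgd(lr=0.1),
)
state = optimizer.init(param)  # tuple holding one state per transformation

loss = (param ** 2).sum()
grad = torch.autograd.grad(loss, param)[0]
updates, state = optimizer.update(grad, state, params=param, inplace=False)
param = torchopt.apply_updates(param, updates, inplace=False)
```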
4 changes: 4 additions & 0 deletions torchopt/__init__.py

@@ -16,6 +16,8 @@
 
 from torchopt._src import accelerated_op_available, clip, combine, hook, schedule, visual
 from torchopt._src.alias import adam, rmsprop, sgd
+from torchopt._src.clip import clip_grad_norm
+from torchopt._src.combine import chain
 from torchopt._src.optimizer import SGD, Adam, Optimizer, RMSProp, RMSprop, meta
 from torchopt._src.optimizer.meta import MetaAdam, MetaOptimizer, MetaRMSProp, MetaRMSprop, MetaSGD
 from torchopt._src.update import apply_updates
@@ -33,6 +35,8 @@
     'adam',
     'rmsprop',
     'sgd',
+    'clip_grad_norm',
+    'chain',
     'Optimizer',
     'SGD',
     'Adam',
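The practical effect of the re-exports is that both helpers become reachable from the package root. A quick check, assuming only the imports shown above:

```python
import torchopt

# The new top-level names are the same objects as the nested ones.
assert torchopt.chain is torchopt.combine.chain
assert torchopt.clip_grad_norm is torchopt.clip.clip_grad_norm
```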
79 changes: 77 additions & 2 deletions torchopt/_src/base.py

@@ -30,14 +30,19 @@
 # limitations under the License.
 # ==============================================================================
 
+import itertools
 from abc import abstractmethod
 from typing import Callable, NamedTuple, Optional, Tuple
 
-from typing_extensions import Protocol
-
 from torchopt._src.typing import Numeric, TensorTree
 
 
+try:
+    from typing import Protocol
+except ImportError:
+    from typing_extensions import Protocol
+
+
 OptState = TensorTree  # States are arbitrary nests of `torch.Tensor`.
 # Parameters are arbitrary nests of `torch.Tensor`.
 Params = TensorTree
@@ -132,6 +137,76 @@ class GradientTransformation(NamedTuple):
     init: TransformInitFn
     update: TransformUpdateFn
 
+    # pylint: disable-next=redefined-builtin
+    def chain(self, next: 'GradientTransformation') -> 'ChainedGradientTransformation':
+        """Chain two gradient transformations together."""
+        return ChainedGradientTransformation(self, next)
+
+
+class ChainedGradientTransformation(GradientTransformation):
+    """A chain of gradient transformations.
+
+    This class is a subclass of :class:`GradientTransformation` which allows for chaining of
+    gradient transformations.
+    """
+
+    transformations: Tuple[GradientTransformation, ...]
+
+    def __new__(cls, *transformations: GradientTransformation) -> 'ChainedGradientTransformation':
+        transformations = tuple(
+            itertools.chain.from_iterable(
+                t.transformations if isinstance(t, ChainedGradientTransformation) else (t,)
+                for t in transformations
+            )
+        )
+
+        init_fns, update_fns = tuple(zip(*transformations))
+
+        def init_fn(params):
+            return tuple(fn(params) for fn in init_fns)
+
+        def update_fn(updates, state, *, params=None, inplace=True):
+            if len(update_fns) != len(state):
+                raise ValueError(
+                    'The number of updates and states has to be the same in chain! Make sure you '
+                    'have called init first!'
+                )
+            new_state = []
+            for s, fn in zip(state, update_fns):  # pylint: disable=invalid-name
+                updates, new_s = fn(updates, s, params=params, inplace=inplace)
+                new_state.append(new_s)
+            return updates, tuple(new_state)
+
+        instance = super().__new__(cls, init_fn, update_fn)
+        instance.transformations = tuple(transformations)
+        return instance
+
+    def __str__(self):
+        return '{}(\n    {}\n)'.format(
+            self.__class__.__name__, ',\n    '.join(repr(t) for t in self.transformations)
+        )
+
+    __repr__ = __str__
+
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, ChainedGradientTransformation):
+            return self.transformations == other.transformations
+        if isinstance(other, GradientTransformation):
+            return self.transformations == (other,)
+        return False
+
+    def __hash__(self) -> int:
+        return hash(self.transformations)
+
+    def __getstate__(self) -> Tuple[GradientTransformation, ...]:
+        return self.transformations
+
+    def __setstate__(self, state: Tuple[GradientTransformation, ...]) -> None:
+        self.transformations = state
+
+    def __reduce__(self) -> Tuple[Callable, Tuple[GradientTransformation, ...]]:
+        return ChainedGradientTransformation, self.transformations
+
 
 def identity() -> GradientTransformation:
     """Stateless identity transformation that leaves input gradients untouched.
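Note that `__new__` flattens nested chains on construction, so repeated chaining keeps a flat tuple of transformations. A small sketch of that behavior (the transformations `a`, `b`, `c` are arbitrary examples chosen for illustration, not from the diff):

```python
import torchopt
from torchopt._src import base

a = torchopt.clip.clip_grad_norm(max_norm=1.0)
b = torchopt.sgd(lr=0.1)
c = torchopt.adam(lr=1e-3)

chained = a.chain(b)       # ChainedGradientTransformation(a, b)
nested = chained.chain(c)  # flattened to (a, b, c), not ((a, b), c)
assert nested.transformations == (a, b, c)
assert nested == base.ChainedGradientTransformation(a, b, c)
```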
3 changes: 1 addition & 2 deletions torchopt/_src/clip.py

@@ -38,8 +38,7 @@ def clip_grad_norm(
         An ``(init_fn, update_fn)`` tuple.
     """
 
-    def init_fn(params):
-        del params
+    def init_fn(_):
         return ClipState()
 
     def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
18 changes: 1 addition & 17 deletions torchopt/_src/combine.py

@@ -47,21 +47,5 @@ def chain(*args: base.GradientTransformation) -> base.GradientTransformation:
     Returns:
         A single ``(init_fn, update_fn)`` tuple.
     """
-    init_fns, update_fns = tuple(zip(*args))
-
-    def init_fn(params):
-        return tuple(fn(params) for fn in init_fns)
-
-    def update_fn(updates, state, *, params=None, inplace=True):
-        if len(update_fns) != len(state):
-            raise ValueError(
-                'The number of updates and states has to be the same in chain! Make sure you have '
-                'called init first!'
-            )
-        new_state = []
-        for s, fn in zip(state, update_fns):  # pylint: disable=invalid-name
-            updates, new_s = fn(updates, s, params=params, inplace=inplace)
-            new_state.append(new_s)
-        return updates, tuple(new_state)
-
-    return base.GradientTransformation(init_fn, update_fn)
+    return base.ChainedGradientTransformation(*args)
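With the helper logic moved into `base.py`, `combine.chain` becomes a thin wrapper, and its documented return contract still holds: the chained type is itself a `GradientTransformation` named tuple. A sketch of that consequence, under the same assumptions as the examples above:

```python
import torchopt
from torchopt._src import base

tx = torchopt.chain(
    torchopt.clip.clip_grad_norm(max_norm=1.0),
    torchopt.sgd(lr=0.1),
)
# Still unpackable as the documented (init_fn, update_fn) tuple ...
init_fn, update_fn = tx
# ... but now also carrying the richer chained type.
assert isinstance(tx, base.ChainedGradientTransformation)
```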
