diff --git a/tests/test_clip.py b/tests/test_clip.py
index 7df650eb..7468afa2 100644
--- a/tests/test_clip.py
+++ b/tests/test_clip.py
@@ -45,7 +45,7 @@ def test_sgd(
     model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)
 
-    chain = torchopt.combine.chain(
+    chain = torchopt.chain(
         torchopt.clip.clip_grad_norm(max_norm=max_norm),
         torchopt.sgd(
             lr=lr,
diff --git a/torchopt/__init__.py b/torchopt/__init__.py
index abc351b0..12fa9d12 100644
--- a/torchopt/__init__.py
+++ b/torchopt/__init__.py
@@ -16,6 +16,8 @@
 from torchopt._src import accelerated_op_available, clip, combine, hook, schedule, visual
 from torchopt._src.alias import adam, rmsprop, sgd
+from torchopt._src.clip import clip_grad_norm
+from torchopt._src.combine import chain
 from torchopt._src.optimizer import SGD, Adam, Optimizer, RMSProp, RMSprop, meta
 from torchopt._src.optimizer.meta import MetaAdam, MetaOptimizer, MetaRMSProp, MetaRMSprop, MetaSGD
 from torchopt._src.update import apply_updates
@@ -33,6 +35,8 @@
     'adam',
     'rmsprop',
     'sgd',
+    'clip_grad_norm',
+    'chain',
     'Optimizer',
     'SGD',
     'Adam',
diff --git a/torchopt/_src/base.py b/torchopt/_src/base.py
index 55201d52..e2bb8dbc 100644
--- a/torchopt/_src/base.py
+++ b/torchopt/_src/base.py
@@ -30,14 +30,19 @@
 # limitations under the License.
 # ==============================================================================
 
+import itertools
 from abc import abstractmethod
 from typing import Callable, NamedTuple, Optional, Tuple
 
-from typing_extensions import Protocol
-
 from torchopt._src.typing import Numeric, TensorTree
 
 
+try:
+    from typing import Protocol
+except ImportError:
+    from typing_extensions import Protocol
+
+
 OptState = TensorTree  # States are arbitrary nests of `torch.Tensor`.
 # Parameters are arbitrary nests of `torch.Tensor`.
 Params = TensorTree
@@ -132,6 +137,76 @@ class GradientTransformation(NamedTuple):
     init: TransformInitFn
     update: TransformUpdateFn
 
+    # pylint: disable-next=redefined-builtin
+    def chain(self, next: 'GradientTransformation') -> 'ChainedGradientTransformation':
+        """Chain two gradient transformations together."""
+        return ChainedGradientTransformation(self, next)
+
+
+class ChainedGradientTransformation(GradientTransformation):
+    """A chain of gradient transformations.
+
+    This class is a subclass of :class:`GradientTransformation` which allows for chaining of
+    gradient transformations.
+    """
+
+    transformations: Tuple[GradientTransformation, ...]
+
+    def __new__(cls, *transformations: GradientTransformation) -> 'ChainedGradientTransformation':
+        transformations = tuple(
+            itertools.chain.from_iterable(
+                t.transformations if isinstance(t, ChainedGradientTransformation) else (t,)
+                for t in transformations
+            )
+        )
+
+        init_fns, update_fns = tuple(zip(*transformations))
+
+        def init_fn(params):
+            return tuple(fn(params) for fn in init_fns)
+
+        def update_fn(updates, state, *, params=None, inplace=True):
+            if len(update_fns) != len(state):
+                raise ValueError(
+                    'The number of updates and states has to be the same in chain! Make sure you '
+                    'have called init first!'
+                )
+            new_state = []
+            for s, fn in zip(state, update_fns):  # pylint: disable=invalid-name
+                updates, new_s = fn(updates, s, params=params, inplace=inplace)
+                new_state.append(new_s)
+            return updates, tuple(new_state)
+
+        instance = super().__new__(cls, init_fn, update_fn)
+        instance.transformations = tuple(transformations)
+        return instance
+
+    def __str__(self):
+        return '{}(\n    {}\n)'.format(
+            self.__class__.__name__, ',\n    '.join(repr(t) for t in self.transformations)
+        )
+
+    __repr__ = __str__
+
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, ChainedGradientTransformation):
+            return self.transformations == other.transformations
+        if isinstance(other, GradientTransformation):
+            return self.transformations == (other,)
+        return False
+
+    def __hash__(self) -> int:
+        return hash(self.transformations)
+
+    def __getstate__(self) -> Tuple[GradientTransformation, ...]:
+        return self.transformations
+
+    def __setstate__(self, state: Tuple[GradientTransformation, ...]) -> None:
+        self.transformations = state
+
+    def __reduce__(self) -> Tuple[Callable, Tuple[Tuple[GradientTransformation, ...]]]:
+        return ChainedGradientTransformation, (self.transformations,)
+
 
 def identity() -> GradientTransformation:
     """Stateless identity transformation that leaves input gradients untouched.
diff --git a/torchopt/_src/clip.py b/torchopt/_src/clip.py
index 67e667c2..2101dd69 100644
--- a/torchopt/_src/clip.py
+++ b/torchopt/_src/clip.py
@@ -38,8 +38,7 @@ def clip_grad_norm(
         An ``(init_fn, update_fn)`` tuple.
     """
 
-    def init_fn(params):
-        del params
+    def init_fn(_):
         return ClipState()
 
     def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
diff --git a/torchopt/_src/combine.py b/torchopt/_src/combine.py
index c0576f51..384ad0b1 100644
--- a/torchopt/_src/combine.py
+++ b/torchopt/_src/combine.py
@@ -47,21 +47,5 @@ def chain(*args: base.GradientTransformation) -> base.GradientTransformation:
     Returns:
         A single ``(init_fn, update_fn)`` tuple.
     """
-    init_fns, update_fns = tuple(zip(*args))
-
-    def init_fn(params):
-        return tuple(fn(params) for fn in init_fns)
-
-    def update_fn(updates, state, *, params=None, inplace=True):
-        if len(update_fns) != len(state):
-            raise ValueError(
-                'The number of updates and states has to be the same in chain! Make sure you have '
-                'called init first!'
-            )
-        new_state = []
-        for s, fn in zip(state, update_fns):  # pylint: disable=invalid-name
-            updates, new_s = fn(updates, s, params=params, inplace=inplace)
-            new_state.append(new_s)
-        return updates, tuple(new_state)
-
-    return base.GradientTransformation(init_fn, update_fn)
+    return base.ChainedGradientTransformation(*args)
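
For context, here is a minimal usage sketch of the top-level API this patch re-exports, mirroring the updated test and the functional init/update/apply_updates pattern. The model `net`, the hyper-parameter values `lr` and `max_norm`, and the toy loss are hypothetical placeholders, not part of the patch; the exact training-loop details may differ in practice.

import torch
import torchopt

# Hypothetical toy model and hyper-parameters, for illustration only.
net = torch.nn.Linear(4, 1)
lr, max_norm = 0.1, 1.0

# Gradient clipping chained in front of SGD via the new top-level re-exports.
impl = torchopt.chain(
    torchopt.clip_grad_norm(max_norm=max_norm),
    torchopt.sgd(lr=lr),
)
# Equivalent form using the new GradientTransformation.chain method:
# impl = torchopt.clip_grad_norm(max_norm=max_norm).chain(torchopt.sgd(lr=lr))

params = tuple(net.parameters())
state = impl.init(params)

# One functional optimization step on a toy loss.
loss = net(torch.randn(8, 4)).sum()
grads = torch.autograd.grad(loss, params)
updates, state = impl.update(grads, state, params=params)
params = torchopt.apply_updates(params, updates)

The extra __eq__/__hash__/__getstate__/__setstate__/__reduce__ methods on ChainedGradientTransformation presumably keep chained transformations comparable and picklable, which a plain closure-based GradientTransformation tuple would not be.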