[Enhance] Support passing kwargs to update_params (#796)
* [Enhance]

Support step arguments and zero arguments with update_params

* Update mmengine/optim/optimizer/optimizer_wrapper.py

* Update mmengine/optim/optimizer/optimizer_wrapper.py

Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
3 people authored Dec 8, 2022
1 parent 57f6644 commit 381c5f1
Showing 2 changed files with 21 additions and 5 deletions.
19 changes: 16 additions & 3 deletions mmengine/optim/optimizer/optimizer_wrapper.py
@@ -161,20 +161,33 @@ def __init__(self,
         # the loss factor will always be the same as `_accumulative_counts`.
         self._remainder_counts = -1
 
-    def update_params(self, loss: torch.Tensor) -> None:
+    def update_params(self,
+                      loss: torch.Tensor,
+                      step_kwargs: Optional[Dict] = None,
+                      zero_kwargs: Optional[Dict] = None) -> None:
         """Update parameters in :attr:`optimizer`.
         Args:
             loss (torch.Tensor): A tensor for back propagation.
+            step_kwargs (dict): Arguments for optimizer.step.
+                Defaults to None.
+                New in version v0.4.0.
+            zero_kwargs (dict): Arguments for optimizer.zero_grad.
+                Defaults to None.
+                New in version v0.4.0.
         """
+        if step_kwargs is None:
+            step_kwargs = {}
+        if zero_kwargs is None:
+            zero_kwargs = {}
         loss = self.scale_loss(loss)
         self.backward(loss)
         # Update parameters only if `self._inner_count` is divisible by
         # `self._accumulative_counts` or `self._inner_count` equals to
         # `self._max_counts`
         if self.should_update():
-            self.step()
-            self.zero_grad()
+            self.step(**step_kwargs)
+            self.zero_grad(**zero_kwargs)
 
     def backward(self, loss: torch.Tensor, **kwargs) -> None:
         """Perform gradient back propagation.
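With this change, extra keyword arguments can be forwarded from `update_params` to `optimizer.step()` and `optimizer.zero_grad()`. A minimal usage sketch follows; the model, optimizer, and keyword values are illustrative and not part of this commit:

import torch
import torch.nn as nn
from mmengine.optim import OptimWrapper

# Toy setup purely for demonstration.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
optim_wrapper = OptimWrapper(optimizer=optimizer)

loss = model(torch.randn(2, 4)).sum()
# `step_kwargs` is unpacked into `optimizer.step(**step_kwargs)`,
# `zero_kwargs` into `optimizer.zero_grad(**zero_kwargs)`.
optim_wrapper.update_params(
    loss,
    zero_kwargs={'set_to_none': True})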
7 changes: 5 additions & 2 deletions mmengine/optim/optimizer/optimizer_wrapper_dict.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from contextlib import contextmanager
-from typing import Dict, Iterator, List, Tuple
+from typing import Dict, Iterator, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -46,7 +46,10 @@ def __init__(self, **optim_wrapper_dict: OptimWrapper):
                 f'but got {key}: {type(value)}')
         self.optim_wrappers = optim_wrapper_dict
 
-    def update_params(self, loss: torch.Tensor) -> None:
+    def update_params(self,
+                      loss: torch.Tensor,
+                      step_kwargs: Optional[Dict] = None,
+                      zero_kwargs: Optional[Dict] = None) -> None:
         """Update all optimizer wrappers would lead to a duplicate backward
         errors, and OptimWrapperDict does not know which optimizer wrapper
         should be updated.
