From 381c5f103c3c77b55ee8ecf233ce0916eac67b61 Mon Sep 17 00:00:00 2001
From: Ming-Hsuan-Tu
Date: Thu, 8 Dec 2022 13:15:42 +0800
Subject: [PATCH] [Enhance] Support passing kwargs to update_params (#796)

* [Enhance] Support step arguments and zero arguments with update_params

* Update mmengine/optim/optimizer/optimizer_wrapper.py

* Update mmengine/optim/optimizer/optimizer_wrapper.py

Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
---
 mmengine/optim/optimizer/optimizer_wrapper.py | 19 ++++++++++++++++---
 .../optim/optimizer/optimizer_wrapper_dict.py |  7 +++++--
 2 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/mmengine/optim/optimizer/optimizer_wrapper.py b/mmengine/optim/optimizer/optimizer_wrapper.py
index 58dbc051d2..644d5739fc 100644
--- a/mmengine/optim/optimizer/optimizer_wrapper.py
+++ b/mmengine/optim/optimizer/optimizer_wrapper.py
@@ -161,20 +161,33 @@ def __init__(self,
         # the loss factor will always be the same as `_accumulative_counts`.
         self._remainder_counts = -1
 
-    def update_params(self, loss: torch.Tensor) -> None:
+    def update_params(self,
+                      loss: torch.Tensor,
+                      step_kwargs: Optional[Dict] = None,
+                      zero_kwargs: Optional[Dict] = None) -> None:
         """Update parameters in :attr:`optimizer`.
 
         Args:
             loss (torch.Tensor): A tensor for back propagation.
+            step_kwargs (dict): Arguments for optimizer.step.
+                Defaults to None.
+                New in version v0.4.0.
+            zero_kwargs (dict): Arguments for optimizer.zero_grad.
+                Defaults to None.
+                New in version v0.4.0.
         """
+        if step_kwargs is None:
+            step_kwargs = {}
+        if zero_kwargs is None:
+            zero_kwargs = {}
         loss = self.scale_loss(loss)
         self.backward(loss)
         # Update parameters only if `self._inner_count` is divisible by
         # `self._accumulative_counts` or `self._inner_count` equals to
         # `self._max_counts`
         if self.should_update():
-            self.step()
-            self.zero_grad()
+            self.step(**step_kwargs)
+            self.zero_grad(**zero_kwargs)
 
     def backward(self, loss: torch.Tensor, **kwargs) -> None:
         """Perform gradient back propagation.
diff --git a/mmengine/optim/optimizer/optimizer_wrapper_dict.py b/mmengine/optim/optimizer/optimizer_wrapper_dict.py
index 6155b62df0..8a4b258003 100644
--- a/mmengine/optim/optimizer/optimizer_wrapper_dict.py
+++ b/mmengine/optim/optimizer/optimizer_wrapper_dict.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from contextlib import contextmanager
-from typing import Dict, Iterator, List, Tuple
+from typing import Dict, Iterator, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -46,7 +46,10 @@ def __init__(self, **optim_wrapper_dict: OptimWrapper):
                 f'but got {key}: {type(value)}')
         self.optim_wrappers = optim_wrapper_dict
 
-    def update_params(self, loss: torch.Tensor) -> None:
+    def update_params(self,
+                      loss: torch.Tensor,
+                      step_kwargs: Optional[Dict] = None,
+                      zero_kwargs: Optional[Dict] = None) -> None:
         """Update all optimizer wrappers would lead to a duplicate backward
         errors, and OptimWrapperDict does not know which optimizer wrapper
         should be updated.
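
A minimal usage sketch of the new keyword arguments. The model, optimizer, and
data below are illustrative placeholders, not part of this patch; only the
`update_params` signature with `step_kwargs` and `zero_kwargs` comes from the
change above.

    import torch
    import torch.nn as nn
    from mmengine.optim import OptimWrapper

    # Toy model and optimizer, used only to exercise the new interface.
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    optim_wrapper = OptimWrapper(optimizer=optimizer)

    loss = model(torch.randn(8, 4)).sum()

    # step_kwargs / zero_kwargs are forwarded to optimizer.step() and
    # optimizer.zero_grad() when the parameter update actually runs.
    optim_wrapper.update_params(
        loss,
        zero_kwargs=dict(set_to_none=True))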