
Commit 0857f9f

nijkah, KKIEEK, and zhouzaida authored
[Feature] Support torch ZeroRedundancyOptimizer (#551)
* [Feature] Support torch ZeRORedundancyOptimizer
  Co-authored-by: Junhwa Song <ethan9867@gmail.com>
  Signed-off-by: Junhwa Song <ethan9867@gmail.com>
  Signed-off-by: Hakjin Lee <nijkah@gmail.com>
* lint
* Fix saving optimizer state_dict
* Fix handling import error
* Add test case
* fix UT
* Revert "fix UT"
  This reverts commit dd64538.
* fix handling import in UT
* Fix saving zero checkpoint and delete redundant master_only
* lint
* test unittest
* Fix handling import error
* Fix UT condition
* Edit docstrings
* Fix typo
* Skip redundant procedure in checkpoint hook
* fix typo again
* Update mmengine/optim/optimizer/zero_optimizer.py
  Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
* Add api info
* lint
* Fix lint
* Handling AmpOptimWrapper case
* handling overlap_with_ddp
* Fix error

Signed-off-by: Junhwa Song <ethan9867@gmail.com>
Signed-off-by: Hakjin Lee <nijkah@gmail.com>
Co-authored-by: Junhwa Song <ethan9867@gmail.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
1 parent bf369da commit 0857f9f
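For context, the user-facing result of this commit is that a ZeRO-sharded optimizer can be selected through the optimizer wrapper config. A minimal sketch mirroring the test case added below (`model` and the hyperparameter values are placeholders):

from mmengine.optim import build_optim_wrapper

# 'ZeroRedundancyOptimizer' is the registry name added by this commit;
# 'optimizer_type' names the local optimizer class inside torch.optim.
optim_wrapper_cfg = dict(
    optimizer=dict(
        type='ZeroRedundancyOptimizer',
        optimizer_type='SGD',
        lr=0.01,
        momentum=0.9,
        weight_decay=0.0001))
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)  # model: any nn.Module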

File tree

6 files changed: +156, -5 lines changed


docs/en/api/optim.rst

Lines changed: 1 addition & 0 deletions

@@ -23,6 +23,7 @@ Optimizer
    OptimWrapper
    OptimWrapperDict
    DefaultOptimWrapperConstructor
+   ZeroRedundancyOptimizer
 
 .. autosummary::
    :toctree: generated

docs/zh_cn/api/optim.rst

Lines changed: 1 addition & 0 deletions

@@ -23,6 +23,7 @@ Optimizer
    OptimWrapper
    OptimWrapperDict
    DefaultOptimWrapperConstructor
+   ZeroRedundancyOptimizer
 
 .. autosummary::
    :toctree: generated

mmengine/hooks/checkpoint_hook.py

Lines changed: 6 additions & 3 deletions

@@ -6,7 +6,7 @@
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Sequence, Union
 
-from mmengine.dist import master_only
+from mmengine.dist import is_main_process
 from mmengine.fileio import FileClient, get_file_backend
 from mmengine.registry import HOOKS
 from mmengine.utils import is_list_of, is_seq_of
@@ -309,7 +309,6 @@ def _get_metric_score(self, metrics, key_indicator):
 
         return eval_res[key_indicator]
 
-    @master_only
     def _save_checkpoint(self, runner) -> None:
         """Save the current checkpoint and delete outdated checkpoint.
@@ -331,6 +330,11 @@ def _save_checkpoint(self, runner) -> None:
             backend_args=self.backend_args,
             **self.args)
 
+        # Model parallel-like training should involve pulling sharded states
+        # from all ranks, but skip the following procedure.
+        if not is_main_process():
+            return
+
         runner.message_hub.update_info(
             'last_ckpt',
             self.file_backend.join_path(self.out_dir, ckpt_filename))
@@ -357,7 +361,6 @@ def _save_checkpoint(self, runner) -> None:
             with open(save_file, 'w') as f:
                 f.write(filepath)
 
-    @master_only
     def _save_best_checkpoint(self, runner, metrics) -> None:
         """Save the current checkpoint and delete outdated checkpoint.
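The `@master_only` decorators are dropped because consolidating a ZeRO-sharded optimizer state is a collective operation: every rank must reach the optimizer's `state_dict()` call, and only afterwards do non-main ranks return early. A minimal sketch of the pattern, not the hook's actual code (`save_to_disk` is a hypothetical helper):

from mmengine.dist import is_main_process

def save_checkpoint_all_ranks(model, optimizer):
    # Every rank participates: for ZeroRedundancyOptimizer, state_dict()
    # internally consolidates the shards, a collective across all ranks.
    checkpoint = dict(
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict())

    # Only the main process writes the file and updates bookkeeping,
    # mirroring the early `return` added to _save_checkpoint above.
    if not is_main_process():
        return
    save_to_disk(checkpoint, 'latest.pth')  # hypothetical helper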

mmengine/optim/optimizer/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -5,9 +5,10 @@
 from .default_constructor import DefaultOptimWrapperConstructor
 from .optimizer_wrapper import OptimWrapper
 from .optimizer_wrapper_dict import OptimWrapperDict
+from .zero_optimizer import ZeroRedundancyOptimizer
 
 __all__ = [
     'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS',
     'DefaultOptimWrapperConstructor', 'build_optim_wrapper', 'OptimWrapper',
-    'AmpOptimWrapper', 'OptimWrapperDict'
+    'AmpOptimWrapper', 'OptimWrapperDict', 'ZeroRedundancyOptimizer'
 ]
mmengine/optim/optimizer/zero_optimizer.py

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import torch
+from torch.distributed.rpc import is_available
+
+from mmengine.dist import is_main_process
+from mmengine.utils import digit_version
+from mmengine.utils.dl_utils import TORCH_VERSION
+
+try:
+    from torch.distributed.optim import \
+        ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer
+except ImportError:
+    _ZeroRedundancyOptimizer = object
+
+from .builder import OPTIMIZERS
+
+
+@OPTIMIZERS.register_module()
+class ZeroRedundancyOptimizer(_ZeroRedundancyOptimizer):
+    """A wrapper class of :class:`ZeroRedundancyOptimizer` that gets an
+    optimizer type as a string.
+
+    This class wraps an arbitrary :class:`torch.optim.Optimizer` and shards its
+    states across ranks in the group as described by ZeRO_. The local optimizer
+    instance in each rank is only responsible for updating approximately
+    ``1 / world_size`` parameters and hence only needs to keep
+    ``1 / world_size`` optimizer states. After parameters are updated locally,
+    each rank will broadcast its parameters to all other peers to keep all
+    model replicas in the same state. ``ZeroRedundancyOptimizer`` can be used
+    in conjunction with :class:`torch.nn.parallel.DistributedDataParallel` to
+    reduce per-rank peak memory consumption.
+
+    ``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a number
+    of parameters at each rank. Each parameter belongs to a single rank and is
+    not divided among ranks. The partition is arbitrary and might not match
+    the parameter registration or usage order.
+
+    Warnings:
+        ``ZeroRedundancyOptimizer`` requires PyTorch >= 1.8.
+
+    Args:
+        params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
+            or :class:`dict` s giving all parameters, which will be sharded
+            across ranks.
+        optimizer_type (str): the string of the local optimizer class.
+
+    .. _ZeRO: https://arxiv.org/abs/1910.02054
+    """
+
+    def __init__(self, params, optimizer_type: str, **kwargs):
+        assert digit_version(TORCH_VERSION) >= digit_version('1.8.0'), (
+            '`torch.distributed.optim.ZeroRedundancyOptimizer` is only '
+            'available when pytorch version >= 1.8.')
+        assert is_available(), 'torch.distributed.rpc is not available.'
+        optimizer_class = getattr(torch.optim, optimizer_type)
+        # TODO: Register a DDP communication hook for `overlap_with_ddp=True`.
+        # Currently only `overlap_with_ddp=False` is supported. For more
+        # details, please refer to PyTorch's official documentation.
+        super().__init__(params, optimizer_class, **kwargs)
+
+    def state_dict(self):
+        """Consolidate `state_dict`s from ranks to save the `state_dict`."""
+        self.consolidate_state_dict()
+        state_dict = super().state_dict() if is_main_process() else dict()
+        return state_dict
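A short usage sketch of the class itself, imported from the subpackage touched by this commit (assumes `torch.distributed` is already initialized and `model` is a placeholder `nn.Module`; values are illustrative):

from mmengine.optim.optimizer import ZeroRedundancyOptimizer

# The local optimizer class is looked up by name in torch.optim, so
# optimizer_type='SGD' resolves to torch.optim.SGD on every rank.
optimizer = ZeroRedundancyOptimizer(
    model.parameters(), optimizer_type='SGD', lr=0.01, momentum=0.9)

# state_dict() first consolidates the shards (a collective call); the
# main process receives the full dict, other ranks get an empty dict.
ckpt = optimizer.state_dict()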

tests/test_optim/test_optimizer/test_optimizer.py

Lines changed: 80 additions & 1 deletion

@@ -1,17 +1,22 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import os
 import sys
+import unittest
 from unittest import TestCase
 from unittest.mock import MagicMock
 
 import torch
 import torch.nn as nn
+from torch.distributed.rpc import is_available
 
 from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
                             DefaultOptimWrapperConstructor, OptimWrapper,
                             build_optim_wrapper)
 from mmengine.optim.optimizer.builder import TORCH_OPTIMIZERS
 from mmengine.registry import build_from_cfg
-from mmengine.utils.dl_utils import mmcv_full_available
+from mmengine.testing._internal import MultiProcessTestCase
+from mmengine.utils.dl_utils import TORCH_VERSION, mmcv_full_available
+from mmengine.utils.version_utils import digit_version
 
 MMCV_FULL_AVAILABLE = mmcv_full_available()
 if not MMCV_FULL_AVAILABLE:
@@ -713,3 +718,77 @@ def test_default_optimizer_constructor_custom_key(self):
             for setting in settings:
                 assert param_groups[i][setting] == settings[
                     setting], f'{name} {setting}'
+
+
+@unittest.skipIf(
+    (digit_version(TORCH_VERSION) < digit_version('1.8.0'))
+    or not is_available(),
+    reason='ZeRO requires pytorch>=1.8 with torch.distributed.rpc available.')
+class TestZeroOptimizer(MultiProcessTestCase):
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._spawn_processes()
+
+    def _check_default_optimizer(self, optimizer, model):
+        self.assertIsInstance(optimizer.optim, torch.optim.SGD)
+        self.assertEqual(optimizer.defaults['lr'], self.base_lr)
+        self.assertEqual(optimizer.defaults['momentum'], self.momentum)
+        self.assertEqual(optimizer.defaults['weight_decay'], self.base_wd)
+        param_groups = optimizer.param_groups[0]
+        if MMCV_FULL_AVAILABLE:
+            param_names = [
+                'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
+                'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
+                'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias', 'dcn.weight',
+                'dcn.conv_offset.weight', 'dcn.conv_offset.bias'
+            ]
+        else:
+            param_names = [
+                'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
+                'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
+                'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias'
+            ]
+        param_dict = dict(model.named_parameters())
+        self.assertEqual(len(param_groups['params']), len(param_names))
+        for i in range(len(param_groups['params'])):
+            assert torch.equal(param_groups['params'][i],
+                               param_dict[param_names[i]])
+
+    def test_build_zero_redundancy_optimizer(self):
+        from torch.distributed.optim import ZeroRedundancyOptimizer
+        self._init_dist_env(self.rank, self.world_size)
+        model = ExampleModel()
+        self.base_lr = 0.01
+        self.momentum = 0.0001
+        self.base_wd = 0.9
+
+        # test build function
+        optim_wrapper_cfg = dict(
+            optimizer=dict(
+                type='ZeroRedundancyOptimizer',
+                optimizer_type='SGD',
+                lr=self.base_lr,
+                weight_decay=self.base_wd,
+                momentum=self.momentum))
+        optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
+        self.assertIsInstance(optim_wrapper.optimizer, ZeroRedundancyOptimizer)
+        self._check_default_optimizer(optim_wrapper.optimizer, model)
+
+        # test build optimizer without ``optimizer_type``
+        with self.assertRaises(TypeError):
+            optim_wrapper_cfg = dict(
+                optimizer=dict(
+                    type='ZeroRedundancyOptimizer',
+                    lr=self.base_lr,
+                    weight_decay=self.base_wd,
+                    momentum=self.momentum))
+            optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
+
+    def _init_dist_env(self, rank, world_size):
+        """Initialize the distributed environment."""
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29510'
+        os.environ['RANK'] = str(rank)
+        torch.distributed.init_process_group(
+            backend='gloo', rank=rank, world_size=world_size)
