12 | 12 |
13 | 13 | from colossalai.context.parallel_mode import ParallelMode |
14 | 14 | from colossalai.core import global_context as gpc |
15 | | -from colossalai.logging import get_global_dist_logger |
16 | | -from colossalai.registry import OPTIMIZER_WRAPPERS |
17 | | -from colossalai.utils import print_rank_0 |
18 | | -from ._utils import copy_tensor_parallel_attributes, clip_grad_norm_fp32, count_zeros_fp32 |
19 | | -from ..multi_tensor_apply import multi_tensor_applier |
| 15 | +from colossalai.logging import get_dist_logger |
| 16 | +from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes, |
| 17 | + clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier) |
20 | 18 |
21 | 19 |
22 | 20 | def _zero_grad_group_helper(group, set_to_none): |
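After this import consolidation, the file takes its distributed logger from `colossalai.logging.get_dist_logger` and the tensor-parallel / gradient-clipping helpers straight from `colossalai.utils`. A minimal usage sketch, assuming the colossalai version targeted by this PR and using only the names visible in the hunk above:

```python
# Post-refactor imports, copied from the hunk above.
from colossalai.logging import get_dist_logger
from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes,
                              clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier)

logger = get_dist_logger()           # replaces the removed get_global_dist_logger()
logger.info("FP16 optimizer module loaded")
print_rank_0("visible on rank 0 only")
```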
@@ -92,7 +90,7 @@ def __init__(self, |
92 | 90 | self._growth_tracker = 0 |
93 | 91 | self._hysteresis_tracker = self.hysteresis |
94 | 92 |
95 | | - self._logger = get_global_dist_logger() |
| 93 | + self._logger = get_dist_logger() |
96 | 94 |
97 | 95 | @property |
98 | 96 | def scale(self): |
@@ -145,7 +143,6 @@ def load_state_dict(self, state_dict): |
145 | 143 | self._max_scale = state_dict['max_scale'] |
146 | 144 |
147 | 145 |
148 | | -@OPTIMIZER_WRAPPERS.register_module |
149 | 146 | class FP16Optimizer(Optimizer): |
150 | 147 | """Float16 optimizer for fp16 and bf16 data types. |
151 | 148 |
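Dropping `@OPTIMIZER_WRAPPERS.register_module` means `FP16Optimizer` is no longer resolved through the registry by name; it is wrapped around a plain PyTorch optimizer directly. A minimal construction sketch; the import path is hypothetical, and only `clip_grad` and `max_scale` come from arguments visible in this diff:

```python
import torch
import torch.nn as nn

# Hypothetical import path: the module's actual location is not shown in this diff.
from fp16_optimizer import FP16Optimizer

model = nn.Linear(16, 16).half()
base_optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# Wrap the base optimizer directly instead of looking FP16Optimizer up in OPTIMIZER_WRAPPERS.
fp16_optim = FP16Optimizer(base_optim,
                           clip_grad=1.0,      # clipping threshold echoed in the config log below
                           max_scale=2 ** 32)  # default shown in the signature below
```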
@@ -184,13 +181,13 @@ def __init__(self, |
184 | 181 | max_scale: int = 2 ** 32): |
185 | 182 | # default args for compatibility |
186 | 183 | bf16 = False |
187 | | - params_have_main_grad = False |
| 184 | + params_have_main_grad = True |
188 | 185 |
189 | 186 | # have defaults for compatibility with pytorch optim |
190 | 187 | self.defaults = optimizer.defaults |
191 | 188 |
192 | 189 | # log config |
193 | | - self._logger = get_global_dist_logger() |
| 190 | + self._logger = get_dist_logger() |
194 | 191 | self._logger.info(f"\n========= FP16 Optimizer Config =========\n" |
195 | 192 | f"Optimizer: {optimizer.__class__.__name__}\n" |
196 | 193 | f"clip_grad = {clip_grad}\n" |
@@ -328,6 +325,7 @@ def _copy_model_grads_to_main_grads(self): |
328 | 325 | else: |
329 | 326 | if model_param.grad is not None: |
330 | 327 | main_param.grad = model_param.grad.float() |
| 328 | + |
331 | 329 | # For fp32 grads, we need to reset the grads to main grad. |
332 | 330 | if self.params_have_main_grad: |
333 | 331 | for model_group in self.fp32_from_fp32_groups: |
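This hunk only adds a blank line for readability, but it sits in the routine that mirrors gradients from the fp16 model parameters onto their fp32 master copies. A condensed sketch of that logic, assuming the Megatron-style layout this optimizer follows (fp16 groups paired with fp32 "main" groups, plus an optional `.main_grad` buffer when `params_have_main_grad` is set):

```python
def copy_model_grads_to_main_grads(float16_groups, fp32_from_float16_groups,
                                   params_have_main_grad):
    """Sketch: mirror fp16 gradients onto their fp32 master parameters."""
    for model_group, main_group in zip(float16_groups, fp32_from_float16_groups):
        for model_param, main_param in zip(model_group, main_group):
            if params_have_main_grad:
                # Assumption: grads were accumulated into an fp32 .main_grad buffer.
                main_param.grad = model_param.main_grad.float()
            elif model_param.grad is not None:
                # Otherwise cast the fp16 .grad to fp32, as in the hunk above.
                main_param.grad = model_param.grad.float()
```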
@@ -387,10 +385,6 @@ def reload_model_params(self): |
387 | 385 |
|
388 | 386 | @torch.no_grad() |
389 | 387 | def step(self): |
390 | | - # for param_group in self.float16_groups: |
391 | | - # for param in param_group: |
392 | | - # print(param.grad is None) |
393 | | - |
394 | 388 | # Copy gradients from model params to main params. |
395 | 389 | self._copy_model_grads_to_main_grads() |
396 | 390 |
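With the leftover debug prints removed, `step()` now starts directly with the gradient copy. In this family of mixed-precision optimizers, the rest of the step typically unscales gradients, checks for inf/NaN to decide whether to skip the update, clips, runs the wrapped optimizer, and writes the updated fp32 master weights back into the fp16 params. The sketch below is schematic: apart from `_copy_model_grads_to_main_grads`, the helper names are assumptions following that pattern, not taken from this diff:

```python
# Schematic method body; would live inside the FP16Optimizer class.
@torch.no_grad()
def step(self):
    # 1. Move fp16 gradients onto the fp32 master params (shown in the diff).
    self._copy_model_grads_to_main_grads()

    # 2. Unscale grads and look for inf/NaN (assumed helper name), then update the loss scale.
    found_inf = self._unscale_main_grads_and_check_for_nan()
    self.grad_scaler.update(found_inf)
    if found_inf:
        return False  # skip this update; the scaler has already backed off

    # 3. Clip on the fp32 copies, then run the wrapped optimizer.
    if self.clip_grad > 0.0:
        self.clip_grad_norm(self.clip_grad)
    self.optimizer.step()

    # 4. Push updated fp32 master weights back into the fp16 model params.
    self._copy_main_params_to_model_params()
    return True
```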