Commit e29c2d1

[amp] dygraph amp support param_group (#34899)
* dygraph amp support param_group
* remove unused code
* fix doc
1 parent b0cb414 · commit e29c2d1

3 files changed, +100 -19 lines changed

python/paddle/amp/grad_scaler.py

Lines changed: 43 additions & 0 deletions
@@ -146,6 +146,49 @@ def minimize(self, optimizer, *args, **kwargs):
         """
         return super(GradScaler, self).minimize(optimizer, *args, **kwargs)
 
+    def step(self, optimizer):
+        """
+        This function is similar to `optimizer.step()`, which performs the parameter update.
+
+        If the scaled gradients of the parameters contain NAN or INF, the parameter update is skipped.
+        Otherwise, it first unscales the scaled gradients of the parameters, then updates the parameters.
+
+        Args:
+            optimizer(Optimizer): The optimizer used to update parameters.
+
+        Examples:
+            .. code-block:: python
+
+                # required: gpu
+                import paddle
+                model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
+                optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
+                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+                data = paddle.rand([10, 3, 32, 32])
+                with paddle.amp.auto_cast():
+                    conv = model(data)
+                    loss = paddle.mean(conv)
+                scaled = scaler.scale(loss)  # scale the loss
+                scaled.backward()            # do backward
+                scaler.step(optimizer)
+                optimizer.clear_grad()
+        """
+        if not self._enable:
+            return optimizer.step()
+
+        # unscale the grad
+        self._unscale(optimizer)
+
+        if self._found_inf:
+            self._cache_founf_inf = True
+        else:
+            optimizer.step()
+            self._cache_founf_inf = False
+
+        if self._use_dynamic_loss_scaling:
+            # update the scale
+            self._update()
+
     def is_enable(self):
         """
         Enable loss scaling or not.

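Usage note: the new GradScaler.step() above is what lets dygraph AMP drive an optimizer built from parameter groups. A minimal end-to-end sketch (assumes a GPU build of Paddle; the two small conv layers and the per-group learning rates are illustrative, not part of this commit):

import paddle

conv1 = paddle.nn.Conv2D(3, 4, 3)
conv2 = paddle.nn.Conv2D(4, 2, 3)

# Two parameter groups with different learning rates -- the case this commit enables.
optimizer = paddle.optimizer.Momentum(parameters=[
    {'params': conv1.parameters(), 'learning_rate': 0.01},
    {'params': conv2.parameters(), 'learning_rate': 0.001},
])
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

data = paddle.rand([4, 3, 32, 32])
with paddle.amp.auto_cast():
    loss = paddle.mean(conv2(conv1(data)))

scaled = scaler.scale(loss)   # scale the loss to avoid float16 underflow
scaled.backward()             # backward on the scaled loss
scaler.step(optimizer)        # unscale grads; the update is skipped if any grad is inf/nan
optimizer.clear_grad()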
python/paddle/fluid/dygraph/amp/loss_scaler.py

Lines changed: 13 additions & 4 deletions
@@ -212,10 +212,19 @@ def minimize(self, optimizer, *args, **kwargs):
     def _unscale(self, optimizer):
         if not self._enable:
             return
-        param_grads = [
-            param._grad_ivar() for param in optimizer._parameter_list
-            if param._grad_ivar() is not None
-        ]
+
+        if getattr(optimizer, '_param_groups', None) and isinstance(
+                optimizer._param_groups[0], dict):
+            param_grads = []
+            for group in optimizer._param_groups:
+                for param in group['params']:
+                    if param._grad_ivar() is not None:
+                        param_grads.append(param._grad_ivar())
+        else:
+            param_grads = [
+                param._grad_ivar() for param in optimizer._parameter_list
+                if param._grad_ivar() is not None
+            ]
         _C_ops.check_finite_and_unscale(param_grads, self._scale, param_grads,
                                         self._found_inf)

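For reference, the branch added above has to cover two internal layouts of a dygraph optimizer. A compact restatement of the same collection logic with the distinction spelled out (a sketch that mirrors the diff; the only private attributes used are the ones the diff itself touches):

def _collect_param_grads(optimizer):
    # Optimizers constructed from a list of dicts keep those dicts in _param_groups;
    # the isinstance check guards against _param_groups holding plain parameters.
    if getattr(optimizer, '_param_groups', None) and isinstance(
            optimizer._param_groups[0], dict):
        params = [p for group in optimizer._param_groups for p in group['params']]
    else:
        # Flat parameter list: the pre-existing path.
        params = list(optimizer._parameter_list)
    # Only parameters that actually received a gradient are unscaled and checked.
    return [p._grad_ivar() for p in params if p._grad_ivar() is not None]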
python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py

Lines changed: 44 additions & 15 deletions
@@ -19,6 +19,9 @@
 import six
 from test_imperative_resnet import ResNet, BottleneckBlock, ConvBNLayer, train_parameters, optimizer_setting
 
+if fluid.core.is_compiled_with_cuda():
+    fluid.set_flags({"FLAGS_cudnn_deterministic": True})
+
 
 class SimpleConv(fluid.dygraph.Layer):
     def __init__(self,
@@ -373,8 +376,6 @@ def train_resnet(self,
         return dy_out, dy_param_value, dy_grad_value
 
     def test_with_state_dict(self):
-        if fluid.core.is_compiled_with_cuda():
-            fluid.set_flags({"FLAGS_cudnn_deterministic": True})
         with fluid.dygraph.guard():
             out_use_state_dict = self.train_resnet(
                 enable_amp=True, use_data_loader=True, use_save_load=True)
@@ -390,18 +391,43 @@ class TestResnet2(unittest.TestCase):
     Use paddle-2.0 API
     """
 
-    def train_resnet(self, enable_amp=True, use_data_loader=False):
+    def train_resnet(self,
+                     enable_amp=True,
+                     use_data_loader=False,
+                     use_param_group=False):
         seed = 90
 
         batch_size = train_parameters["batch_size"]
-        batch_num = 1
+        batch_num = 10
 
         paddle.seed(seed)
         paddle.framework.random._manual_program_seed(seed)
 
         resnet = ResNet(use_cudnn=True)
-        optimizer = optimizer_setting(
-            train_parameters, parameter_list=resnet.parameters())
+
+        if use_param_group:
+            conv_params = resnet.conv.parameters()
+            other_params = []
+            for p in resnet.parameters():
+                contains = False
+                for q in conv_params:
+                    if p is q:
+                        contains = True
+                if not contains:
+                    other_params.append(p)
+            # NOTE(zhiqiu): The membership test operations (in / not in) call "is" and "equal",
+            # see details: https://docs.python.org/3/reference/expressions.html#membership-test-operations.
+            # So do not use other_params = [p for p in resnet.parameters() if p not in conv_params]
+            optimizer = paddle.optimizer.Momentum(parameters=[{
+                'params': conv_params,
+                'learning_rate': 0.01
+            }, {
+                'params': other_params,
+                'learning_rate': 0.001
+            }])
+        else:
+            optimizer = paddle.optimizer.SGD(parameters=resnet.parameters())
+
         np.random.seed(seed)
         train_reader = paddle.batch(
             paddle.dataset.flowers.train(use_xmap=False), batch_size=batch_size)
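Aside on the identity loop in the hunk above: Python's membership tests fall back to "==" when "is" fails, and comparing Paddle tensors with "==" is elementwise, so "p not in conv_params" cannot yield a reliable Python bool. A hedged sketch of the equivalent explicit form (reusing the resnet and conv_params names from the test):

# Keep only parameters that are not (by identity) in conv_params,
# without ever invoking Tensor.__eq__.
other_params = [
    p for p in resnet.parameters()
    if not any(p is q for q in conv_params)
]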
@@ -456,7 +482,7 @@ def train_resnet(self, enable_amp=True, use_data_loader=False):
             scaled_loss = scaler.scale(avg_loss)
             scaled_loss.backward()
 
-            scaler.minimize(optimizer, scaled_loss)
+            scaler.step(optimizer)
 
             dy_grad_value = {}
             for param in resnet.parameters():
@@ -475,22 +501,27 @@ def train_resnet(self, enable_amp=True, use_data_loader=False):
         return dy_out, dy_param_value, dy_grad_value
 
     def test_resnet(self):
-        if fluid.core.is_compiled_with_cuda():
-            fluid.set_flags({"FLAGS_cudnn_deterministic": True})
         with fluid.dygraph.guard():
             out_fp32 = self.train_resnet(enable_amp=False)
             out_amp = self.train_resnet(enable_amp=True)
         print(out_fp32[0], out_amp[0])
-        self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-2))
+        self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-5))
 
     def test_with_data_loader(self):
-        if fluid.core.is_compiled_with_cuda():
-            fluid.set_flags({"FLAGS_cudnn_deterministic": True})
         with fluid.dygraph.guard():
             out_fp32 = self.train_resnet(enable_amp=False, use_data_loader=True)
             out_amp = self.train_resnet(enable_amp=True, use_data_loader=True)
         print(out_fp32[0], out_amp[0])
-        self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-2))
+        self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-5))
+
+    def test_param_group(self):
+        with fluid.dygraph.guard():
+            out_fp32 = self.train_resnet(
+                enable_amp=False, use_data_loader=True, use_param_group=True)
+            out_amp = self.train_resnet(
+                enable_amp=True, use_data_loader=True, use_param_group=True)
+        print(out_fp32[0], out_amp[0])
+        self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-5))
 
 
 class TestResnet(unittest.TestCase):
@@ -566,8 +597,6 @@ def train_resnet(self, enable_amp=True):
         return dy_out, dy_param_value, dy_grad_value
 
     def test_resnet(self):
-        if fluid.core.is_compiled_with_cuda():
-            fluid.set_flags({"FLAGS_cudnn_deterministic": True})
         out_fp32 = self.train_resnet(enable_amp=False)
         out_amp = self.train_resnet(enable_amp=True)
         print(out_fp32[0], out_amp[0])
