
Commit 7a0a576

fix adamw lr_to_coeff is fixed when dygraph (#30526)
1 parent 59ad6ff commit 7a0a576
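
What changed: AdamW applies decoupled weight decay by scaling each parameter with coeff = 1.0 - learning_rate * weight_decay, and caches that coefficient per learning-rate variable in _lr_to_coeff. In dygraph mode the cache was never cleared, so once a learning-rate scheduler changed the learning rate the stale coefficient kept being reused and the effective weight decay stayed fixed. This commit clears _lr_to_coeff at the end of every optimization pass and extends the unit test to verify that the coefficient tracks the scheduler.

A minimal standalone sketch of the caching pitfall; this is illustrative only, not Paddle's actual implementation, and the class and method names below are made up:

class TinyDecoupledDecay:
    """Toy cache of the decoupled weight decay coefficient."""

    def __init__(self, weight_decay):
        self._coeff = weight_decay
        self._lr_to_coeff = {}  # cache keyed by the learning-rate object

    def decay_coeff(self, lr_key, lr_value):
        # Pre-fix behaviour: reuse the cached coefficient if one exists,
        # even after a scheduler has changed the learning rate.
        cached = self._lr_to_coeff.get(lr_key)
        if cached is None:
            cached = 1.0 - lr_value * self._coeff
            self._lr_to_coeff[lr_key] = cached
        return cached

    def end_of_step(self):
        # The fix: drop the cache after every optimization pass so the
        # next step recomputes the coefficient from the current lr.
        self._lr_to_coeff = {}

Without end_of_step(), calling decay_coeff("lr", 0.001) and later decay_coeff("lr", 0.0005) would both return 1.0 - 0.001 * weight_decay; with it, the second call picks up the new learning rate.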

File tree

2 files changed: +29, -8 lines

python/paddle/fluid/tests/unittests/test_adamw_op.py

Lines changed: 18 additions & 7 deletions
@@ -98,16 +98,27 @@ def test_adamw_lr_decay(self):
         value = np.arange(26).reshape(2, 13).astype("float32")
         a = paddle.to_tensor(value)
         linear = paddle.nn.Linear(13, 5)
+
+        lr = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=10)
+        wd = 0.1
         adam = paddle.optimizer.AdamW(
-            learning_rate=paddle.optimizer.lr.NoamDecay(
-                d_model=512, warmup_steps=4000),
+            learning_rate=lr,
             parameters=linear.parameters(),
             apply_decay_param_fun=lambda name: True,
-            weight_decay=0.01)
-        out = linear(a)
-        out.backward()
-        adam.step()
-        adam.clear_gradients()
+            weight_decay=wd)
+
+        for _ in range(2):
+            out = linear(a)
+            out.backward()
+            lr_to_coeff = adam._lr_to_coeff
+            adam.step()
+
+            for i, value in enumerate(lr_to_coeff.values()):
+                self.assertAlmostEqual(value.numpy()[0], 1.0 - lr() * wd)
+            self.assertEqual(len(adam._lr_to_coeff), 0)
+
+            lr.step()
+            adam.clear_gradients()
 
 
 if __name__ == "__main__":
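
The new assertions boil down to two checks per iteration (reusing lr and wd from the test above; the lines below are just a restatement, not extra test code):

expected = 1.0 - lr() * wd   # coefficient for the scheduler's current learning rate
# 1) each coefficient cached while adam.step() ran must equal `expected`
# 2) after adam.step(), adam._lr_to_coeff must be empty so the next (decayed) lr is picked up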

python/paddle/optimizer/adamw.py

Lines changed: 11 additions & 1 deletion
@@ -173,7 +173,10 @@ def _append_decoupled_weight_decay(self, block, param_and_grad):
                 [param, grad]), framework.name_scope('weight decay'):
             self._params_name.add(param.name)
 
-            # If it has been calculated, the result will be reused
+            # If it has been calculated, the result will be reused.
+            # NOTE(wangxi): In dygraph mode, apply_gradient is executed on
+            # every step, so _lr_to_coeff has to be cleared every step;
+            # we do this in _create_optimization_pass.
             decay_coeff = self._lr_to_coeff.get(learning_rate, None)
             if decay_coeff is None:
                 decay_coeff = 1.0 - learning_rate * self._coeff
@@ -186,5 +189,12 @@ def _append_optimize_op(self, block, param_and_grad):
         self._append_decoupled_weight_decay(block, param_and_grad)
         return super(AdamW, self)._append_optimize_op(block, param_and_grad)
 
+    def _create_optimization_pass(self, parameters_and_grads):
+        optimize_ops = super(
+            AdamW, self)._create_optimization_pass(parameters_and_grads)
+        # In dygraph mode, clear _lr_to_coeff after the gradients have been applied.
+        self._lr_to_coeff = dict()
+        return optimize_ops
+
     def __str__(self):
         return " ".join(["Weight Decay, params:", ",".join(self._params_name)])
