Add multi_tensor for momentum optimizer and clear_grads #37564

Merged, Dec 20, 2021
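
Overview (as reflected in the commits and diffs below): this PR gives paddle.optimizer.Momentum a use_multi_tensor option, backed by a new merged_momentum op that appears to apply the momentum update to a whole list of parameters in a single call (dygraph and static mode, FP32 and AMP paths), and adds a batched clear_gradients binding reachable through the optimizer's clear_grad(set_to_zero=...). A minimal usage sketch, using only API names that appear in this PR's tests:

    import paddle

    model = paddle.nn.Linear(5, 5)
    # use_multi_tensor routes the update through the fused merged_momentum op.
    opt = paddle.optimizer.Momentum(parameters=model.parameters(),
                                    use_multi_tensor=True)
    loss = paddle.mean(model(paddle.randn((5, 5))))
    loss.backward()
    opt.step()
    # set_to_zero=False presumably frees gradient storage instead of zero-filling it.
    opt.clear_grad(set_to_zero=False)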
30 commits (all by zhangbo9674):
ee0611c  add multi_tensor for momentum and clear_grads for optimizer (Nov 25, 2021)
cff538e  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Dec 1, 2021)
d1f2e65  fix bug for dygraph (Dec 1, 2021)
f2170fd  add unittest (Dec 1, 2021)
517539e  refine comment (Dec 1, 2021)
b5b0181  add param_group (Dec 2, 2021)
5040d32  refine regularizaiton logic (Dec 3, 2021)
9359158  del clear_grads (Dec 3, 2021)
eaa498c  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Dec 3, 2021)
335fc20  add clear_grads (Dec 8, 2021)
6575038  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Dec 8, 2021)
7feaa77  merge develop & fix confilct (Dec 8, 2021)
debef46  add dispensable check of None (Dec 8, 2021)
4fef207  refine clear_grad (Dec 9, 2021)
aa370b0  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Dec 9, 2021)
aa37e17  fix build bug (Dec 9, 2021)
28e4a7e  refine code by comment (Dec 13, 2021)
c5981f3  refine code (Dec 14, 2021)
c152fbb  add multi tensor check (Dec 14, 2021)
ff221ae  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Dec 14, 2021)
2cf25e2  refine param_group update (Dec 14, 2021)
668d5ad  add multi tensor for static mode (Dec 15, 2021)
f7317e2  refine comments (Dec 15, 2021)
da9c4a4  merge develop (Dec 15, 2021)
deb10be  delete useless comma for momentum (Dec 16, 2021)
9d2df8a  refine comment for momentum (Dec 16, 2021)
3d1cd6a  refine code by commment (Dec 16, 2021)
5423606  Merge branch 'develop' into dev/approve_momentum_py (Dec 17, 2021)
c363ba5  merge develop (Dec 17, 2021)
5d9239b  fix conflict (Dec 17, 2021)
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/op_function_common.cc
@@ -688,7 +688,7 @@ std::vector<std::shared_ptr<imperative::VarBase>> GetVarBaseListFromArgs(
     ssize_t arg_idx, bool dispensable) {
   PyObject* list = PyTuple_GET_ITEM(args, arg_idx);

-  if (list == nullptr) {
+  if (list == nullptr || list == Py_None) {
     if (!dispensable) {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "%s(): argument '%s' (position %d) must be list of Tensor, but got "
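
With this check, a Python-side None for an optional (dispensable) Tensor-list argument now counts as "not provided" instead of failing the "must be list of Tensor" validation. That matters for merged_momentum's MasterParam input: in the plain FP32 path there are no master weights, so the optimizer layer presumably hands None to the op. A hedged sketch of the case this unblocks:

    import paddle

    model = paddle.nn.Linear(5, 5)
    # multi_precision defaults to False, so no master-weight tensors exist and
    # the MasterParam list reaching merged_momentum is None; the Py_None branch
    # above lets that through as a dispensable (optional) input.
    opt = paddle.optimizer.Momentum(parameters=model.parameters(),
                                    use_multi_tensor=True)
    paddle.mean(model(paddle.randn((5, 5)))).backward()
    opt.step()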
4 changes: 4 additions & 0 deletions paddle/fluid/pybind/op_function_generator.h
@@ -58,6 +58,8 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
     {"multiclass_nms3", {"BBoxes", "Scores", "RoisNum"}},
     {"box_coder", {"PriorBox", "PriorBoxVar", "TargetBox"}},
     {"momentum", {"Param", "Grad", "Velocity", "LearningRate", "MasterParam"}},
+    {"merged_momentum",
+     {"Param", "Grad", "Velocity", "LearningRate", "MasterParam"}},
     {"sparse_momentum", {"Param", "Grad", "Velocity", "Index", "LearningRate"}},
     {"rnn", {"Input", "PreState", "WeightList", "SequenceLength"}},
     {"run_program", {"X", "Params"}},
@@ -113,6 +115,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
     {"multiclass_nms3", {"Out", "NmsRoisNum"}},
     {"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
     {"momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
+    {"merged_momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
     {"sparse_momentum", {"ParamOut", "VelocityOut"}},
     {"rnn", {"DropoutState", "Reserve", "Out", "State"}},
     {"run_program", {"DOut"}},
@@ -153,6 +156,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
      {"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates",
       "out_old_num_accumulates", "out_num_updates"}},
     {"momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
+    {"merged_momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
     {"sparse_momentum", {"ParamOut", "VelocityOut"}},
     {"batch_norm", {"MeanOut", "VarianceOut"}},
     {"sync_batch_norm", {"MeanOut", "VarianceOut"}},
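
As their names suggest, op_ins_map and op_outs_map declare the named inputs and outputs of the generated Python binding for merged_momentum, while op_passing_outs_map appears to mark outputs the caller passes in so the kernel updates them in place; listing ParamOut, VelocityOut and MasterParamOut there is what lets the fused update write back into the original parameter and velocity tensors. A small sketch of the observable effect (an illustration under that reading, not code from this PR):

    import paddle

    model = paddle.nn.Linear(5, 5)
    opt = paddle.optimizer.Momentum(parameters=model.parameters(),
                                    use_multi_tensor=True)
    w = model.weight                  # keep a handle to one parameter
    before = w.numpy().copy()
    paddle.mean(model(paddle.randn((5, 5)))).backward()
    opt.step()
    assert w is model.weight          # same tensor object after the update
    assert not (w.numpy() == before).all()  # values changed in place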
8 changes: 8 additions & 0 deletions paddle/fluid/pybind/pybind.cc
@@ -579,6 +579,14 @@ PYBIND11_MODULE(core_noavx, m) {

   m.def("disable_signal_handler", &DisableSignalHandler);

+  m.def("clear_gradients",
+        [](std::vector<std::shared_ptr<imperative::VarBase>> param_list,
+           bool set_to_zero) {
+          for (auto param : param_list) {
+            param->ClearGradient(set_to_zero);
+          }
+        });
+
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   m.def("cudnn_version", &platform::DnnVersion);
   m.def("gpu_memory_available", []() {
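
The new binding clears the gradients of an entire parameter list in one C++ call instead of looping in Python; the tests below reach it through Optimizer.clear_grad(set_to_zero=...). Calling it directly would look roughly like this (a sketch, assuming paddle.fluid.core exposes the same clear_gradients symbol the core_noavx module registers above):

    import paddle
    from paddle.fluid import core

    model = paddle.nn.Linear(3, 3)
    paddle.mean(model(paddle.randn((2, 3)))).backward()
    # True zero-fills each parameter's gradient; False presumably releases
    # the gradient storage instead.
    core.clear_gradients(list(model.parameters()), True)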
185 changes: 185 additions & 0 deletions python/paddle/fluid/tests/unittests/test_momentum_op.py
@@ -21,6 +21,7 @@
 from op_test import OpTest
 import paddle
 import paddle.fluid as fluid
+import numpy


 def calculate_momentum_by_numpy(param,
@@ -805,5 +806,189 @@ def test_momentum_dygraph(self):
         adam.clear_gradients()

+
+class TestMultiTensorMomentumDygraph(unittest.TestCase):
+    def _momentum_optimize_dygraph(self,
+                                   place,
+                                   use_param_attr=False,
+                                   use_param_group=False,
+                                   use_amp=False,
+                                   use_multi_tensor=False):
+        paddle.disable_static()
+        paddle.seed(10)
+        paddle.set_device(place)
+        input = paddle.randn((5, 5))
+        weight_attr = paddle.ParamAttr(
+            learning_rate=0.5,
+            regularizer=paddle.regularizer.L2Decay(1.0),
+            trainable=True)
+        if use_param_attr:
+            model = paddle.nn.Linear(5, 5, weight_attr)
+        else:
+            model = paddle.nn.Linear(5, 5)
+        if not use_param_group:
+            optimizer = paddle.optimizer.Momentum(
+                parameters=model.parameters(),
+                use_multi_tensor=use_multi_tensor,
+                multi_precision=use_amp)
+        else:
+            optimizer = paddle.optimizer.Momentum(
+                parameters=[{
+                    'params': model.parameters(),
+                    'weight_decay': 0.001,
+                    'learning_rate': 0.1,
+                    'momentum': 0.99
+                }],
+                use_multi_tensor=use_multi_tensor,
+                multi_precision=use_amp)
+        for idx in range(5):
+            if place == 'gpu' and use_amp:
+                model = paddle.amp.decorate(models=model, level='O2')
+                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+            if place == 'gpu' and use_amp:
+                with paddle.amp.auto_cast(level='O2'):
+                    output = model(input)
+                    loss = paddle.mean(output)
+                scaled = scaler.scale(loss)
+                scaled.backward()
+                scaler.step(optimizer)
+                optimizer.clear_grad(set_to_zero=False)
+            else:
+                output = model(input)
+                loss = paddle.mean(output)
+                loss.backward()
+                optimizer.step()
+                optimizer.clear_grad(set_to_zero=False)
+        return output, model.parameters()
+
+    def _get_places(self):
+        places = ['cpu']
+        if paddle.is_compiled_with_cuda():
+            places.append('gpu')
+        return places
+
+    def _check_with_place_amp(self, place, use_amp):
+        output1, params1 = self._momentum_optimize_dygraph(
+            place=place, use_amp=use_amp, use_multi_tensor=True)
+        output2, params2 = self._momentum_optimize_dygraph(
+            place=place, use_amp=use_amp, use_multi_tensor=False)
+        self.assertTrue(np.allclose(output1, output2, rtol=1e-05))
+        for idx in range(len(params1)):
+            self.assertTrue(
+                np.allclose(params1[idx], params2[idx], rtol=1e-05))
+
+    def _check_with_param_attr(self, place, use_amp):
+        output1, params1 = self._momentum_optimize_dygraph(
+            place=place,
+            use_amp=use_amp,
+            use_param_attr=True,
+            use_multi_tensor=True)
+        output2, params2 = self._momentum_optimize_dygraph(
+            place=place,
+            use_amp=use_amp,
+            use_param_attr=True,
+            use_multi_tensor=False)
+        self.assertTrue(np.allclose(output1, output2, rtol=1e-05))
+        for idx in range(len(params1)):
+            self.assertTrue(
+                np.allclose(params1[idx], params2[idx], rtol=1e-05))
+
+    def _check_with_param_group(self, place, use_amp):
+        output1, params1 = self._momentum_optimize_dygraph(
+            place=place,
+            use_amp=use_amp,
+            use_param_group=True,
+            use_multi_tensor=True)
+        output2, params2 = self._momentum_optimize_dygraph(
+            place=place,
+            use_amp=use_amp,
+            use_param_group=True,
+            use_multi_tensor=False)
+        self.assertTrue(np.allclose(output1, output2, rtol=1e-05))
+        for idx in range(len(params1)):
+            self.assertTrue(
+                np.allclose(params1[idx], params2[idx], rtol=1e-05))
+
+    def test_main(self):
+        for place in self._get_places():
+            use_amp_list = [True, False]
+            for use_amp in use_amp_list:
+                self._check_with_place_amp(place, use_amp)
+                self._check_with_param_attr(place, use_amp)
+                self._check_with_param_group(place, use_amp)
+
+
+class TestMultiTensorMomentumStatic(unittest.TestCase):
+    def _momentum_optimize_static(self,
+                                  place,
+                                  use_amp=False,
+                                  use_multi_tensor=False):
+        paddle.enable_static()
+        paddle.seed(10)
+        np.random.seed(10)
+        if place == 'cpu':
+            use_amp = False
+        exe = paddle.static.Executor(place=place)
+        train_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        optimizer = paddle.optimizer.Momentum(
+            multi_precision=use_amp, use_multi_tensor=use_multi_tensor)
+        if use_amp:
+            optimizer = paddle.static.amp.decorate(
+                optimizer,
+                init_loss_scaling=128.0,
+                use_dynamic_loss_scaling=True,
+                use_pure_fp16=True,
+                use_fp16_guard=False)
+        with paddle.static.program_guard(train_program, startup_program):
+            if use_amp:
+                data = paddle.static.data(
+                    shape=[2, 2], name='X', dtype='float16')
+            else:
+                data = paddle.static.data(
+                    shape=[2, 2], name='X', dtype='float32')
+            hidden = paddle.static.nn.fc(x=data, size=10)
+            loss = paddle.fluid.layers.mean(hidden)
+            optimizer.minimize(loss)
+        exe.run(startup_program)
+        if use_amp:
+            optimizer.amp_init(place=place, scope=paddle.static.global_scope())
+            x = numpy.random.random(size=(2, 2)).astype('float16')
+        else:
+            x = numpy.random.random(size=(2, 2)).astype('float32')
+        out = []
+        for idx in range(5):
+            loss_data, = exe.run(train_program,
+                                 feed={"X": x},
+                                 fetch_list=[loss.name])
+            out.append(loss_data)
+        return out
+
+    def _get_places(self):
+        places = ['cpu']
+        if paddle.is_compiled_with_cuda():
+            places.append('gpu')
+        return places
+
+    def _check_with_place_amp(self, place, use_amp):
+        output1 = self._momentum_optimize_static(
+            place=place, use_amp=use_amp, use_multi_tensor=True)
+        output2 = self._momentum_optimize_static(
+            place=place, use_amp=use_amp, use_multi_tensor=False)
+        for idx in range(len(output1)):
+            self.assertTrue(
+                np.allclose(output1[idx], output2[idx], rtol=1e-05))
+
+    def test_main(self):
+        for place in self._get_places():
+            use_amp_list = [True, False]
+            for use_amp in use_amp_list:
+                self._check_with_place_amp(place, use_amp)
+

 if __name__ == "__main__":
     unittest.main()