Skip to content

Add LARS to SGD and Momentum Optimizers (#6811) #7788

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 69 additions & 3 deletions paddle/operators/momentum_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,27 +82,93 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) "
"Use Nesterov Momentum")
.SetDefault(false);
AddAttr<bool>("use_local_lr",
"(bool, default false) "
"Use LARS")
.SetDefault(false);
AddAttr<float>("local_gw_ratio", "(float) LARS coefficient")
.SetDefault(0.001);
AddAttr<float>("weight_decay", "(float) LARS weight decay")
.SetDefault(0.0005);

AddComment(R"DOC(
Momentum Optimizer.

This optimizer has a flag for Nestrov Momentum.
Thie optimizer has attributes for LARS to adjust local LR for large batch training of CNN.
paper : https://arxiv.org/abs/1708.03888.
The update equations are as follows:

$$
velocity = mu * velocity + gradient \\
if (use\_nesterov): \\
param = param - gradient * learning\_rate + mu * velocity * learning\_rate \\
else: \\
else if (use\_lcoal\_lr): \\
learning\_rate *= local\_gw\_ratio * sqrt(sumsq(param))
/ (sqrt(sumsq(gradient))+ weight\_decay * sqrt(sumsq(param))) \\
param = param - learning\_rate * velocity. \\
else: \\
param = param - learning\_rate * velocity. \\
$$

)DOC");
}
};

template <typename T>
class MomentumOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
auto param = ctx.Input<framework::Tensor>("Param");
auto velocity = ctx.Input<framework::Tensor>("Velocity");
auto grad = ctx.Input<framework::Tensor>("Grad");
auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");

param_out->mutable_data<T>(ctx.GetPlace());
velocity_out->mutable_data<T>(ctx.GetPlace());

T mu = static_cast<T>(ctx.Attr<float>("mu"));
bool use_nesterov = ctx.Attr<bool>("use_nesterov");
bool use_local_lr = ctx.Attr<bool>("use_local_lr");
T local_gw_ratio = static_cast<T>(ctx.Attr<float>("local_gw_ratio"));
T weight_decay = static_cast<T>(ctx.Attr<float>("weight_decay"));

auto p_out = framework::EigenVector<T>::Flatten(*param_out);
auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);

auto p = framework::EigenVector<T>::Flatten(*param);
auto v = framework::EigenVector<T>::Flatten(*velocity);
auto g = framework::EigenVector<T>::Flatten(*grad);
auto *lr = learning_rate->data<T>();

T local_lr = lr[0];
if (use_local_lr) {
Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor,
Eigen::DenseIndex>
p_norm = p.square().sum().sqrt();
Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor,
Eigen::DenseIndex>
g_norm = g.square().sum().sqrt();
if ((p_norm(0) > static_cast<T>(0)) && (g_norm(0) > static_cast<T>(0)))
local_lr = lr[0] * local_gw_ratio * p_norm(0) /
(g_norm(0) + weight_decay * p_norm(0));
}

v_out = v * mu + g;
if (use_nesterov) {
p_out = p - (g - v_out * mu) * lr[0];
} else {
p_out = p - local_lr * v_out;
}
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(momentum, ops::MomentumOp, ops::MomentumOpMaker);
REGISTER_OP_CPU_KERNEL(momentum, ops::MomentumOpKernel<float>,
ops::MomentumOpKernel<double>);
REGISTER_OP_CPU_KERNEL(momentum, ops::MomentumOpCPUKernel<float>,
ops::MomentumOpCPUKernel<double>);
91 changes: 88 additions & 3 deletions paddle/operators/sgd_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,99 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("LearningRate", "(Tensor) Learning rate of SGD");
AddInput("Grad", "(Tensor) Input gradient");
AddOutput("ParamOut", "(Tensor) Output parameter");
AddAttr<bool>("use_local_lr",
"(bool, default false) "
"Use LARS")
.SetDefault(false);
AddAttr<float>("local_gw_ratio", "(float) LARS coefficient")
.SetDefault(0.001);
AddAttr<float>("weight_decay", "(float) LARS weight decay")
.SetDefault(0.0005);

AddComment(R"DOC(

SGD operator

This operator implements one step of the stochastic gradient descent algorithm.
This optimizer has attributes for LARS to adjust local LR for large batch training of CNN.
paper : https://arxiv.org/abs/1708.03888.
$$
if (use\_local\_lr): \\
learning\_rate *= local\_gw\_ratio * sqrt(sumsq(param))
/ (sqrt(sumsq(grad))+ weight\_decay * sqrt(sumsq(param))) \\
param\_out = param - learning\_rate * grad
$$
)DOC");
}
};

$$param\_out = param - learning\_rate * grad$$
template <typename T>
class SGDOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* param = ctx.Input<framework::Tensor>("Param");
auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");

)DOC");
auto* grad_var = ctx.InputVar("Grad");

bool use_local_lr = ctx.Attr<bool>("use_local_lr");
T local_gw_ratio = static_cast<T>(ctx.Attr<float>("local_gw_ratio"));
T weight_decay = static_cast<T>(ctx.Attr<float>("weight_decay"));

// Actually, all tensors are LoDTensor except SelectedRows.
if (grad_var->IsType<framework::LoDTensor>()) {
param_out->mutable_data<T>(ctx.GetPlace());
auto* grad = ctx.Input<framework::Tensor>("Grad");

auto p = framework::EigenVector<T>::Flatten(*param);
auto g = framework::EigenVector<T>::Flatten(*grad);
auto o = framework::EigenVector<T>::Flatten(*param_out);
auto* lr = learning_rate->data<T>();

T local_lr = lr[0];
if (use_local_lr) {
Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor,
Eigen::DenseIndex>
p_norm = p.square().sum().sqrt();
Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor,
Eigen::DenseIndex>
g_norm = g.square().sum().sqrt();
if ((p_norm(0) > static_cast<T>(0)) && (g_norm(0) > static_cast<T>(0)))
local_lr = lr[0] * local_gw_ratio * p_norm(0) /
(g_norm(0) + weight_decay * p_norm(0));
}
o = p - local_lr * g;
} else if (grad_var->IsType<framework::SelectedRows>()) {
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
// This manual optimization brings difficulty to track data dependency.
// It's better to find a more elegant solution.
PADDLE_ENFORCE_EQ(param, param_out);
auto* grad = ctx.Input<framework::SelectedRows>("Grad");

auto in_height = grad->height();
auto out_dims = param_out->dims();
PADDLE_ENFORCE_EQ(in_height, out_dims[0]);

auto& in_value = grad->value();
auto& in_rows = grad->rows();

int64_t in_row_numel = in_value.numel() / in_rows.size();
PADDLE_ENFORCE_EQ(in_row_numel, param_out->numel() / in_height);

auto* in_data = in_value.data<T>();
auto* out_data = param_out->data<T>();
auto* lr = learning_rate->data<T>();

for (size_t i = 0; i < in_rows.size(); i++) {
for (int64_t j = 0; j < in_row_numel; j++) {
out_data[in_rows[i] * in_row_numel + j] -=
lr[0] * in_data[i * in_row_numel + j];
}
}
} else {
PADDLE_THROW("Unsupported Variable Type of Grad");
}
}
};

Expand All @@ -66,4 +150,5 @@ This operator implements one step of the stochastic gradient descent algorithm.

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(sgd, ops::SGDOp, ops::SGDOpMaker);
REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel<float>, ops::SGDOpKernel<double>);
REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpCPUKernel<float>,
ops::SGDOpCPUKernel<double>);
38 changes: 33 additions & 5 deletions python/paddle/v2/fluid/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,19 @@ class SGDOptimizer(Optimizer):
""" Simple SGD optimizer without any state.
"""

def __init__(self, learning_rate, **kwargs):
def __init__(self,
learning_rate,
use_local_lr=False,
local_gw_ratio=0.001,
weight_decay=0.0005,
**kwargs):
assert learning_rate is not None
super(SGDOptimizer, self).__init__(**kwargs)
self.type = "sgd"
self._learning_rate = learning_rate
self._use_local_lr = bool(use_local_lr)
self._local_gw_ratio = local_gw_ratio
self._weight_decay = weight_decay

def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)
Expand All @@ -246,7 +254,12 @@ def _append_optimize_op(self, block, param_and_grad):
"Grad": param_and_grad[1],
"LearningRate": self._create_param_lr(param_and_grad)
},
outputs={"ParamOut": param_and_grad[0]})
outputs={"ParamOut": param_and_grad[0]},
attrs={
"use_local_lr": self._use_local_lr,
"local_gw_ratio": self._local_gw_ratio,
"weight_decay": self._weight_decay
})

return sgd_op

Expand All @@ -256,14 +269,24 @@ class MomentumOptimizer(Optimizer):
"""
_velocity_acc_str = "velocity"

def __init__(self, learning_rate, momentum, use_nesterov=False, **kwargs):
def __init__(self,
learning_rate,
momentum,
use_nesterov=False,
use_local_lr=False,
local_gw_ratio=0.001,
weight_decay=0.0005,
**kwargs):
assert learning_rate is not None
assert momentum is not None
super(MomentumOptimizer, self).__init__(**kwargs)
self.type = "momentum"
self._learning_rate = learning_rate
self._momentum = momentum
self._use_nesterov = bool(use_nesterov)
self._use_local_lr = bool(use_local_lr)
self._local_gw_ratio = local_gw_ratio
self._weight_decay = weight_decay

def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block)
Expand All @@ -289,8 +312,13 @@ def _append_optimize_op(self, block, param_and_grad):
"ParamOut": param_and_grad[0],
"VelocityOut": velocity_acc
},
attrs={"mu": self._momentum,
"use_nesterov": self._use_nesterov})
attrs={
"mu": self._momentum,
"use_nesterov": self._use_nesterov,
"use_local_lr": self._use_local_lr,
"local_gw_ratio": self._local_gw_ratio,
"weight_decay": self._weight_decay
})

return momentum_op

Expand Down
53 changes: 53 additions & 0 deletions python/paddle/v2/fluid/tests/test_momentum_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import unittest
import numpy as np
import paddle.v2.fluid.core as core
from op_test import OpTest


Expand Down Expand Up @@ -86,5 +87,57 @@ def test_check_output(self):
self.check_output()


class TestMomentumOp3(OpTest):
'''Test Momentum with LARS attribute on
'''

def setUp(self):
self.op_type = "momentum"

param = np.random.random((123, 321)).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
velocity = np.zeros((123, 321)).astype("float32")
learning_rate = np.array([2.56]).astype("float32")
mu = 0.0001
use_nesterov = False
use_local_lr = True
local_gw_ratio = 0.1
weight_decay = 0.5

self.inputs = {
'Param': param,
'Grad': grad,
'Velocity': velocity,
'LearningRate': learning_rate
}

self.attrs = {
'mu': mu,
'use_nesterov': use_nesterov,
'use_local_lr': use_local_lr,
'local_gw_ratio': local_gw_ratio,
'weight_decay': weight_decay,
}

velocity_out = mu * velocity + grad
if use_nesterov:
param_out = param - grad * learning_rate + \
velocity_out * mu * learning_rate
else:
local_lr = learning_rate
if use_local_lr:
param_norm = np.sqrt(np.sum(np.square(param)))
grad_norm = np.sqrt(np.sum(np.square(grad)))
local_lr = learning_rate * local_gw_ratio * param_norm / \
(grad_norm + weight_decay * param_norm)
param_out = param - local_lr * velocity_out

self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}

def test_check_output(self):
place = core.CPUPlace()
self.check_output_with_place(place, atol=1e-3)


if __name__ == "__main__":
unittest.main()
33 changes: 33 additions & 0 deletions python/paddle/v2/fluid/tests/test_sgd_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,39 @@ def test_check_output(self):
self.check_output()


class TestSGDLARSOp(OpTest):
def setUp(self):
self.op_type = "sgd"
w = np.random.random((102, 105)).astype("float32")
g = np.random.random((102, 105)).astype("float32")
lr = np.array([3.0]).astype("float32")

use_local_lr = True
local_gw_ratio = 0.02
weight_decay = 0.005
self.inputs = {'Param': w, 'Grad': g, 'LearningRate': lr}

self.attrs = {
'use_local_lr': use_local_lr,
'local_gw_ratio': local_gw_ratio,
'weight_decay': weight_decay,
}

local_lr = lr
if use_local_lr:
w_norm = np.sqrt(np.sum(np.square(w)))
g_norm = np.sqrt(np.sum(np.square(g)))
local_lr = lr * local_gw_ratio * w_norm / \
(g_norm + weight_decay * w_norm)

param_out = w - local_lr * g
self.outputs = {'ParamOut': param_out}

def test_check_output(self):
place = core.CPUPlace()
self.check_output_with_place(place, atol=1e-3)


class TestSparseSGDOp(unittest.TestCase):
def check_with_place(self, place):
scope = core.Scope()
Expand Down