Commit fd5199f

Merge pull request #3989 from pkuyym/fix-3923-r

Add huber loss operator.

2 parents 24d988c + ac5f421
File tree

4 files changed: +311 -0 lines

paddle/operators/huber_loss_op.cc

+122
@@ -0,0 +1,122 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/huber_loss_op.h"

namespace paddle {
namespace operators {

class HuberLossOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must be initialized.");
    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must be initialized.");

    auto x_dims = ctx->GetInputDim("X");
    auto y_dims = ctx->GetInputDim("Y");

    PADDLE_ENFORCE_EQ(x_dims, y_dims,
                      "Input(X) and Input(Y) must have the same shape.");
    PADDLE_ENFORCE_EQ(x_dims.size(), 2,
                      "The rank of Input(X) must be 2 and the shape is "
                      "[batch_size, 1].");
    PADDLE_ENFORCE_EQ(x_dims[1], 1,
                      "Each row of Input(X) contains a real value, "
                      "so the 2nd dimension of Input(X) must be 1.");

    ctx->SetOutputDim("Residual", x_dims);
    ctx->SetOutputDim("Out", {x_dims[0], 1});
    ctx->ShareLoD("X", "Out");
  }
};

template <typename AttrType>
class HuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  HuberLossOpMaker(framework::OpProto* proto,
                   framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "The input value of huber loss op. "
             "X is a 2-D tensor with shape [batch_size, 1].");
    AddInput("Y",
             "The target value of huber loss op. "
             "Y is a 2-D tensor with shape [batch_size, 1].");
    AddOutput("Residual",
              "Intermediate tensor that caches the residual value between "
              "Y and X. Its shape is the same as Input(X) and it is reused "
              "in the backward pass.")
        .AsIntermediate();
    AddOutput("Out",
              "The output tensor with shape [batch_size, 1] which represents "
              "the huber loss.");
    AddAttr<AttrType>("delta", "Hyperparameter of huber loss.");
    AddComment(R"DOC(
Huber loss is a loss function used in robust regression. We define X as the
input value and Y as the target value. Huber loss can evaluate the fitness of
X to Y. Different from MSE loss, Huber loss is more robust to outliers. The
shapes of X and Y are both [batch_size, 1]. The equation is:

L_{\delta}(y, f(x)) =
\begin{cases}
0.5 * (y - f(x))^2, & |y - f(x)| \leq \delta \\
\delta * (|y - f(x)| - 0.5 * \delta), & otherwise
\end{cases}

)DOC");
  }
};

class HuberLossGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Residual"),
                   "Input(Residual) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null.");

    auto x_dims = ctx->GetInputDim("X");
    auto y_dims = ctx->GetInputDim("Y");
    auto residual_dims = ctx->GetInputDim("Residual");
    auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));

    PADDLE_ENFORCE_EQ(residual_dims, x_dims);
    PADDLE_ENFORCE_EQ(out_grad_dims, x_dims);

    auto x_grad_name = framework::GradVarName("X");
    auto y_grad_name = framework::GradVarName("Y");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }
    if (ctx->HasOutput(y_grad_name)) {
      ctx->SetOutputDim(y_grad_name, y_dims);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(huber_loss, ops::HuberLossOp, ops::HuberLossOpMaker<float>,
            huber_loss_grad, ops::HuberLossGradOp);
REGISTER_OP_CPU_KERNEL(huber_loss,
                       ops::HuberLossKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    huber_loss_grad,
    ops::HuberLossGradKernel<paddle::platform::CPUPlace, float>);
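The forward kernel defined in the header (see below) reduces to two elementwise Eigen expressions: residual = y - x, then loss = residual.unaryExpr(HuberLossForward<T>(delta)). A minimal plain-Eigen sketch of the same pattern, independent of the Paddle framework (the HuberForward functor name and the sample values are illustrative assumptions, not part of this commit):

#include <Eigen/Dense>
#include <cmath>
#include <iostream>

// Same piecewise definition as HuberLossForward, restated standalone.
struct HuberForward {
  float delta;
  float operator()(float val) const {
    float abs_val = std::abs(val);
    return abs_val <= delta ? 0.5f * val * val
                            : delta * (abs_val - 0.5f * delta);
  }
};

int main() {
  Eigen::ArrayXf x(4), y(4);
  x << 0.1f, 0.5f, 0.9f, 0.2f;  // hypothetical predictions
  y << 0.3f, 0.4f, 2.5f, 0.2f;  // hypothetical targets
  Eigen::ArrayXf residual = y - x;  // cached for the backward pass
  // |residual| <= delta gives 0.5*r^2; larger residuals grow only linearly.
  Eigen::ArrayXf loss = residual.unaryExpr(HuberForward{1.0f});
  std::cout << loss.transpose() << std::endl;
  return 0;
}

With delta = 1.0, the residuals 0.2, -0.1, and 0.0 fall in the quadratic branch while 1.6 falls in the linear one (loss 1.0 * (1.6 - 0.5) = 1.1), exercising both branches of the functor.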

paddle/operators/huber_loss_op.cu

+23
@@ -0,0 +1,23 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/operators/huber_loss_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(huber_loss,
                       ops::HuberLossKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    huber_loss_grad,
    ops::HuberLossGradKernel<paddle::platform::GPUPlace, float>);

paddle/operators/huber_loss_op.h

+119
@@ -0,0 +1,119 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/hostdevice.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

// Elementwise Huber loss: quadratic inside [-delta, delta], linear outside.
template <typename T>
struct HuberLossForward {
  HOSTDEVICE HuberLossForward(const T& delta) : delta(delta) {}

  HOSTDEVICE T operator()(const T& val) const {
    T abs_val = std::abs(val);
    if (abs_val <= delta) {
      return static_cast<T>(0.5) * val * val;
    } else {
      return delta * (abs_val - static_cast<T>(0.5) * delta);
    }
  }

  T delta;
};

template <typename Place, typename T, typename AttrType = T>
class HuberLossKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in0 = context.Input<Tensor>("X");
    auto* in1 = context.Input<Tensor>("Y");
    auto* out0 = context.Output<Tensor>("Residual");
    auto* out1 = context.Output<Tensor>("Out");
    auto delta = static_cast<T>(context.Attr<AttrType>("delta"));
    auto place = context.GetEigenDevice<Place>();

    auto x = EigenVector<T>::Flatten(*in0);
    auto y = EigenVector<T>::Flatten(*in1);
    // Cache the residual y - x; the backward pass reuses it.
    out0->mutable_data<T>(context.GetPlace());
    auto residual = EigenVector<T>::Flatten(*out0);
    residual.device(place) = y - x;
    out1->mutable_data<T>(context.GetPlace());
    auto loss = EigenVector<T>::Flatten(*out1);
    loss.device(place) = residual.unaryExpr(HuberLossForward<T>(delta));
  }
};

// Derivative of the loss w.r.t. the residual, scaled by `sign` to account
// for whether the variable enters the residual positively (Y) or
// negatively (X).
template <typename T>
struct HuberLossBackward {
  HOSTDEVICE HuberLossBackward(const T& delta, T sign)
      : sign(sign), delta(delta) {}

  HOSTDEVICE T operator()(const T& val) const {
    T abs_val = std::abs(val);
    if (abs_val <= delta) {
      return sign * val;
    } else {
      if (val > 0) {
        return sign * delta;
      } else {
        return -1 * sign * delta;
      }
    }
  }

  T sign;
  T delta;
};

template <typename Place, typename T, typename AttrType = T>
class HuberLossGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in0 = context.Input<Tensor>("Residual");
    auto* in1 = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
    auto* out1 = context.Output<Tensor>(framework::GradVarName("Y"));
    auto delta = static_cast<T>(context.Attr<AttrType>("delta"));
    auto place = context.GetEigenDevice<Place>();

    auto residual = EigenVector<T>::Flatten(*in0);
    auto out_grad = EigenVector<T>::Flatten(*in1);

    if (out0) {
      out0->mutable_data<T>(context.GetPlace());
      auto x_grad = EigenVector<T>::Flatten(*out0);
      // dLoss/dX = -dLoss/dResidual, chained with the upstream gradient.
      x_grad.device(place) =
          out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, -1.0));
    }

    if (out1) {
      out1->mutable_data<T>(context.GetPlace());
      auto y_grad = EigenVector<T>::Flatten(*out1);
      // dLoss/dY = +dLoss/dResidual, chained with the upstream gradient.
      y_grad.device(place) =
          out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, 1.0));
    }
  }
};

}  // namespace operators
}  // namespace paddle
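HuberLossBackward encodes the derivative of the loss with respect to the residual r = y - f(x), scaled by sign: since X enters r negatively, the X gradient is built with sign = -1 and the Y gradient with sign = +1. Written out (a standard calculus check, not text from the patch):

\frac{\partial L_{\delta}}{\partial r} =
\begin{cases}
r, & |r| \leq \delta \\
\delta \cdot \operatorname{sign}(r), & otherwise
\end{cases}
\qquad
\frac{\partial L_{\delta}}{\partial x} = -\frac{\partial L_{\delta}}{\partial r},
\quad
\frac{\partial L_{\delta}}{\partial y} = +\frac{\partial L_{\delta}}{\partial r}

The kernels then multiply this elementwise by out_grad, which is the chain rule applied to whatever downstream loss consumes Out.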
Python unit test for the huber_loss op

+47
@@ -0,0 +1,47 @@
import unittest
import numpy as np
from op_test import OpTest


def huber_loss_forward(val, delta):
    # Reference implementation mirroring the C++ HuberLossForward functor.
    abs_val = abs(val)
    if abs_val <= delta:
        return 0.5 * val * val
    else:
        return delta * (abs_val - 0.5 * delta)


class TestHuberLossOp(OpTest):
    def setUp(self):
        self.op_type = 'huber_loss'
        samples_num = 64
        delta = 1.0
        self.inputs = {
            'X': np.random.uniform(0, 1., (samples_num, 1)).astype('float32'),
            'Y': np.random.uniform(0, 1., (samples_num, 1)).astype('float32'),
        }
        residual = self.inputs['Y'] - self.inputs['X']
        loss = np.vectorize(huber_loss_forward)(residual, delta)
        self.attrs = {'delta': delta}
        self.outputs = {
            'Residual': residual,
            'Out': loss.reshape((samples_num, 1))
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.008)

    def test_check_grad_ignore_x(self):
        # X is excluded from the gradient check via no_grad_set.
        self.check_grad(
            ['Y'], 'Out', max_relative_error=0.008, no_grad_set=set(['X']))

    def test_check_grad_ignore_y(self):
        # Y is excluded from the gradient check via no_grad_set.
        self.check_grad(
            ['X'], 'Out', max_relative_error=0.008, no_grad_set=set(['Y']))


if __name__ == '__main__':
    unittest.main()
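check_grad compares the operator's analytic gradients against numeric ones obtained by perturbing the inputs. A self-contained C++ sketch of the same finite-difference idea for the scalar Huber derivative (the helper names, sample residuals, and epsilon are hypothetical, unrelated to OpTest's internals):

// Finite-difference check of the Huber gradient (standalone sketch).
#include <cmath>
#include <cstdio>

float huber(float r, float delta) {
  float a = std::fabs(r);
  return a <= delta ? 0.5f * r * r : delta * (a - 0.5f * delta);
}

// Analytic derivative of the loss w.r.t. the residual r = y - x.
float huber_grad(float r, float delta) {
  if (std::fabs(r) <= delta) return r;
  return r > 0 ? delta : -delta;
}

int main() {
  const float delta = 1.0f, eps = 1e-3f;
  const float rs[] = {0.3f, -0.9f, 2.0f};  // residuals on both sides of delta
  for (float r : rs) {
    // Central difference approximates dL/dr; compare with the analytic form.
    float numeric = (huber(r + eps, delta) - huber(r - eps, delta)) / (2 * eps);
    float analytic = huber_grad(r, delta);
    std::printf("r=%5.2f numeric=%.5f analytic=%.5f diff=%.2e\n", r, numeric,
                analytic, std::fabs(numeric - analytic));
  }
  return 0;
}

The test's max_relative_error=0.008 plays the role of the diff tolerance here; the quadratic/linear switch at |r| = delta is exactly where numeric and analytic gradients should still agree.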
