
Commit 000d751

fix backward op.
1 parent 201c2bc commit 000d751

File tree: 4 files changed, +76 -73 lines

  paddle/operators/cross_entropy_op.cc
  paddle/operators/cross_entropy_op.cu
  paddle/operators/cross_entropy_op.h
  python/paddle/v2/framework/tests/test_cross_entropy_op.py

paddle/operators/cross_entropy_op.cc

Lines changed: 20 additions & 17 deletions
@@ -37,13 +37,13 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
                       "The 1st dimension of Input(X) and Input(Label) should "
                       "be equal.");
-    if (ctx.Attr<bool>("soft_label")) {
+    if (ctx.Attr<bool>("softLabel")) {
       PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
-                        "If Attr(soft_label) == true, the 2nd dimension of "
+                        "If Attr(softLabel) == true, the 2nd dimension of "
                         "Input(X) and Input(Label) should be equal.");
     } else {
       PADDLE_ENFORCE_EQ(label->dims()[1], 1,
-                        "If Attr(soft_label) == false, the 2nd dimension of "
+                        "If Attr(softLabel) == false, the 2nd dimension of "
                         "Input(Label) should be 1.");
     }

@@ -63,6 +63,8 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
                             "Input(Label) should be not null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
                             "Input(Y@GRAD) should be not null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("X")),
+                            "Output(X@GRAD) should be not null.");
 
     auto x = ctx.Input<Tensor>("X");
     auto label = ctx.Input<Tensor>("Label");
@@ -80,13 +82,13 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
                       "be equal.");
     PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
                       "The 2nd dimension of Input(Y@Grad) should be 1.");
-    if (ctx.Attr<bool>("soft_label")) {
+    if (ctx.Attr<bool>("softLabel")) {
       PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
-                        "When Attr(soft_label) == true, the 2nd dimension of "
+                        "When Attr(softLabel) == true, the 2nd dimension of "
                         "Input(X) and Input(Label) should be equal.");
     } else {
       PADDLE_ENFORCE_EQ(label->dims()[1], 1,
-                        "When Attr(soft_label) == false, the 2nd dimension of "
+                        "When Attr(softLabel) == false, the 2nd dimension of "
                         "Input(Label) should be 1.");
     }

@@ -105,18 +107,19 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
              "where N is the batch size and D is the number of classes. "
              "This input is a probability computed by the previous operator, "
              "which is almost always the result of a softmax operator.");
-    AddInput("Label",
-             "(Tensor, default Tensor<int>), the ground truth which is "
-             "a 1-D or 2-D tensor. "
-             "When soft_label is set to 0, `Label` is a Tensor<int> with shape "
-             "[N x 1]. "
-             "When soft_label is set to 1, `Label` is a Tensor<float/double> "
-             "with shape [N x K].");
+    AddInput(
+        "Label",
+        "(Tensor, default Tensor<int>), the ground truth which is "
+        "a 2-D tensor. "
+        "When softLabel is set to false, `Label` is a Tensor<int> with shape "
+        "[N x 1]. "
+        "When softLabel is set to true, `Label` is a Tensor<float/double> "
+        "with shape [N x K].");
     AddOutput("Y",
-              "(Tensor, default Tensor<float>), a 1-D tensor "
+              "(Tensor, default Tensor<float>), a 2-D tensor "
               "with shape [N x 1]. The cross entropy loss.");
     AddAttr<bool>(
-        "soft_label",
+        "softLabel",
         "(bool, default false), a flag to indicate whether to interpret "
         "the given labels as soft labels.")
         .SetDefault(false);
@@ -126,12 +129,12 @@ CrossEntropy Operator.
 It supports both standard cross-entropy and soft-label cross-entropy loss
 computation.
 1) One-hot cross-entropy:
-    soft_label = False, Label[i, 0] indicates the class index for sample i:
+    softLabel = false, Label[i, 0] indicates the class index for sample i:
 
                 Y[i] = -log(X[i, Label[i]])
 
 2) Soft-label cross-entropy:
-    soft_label = True, Label[i, j] indicates the soft label of class j
+    softLabel = true, Label[i, j] indicates the soft label of class j
     for sample i:
 
                 Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
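
For reference, the two formulas above fully determine the forward kernels touched by this commit. A minimal NumPy sketch of both branches (illustrative only; the helper name and argument layout are not part of the commit):

    import numpy as np

    def cross_entropy_ref(X, label, soft_label=False):
        # X: [N x D] probabilities; hard Label: [N x 1] int; soft Label: [N x D].
        if soft_label:
            # Y[i] = \sum_j -Label[i, j] * log(X[i, j])
            return -(label * np.log(X)).sum(axis=1, keepdims=True)
        # Y[i] = -log(X[i, Label[i]])
        idx = label.reshape(-1).astype(int)
        return -np.log(X[np.arange(X.shape[0]), idx]).reshape(-1, 1)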

paddle/operators/cross_entropy_op.cu

Lines changed: 27 additions & 25 deletions
@@ -70,7 +70,7 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
 
 // TODO(qingqing): make zero setting a common function.
 template <typename T>
-__global__ void zero(T* X, const int N) {
+__global__ void Zero(T* X, const int N) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
        i += blockDim.x * gridDim.x) {
     X[i] = 0.0;
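
The renamed Zero kernel keeps its grid-stride loop: thread t clears indices t, t + blockDim.x * gridDim.x, t + 2 * blockDim.x * gridDim.x, and so on, so the buffer is fully zeroed even when grid * block < N. A small Python sketch of that coverage (illustrative, not code from the commit):

    # Thread t of `total` threads visits t, t + total, t + 2 * total, ... below N.
    def covered_indices(t, total, N):
        return list(range(t, N, total))

    # Together the threads visit every index in [0, N) exactly once.
    N, total = 10, 4
    visited = sorted(i for t in range(total) for i in covered_indices(t, total, N))
    assert visited == list(range(N))
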
@@ -108,18 +108,17 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                    "This kernel only runs on GPU device.");
 
-    auto x = ctx.Input<Tensor>("X");
-    auto y = ctx.Output<Tensor>("Y");
-    auto label = ctx.Input<Tensor>("Label");
+    const Tensor* x = ctx.Input<Tensor>("X");
+    const Tensor* label = ctx.Input<Tensor>("Label");
+    Tensor* y = ctx.Output<Tensor>("Y");
 
-    auto* x_data = x->data<T>();
-    y->mutable_data<T>(ctx.GetPlace());
-    auto* y_data = y->data<T>();
+    const T* x_data = x->data<T>();
+    T* y_data = y->mutable_data<T>(ctx.GetPlace());
 
     int batch_size = x->dims()[0];
     int class_num = x->dims()[1];
 
-    if (ctx.Attr<bool>("soft_label")) {
+    if (ctx.Attr<bool>("softLabel")) {
       auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
       int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));

@@ -148,38 +147,41 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                    "This kernel only runs on GPU device.");
 
-    auto x = ctx.Input<Tensor>("X");
-    auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
-    auto label = ctx.Input<Tensor>("Label");
+    const Tensor* x = ctx.Input<Tensor>("X");
+    const Tensor* label = ctx.Input<Tensor>("Label");
+    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
 
-    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
-    auto* dy_data = dy->data<T>();
-    auto* x_data = x->data<T>();
+    const T* dy_data =
+        ctx.Input<Tensor>(framework::GradVarName("Y"))->data<T>();
+    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
+    const T* x_data = x->data<T>();
 
-    int n = x->dims()[0];
-    int d = x->dims()[1];
+    int batch_size = x->dims()[0];
+    int class_num = x->dims()[1];
 
     int block = 512;
-    int grid = (n * d + block - 1) / block;
-    zero<T><<<grid, block, 0,
-              reinterpret_cast<const platform::CUDADeviceContext&>(
-                  ctx.device_context())
-                  .stream()>>>(dx_data, n * d);
-    if (ctx.Attr<bool>("soft_label")) {
+    int grid = (batch_size * class_num + block - 1) / block;
+
+    if (ctx.Attr<bool>("softLabel")) {
       auto* label_data = label->data<T>();
       SoftCrossEntropyGradientKernel<T><<<
           grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
                               ctx.device_context())
                               .stream()>>>(dx_data, dy_data, x_data, label_data,
-                                           n, d);
+                                           batch_size, class_num);
     } else {
+      Zero<T><<<grid, block, 0,
+                reinterpret_cast<const platform::CUDADeviceContext&>(
+                    ctx.device_context())
+                    .stream()>>>(dx_data, batch_size * class_num);
+
       auto* label_data = label->data<int>();
+      grid = (batch_size + block - 1) / block;
       CrossEntropyGradientKernel<T><<<
           grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
                               ctx.device_context())
                               .stream()>>>(dx_data, dy_data, x_data, label_data,
-                                           n, d);
+                                           batch_size, class_num);
     }
   }
 };
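
Two behavioral points in the rewritten gradient launch: dX is now zeroed only on the hard-label path, since the soft-label kernel writes every element of dX while the hard-label kernel writes exactly one element per row; and the hard-label launch re-sizes the grid to one thread per sample instead of one per element, both grids using ceil division. A NumPy reference for the two write patterns (a sketch under the op's shape conventions, not code from the commit):

    import numpy as np

    block = 512
    def grid(n):                      # ceil-division grid sizing, as above
        return (n + block - 1) // block

    def cross_entropy_grad_ref(X, label, dY, soft_label=False):
        if soft_label:
            # Dense: every element of dX is written, so no pre-zeroing is needed.
            return -label * dY / X    # dY is [N x 1] and broadcasts over columns
        # Sparse: one element per row is non-zero, hence the Zero kernel first.
        dX = np.zeros_like(X)
        rows = np.arange(X.shape[0])
        idx = label.reshape(-1).astype(int)
        dX[rows, idx] = -dY.reshape(-1) / X[rows, idx]
        return dX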

paddle/operators/cross_entropy_op.h

Lines changed: 25 additions & 28 deletions
@@ -42,14 +42,14 @@ class CrossEntropyOpKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
-                   "It must use CPUPlace.");
+                   "This kernel only runs on CPU.");
     const Tensor* x = ctx.Input<Tensor>("X");
     const Tensor* labels = ctx.Input<Tensor>("Label");
     Tensor* y = ctx.Output<Tensor>("Y");
-    y->mutable_data<T>(ctx.GetPlace());
+    T* y_data = y->mutable_data<T>(ctx.GetPlace());
 
     const int batch_size = x->dims()[0];
-    if (ctx.Attr<bool>("soft_label")) {
+    if (ctx.Attr<bool>("softLabel")) {
       auto prob = EigenMatrix<T>::From(*x);
       auto lbl_mat = EigenMatrix<T>::From(*labels);
       auto loss = EigenMatrix<T>::From(*y);
@@ -60,9 +60,7 @@ class CrossEntropyOpKernel : public framework::OpKernel {
               .reshape(Eigen::DSizes<int, 2>(batch_size, 1)));
     } else {
       const int class_num = x->dims()[1];
-
       const T* x_data = x->data<T>();
-      T* y_data = y->data<T>();
 
       const int* label_data = labels->data<int>();
       for (int i = 0; i < batch_size; ++i) {
@@ -78,33 +76,32 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
-                   "It must use CPUPlace.");
-
-    auto x = ctx.Input<Tensor>("X");
-    auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
-    auto label = ctx.Input<Tensor>("Label");
-
-    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
-    auto* dy_data = dy->data<T>();
-    auto* x_data = x->data<T>();
+                   "This kernel only runs on CPU.");
+    const Tensor* x = ctx.Input<Tensor>("X");
+    const Tensor* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    const Tensor* label = ctx.Input<Tensor>("Label");
+    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
 
-    int batch_size = x->dims()[0];
     int class_num = x->dims()[1];
-
-    // TODO(qingqing): make zero setting an common function.
-    if (ctx.Attr<bool>("soft_label")) {
-      auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
-      int index = 0;
-      for (int i = 0; i < batch_size; ++i) {
-        for (int j = 0; j < class_num; ++j) {
-          dx_data[index] = -label_data[index] * dy_data[i] / x_data[index];
-          index++;
-        }
-      }
+    if (ctx.Attr<bool>("softLabel")) {
+      auto x_mat = EigenMatrix<T>::From(*x);
+      auto dy_mat = EigenMatrix<T>::From(*dy);
+      auto lbl_mat = EigenMatrix<T>::From(*label);
+      auto dx_mat = EigenMatrix<T>::From(*dx);
+
+      dx_mat.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
+          -(lbl_mat * dy_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num)) /
+            x_mat);
     } else {
-      auto* label_data = label->data<int>();
+      int batch_size = x->dims()[0];
+      const T* dy_data = dy->data<T>();
+      const T* x_data = x->data<T>();
+      const int* label_data = label->data<int>();
+
+      // TODO(qingqing): make zero setting a common function.
       memset(dx_data, 0, sizeof(T) * batch_size * class_num);
+
       for (int i = 0; i < batch_size; ++i) {
         PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num);
         int index = i * class_num + label_data[i];
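
The soft-label branch above trades the old element-wise double loop, dx_data[index] = -label_data[index] * dy_data[i] / x_data[index], for one Eigen expression that broadcasts the [N x 1] dY across class_num columns. Both forms compute the same values; a NumPy sketch of the equivalence (illustrative only):

    import numpy as np

    N, D = 3, 4
    X = np.random.uniform(0.1, 1.0, (N, D))
    label = np.random.uniform(0.1, 1.0, (N, D))
    dY = np.random.uniform(-1.0, 1.0, (N, 1))

    # Old form: explicit double loop over samples and classes.
    loop_dX = np.empty_like(X)
    for i in range(N):
        for j in range(D):
            loop_dX[i, j] = -label[i, j] * dY[i, 0] / X[i, j]

    # New form: broadcast dY over the class dimension, mirroring
    # dy_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num)).
    vec_dX = -(label * np.broadcast_to(dY, (N, D)) / X)

    assert np.allclose(loop_dX, vec_dX)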

python/paddle/v2/framework/tests/test_cross_entropy_op.py

Lines changed: 4 additions & 3 deletions
@@ -21,7 +21,7 @@ def setUp(self):
 
         self.inputs = {"X": X, "Label": label}
         self.outputs = {"Y": cross_entropy}
-        self.attrs = {"soft_label": False}
+        self.attrs = {"softLabel": False}
 
     def test_check_output(self):
         self.check_output()
@@ -49,7 +49,7 @@ def setUp(self):
 
         self.inputs = {"X": X, "Label": label}
         self.outputs = {"Y": cross_entropy}
-        self.attrs = {"soft_label": True}
+        self.attrs = {"softLabel": True}
 
     def test_check_output(self):
         self.check_output()
@@ -73,6 +73,7 @@ def setUp(self):
             0, class_num, (batch_size), dtype="int32")
         label = np.zeros(X.shape)
         label[np.arange(batch_size), label_index] = 1
+
         cross_entropy = np.asmatrix(
             [[-np.log(X[i][label_index[i]])] for i in range(X.shape[0])],
             dtype="float32")
@@ -81,7 +82,7 @@ def setUp(self):
 
         self.inputs = {"X": X, "Label": label}
         self.outputs = {"Y": cross_entropy}
-        self.attrs = {"soft_label": True}
+        self.attrs = {"softLabel": True}
 
     def test_check_output(self):
         self.check_output()
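
The third test case still drives the soft-label path, but with one-hot labels; for such labels the soft-label loss collapses to the hard-label loss, which is why the expected output stays -log(X[i, label_index[i]]). A compact sketch of that identity (illustrative, not part of the test file):

    import numpy as np

    batch_size, class_num = 5, 10
    X = np.random.uniform(0.1, 1.0, (batch_size, class_num)).astype("float32")
    label_index = np.random.randint(0, class_num, (batch_size,))

    # One-hot soft labels, built the same way as in the test above.
    label = np.zeros(X.shape)
    label[np.arange(batch_size), label_index] = 1

    soft_loss = -(label * np.log(X)).sum(axis=1)
    hard_loss = -np.log(X[np.arange(batch_size), label_index])
    assert np.allclose(soft_loss, hard_loss)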
