-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Add soft-label support for cross-entropy operator. #4081
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
qingqing01
merged 6 commits into
PaddlePaddle:develop
from
xinghai-sun:soft_label_cross_entropy
Sep 19, 2017
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
6d60352
Add soft-label support for cross-entropy operator.
xinghai-sun d7717f2
Merge branch 'develop' into soft_label_cross_entropy
xinghai-sun e870682
Update cross entropy operator by following reviewer's comments.
xinghai-sun 8e7fe8c
Merge branch 'develop' into soft_label_cross_entropy
xinghai-sun d8046da
Use soft_label attribute for cross-entropy.
xinghai-sun 19de8ae
Fixed a error in mnist unitest.
xinghai-sun File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/operators/cross_entropy_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
using framework::LoDTensor; | ||
|
||
class CrossEntropyOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
protected: | ||
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); | ||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), | ||
"Input(Label) must not be null."); | ||
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), "Output(Y) must not be null."); | ||
|
||
auto x = ctx.Input<Tensor>("X"); | ||
auto label = ctx.Input<Tensor>("Label"); | ||
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2."); | ||
PADDLE_ENFORCE_EQ(label->dims().size(), 2, | ||
"Input(Label)'s rank must be 2."); | ||
// TODO(xinghai-sun): remove this check after swtiching to bool | ||
PADDLE_ENFORCE(ctx.Attr<int>("soft_label") == 0 || | ||
ctx.Attr<int>("soft_label") == 1); | ||
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0], | ||
"The 1st dimension of Input(X) and Input(Label) must " | ||
"be equal."); | ||
if (ctx.Attr<int>("soft_label") == 1) { | ||
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1], | ||
"If Attr(soft_label) == 1, The 2nd dimension of " | ||
"Input(X) and Input(Label) must be equal."); | ||
} else { | ||
PADDLE_ENFORCE_EQ(label->dims()[1], 1, | ||
"If Attr(soft_label) == 0, The 2nd dimension of " | ||
"Input(Label) must be 1."); | ||
} | ||
|
||
ctx.Output<LoDTensor>("Y")->Resize({x->dims()[0], 1}); | ||
} | ||
}; | ||
|
||
class CrossEntropyGradientOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
protected: | ||
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); | ||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), | ||
"Input(Label) must not be null."); | ||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")), | ||
"Input(Y@GRAD) must not be null."); | ||
|
||
auto x = ctx.Input<Tensor>("X"); | ||
auto label = ctx.Input<Tensor>("Label"); | ||
auto dy = ctx.Input<Tensor>(framework::GradVarName("Y")); | ||
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2."); | ||
PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2."); | ||
PADDLE_ENFORCE_EQ(label->dims().size(), 2, | ||
"Input(Label)'s rank must be 2."); | ||
// TODO(xinghai-sun): remove this check after swtiching to bool | ||
PADDLE_ENFORCE(ctx.Attr<int>("soft_label") == 0 || | ||
ctx.Attr<int>("soft_label") == 1); | ||
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0], | ||
"The 1st dimension of Input(X) and Input(Label) must " | ||
"be equal."); | ||
PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0], | ||
"The 1st dimension of Input(X) and Input(Y@Grad) must " | ||
"be equal."); | ||
PADDLE_ENFORCE_EQ(dy->dims()[1], 1, | ||
"The 2nd dimension of Input(Y@Grad) must be 1."); | ||
if (ctx.Attr<int>("soft_label") == 1) { | ||
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1], | ||
"If Attr(soft_label) == 1, The 2nd dimension of " | ||
"Input(X) and Input(Label) must be equal."); | ||
} else { | ||
PADDLE_ENFORCE_EQ(label->dims()[1], 1, | ||
"If Attr(soft_label) == 0, The 2nd dimension of " | ||
"Input(Label) must be 1."); | ||
} | ||
|
||
auto dx = ctx.Output<LoDTensor>(framework::GradVarName("X")); | ||
dx->Resize(x->dims()); | ||
} | ||
}; | ||
|
||
class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
CrossEntropyOpMaker(framework::OpProto *proto, | ||
framework::OpAttrChecker *op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("X", "The first input of CrossEntropyOp"); | ||
AddInput("Label", "The second input of CrossEntropyOp"); | ||
AddOutput("Y", "The output of CrossEntropyOp"); | ||
AddAttr<int>("soft_label", "Is soft label. Default zero.").SetDefault(0); | ||
|
||
AddComment(R"DOC( | ||
CrossEntropy Operator. | ||
|
||
It supports both standard cross-entropy and soft-label cross-entropy loss | ||
computation. | ||
1) One-hot cross-entropy: | ||
soft_label = 0, Label[i, 0] indicates the class index for sample i: | ||
|
||
Y[i] = -log(X[i, Label[i]]) | ||
|
||
2) Soft-label cross-entropy: | ||
soft_label = 1, Label[i, j] indicates the soft label of class j | ||
for sample i: | ||
|
||
Y[i] = \sum_j{-Label[i, j] * log(X[i, j])} | ||
|
||
Please make sure that in this case the summuation of each row of Label | ||
equals one. | ||
|
||
3) One-hot cross-entropy with vecterized Input(Label): | ||
As a special case of 2), when each row of Input(Label) has only one | ||
non-zero element (equals 1), soft-label cross-entropy degenerates to a | ||
one-hot cross-entropy with one-hot label representation. | ||
)DOC"); | ||
} | ||
}; | ||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
REGISTER_OP(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker, | ||
cross_entropy_grad, ops::CrossEntropyGradientOp); | ||
REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<float>); | ||
REGISTER_OP_CPU_KERNEL(cross_entropy_grad, | ||
ops::CrossEntropyGradientOpKernel<float>); |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/framework/op_registry.h" | ||
#include "paddle/operators/cross_entropy_op.h" | ||
#include "paddle/platform/assert.h" | ||
#include "paddle/platform/hostdevice.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename T> | ||
__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, | ||
const int N, const int D) { | ||
// TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. | ||
// CUDA_1D_KERNEL_LOOP(i, N) { | ||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; | ||
i += blockDim.x * gridDim.x) { | ||
PADDLE_ASSERT(label[i] >= 0 && label[i] < D); | ||
Y[i] = -tolerable_value(log(X[i * D + label[i]])); | ||
} | ||
} | ||
|
||
template <typename T> | ||
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, | ||
const int N, const int D) { | ||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; | ||
i += blockDim.x * gridDim.x) { | ||
T sum = static_cast<T>(0); | ||
for (int j = 0; j < D; j++) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
sum += label[i * D + j] * tolerable_value(log(X[i * D + j])); | ||
} | ||
Y[i] = -sum; | ||
} | ||
} | ||
|
||
// TODO(qingqing): make zero setting an common function. | ||
template <typename T> | ||
__global__ void zero(T* X, const int N) { | ||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; | ||
i += blockDim.x * gridDim.x) { | ||
X[i] = 0.0; | ||
} | ||
} | ||
|
||
template <typename T> | ||
__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X, | ||
const int* label, const int N, | ||
const int D) { | ||
// TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. | ||
// CUDA_1D_KERNEL_LOOP(i, N) { | ||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; | ||
i += blockDim.x * gridDim.x) { | ||
int idx = i * D + label[i]; | ||
dX[idx] = -dY[i] / X[idx]; | ||
} | ||
} | ||
|
||
template <typename T> | ||
__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X, | ||
const T* label, const int N, | ||
const int D) { | ||
// TOOD(qingqing): optimize for this kernel | ||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; | ||
i += blockDim.x * gridDim.x) { | ||
for (int j = 0; j < D; ++j) { | ||
int idx = i * D + j; | ||
dX[idx] = -label[idx] * dY[i] / X[idx]; | ||
} | ||
} | ||
} | ||
|
||
template <typename T> | ||
class CrossEntropyOpCUDAKernel : public framework::OpKernel { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const override { | ||
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), | ||
"It must use GPUPlace."); | ||
|
||
auto x = ctx.Input<Tensor>("X"); | ||
auto y = ctx.Output<Tensor>("Y"); | ||
auto label = ctx.Input<Tensor>("Label"); | ||
|
||
auto* x_data = x->data<T>(); | ||
y->mutable_data<T>(ctx.GetPlace()); | ||
auto* y_data = y->data<T>(); | ||
|
||
int n = x->dims()[0]; | ||
int d = x->dims()[1]; | ||
int block = 512; | ||
int grid = (n + block - 1) / block; | ||
// TODO(qingqing) launch kernel on specified stream | ||
// base on ExecutionContext. | ||
if (ctx.Attr<int>("soft_label") == 1) { | ||
auto* label_data = ctx.Input<Tensor>("Label")->data<T>(); | ||
SoftCrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, | ||
d); | ||
} else { | ||
auto* label_data = ctx.Input<Tensor>("Label")->data<int>(); | ||
CrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, d); | ||
} | ||
} | ||
}; | ||
|
||
template <typename T> | ||
class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const override { | ||
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), | ||
"It must use GPUPlace."); | ||
|
||
auto x = ctx.Input<Tensor>("X"); | ||
auto dx = ctx.Output<Tensor>(framework::GradVarName("X")); | ||
auto dy = ctx.Input<Tensor>(framework::GradVarName("Y")); | ||
auto label = ctx.Input<Tensor>("Label"); | ||
|
||
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace()); | ||
auto* dy_data = dy->data<T>(); | ||
auto* x_data = x->data<T>(); | ||
|
||
int n = x->dims()[0]; | ||
int d = x->dims()[1]; | ||
int block = 512; | ||
int grid = (n * d + block - 1) / block; | ||
zero<T><<<grid, block>>>(dx_data, n * d); | ||
grid = (n + block - 1) / block; | ||
// TODO(qingqing): launch kernel on specified stream | ||
// base on ExecutionContext. | ||
if (ctx.Attr<int>("soft_label") == 1) { | ||
auto* label_data = label->data<T>(); | ||
SoftCrossEntropyGradientKernel<T><<<grid, block>>>( | ||
dx_data, dy_data, x_data, label_data, n, d); | ||
} else { | ||
auto* label_data = label->data<int>(); | ||
CrossEntropyGradientKernel<T><<<grid, block>>>(dx_data, dy_data, x_data, | ||
label_data, n, d); | ||
} | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>); | ||
REGISTER_OP_GPU_KERNEL(cross_entropy_grad, | ||
ops::CrossEntropyGradientOpCUDAKernel<float>); |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我的小疑问请教一下 @qingqing01
i += blockDim.x * gridDim.x
循环变量i
只要增加一次,就会直接超过最大thread数,也就是这个kernel函数实际是并不需要循环多次,只计算输出向量的一个位置,逻辑上等价于判断i>= batch_size
时,直接return,写成循环有什么考虑呢?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果针对下面设置grid, threadd的方式,确实不需要for循环。但如果将下面的grid的设置为一个固定的数,就是总共发起固定数目的总线程数,for循环就是有用的,有可能一个线程计算多个输出。这样这个kernel已经处理了边界,就不需要修改了。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
明白啦~ 确实,cross entropy 这个 kernel 比较简单也比较特殊。grid 数目也已经计算好。