
Commit 7d65321

Merge pull request #4237 from lcy-seso/optimize_cross_entropy_kernel

Optimize cross entropy kernel.

2 parents: 1c0a4c9 + 000d751

File tree: 7 files changed, +242 additions, −171 deletions

paddle/operators/accuracy_op.cu

Lines changed: 6 additions & 2 deletions
@@ -69,8 +69,12 @@ class AccuracyOpCUDAKernel : public framework::OpKernel {
       return;
     }
 
-    AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS>>>(
-        num_samples, infer_width, inference_data, label_data, accuracy_data);
+    AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<
+        1, PADDLE_CUDA_NUM_THREADS, 0,
+        reinterpret_cast<const platform::CUDADeviceContext&>(
+            ctx.device_context())
+            .stream()>>>(num_samples, infer_width, inference_data, label_data,
+                         accuracy_data);
   }
 };
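The recurring change across this PR is visible here in miniature: every kernel launch gains the two optional launch-configuration arguments, dynamic shared memory size and stream, with the stream taken from the execution context's CUDADeviceContext. Kernels are thereby enqueued on Paddle's per-device stream rather than the legacy default stream (this also retires the "launch kernel on specified stream" TODO removed later in this diff). A minimal standalone sketch of the pattern; the Scale kernel and its arguments are hypothetical, only the <<<grid, block, sharedMem, stream>>> shape is the point:

#include <cuda_runtime.h>

// Hypothetical element-wise kernel, used only to illustrate the launch form.
template <typename T>
__global__ void Scale(T* x, T factor, int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    x[i] *= factor;
  }
}

int main() {
  const int n = 1 << 20;
  float* d_x;
  cudaMalloc(&d_x, n * sizeof(float));

  // In the operator code the stream comes from
  // ctx.device_context().stream(); here we create one explicitly.
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  int block = 512;
  int grid = (n + block - 1) / block;
  // Third launch argument: dynamic shared memory bytes (0 here).
  // Fourth: the stream the kernel is enqueued on.
  Scale<float><<<grid, block, 0, stream>>>(d_x, 2.0f, n);

  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  cudaFree(d_x);
  return 0;
}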

paddle/operators/cross_entropy_op.cc

Lines changed: 51 additions & 32 deletions
@@ -23,27 +23,28 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
-                            "Input(Label) must not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), "Output(Y) must not be null.");
+                            "Input(Label) should be not null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"),
+                            "Output(Y) should be not null.");
 
     auto x = ctx.Input<Tensor>("X");
     auto label = ctx.Input<Tensor>("Label");
-    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(label->dims().size(), 2,
-                      "Input(Label)'s rank must be 2.");
+                      "Input(Label)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
-                      "The 1st dimension of Input(X) and Input(Label) must "
+                      "The 1st dimension of Input(X) and Input(Label) should "
                       "be equal.");
-    if (ctx.Attr<bool>("soft_label")) {
+    if (ctx.Attr<bool>("softLabel")) {
       PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
-                        "If Attr(soft_label) == true, The 2nd dimension of "
-                        "Input(X) and Input(Label) must be equal.");
+                        "If Attr(softLabel) == true, the 2nd dimension of "
+                        "Input(X) and Input(Label) should be equal.");
     } else {
       PADDLE_ENFORCE_EQ(label->dims()[1], 1,
-                        "If Attr(soft_label) == false, The 2nd dimension of "
-                        "Input(Label) must be 1.");
+                        "If Attr(softLabel) == false, the 2nd dimension of "
+                        "Input(Label) should be 1.");
     }
 
     ctx.Output<Tensor>("Y")->Resize({x->dims()[0], 1});

@@ -57,35 +58,38 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
-                            "Input(Label) must not be null.");
+                            "Input(Label) should be not null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
-                            "Input(Y@GRAD) must not be null.");
+                            "Input(Y@GRAD) should be not null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("X")),
+                            "Output(X@GRAD) should be not null.");
 
     auto x = ctx.Input<Tensor>("X");
     auto label = ctx.Input<Tensor>("Label");
     auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
-    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
-    PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2.");
+    PADDLE_ENFORCE_EQ(dy->dims().size(), 2,
+                      "Input(Y@Grad)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(label->dims().size(), 2,
-                      "Input(Label)'s rank must be 2.");
+                      "Input(Label)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
-                      "The 1st dimension of Input(X) and Input(Label) must "
+                      "The 1st dimension of Input(X) and Input(Label) should "
                       "be equal.");
     PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0],
-                      "The 1st dimension of Input(X) and Input(Y@Grad) must "
+                      "The 1st dimension of Input(X) and Input(Y@Grad) should "
                       "be equal.");
     PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
-                      "The 2nd dimension of Input(Y@Grad) must be 1.");
-    if (ctx.Attr<bool>("soft_label")) {
+                      "The 2nd dimension of Input(Y@Grad) should be 1.");
+    if (ctx.Attr<bool>("softLabel")) {
       PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
-                        "If Attr(soft_label) == true, The 2nd dimension of "
-                        "Input(X) and Input(Label) must be equal.");
+                        "When Attr(softLabel) == true, the 2nd dimension of "
+                        "Input(X) and Input(Label) should be equal.");
     } else {
       PADDLE_ENFORCE_EQ(label->dims()[1], 1,
-                        "If Attr(soft_label) == false, The 2nd dimension of "
-                        "Input(Label) must be 1.");
+                        "When Attr(softLabel) == false, the 2nd dimension of "
+                        "Input(Label) should be 1.");
     }
 
     auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));

@@ -98,24 +102,39 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
   CrossEntropyOpMaker(framework::OpProto *proto,
                       framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "The first input of CrossEntropyOp");
-    AddInput("Label", "The second input of CrossEntropyOp");
-    AddOutput("Y", "The output of CrossEntropyOp");
-    AddAttr<bool>("soft_label", "Is soft label. Default zero.")
+    AddInput("X",
+             "(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
+             "where N is the batch size and D is the number of classes. "
+             "This input is a probability computed by the previous operator, "
+             "which is almost always the result of a softmax operator.");
+    AddInput(
+        "Label",
+        "(Tensor, default Tensor<int>), the ground truth which is "
+        "a 2-D tensor. "
+        "When softLabel is set to false, `Label` is a Tensor<int> with shape "
+        "[N x 1]. "
+        "When softLabel is set to true, `Label` is a Tensor<float/double> "
+        "with shape [N x K].");
+    AddOutput("Y",
+              "(Tensor, default Tensor<float>), a 2-D tensor "
+              "with shape [N x 1]. The cross entropy loss.");
+    AddAttr<bool>(
+        "softLabel",
+        "(bool, default false), a flag to indicate whether to interpret "
+        "the given labels as soft labels.")
         .SetDefault(false);
-
     AddComment(R"DOC(
 CrossEntropy Operator.
 
 It supports both standard cross-entropy and soft-label cross-entropy loss
 computation.
 1) One-hot cross-entropy:
-    soft_label = False, Label[i, 0] indicates the class index for sample i:
+    softLabel = false, Label[i, 0] indicates the class index for sample i:
 
                 Y[i] = -log(X[i, Label[i]])
 
 2) Soft-label cross-entropy:
-    soft_label = True, Label[i, j] indicates the soft label of class j
+    softLabel = true, Label[i, j] indicates the soft label of class j
     for sample i:
 
                 Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
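The two formulas in the DOC comment above are easy to sanity-check on the CPU. A minimal reference implementation of both modes (plain C++ with illustrative names, no Paddle dependencies assumed):

#include <cmath>
#include <cstdio>
#include <vector>

// X is row-major [N x D] and holds per-class probabilities.

// 1) One-hot mode: Y[i] = -log(X[i, label[i]])
std::vector<float> HardCrossEntropy(const std::vector<float>& x,
                                    const std::vector<int>& label, int n,
                                    int d) {
  std::vector<float> y(n);
  for (int i = 0; i < n; ++i) y[i] = -std::log(x[i * d + label[i]]);
  return y;
}

// 2) Soft-label mode: Y[i] = -sum_j Label[i, j] * log(X[i, j])
std::vector<float> SoftCrossEntropy(const std::vector<float>& x,
                                    const std::vector<float>& label, int n,
                                    int d) {
  std::vector<float> y(n, 0.f);
  for (int i = 0; i < n; ++i)
    for (int j = 0; j < d; ++j)
      y[i] -= label[i * d + j] * std::log(x[i * d + j]);
  return y;
}

int main() {
  // One sample, three classes.
  std::vector<float> x = {0.2f, 0.7f, 0.1f};     // probabilities (sum to 1)
  std::vector<int> hard = {1};                   // true class index
  std::vector<float> soft = {0.1f, 0.8f, 0.1f};  // soft label distribution
  printf("hard: %f\n", HardCrossEntropy(x, hard, 1, 3)[0]);  // -log(0.7)
  printf("soft: %f\n", SoftCrossEntropy(x, soft, 1, 3)[0]);
  return 0;
}

For this input the one-hot loss is -log(0.7) ≈ 0.357 and the soft loss is -(0.1 log 0.2 + 0.8 log 0.7 + 0.1 log 0.1) ≈ 0.676.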

paddle/operators/cross_entropy_op.cu

Lines changed: 92 additions & 55 deletions
@@ -28,26 +28,49 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
        i += blockDim.x * gridDim.x) {
     PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
-    Y[i] = -tolerable_value(log(X[i * D + label[i]]));
+    Y[i] = -TolerableValue<T>()(log(X[i * D + label[i]]));
   }
 }
 
+template <typename T>
+__device__ __forceinline__ T sum_single_warp(T val) {
+  val += __shfl_down(val, 16);
+  val += __shfl_down(val, 8);
+  val += __shfl_down(val, 4);
+  val += __shfl_down(val, 2);
+  val += __shfl_down(val, 1);
+  return val;
+}
+
 template <typename T>
 __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
-                                       const int N, const int D) {
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
-       i += blockDim.x * gridDim.x) {
-    T sum = static_cast<T>(0);
-    for (int j = 0; j < D; j++) {
-      sum += label[i * D + j] * tolerable_value(log(X[i * D + j]));
-    }
-    Y[i] = -sum;
+                                       const int class_num) {
+  int tid = threadIdx.x;
+  extern __shared__ T d_sum[];
+  d_sum[tid] = 0;
+
+  int cur_idx = tid;
+  int next_idx = blockIdx.x * class_num + tid;
+  while (cur_idx < class_num) {
+    d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
+    next_idx += blockDim.x;
+    cur_idx += blockDim.x;
+  }
+  __syncthreads();
+
+  for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
+    if (tid < stride) d_sum[tid] += d_sum[tid + stride];
+    __syncthreads();
   }
+
+  T val = d_sum[tid];
+  val = sum_single_warp<T>(val);
+  if (tid == 0) Y[blockIdx.x] = -val;
 }
 
-// TODO(qingqing): make zero setting an common function.
+// TODO(qingqing): make zero setting a common function.
 template <typename T>
-__global__ void zero(T* X, const int N) {
+__global__ void Zero(T* X, const int N) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
        i += blockDim.x * gridDim.x) {
     X[i] = 0.0;
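Two things happen in this hunk. First, the helper tolerable_value becomes the functor TolerableValue<T>(), applied to every log so that log(0) = -inf cannot poison the loss. The functor lives in cross_entropy_op.h and is outside this diff; a plausible sketch of such a clamp (an assumption, not the verbatim header):

#include <math.h>

// Hypothetical sketch: replace infinities from log(0) with a large
// finite stand-in so the loss stays representable.
template <typename T>
struct TolerableValue {
  __host__ __device__ T operator()(const T& x) const {
    const T kApproInf = 1e20;
    if (x == INFINITY) return kApproInf;
    if (x == -INFINITY) return -kApproInf;
    return x;
  }
};

Second, SoftCrossEntropyKernel is restructured from one thread per sample (with a serial loop over classes) to one block per sample: the block's threads stride over the class dimension accumulating partial sums into shared memory, a tree reduction folds the partials down to a single warp, and sum_single_warp finishes the sum in registers via shuffles, with no further __syncthreads(). This pattern assumes blockDim.x is a power of two and at least 32 (one full warp); the host code later in this file picks exactly such a block size. Note that __shfl_down is the pre-CUDA-9 intrinsic. A self-contained sketch of the same reduction shape, using the modern __shfl_down_sync (names are illustrative, not from the PR):

#include <cstdio>
#include <cuda_runtime.h>

// Sum the 32 lanes of one warp into lane 0 using register shuffles.
__device__ __forceinline__ float WarpSum(float val) {
  for (int offset = 16; offset > 0; offset >>= 1)
    val += __shfl_down_sync(0xffffffffu, val, offset);
  return val;
}

// One block reduces n values to a single sum: strided loads into shared
// memory, tree reduction down to warp width, then a warp-shuffle tail.
// Requires blockDim.x to be a power of two and >= 32.
__global__ void BlockSum(const float* x, float* out, int n) {
  extern __shared__ float buf[];
  int tid = threadIdx.x;

  float acc = 0.f;
  for (int i = tid; i < n; i += blockDim.x) acc += x[i];  // stage 1: stride
  buf[tid] = acc;
  __syncthreads();

  for (unsigned stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
    if (tid < stride) buf[tid] += buf[tid + stride];      // stage 2: tree
    __syncthreads();
  }

  if (tid < 32) {                                         // stage 3: one warp
    float val = WarpSum(buf[tid]);
    if (tid == 0) *out = val;
  }
}

int main() {
  const int n = 1000;
  float h_x[n], h_out = 0.f;
  for (int i = 0; i < n; ++i) h_x[i] = 1.f;  // expected sum: 1000

  float *d_x, *d_out;
  cudaMalloc(&d_x, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_x, h_x, n * sizeof(float), cudaMemcpyHostToDevice);

  const int block = 256;  // power of two, as the tree reduction requires
  BlockSum<<<1, block, block * sizeof(float)>>>(d_x, d_out, n);

  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f\n", h_out);  // prints 1000.000000
  cudaFree(d_x);
  cudaFree(d_out);
  return 0;
}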
@@ -71,13 +94,10 @@ template <typename T>
 __global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                                const T* label, const int N,
                                                const int D) {
-  // TOOD(qingqing): optimize for this kernel
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
-       i += blockDim.x * gridDim.x) {
-    for (int j = 0; j < D; ++j) {
-      int idx = i * D + j;
-      dX[idx] = -label[idx] * dY[i] / X[idx];
-    }
+  int ids = blockIdx.x * blockDim.x + threadIdx.x;
+  if (ids < N * D) {
+    int row_ids = ids / D;
+    dX[ids] = -label[ids] * dY[row_ids] / X[ids];
   }
 }
 
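Rewriting the soft-label gradient this way swaps the parallelization from one thread per row (with a serial inner loop over the class dimension) to one thread per element, so consecutive threads read and write consecutive addresses (coalesced access) and the launch must cover all N * D elements, as the grid computation further down now does. The element-wise formula follows directly from the soft-label loss:

    Y_i = -\sum_j \mathrm{Label}_{ij} \log X_{ij}
    \quad\Rightarrow\quad
    \frac{\partial Y_i}{\partial X_{ij}} = -\frac{\mathrm{Label}_{ij}}{X_{ij}}
    \quad\Rightarrow\quad
    dX_{ij} = dY_i \cdot \frac{\partial Y_i}{\partial X_{ij}}
            = -\frac{\mathrm{Label}_{ij}\, dY_i}{X_{ij}}

which is exactly the dX[ids] = -label[ids] * dY[row_ids] / X[ids] line above.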

@@ -86,29 +106,36 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
-                   "It must use GPUPlace.");
+                   "This kernel only runs on GPU device.");
 
-    auto x = ctx.Input<Tensor>("X");
-    auto y = ctx.Output<Tensor>("Y");
-    auto label = ctx.Input<Tensor>("Label");
+    const Tensor* x = ctx.Input<Tensor>("X");
+    const Tensor* label = ctx.Input<Tensor>("Label");
+    Tensor* y = ctx.Output<Tensor>("Y");
 
-    auto* x_data = x->data<T>();
-    y->mutable_data<T>(ctx.GetPlace());
-    auto* y_data = y->data<T>();
+    const T* x_data = x->data<T>();
+    T* y_data = y->mutable_data<T>(ctx.GetPlace());
 
-    int n = x->dims()[0];
-    int d = x->dims()[1];
-    int block = 512;
-    int grid = (n + block - 1) / block;
-    // TODO(qingqing) launch kernel on specified stream
-    // base on ExecutionContext.
-    if (ctx.Attr<bool>("soft_label")) {
+    int batch_size = x->dims()[0];
+    int class_num = x->dims()[1];
+
+    if (ctx.Attr<bool>("softLabel")) {
       auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
-      SoftCrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n,
-                                                 d);
+      int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
+
+      SoftCrossEntropyKernel<
+          T><<<batch_size, block, block * sizeof(T),
+               reinterpret_cast<const platform::CUDADeviceContext&>(
+                   ctx.device_context())
+                   .stream()>>>(y_data, x_data, label_data, class_num);
     } else {
       auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
-      CrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, d);
+      int block = 512;
+      int grid = (batch_size + block - 1) / block;
+      CrossEntropyKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              ctx.device_context())
+                              .stream()>>>(y_data, x_data, label_data,
+                                           batch_size, class_num);
     }
   }
 };
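On the host side, the soft-label path now launches one block per sample: grid = batch_size, blockDim.x = the largest power of two not exceeding class_num (capped at 512), and block * sizeof(T) bytes of dynamic shared memory backing the extern __shared__ d_sum array. The power-of-two width is what the shared-memory tree reduction above relies on. An integer-only equivalent of the pow/std::log2 expression, as a small sketch (the function name is illustrative, not from the PR):

// Largest power of two <= class_num, capped at 512; equivalent to
// class_num > 512 ? 512 : pow(2, int(std::log2(class_num)))
// without a round-trip through floating point. Assumes class_num >= 1.
int SoftCrossEntropyBlockSize(int class_num) {
  if (class_num > 512) return 512;
  int block = 1;
  while (block * 2 <= class_num) block <<= 1;
  return block;
}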
@@ -118,33 +145,43 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
-                   "It must use GPUPlace.");
+                   "This kernel only runs on GPU device.");
+
+    const Tensor* x = ctx.Input<Tensor>("X");
+    const Tensor* label = ctx.Input<Tensor>("Label");
+    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
 
-    auto x = ctx.Input<Tensor>("X");
-    auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
-    auto label = ctx.Input<Tensor>("Label");
+    const T* dy_data =
+        ctx.Input<Tensor>(framework::GradVarName("Y"))->data<T>();
+    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
+    const T* x_data = x->data<T>();
 
-    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
-    auto* dy_data = dy->data<T>();
-    auto* x_data = x->data<T>();
+    int batch_size = x->dims()[0];
+    int class_num = x->dims()[1];
 
-    int n = x->dims()[0];
-    int d = x->dims()[1];
     int block = 512;
-    int grid = (n * d + block - 1) / block;
-    zero<T><<<grid, block>>>(dx_data, n * d);
-    grid = (n + block - 1) / block;
-    // TODO(qingqing): launch kernel on specified stream
-    // base on ExecutionContext.
-    if (ctx.Attr<bool>("soft_label")) {
+    int grid = (batch_size * class_num + block - 1) / block;
+
+    if (ctx.Attr<bool>("softLabel")) {
       auto* label_data = label->data<T>();
-      SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
-          dx_data, dy_data, x_data, label_data, n, d);
+      SoftCrossEntropyGradientKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              ctx.device_context())
+                              .stream()>>>(dx_data, dy_data, x_data, label_data,
+                                           batch_size, class_num);
     } else {
+      Zero<T><<<grid, block, 0,
+                reinterpret_cast<const platform::CUDADeviceContext&>(
+                    ctx.device_context())
+                    .stream()>>>(dx_data, batch_size * class_num);
+
       auto* label_data = label->data<int>();
-      CrossEntropyGradientKernel<T><<<grid, block>>>(dx_data, dy_data, x_data,
-                                                     label_data, n, d);
+      grid = (batch_size + block - 1) / block;
+      CrossEntropyGradientKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              ctx.device_context())
+                              .stream()>>>(dx_data, dy_data, x_data, label_data,
+                                           batch_size, class_num);
     }
   }
 };
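The zero-fill also moves: previously dX was cleared unconditionally, now only in the hard-label branch. The soft-label gradient kernel writes every one of the batch_size * class_num elements of dX, so pre-zeroing it would be redundant work; the hard-label kernel, by contrast, writes a single element per row, so the rest of dX must be cleared first by Zero. That hard-label kernel is unchanged by this PR and therefore not shown in the diff; a reconstructed sketch of such a scatter-style gradient kernel (an assumption for illustration, not copied from the repo):

template <typename T>
__global__ void CrossEntropyGradientKernelSketch(T* dX, const T* dY,
                                                 const T* X, const int* label,
                                                 const int N, const int D) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    // Only the true-class column has a nonzero gradient:
    // dX[i, label[i]] = -dY[i] / X[i, label[i]]; every other column stays 0,
    // which is why dX is zero-filled before this launch.
    int idx = i * D + label[i];
    dX[idx] = -dY[i] / X[idx];
  }
}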
