
Commit 8b8ad6b

Fix the implementations supporting soft labels.
1 parent bb58b63 commit 8b8ad6b

File tree: 9 files changed, +272/-102 lines changed

paddle/operators/cross_entropy_op.cu

Lines changed: 2 additions & 2 deletions

@@ -28,7 +28,7 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
        i += blockDim.x * gridDim.x) {
     PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
-    Y[i] = -tolerable_value(log(X[i * D + label[i]]));
+    Y[i] = -TolerableValue<T>()(log(X[i * D + label[i]]));
   }
 }

@@ -39,7 +39,7 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
        i += blockDim.x * gridDim.x) {
     T sum = static_cast<T>(0);
     for (int j = 0; j < D; j++) {
-      sum += label[i * D + j] * tolerable_value(log(X[i * D + j]));
+      sum += label[i * D + j] * TolerableValue<T>()(log(X[i * D + j]));
     }
     Y[i] = -sum;
   }
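
For reference, the two kernels above evaluate the standard cross-entropy losses over D classes (a summary of the underlying math, not text from the commit):

% Hard labels: the loss of sample i is the negative log-probability
% assigned to its true class.
Y_i = -\log\left(X_{i,\,\mathrm{label}_i}\right)

% Soft labels: the expectation of -log X_{ij} under the per-sample
% label distribution.
Y_i = -\sum_{j=1}^{D} \mathrm{label}_{ij} \log\left(X_{ij}\right)

In both cases TolerableValue clamps the -inf produced by log(0), so a zero predicted probability yields a large finite loss instead of inf.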

paddle/operators/cross_entropy_op.h

Lines changed: 12 additions & 12 deletions

@@ -22,17 +22,16 @@ namespace operators {
 using Tensor = framework::Tensor;
 
 template <typename T>
-HOSTDEVICE T tolerable_value(const T x) {
-  PADDLE_ASSERT(std::is_floating_point<T>::value);
-  const T kApproInf = 1e20;
-  if (x == INFINITY) {
-    return kApproInf;
+struct TolerableValue {
+  HOSTDEVICE T operator()(const T& x) const {
+    PADDLE_ASSERT(std::is_floating_point<T>::value);
+    const T kApproInf = 1e20;
+
+    if (x == INFINITY) return kApproInf;
+    if (x == -INFINITY) return -kApproInf;
+    return x;
   }
-  if (x == -INFINITY) {
-    return -kApproInf;
-  }
-  return x;
-}
+};
 
 template <typename T>
 class CrossEntropyOpKernel : public framework::OpKernel {

@@ -57,7 +56,8 @@ class CrossEntropyOpKernel : public framework::OpKernel {
     for (int i = 0; i < batch_size; ++i) {
       T sum = static_cast<T>(0);
       for (int j = 0; j < class_num; ++j) {
-        sum += label_data[index] * tolerable_value(std::log(x_data[index]));
+        sum +=
+            label_data[index] * TolerableValue<T>()(std::log(x_data[index]));
         y_data[i] = -sum;
         index++;
       }

@@ -66,7 +66,7 @@ class CrossEntropyOpKernel : public framework::OpKernel {
     auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
     for (int i = 0; i < batch_size; ++i) {
       int index = i * class_num + label_data[i];
-      y_data[i] = -tolerable_value(std::log(x_data[index]));
+      y_data[i] = -TolerableValue<T>()(std::log(x_data[index]));
     }
   }
 }
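
Turning the free function tolerable_value into the TolerableValue functor gives a type that can be instantiated per element type and invoked identically from host and device code. A minimal host-only sketch of the same clamping behaviour (standard C++ stand-ins for HOSTDEVICE and PADDLE_ASSERT; not part of the commit):

#include <cassert>
#include <cmath>
#include <cstdio>
#include <type_traits>

// Standalone re-creation of the functor above, host-only for illustration.
template <typename T>
struct TolerableValue {
  T operator()(const T& x) const {
    static_assert(std::is_floating_point<T>::value,
                  "TolerableValue only makes sense for floating point types");
    const T kApproInf = 1e20;
    if (x == INFINITY) return kApproInf;
    if (x == -INFINITY) return -kApproInf;
    return x;
  }
};

int main() {
  // std::log(0.0) is -inf; the functor clamps it to a large finite value so
  // downstream arithmetic on the loss stays finite.
  std::printf("%g\n", TolerableValue<double>()(std::log(0.0)));  // -1e+20
  return 0;
}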

paddle/operators/math/softmax.cc

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-template class SoftmaxFunctor<platform::CPUPlace, float>;
+template class SoftmaxFunctor<platform::GPUPlace, float>;
 
 }  // namespace math
 }  // namespace operators

paddle/operators/math/softmax.h

Lines changed: 2 additions & 2 deletions

@@ -28,8 +28,8 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 template <typename Place, typename T>
 class SoftmaxFunctor {
  public:
-  void operator()(const framework::Tensor* X, framework::Tensor* Y,
-                  const framework::ExecutionContext& context) {
+  void operator()(const framework::ExecutionContext& context,
+                  const framework::Tensor* X, framework::Tensor* Y) {
     auto logits = EigenMatrix<T>::From(*X);
     auto softmax = EigenMatrix<T>::From(*Y);
 

paddle/operators/softmax_op.h

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ class SoftmaxKernel : public framework::OpKernel {
     // allocate memory on device.
     Y->mutable_data<T>(context.GetPlace());
 
-    math::SoftmaxFunctor<Place, T>()(X, Y, context);
+    math::SoftmaxFunctor<Place, T>()(context, X, Y);
   }
 };

paddle/operators/softmax_with_cross_entropy_op.cc

Lines changed: 60 additions & 19 deletions

@@ -23,31 +23,31 @@ class SoftmaxWithCrossEntropyOpMaker
   SoftmaxWithCrossEntropyOpMaker(framework::OpProto* proto,
                                  framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    //(TODO caoying) replace int with boolean
-    AddAttr<int>("soft_label",
-                 "(int, default 0), A flag to indicate whether to interpretate "
-                 "the given labels as soft labels.")
-        .SetDefault(0);
+    AddAttr<bool>(
+        "softLabel",
+        "(bool, default: false), A flag to indicate whether to interpretate "
+        "the given labels as soft labels.")
+        .SetDefault(false);
     AddInput("Logits",
-             "(Tensor, default Tensor<float>), The unscaled log probabilities "
+             "(Tensor, default: Tensor<float>), The unscaled log probabilities "
              "which is a 2-D tensor with shape [N x K]. N is the batch_size, "
             "and K is the class number.")
         .NotInGradient();
     AddInput(
         "Label",
-        "(Tensor, default Tensor<int>), The ground truth which is "
-        "a 1-D or 2-D tensor. "
-        "If soft_label is set to 0, Label is a Tensor<int> with shape [N x 1]. "
-        "If soft_label is set to 1, Label is a Tensor<float/double> "
+        "(Tensor, default: Tensor<int>), The ground truth which is a 2-D "
+        "tensor. "
+        "If softLable is set to 0, Label is a Tensor<int> with shape [N x 1]. "
+        "If softLable is set to 1, Label is a Tensor<float/double> "
         "with shape [N x K].");
     AddOutput(
         "Softmax",
-        "(Tensor, default Tensor<float>), A 2-D tensor with shape [N x K]. "
+        "(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x K]. "
         "The outputs value of softmax activation by given the input batch, "
         "which will be used in backward calculation.")
         .AsIntermediate();
     AddOutput("Loss",
-              "(Tensor, default Tensor<float>), A 1-D tensor. The cross "
+              "(Tensor, default: Tensor<float>), A 2-D tensor. The cross "
              "entropy loss with shape [N x 1].");
     AddComment(R"DOC(
 Cross entropy loss with softmax are used as the output layer extensively. This

@@ -83,15 +83,39 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(const framework::InferShapeContext& ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Logits"),
+                            "Input(Logits) should be not null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
+                            "Input(Label) should be not null.");
+
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Softmax"),
+                            "Output(Softmax) should be not null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Loss"),
+                            "Output(Loss) should be not null.");
+
     const Tensor* logits = ctx.Input<Tensor>("Logits");
+    const Tensor* labels = ctx.Input<Tensor>("Label");
     PADDLE_ENFORCE(
         logits->dims().size() == 2UL,
-        "The input of softmax_with_cross_entropy should be a 2-d tensor.");
-    PADDLE_ENFORCE(ctx.Input<Tensor>("Label")->dims().size() == 1UL,
-                   "The label should be a 1-d tensor.");
-
-    ctx.Output<framework::LoDTensor>("Softmax")->Resize(logits->dims());
-    ctx.Output<framework::LoDTensor>("Loss")->Resize({logits->dims()[0], 1});
+        "The input of softmax_with_cross_entropy should be a 2-D tensor.");
+    PADDLE_ENFORCE(ctx.Input<Tensor>("Label")->dims().size() == 2UL,
+                   "The labels should be a 2-D tensor.");
+
+    if (ctx.Attr<bool>("softLabel")) {
+      PADDLE_ENFORCE_EQ(logits->dims()[1], labels->dims()[1],
+                        "If Attr(softLabel) == true, the 2nd dimension of "
+                        "Input(X) and Input(Label) should be equal.");
+    } else {
+      PADDLE_ENFORCE_EQ(labels->dims()[1], 1,
+                        "If Attr(softLabel) == false, the 2nd dimension of "
+                        "Input(Label) should be 1.");
+    }
+
+    ctx.Output<framework::Tensor>("Softmax")->Resize(logits->dims());
+    ctx.Output<framework::Tensor>("Loss")->Resize({logits->dims()[0], 1});
+
+    ctx.ShareLoD("Logits", /*->*/ "Softmax");
+    ctx.ShareLoD("Logits", /*->*/ "Loss");
   }
 };

@@ -102,11 +126,28 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext& ctx) const override {
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Loss")),
-                            "Input(Loss@Grad) should not be null");
+                            "Input(Loss@Grad) should not be null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"),
                             "Input(Softmax) should be not null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
                             "Input(Lable) should be not null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("Logits")),
+                            "Output(Logits@Grad) should be not null.");
+
+    const Tensor* softmax = ctx.Input<Tensor>("Softmax");
+    const Tensor* labels = ctx.Input<Tensor>("Label");
+    PADDLE_ENFORCE(ctx.Input<Tensor>("Label")->dims().size() == 2UL,
+                   "The labels should be a 2-D tensor.");
+
+    if (ctx.Attr<bool>("softLabel")) {
+      PADDLE_ENFORCE_EQ(softmax->dims()[1], labels->dims()[1],
+                        "When Attr(softLabel) == true, the 2nd dimension of "
+                        "Input(X) and Input(Label) should be equal.");
+    } else {
+      PADDLE_ENFORCE_EQ(labels->dims()[1], 1,
+                        "When Attr(softLabel) == false, the 2nd dimension of "
+                        "Input(Label) should be 1.");
+    }
 
     ctx.Output<framework::LoDTensor>(framework::GradVarName("Logits"))
         ->Resize(ctx.Input<Tensor>("Softmax")->dims());
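
To summarize the contract the rewritten InferShape enforces, here is an illustrative stand-alone check (a hypothetical helper, not PaddlePaddle API; N is the batch size, K the class number):

#include <cassert>
#include <cstdint>
#include <utility>

// {rows, cols} of a 2-D tensor.
using Dims2 = std::pair<int64_t, int64_t>;

// Mirrors the checks above. Logits is always [N x K]. Label is [N x 1]
// (one class index per sample) for hard labels, and [N x K] (one weight
// per class) for soft labels. Softmax is resized to [N x K], Loss to
// [N x 1].
void CheckLabelShape(bool soft_label, Dims2 logits, Dims2 label) {
  if (soft_label) {
    assert(label.second == logits.second);
  } else {
    assert(label.second == 1);
  }
}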

paddle/operators/softmax_with_cross_entropy_op.cu

Lines changed: 108 additions & 34 deletions

@@ -24,25 +24,78 @@ namespace operators {
 using Tensor = framework::Tensor;
 
 template <typename T>
-__global__ void CrossEntropyKernel(T* out, const T* softmax_out,
-                                   const int* label, const int batch_size,
-                                   const int class_num) {
+__global__ void CrossEntropy(T* out, const T* softmax_out, const int* labels,
+                             const int batch_size, const int class_num) {
   int i = blockIdx.x * blockDim.x + threadIdx.x;
   if (i < batch_size) {
-    PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num);
-    out[i] = -tolerable_value(std::log(softmax_out[i * class_num + label[i]]));
+    PADDLE_ASSERT(labels[i] >= 0 && labels[i] < class_num);
+    out[i] =
+        -TolerableValue<T>()(std::log(softmax_out[i * class_num + labels[i]]));
   }
 }
 
 template <typename T>
-__global__ void CrossEntropyWithSoftmaxGradKernel(T* softmax_out,
-                                                  const int* label,
-                                                  const int batch_size,
-                                                  const int class_num) {
-  int i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i < batch_size) {
-    PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num);
-    softmax_out[i * class_num + label[i]] -= 1.;
+__global__ void CrossEntropyGrad(T* out_grad, const T* in_grad,
+                                 const int* labels, const int batch_size,
+                                 const int class_num) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int sample_idx = tid / class_num;
+
+  if (tid < batch_size * class_num) out_grad[tid] *= in_grad[sample_idx];
+  __syncthreads();
+
+  if (tid < batch_size) {
+    PADDLE_ASSERT(labels[sample_idx] >= 0 && labels[sample_idx] < class_num);
+    out_grad[tid * class_num + labels[tid]] -= 1.;
+  }
+}
+
+template <typename T>
+__device__ __forceinline__ T sum_single_warp(T val) {
+  val += __shfl_down(val, 16);
+  val += __shfl_down(val, 8);
+  val += __shfl_down(val, 4);
+  val += __shfl_down(val, 2);
+  val += __shfl_down(val, 1);
+  return val;
+}
+
+template <typename T>
+__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
+                                       const int class_num) {
+  int tid = threadIdx.x;
+  extern __shared__ T d_sum[];
+  d_sum[tid] = 0;
+
+  int cur_idx = tid;
+  int next_idx = blockIdx.x * class_num + tid;
+  while (cur_idx < class_num) {
+    d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
+    next_idx += blockDim.x;
+    cur_idx += blockDim.x;
+  }
+  __syncthreads();
+
+  for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
+    if (tid < stride) d_sum[tid] += d_sum[tid + stride];
+    __syncthreads();
+  }
+
+  T val = d_sum[tid];
+  val = sum_single_warp<T>(val);
+  if (tid == 0) Y[blockIdx.x] = -val;
+}
+
+template <typename T>
+__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
+                                               const T* loss_grad,
+                                               const T* labels,
+                                               const int batch_size,
+                                               const int class_num) {
+  int ids = blockIdx.x * blockDim.x + threadIdx.x;
+  if (ids < batch_size * class_num) {
+    int row_ids = ids / class_num;
+    logit_grad[ids] = logit_grad[ids] * loss_grad[row_ids] - labels[ids];
   }
 }
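
SoftCrossEntropyKernel reduces per-thread partial sums with a shared-memory tree until one warp's worth of values remains, then sum_single_warp finishes the last 32 with register shuffles. A standalone sketch of that reduction pattern (assumes blockDim.x is a power of two, at least 32; written with __shfl_down_sync, the CUDA 9+ replacement for the __shfl_down intrinsic used above):

// Sum the 32 values of one warp into lane 0 via register shuffles.
__device__ __forceinline__ float WarpSum(float val) {
  for (int offset = 16; offset > 0; offset >>= 1)
    val += __shfl_down_sync(0xffffffffu, val, offset);
  return val;
}

// One partial sum per block: shared-memory tree until a single warp
// remains, then shuffle within that warp. Launch with a power-of-two
// block size >= 32 and blockDim.x * sizeof(float) bytes of shared
// memory, e.g. BlockSum<<<grid, 256, 256 * sizeof(float)>>>(in, out, n);
__global__ void BlockSum(const float* in, float* out, int n) {
  extern __shared__ float smem[];
  int tid = threadIdx.x;
  int idx = blockIdx.x * blockDim.x + tid;
  smem[tid] = idx < n ? in[idx] : 0.f;
  __syncthreads();

  // Halve the active range until only 32 partial sums remain.
  for (unsigned stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
    if (tid < stride) smem[tid] += smem[tid + stride];
    __syncthreads();
  }

  if (tid < 32) {
    float val = WarpSum(smem[tid]);
    if (tid == 0) out[blockIdx.x] = val;
  }
}

Stopping the shared-memory tree at 32 and finishing with shuffles avoids both shared-memory traffic and __syncthreads() inside the final warp.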

@@ -52,27 +105,36 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel {
   void Compute(const framework::ExecutionContext& context) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                    "This kernel only runs on GPU device.");
+    T* loss_data =
+        context.Output<Tensor>("Loss")->mutable_data<T>(context.GetPlace());
 
-    // Calculate ths softmax outputs.
     const Tensor* logits = context.Input<Tensor>("Logits");
     Tensor* softmax = context.Output<Tensor>("Softmax");
-    softmax->mutable_data<T>(context.GetPlace());
-    math::SoftmaxFunctor<platform::GPUPlace, T>()(logits, softmax, context);
-    T* softmax_out = softmax->data<T>();
-
-    // Calculate the cross entropy loss based on hard labels.
-    const int* label_data = context.Input<Tensor>("Label")->data<int>();
-    Tensor* loss = context.Output<Tensor>("Loss");
-    loss->mutable_data<T>(context.GetPlace());
-    T* loss_data = loss->data<T>();
+    T* softmax_out = softmax->mutable_data<T>(context.GetPlace());
+    math::SoftmaxFunctor<platform::GPUPlace, T>()(context, logits, softmax);
 
     const int batch_size = logits->dims()[0];
     const int class_num = logits->dims()[1];
     int block = 512;
     int grid = (batch_size + block - 1) / block;
 
-    CrossEntropyKernel<T><<<grid, block>>>(loss_data, softmax_out, label_data,
-                                           batch_size, class_num);
+    if (context.Attr<bool>("softLabel")) {
+      const T* label_data = context.Input<Tensor>("Label")->data<T>();
+      block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
+
+      SoftCrossEntropyKernel<
+          T><<<batch_size, block, block * sizeof(T),
+               reinterpret_cast<const platform::CUDADeviceContext&>(
+                   context.device_context())
+                   .stream()>>>(loss_data, softmax_out, label_data, class_num);
+    } else {
+      const int* label_data = context.Input<Tensor>("Label")->data<int>();
+      CrossEntropy<T><<<grid, block, 0,
+                        reinterpret_cast<const platform::CUDADeviceContext&>(
+                            context.device_context())
+                            .stream()>>>(loss_data, softmax_out, label_data,
+                                         batch_size, class_num);
+    }
   }
 };
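
The soft-label branch shrinks the block to the largest power of two not exceeding class_num (capped at 512), because the tree reduction in SoftCrossEntropyKernel repeatedly halves blockDim.x and therefore assumes a power-of-two block. An integer-only sketch of the same computation (an illustrative alternative to the pow/std::log2 expression above):

// Largest power of two <= class_num, capped at 512 threads per block.
// Matches: class_num > 512 ? 512 : pow(2, int(std::log2(class_num)))
inline int BlockSizeFor(int class_num) {
  if (class_num >= 512) return 512;
  int block = 1;
  while (block * 2 <= class_num) block *= 2;  // stay integer; no FP rounding
  return block;
}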

@@ -82,22 +144,34 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel {
   void Compute(const framework::ExecutionContext& context) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                    "This kernel only runs on GPU device.");
-
+    const Tensor* labels = context.Input<Tensor>("Label");
+    const T* loss_grad_data =
+        context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>();
     Tensor* logit_grad =
         context.Output<Tensor>(framework::GradVarName("Logits"));
     logit_grad->ShareDataWith<T>(*context.Input<Tensor>("Softmax"));
     T* logit_grad_data = logit_grad->data<T>();
 
     const int batch_size = logit_grad->dims()[0];
     const int class_num = logit_grad->dims()[1];
-
-    const int* label_data = context.Input<Tensor>("Label")->data<int>();
-
-    const int block = 512;
-    const int grid = (batch_size + block - 1) / block;
-
-    CrossEntropyWithSoftmaxGradKernel<T><<<grid, block>>>(
-        logit_grad_data, label_data, batch_size, class_num);
+    int block = 512;
+    int grid = (batch_size * class_num + block - 1) / block;
+
+    if (context.Attr<bool>("softLabel")) {
+      const T* label_data = labels->data<T>();
+      SoftCrossEntropyGradientKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              context.device_context())
+                              .stream()>>>(logit_grad_data, loss_grad_data,
+                                           label_data, batch_size, class_num);
+    } else {
+      const int* label_data = labels->data<int>();
+      CrossEntropyGrad<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              context.device_context())
+                              .stream()>>>(logit_grad_data, loss_grad_data,
+                                           label_data, batch_size, class_num);
+    }
   }
 };
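
For reference, the two backward kernels compute the usual softmax-with-cross-entropy gradient fused with the upstream loss gradient (a summary of well-known math, not text from the commit). Writing s for the saved Softmax output, z for the logits, and g_i for the incoming Loss gradient of sample i:

% Hard labels (CrossEntropyGrad), with c_i the true class of sample i:
\frac{\partial L}{\partial z_{ij}} = g_i \, s_{ij} - \mathbf{1}[j = c_i]

% Soft labels (SoftCrossEntropyGradientKernel):
\frac{\partial L}{\partial z_{ij}} = g_i \, s_{ij} - \mathrm{label}_{ij}

With g_i = 1 both reduce to the textbook s_ij - label_ij. Note the kernels update in place: logit_grad aliases the forward Softmax output through ShareDataWith.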
