@@ -28,26 +28,49 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
        i += blockDim.x * gridDim.x) {
     PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
-    Y[i] = -tolerable_value(log(X[i * D + label[i]]));
+    Y[i] = -TolerableValue<T>()(log(X[i * D + label[i]]));
   }
 }
 
+template <typename T>
+__device__ __forceinline__ T sum_single_warp(T val) {
+  val += __shfl_down(val, 16);
+  val += __shfl_down(val, 8);
+  val += __shfl_down(val, 4);
+  val += __shfl_down(val, 2);
+  val += __shfl_down(val, 1);
+  return val;
+}
+
 template <typename T>
 __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
-                                       const int N, const int D) {
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
-       i += blockDim.x * gridDim.x) {
-    T sum = static_cast<T>(0);
-    for (int j = 0; j < D; j++) {
-      sum += label[i * D + j] * tolerable_value(log(X[i * D + j]));
-    }
-    Y[i] = -sum;
+                                       const int class_num) {
+  int tid = threadIdx.x;
+  extern __shared__ T d_sum[];
+  d_sum[tid] = 0;
+
+  int cur_idx = tid;
+  int next_idx = blockIdx.x * class_num + tid;
+  while (cur_idx < class_num) {
+    d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
+    next_idx += blockDim.x;
+    cur_idx += blockDim.x;
+  }
+  __syncthreads();
+
+  for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
+    if (tid < stride) d_sum[tid] += d_sum[tid + stride];
+    __syncthreads();
   }
+
+  T val = d_sum[tid];
+  val = sum_single_warp<T>(val);
+  if (tid == 0) Y[blockIdx.x] = -val;
 }
 
-// TODO(qingqing): make zero setting an common function.
+// TODO(qingqing): make zero setting a common function.
 template <typename T>
-__global__ void zero(T* X, const int N) {
+__global__ void Zero(T* X, const int N) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
        i += blockDim.x * gridDim.x) {
     X[i] = 0.0;
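The rewritten SoftCrossEntropyKernel assigns one block per sample and reduces the per-class terms -label * log(X) in three stages: a strided accumulation into shared memory, a tree reduction down to one warp, and a final warp-shuffle reduction in sum_single_warp (TolerableValue<T>() is the functor form of the old tolerable_value, presumably clamping non-finite results of log). Below is a minimal sketch of the same reduction pattern on a plain block sum, assuming blockDim.x is a power of two of at least one full warp (32); BlockSumKernel is an illustrative name, not part of the patch.

// Sketch only: same two-stage (shared memory + warp shuffle) block reduction.
template <typename T>
__global__ void BlockSumKernel(T* out, const T* in, const int n) {
  extern __shared__ T partial[];
  int tid = threadIdx.x;
  partial[tid] = 0;

  // Stage 1: each thread accumulates a strided slice of the input.
  for (int i = tid; i < n; i += blockDim.x) partial[tid] += in[i];
  __syncthreads();

  // Stage 2: shared-memory tree reduction until one warp of partials is left.
  for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
    if (tid < stride) partial[tid] += partial[tid + stride];
    __syncthreads();
  }

  // Stage 3: the last 32 partials are combined with register shuffles,
  // needing neither shared memory nor __syncthreads().
  T val = sum_single_warp<T>(partial[tid]);
  if (tid == 0) out[blockIdx.x] = val;
}

As in the launch below, the dynamic shared-memory size passed at launch (block * sizeof(T)) must match the block size so each thread owns one slot of the partial-sum array.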
@@ -71,13 +94,10 @@ template <typename T>
 __global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                                const T* label, const int N,
                                                const int D) {
-  // TOOD(qingqing): optimize for this kernel
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
-       i += blockDim.x * gridDim.x) {
-    for (int j = 0; j < D; ++j) {
-      int idx = i * D + j;
-      dX[idx] = -label[idx] * dY[i] / X[idx];
-    }
+  int ids = blockIdx.x * blockDim.x + threadIdx.x;
+  if (ids < N * D) {
+    int row_ids = ids / D;
+    dX[ids] = -label[ids] * dY[row_ids] / X[ids];
   }
 }
 
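The soft-label gradient kernel now uses one thread per (sample, class) element instead of one thread looping over every class; the sample index is recovered from the flat element index by integer division. For clarity only, a CPU reference of the same index mapping (SoftCrossEntropyGradRef is a hypothetical name, not in the patch):

// Sketch only: serial reference for the flattened gradient indexing.
template <typename T>
void SoftCrossEntropyGradRef(T* dX, const T* dY, const T* X, const T* label,
                             const int N, const int D) {
  for (int ids = 0; ids < N * D; ++ids) {
    int row_ids = ids / D;  // which sample this flat element belongs to
    // Gradient of -sum_j label * log(X) w.r.t. X, scaled by the upstream dY.
    dX[ids] = -label[ids] * dY[row_ids] / X[ids];
  }
}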
@@ -86,29 +106,36 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
-                   "It must use GPUPlace.");
+                   "This kernel only runs on GPU device.");
 
-    auto x = ctx.Input<Tensor>("X");
-    auto y = ctx.Output<Tensor>("Y");
-    auto label = ctx.Input<Tensor>("Label");
+    const Tensor* x = ctx.Input<Tensor>("X");
+    const Tensor* label = ctx.Input<Tensor>("Label");
+    Tensor* y = ctx.Output<Tensor>("Y");
 
-    auto* x_data = x->data<T>();
-    y->mutable_data<T>(ctx.GetPlace());
-    auto* y_data = y->data<T>();
+    const T* x_data = x->data<T>();
+    T* y_data = y->mutable_data<T>(ctx.GetPlace());
 
-    int n = x->dims()[0];
-    int d = x->dims()[1];
-    int block = 512;
-    int grid = (n + block - 1) / block;
-    // TODO(qingqing) launch kernel on specified stream
-    // base on ExecutionContext.
-    if (ctx.Attr<bool>("soft_label")) {
+    int batch_size = x->dims()[0];
+    int class_num = x->dims()[1];
+
+    if (ctx.Attr<bool>("softLabel")) {
       auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
-      SoftCrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n,
-                                                 d);
+      int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
+
+      SoftCrossEntropyKernel<
+          T><<<batch_size, block, block * sizeof(T),
+               reinterpret_cast<const platform::CUDADeviceContext&>(
+                   ctx.device_context())
+                   .stream()>>>(y_data, x_data, label_data, class_num);
     } else {
       auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
-      CrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, d);
+      int block = 512;
+      int grid = (batch_size + block - 1) / block;
+      CrossEntropyKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              ctx.device_context())
+                              .stream()>>>(y_data, x_data, label_data,
+                                           batch_size, class_num);
     }
   }
 };
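In the soft-label branch the forward kernel is launched with one block per sample; the block size is the largest power of two not exceeding class_num, capped at 512, and the dynamic shared-memory size is block * sizeof(T) so each thread gets one d_sum slot. A small sketch of that block-size choice, assuming <cmath> is available (SoftLabelBlockSize is a hypothetical helper, not in the patch):

// Sketch only: block size = largest power of two <= class_num, capped at 512.
#include <cmath>

inline int SoftLabelBlockSize(int class_num) {
  if (class_num > 512) return 512;  // cap at 512 threads per block
  return 1 << static_cast<int>(std::log2(class_num));  // round down to a power of two
}

// Hypothetical usage mirroring the launch above:
//   int block = SoftLabelBlockSize(class_num);
//   SoftCrossEntropyKernel<T><<<batch_size, block, block * sizeof(T), stream>>>(...);

Keeping the block size a power of two is what lets the halving tree reduction in SoftCrossEntropyKernel cover every shared-memory slot.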
@@ -118,33 +145,43 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
-                   "It must use GPUPlace.");
+                   "This kernel only runs on GPU device.");
+
+    const Tensor* x = ctx.Input<Tensor>("X");
+    const Tensor* label = ctx.Input<Tensor>("Label");
+    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
 
-    auto x = ctx.Input<Tensor>("X");
-    auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
-    auto label = ctx.Input<Tensor>("Label");
+    const T* dy_data =
+        ctx.Input<Tensor>(framework::GradVarName("Y"))->data<T>();
+    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
+    const T* x_data = x->data<T>();
 
-    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
-    auto* dy_data = dy->data<T>();
-    auto* x_data = x->data<T>();
+    int batch_size = x->dims()[0];
+    int class_num = x->dims()[1];
 
-    int n = x->dims()[0];
-    int d = x->dims()[1];
     int block = 512;
-    int grid = (n * d + block - 1) / block;
-    zero<T><<<grid, block>>>(dx_data, n * d);
-    grid = (n + block - 1) / block;
-    // TODO(qingqing): launch kernel on specified stream
-    // base on ExecutionContext.
-    if (ctx.Attr<bool>("soft_label")) {
+    int grid = (batch_size * class_num + block - 1) / block;
+
+    if (ctx.Attr<bool>("softLabel")) {
       auto* label_data = label->data<T>();
-      SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
-          dx_data, dy_data, x_data, label_data, n, d);
+      SoftCrossEntropyGradientKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              ctx.device_context())
+                              .stream()>>>(dx_data, dy_data, x_data, label_data,
+                                           batch_size, class_num);
     } else {
+      Zero<T><<<grid, block, 0,
+                reinterpret_cast<const platform::CUDADeviceContext&>(
+                    ctx.device_context())
+                    .stream()>>>(dx_data, batch_size * class_num);
+
       auto* label_data = label->data<int>();
-      CrossEntropyGradientKernel<T><<<grid, block>>>(dx_data, dy_data, x_data,
-                                                     label_data, n, d);
+      grid = (batch_size + block - 1) / block;
+      CrossEntropyGradientKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              ctx.device_context())
+                              .stream()>>>(dx_data, dy_data, x_data, label_data,
+                                           batch_size, class_num);
     }
   }
 };
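In the hard-label branch, dX is first cleared by Zero over all batch_size * class_num elements and only then filled by CrossEntropyGradientKernel with a one-thread-per-sample grid, presumably because the hard-label gradient only writes the label position of each row and every other position must already be zero; the soft-label gradient writes every element, so it skips the Zero pass. Every launch pulls the CUDA stream out of the execution context with the same reinterpret_cast; a sketch of how that lookup could be factored, assuming stream() returns a cudaStream_t as the launch syntax above suggests (GetCUDAStream is a hypothetical helper, not part of the patch):

// Sketch only: hypothetical helper for the repeated stream lookup.
inline cudaStream_t GetCUDAStream(const framework::ExecutionContext& ctx) {
  // Same cast as done inline at every kernel launch site above.
  return reinterpret_cast<const platform::CUDADeviceContext&>(
             ctx.device_context())
      .stream();
}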