Now the code is as follows:
template <typename T>
__global__ void CrossEntropyGrad(T* logit_grad, const T* loss_grad,
const int64_t* labels, const int batch_size,
const int class_num) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int sample_idx = tid / class_num;
if (tid < batch_size) {
PADDLE_ASSERT(labels[sample_idx] >= 0 && labels[sample_idx] < class_num);
logit_grad[tid * class_num + labels[tid]] -= static_cast<T>(1.);
}
__syncthreads();
if (tid < batch_size * class_num) {
logit_grad[tid] *= loss_grad[sample_idx];
}
}
Actually, this code needs a global synchronization, not __syncthreads(), which is a thread barrier in one block.