There is a bug in softmax_with_cross_entropy_op backward. #9119

Description

@qingqing01

Now the code is as follows:

template <typename T>
__global__ void CrossEntropyGrad(T* logit_grad, const T* loss_grad,
                                 const int64_t* labels, const int batch_size,
                                 const int class_num) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int sample_idx = tid / class_num;

  // Phase 1: one thread per sample subtracts 1 at the label position.
  if (tid < batch_size) {
    PADDLE_ASSERT(labels[sample_idx] >= 0 && labels[sample_idx] < class_num);
    logit_grad[tid * class_num + labels[tid]] -= static_cast<T>(1.);
  }

  // Only a barrier for the threads within this block.
  __syncthreads();

  // Phase 2: one thread per (sample, class) element scales the gradient.
  if (tid < batch_size * class_num) {
    logit_grad[tid] *= loss_grad[sample_idx];
  }
}

Actually, this code needs a global (grid-wide) synchronization between the two phases, not __syncthreads(), which is only a barrier for the threads within one block. The element written by thread tid in the first branch, logit_grad[tid * class_num + labels[tid]], is scaled in the second branch by a different thread that may belong to another block, so without cross-block ordering the subtraction and the multiplication can execute in either order and the gradient is wrong.
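
One common way to get that ordering, sketched below, is to split the work into two kernels: consecutive launches on the same CUDA stream are serialized, so every block of the first kernel finishes before the second kernel starts. This is only an illustration, not necessarily the fix adopted in the repository; the kernel names and launch parameters here are hypothetical.

template <typename T>
__global__ void CrossEntropyGradSubtract(T* logit_grad, const int64_t* labels,
                                         const int batch_size,
                                         const int class_num) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < batch_size) {
    // One thread per sample: subtract 1 at the label position.
    logit_grad[tid * class_num + labels[tid]] -= static_cast<T>(1.);
  }
}

template <typename T>
__global__ void CrossEntropyGradScale(T* logit_grad, const T* loss_grad,
                                      const int batch_size,
                                      const int class_num) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < batch_size * class_num) {
    // One thread per (sample, class) element: scale by the upstream gradient.
    int sample_idx = tid / class_num;
    logit_grad[tid] *= loss_grad[sample_idx];
  }
}

// Launching them back to back on the same stream gives the required
// grid-wide ordering: the second kernel cannot start before every block
// of the first one has finished (grid_sub, grid_scale, block, and stream
// are hypothetical launch parameters):
//
//   CrossEntropyGradSubtract<T><<<grid_sub, block, 0, stream>>>(
//       logit_grad, labels, batch_size, class_num);
//   CrossEntropyGradScale<T><<<grid_scale, block, 0, stream>>>(
//       logit_grad, loss_grad, batch_size, class_num);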
