
Commit ebd992e

backpropagate gradients the CRF operator receives.
1 parent 2ac9a3d commit ebd992e
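
In short: the backward kernel previously wrote out the analytic gradients of the log-likelihood directly, implicitly assuming the gradient arriving at the operator's LogLikelihood output was always 1. After this patch, every emission and transition term is scaled by that incoming gradient (ll_grad in the diff below), and the zero-fill of Output(Emission@Grad) is dropped, presumably because the kernel assigns, rather than accumulates into, every coefficient of x_grad_mat. The chain-rule identity being implemented, restated here for either parameter block theta (not part of the commit):

    \frac{\partial\,\text{loss}}{\partial\theta}
      = \frac{\partial\,\text{loss}}{\partial\,\text{ll}} \cdot \frac{\partial\,\text{ll}}{\partial\theta}
      = \text{ll\_grad} \cdot \frac{\partial\,\text{ll}}{\partial\theta}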

1 file changed, +17 −8 lines

paddle/operators/linear_chain_crf_op.h

@@ -35,6 +35,14 @@ static inline T NormalizeL1(T* x, size_t len) {
   return sum;
 }
 
+template <typename T>
+struct ScalarMul {
+  explicit ScalarMul(const T& scalar) : scalar(scalar) {}
+  T operator()(const T& val) const { return val * scalar; }
+
+  T scalar;
+};
+
 using framework::LoDTensor;
 using framework::LoD;
 using framework::Tensor;
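
The ScalarMul functor added above multiplies every coefficient by a fixed scalar; Eigen applies it coefficient-wise via unaryExpr. A minimal standalone sketch of how the two compose, using Eigen's dense Matrix API rather than the Tensor expressions in the kernel (assumes Eigen is installed; not Paddle code):

    #include <iostream>
    #include <Eigen/Dense>

    template <typename T>
    struct ScalarMul {
      explicit ScalarMul(const T& scalar) : scalar(scalar) {}
      T operator()(const T& val) const { return val * scalar; }
      T scalar;
    };

    int main() {
      Eigen::MatrixXf prob(2, 3);
      prob << 0.2f, 0.3f, 0.5f,
              0.1f, 0.6f, 0.3f;
      // Upstream gradient received for the LogLikelihood output.
      float ll_grad = 2.0f;
      // Coefficient-wise scaling, same shape as the input; mirrors
      // (prob / row_sum).unaryExpr(ScalarMul<T>(ll_grad)) in the kernel.
      Eigen::MatrixXf scaled = prob.unaryExpr(ScalarMul<float>(ll_grad));
      std::cout << scaled << std::endl;  // every entry multiplied by ll_grad
      return 0;
    }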
@@ -349,8 +357,6 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
     // data reader operator, it can have no gradients.
     PADDLE_ENFORCE(emission_grad, "Output(Emission@Grad) should not be null.");
     emission_grad->mutable_data<T>(platform::CPUPlace());
-    math::SetConstant<platform::CPUPlace, T>()(ctx.device_context(),
-                                               emission_grad, 0.);
     if (transition_grad) {
       transition_grad->mutable_data<T>(platform::CPUPlace());
       math::SetConstant<platform::CPUPlace, T>()(ctx.device_context(),
@@ -480,24 +486,27 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
     auto row_sum = prob.sum(Eigen::DSizes<int, 1>(1))
                        .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
                        .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
-    x_grad_mat.device(*place) = prob / row_sum;
+    x_grad_mat.device(*place) =
+        (prob / row_sum).unaryExpr(ScalarMul<T>(ll_grad));
 
     for (size_t k = 0; k < seq_length; ++k) {
-      x_grad_mat(k, label_value[k]) -= static_cast<T>(1.);
+      x_grad_mat(k, label_value[k]) -= static_cast<T>(ll_grad);
     }
 
     if (transition_grad) {
       T* trans_grad = transition_grad->data<T>();
       for (size_t k = 0; k < tag_num; ++k) {
+        // Do not multiply by the output gradient here, because x_grad_mat has
+        // already done this.
         trans_grad[k] += x_grad_mat(/*from start state*/ 0, k);
         trans_grad[tag_num + k] +=
             x_grad_mat(/*to end state*/ seq_length - 1, k);
       }
 
       auto x_exps_mat = EigenMatrix<T>::From(emission_exps);
 
-      // TODO(caoying): Fix this to avoid using this local variable if when can
-      // profiling the training process.
+      // TODO(caoying): Fix this to avoid using this local variable if we can
+      // profile the training process.
       Tensor tmp;
       tmp.mutable_data<T>(beta->dims(), platform::CPUPlace());
       auto tmp_mat = EigenMatrix<T>::From(tmp);
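
For reference, the two changed statements in this hunk compute, per time step k and tag j, a scaled marginal-minus-one-hot gradient. A hedged reconstruction of the formula (assuming, per the standard forward-backward recursion, that prob holds the elementwise product of forward scores alpha and backward scores beta, so prob / row_sum is the marginal tag probability):

    \frac{\partial\,\text{loss}}{\partial x_{k,j}}
      = \text{ll\_grad} \cdot
        \left( \frac{\alpha_k(j)\,\beta_k(j)}{\sum_{j'} \alpha_k(j')\,\beta_k(j')}
               - \mathbf{1}[\, j = y_k \,] \right)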
@@ -520,11 +529,11 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
         for (size_t j = 0; j < tag_num; ++j) {
           trans_grad[(i + state_trans_base_idx) * tag_num + j] +=
               sum * w_exps[(i + state_trans_base_idx) * tag_num + j] *
-              alpha_mat(k - 1, i) * tmp_mat(k, j);
+              alpha_mat(k - 1, i) * tmp_mat(k, j) * ll_grad;
         }
       }
       trans_grad[(label_value[k - 1] + state_trans_base_idx) * tag_num +
-                 label_value[k]] -= static_cast<T>(1.);
+                 label_value[k]] -= static_cast<T>(ll_grad);
     }
   }
 }
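
A quick way to validate this kind of change is a finite-difference probe: for a scalar output, the backpropagated gradient must equal the upstream gradient times the analytic derivative. A minimal standalone sketch of that check (toy function, illustrative names, not Paddle code):

    #include <cmath>
    #include <cstdio>

    // Toy scalar-output function standing in for the CRF log-likelihood.
    static double f(double x) { return std::log(1.0 + x * x); }

    // Its analytic derivative.
    static double dfdx(double x) { return 2.0 * x / (1.0 + x * x); }

    int main() {
      const double x = 0.7;
      const double g = 3.0;    // upstream gradient, playing the role of ll_grad
      const double eps = 1e-6;

      // What the patched kernel does: analytic gradient scaled by the
      // upstream gradient, per the chain rule.
      const double analytic = g * dfdx(x);

      // Numeric reference: d(g * f(x)) / dx by central differences.
      const double numeric = g * (f(x + eps) - f(x - eps)) / (2.0 * eps);

      std::printf("analytic=%.8f numeric=%.8f\n", analytic, numeric);
      return 0;
    }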
