@@ -35,6 +35,14 @@ static inline T NormalizeL1(T* x, size_t len) {
   return sum;
 }
 
+template <typename T>
+struct ScalarMul {
+  explicit ScalarMul(const T& scalar) : scalar(scalar) {}
+  T operator()(const T& val) const { return val * scalar; }
+
+  T scalar;
+};
+
 using framework::LoDTensor;
 using framework::LoD;
 using framework::Tensor;
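
The ScalarMul functor added above is a small element-wise multiplier intended for Eigen's unaryExpr. Below is a minimal standalone sketch of that usage, assuming plain Eigen and purely illustrative names; it is not code from this patch.

#include <Eigen/Dense>
#include <iostream>

// Same shape as the functor introduced in the hunk above.
template <typename T>
struct ScalarMul {
  explicit ScalarMul(const T& scalar) : scalar(scalar) {}
  T operator()(const T& val) const { return val * scalar; }
  T scalar;
};

int main() {
  // Hypothetical 2x3 matrix standing in for the normalized probabilities.
  Eigen::MatrixXf prob(2, 3);
  prob << 0.1f, 0.2f, 0.7f,
          0.3f, 0.3f, 0.4f;
  float ll_grad = 0.5f;  // stand-in for the incoming output gradient
  // unaryExpr applies the functor element-wise, scaling every entry by
  // ll_grad, mirroring (prob / row_sum).unaryExpr(ScalarMul<T>(ll_grad)).
  Eigen::MatrixXf scaled = prob.unaryExpr(ScalarMul<float>(ll_grad));
  std::cout << scaled << std::endl;
  return 0;
}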
@@ -349,8 +357,6 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
     // data reader operator, it can have no gradients.
     PADDLE_ENFORCE(emission_grad, "Output(Emission@Grad) should not be null.");
     emission_grad->mutable_data<T>(platform::CPUPlace());
-    math::SetConstant<platform::CPUPlace, T>()(ctx.device_context(),
-                                               emission_grad, 0.);
     if (transition_grad) {
       transition_grad->mutable_data<T>(platform::CPUPlace());
       math::SetConstant<platform::CPUPlace, T>()(ctx.device_context(),
@@ -480,24 +486,27 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
     auto row_sum = prob.sum(Eigen::DSizes<int, 1>(1))
                        .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
                        .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
-    x_grad_mat.device(*place) = prob / row_sum;
+    x_grad_mat.device(*place) =
+        (prob / row_sum).unaryExpr(ScalarMul<T>(ll_grad));
 
     for (size_t k = 0; k < seq_length; ++k) {
-      x_grad_mat(k, label_value[k]) -= static_cast<T>(1.);
+      x_grad_mat(k, label_value[k]) -= static_cast<T>(ll_grad);
     }
 
     if (transition_grad) {
       T* trans_grad = transition_grad->data<T>();
       for (size_t k = 0; k < tag_num; ++k) {
+        // Do not multiply by the output gradient here, because x_grad_mat has
+        // already done this.
         trans_grad[k] += x_grad_mat(/*from start state*/ 0, k);
         trans_grad[tag_num + k] +=
             x_grad_mat(/*to end state*/ seq_length - 1, k);
       }
 
       auto x_exps_mat = EigenMatrix<T>::From(emission_exps);
 
-      // TODO(caoying): Fix this to avoid using this local variable if when can
-      // profiling the training process.
+      // TODO(caoying): Fix this to avoid using this local variable if we can
+      // profile the training process.
       Tensor tmp;
       tmp.mutable_data<T>(beta->dims(), platform::CPUPlace());
       auto tmp_mat = EigenMatrix<T>::From(tmp);
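
A note on the ll_grad factor introduced in this hunk (an interpretation for the reader, not text from the patch): ll_grad is the gradient flowing in from the op's log-likelihood output, so by the chain rule every gradient term is scaled by it. Where the old code computed

  x_grad_mat = prob / row_sum;              x_grad_mat(k, label_value[k]) -= 1.;

the new code computes

  x_grad_mat = ll_grad * (prob / row_sum);  x_grad_mat(k, label_value[k]) -= ll_grad;

i.e. the previous gradient multiplied through by ll_grad, with the old behaviour recovered when ll_grad == 1. The transition-gradient terms in the next hunk pick up the same factor, which is why the loop above notes that the output gradient must not be applied to them a second time.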
@@ -520,11 +529,11 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
           for (size_t j = 0; j < tag_num; ++j) {
             trans_grad[(i + state_trans_base_idx) * tag_num + j] +=
                 sum * w_exps[(i + state_trans_base_idx) * tag_num + j] *
-                alpha_mat(k - 1, i) * tmp_mat(k, j);
+                alpha_mat(k - 1, i) * tmp_mat(k, j) * ll_grad;
           }
         }
         trans_grad[(label_value[k - 1] + state_trans_base_idx) * tag_num +
-                   label_value[k]] -= static_cast<T>(1.);
+                   label_value[k]] -= static_cast<T>(ll_grad);
       }
     }
   }