Skip to content

Commit 41a5220

Browse files
authored
[Bug fixes]fix bce error when option pos_weight given (#65859)
* fix bce and bce grad with option pos_weight
1 parent 147a8d2 commit 41a5220

7 files changed

+66
-45
lines changed

paddle/fluid/primitive/composite/composite.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1340,15 +1340,19 @@ Tensor sigmoid_cross_entropy_with_logits_decomp(
13401340
const Tensor zero = full_like_decomp<T>(x, 0, x.type(), x.place());
13411341
const Tensor one = full_like_decomp<T>(x, 1, x.type(), x.place());
13421342
Tensor pos_weight_tensor;
1343+
Tensor tmp_out;
13431344
if (pos_weight) {
13441345
pos_weight_tensor = pos_weight.get();
1346+
auto max_val = where<T>(x < zero, -x, zero);
1347+
auto term1 = (one - label) * x;
1348+
auto term2 = log<T>(exp<T>(-max_val) + exp<T>(-x - max_val));
1349+
tmp_out = term1 + pos_weight_tensor * (term2 + max_val);
13451350
} else {
1346-
pos_weight_tensor = one;
1351+
auto term1 = where<T>(x > zero, x, zero);
1352+
auto term2 = x * label;
1353+
auto term3 = log<T>(one + exp<T>(-abs<T>(x)));
1354+
tmp_out = term1 - term2 + term3;
13471355
}
1348-
auto term1 = where<T>(x > zero, x, zero);
1349-
auto term2 = x * label;
1350-
auto term3 = log<T>(one + exp<T>(-abs<T>(x)));
1351-
const Tensor tmp_out = term1 - term2 + term3 * pos_weight_tensor;
13521356
const Tensor ignore_index_tensor =
13531357
full_like_decomp<T>(x, ignore_index, label.type(), label.place());
13541358
auto out = where<T>(label == ignore_index_tensor, zero, tmp_out);

paddle/phi/kernels/cpu/sigmoid_cross_entropy_with_logits_grad_kernel.cc

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,20 +43,30 @@ void SigmoidCrossEntropyWithLogitsGradKernel(
4343
T x = x_data[idx];
4444
T label = label_data[idx];
4545
T dout = dout_data[idx];
46-
T pos_weight_idx = pos_weight_data == nullptr ? 1 : pos_weight_data[idx];
4746
if (static_cast<int>(label) == ignore_index) {
4847
dx_data[idx] = static_cast<T>(0.);
4948
} else {
50-
T term1 = (x > 0) ? static_cast<T>(1) : static_cast<T>(0);
49+
if (pos_weight_data == nullptr) {
50+
T term1 = (x > 0) ? static_cast<T>(1) : static_cast<T>(0);
51+
T e_x = std::exp(-std::abs(x));
52+
T down = 1 + e_x;
53+
T abs_grad = (x >= 0) ? static_cast<T>(1) : static_cast<T>(-1);
54+
T up = -e_x * abs_grad;
55+
T term3 = up / down;
5156

52-
T e_x = std::exp(-std::abs(x));
53-
T down = 1 + e_x;
54-
T abs_grad = (x >= 0) ? static_cast<T>(1) : static_cast<T>(-1);
55-
T up = -e_x * abs_grad * pos_weight_idx;
56-
T term3 = up / down;
57-
58-
T diff = term1 - label + term3;
59-
dx_data[idx] = dout * diff;
57+
T diff = term1 - label + term3;
58+
dx_data[idx] = dout * diff;
59+
} else {
60+
T max_val = x < 0 ? -x : 0;
61+
T term1 = (x < 0) ? static_cast<T>(-1) : static_cast<T>(0);
62+
T down1 = std::exp(-max_val);
63+
T down2 = std::exp(-x - max_val);
64+
T term2 = down1 * (-term1) + down2 * (-1 - term1);
65+
T term3 = (static_cast<T>(1.) - label);
66+
T diff =
67+
pos_weight_data[idx] * (term2 / (down1 + down2) + term1) + term3;
68+
dx_data[idx] = dout * diff;
69+
}
6070
}
6171
}
6272
if (normalize) {

paddle/phi/kernels/cpu/sigmoid_cross_entropy_with_logits_kernel.cc

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,17 @@ void SigmoidCrossEntropyWithLogitsKernel(
4545
if (static_cast<int>(label) == ignore_index) {
4646
out_data[idx] = static_cast<T>(0.);
4747
} else {
48-
T pos_weight_idx = pos_weight_data == nullptr ? 1 : pos_weight_data[idx];
49-
T term1 = (x > 0) ? x : 0;
50-
T term2 = x * label;
51-
T term3 = std::log(static_cast<T>(1) + std::exp(-std::abs(x)));
52-
out_data[idx] = term1 - term2 + term3 * pos_weight_idx;
48+
if (pos_weight_data == nullptr) {
49+
T term1 = (x > 0) ? x : 0;
50+
T term2 = x * label;
51+
T term3 = std::log(static_cast<T>(1) + std::exp(-std::abs(x)));
52+
out_data[idx] = term1 - term2 + term3;
53+
} else {
54+
T max_val = x < 0 ? -x : 0;
55+
T term1 = (static_cast<T>(1.) - label) * x;
56+
T term2 = std::log(std::exp(-max_val) + std::exp(-x - max_val));
57+
out_data[idx] = term1 + pos_weight_data[idx] * (term2 + max_val);
58+
}
5359
}
5460
}
5561

paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits_grad_kernel.cu

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,15 @@ struct SigmoidBwdPosWeightFunctor {
7373
dx_data = static_cast<T>(0.);
7474
counts = 0;
7575
} else {
76-
T term1 = (x > 0) ? static_cast<T>(1) : static_cast<T>(0);
77-
T e_x = phi::funcs::real_exp(-abs(x));
78-
T down = 1 + e_x;
79-
T abs_grad = (x >= 0) ? static_cast<T>(1) : static_cast<T>(-1);
80-
T up = -e_x * abs_grad * pos_weight;
81-
T term3 = up / down;
82-
83-
T diff = term1 - label + term3;
76+
T max_val = x < 0 ? -x : 0;
77+
T term1 = (x < 0) ? static_cast<T>(-1) : static_cast<T>(0);
78+
T down1 = phi::funcs::real_exp(-max_val);
79+
T down2 = phi::funcs::real_exp(-x - max_val);
80+
T term2 = down1 * (-term1) + down2 * (-1 - term1);
81+
T term3 = (static_cast<T>(1.) - label);
82+
T diff = pos_weight * (term2 / (down1 + down2) + term1) + term3;
8483
dx_data = dout * diff;
84+
8585
counts = 1;
8686
}
8787
phi::Array<T, 2> outs;

paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits_kernel.cu

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,12 @@ struct SigmoidFwdPosWeightFunctor {
7272
out_data = static_cast<T>(0.);
7373
counts = 0;
7474
} else {
75-
T term1 = (x > 0) ? x : 0;
76-
T term2 = x * label;
77-
T term3 =
78-
phi::funcs::real_log(static_cast<T>(1) +
79-
phi::funcs::real_exp(static_cast<T>(-abs(x)))) *
80-
pos_weight;
75+
T max_val = x < 0 ? -x : 0;
76+
T term1 = (static_cast<T>(1.) - label) * x;
77+
T term2 = phi::funcs::real_log(phi::funcs::real_exp(-max_val) +
78+
phi::funcs::real_exp(-x - max_val));
79+
out_data = term1 + pos_weight * (term2 + max_val);
8180

82-
out_data = term1 - term2 + term3;
8381
counts = 1;
8482
}
8583
phi::Array<T, 2> outs;

test/deprecated/legacy_test/test_sigmoid_cross_entropy_with_logits_op.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,12 @@ def setUp(self):
173173

174174
# Fw Pass is implemented as elementwise sigmoid followed by
175175
# elementwise logistic loss
176-
term1 = np.maximum(self.inputs['X'], 0)
177-
term2 = self.inputs['X'] * self.inputs['Label']
178-
term3 = np.log(1 + np.exp(-1 * np.abs(self.inputs['X']))) * pos_weight
179-
self.outputs = {'Out': term1 - term2 + term3}
176+
max_val = np.clip(-self.inputs['X'], 0, np.finfo(np.float64).max)
177+
term1 = (1 - label) * self.inputs['X']
178+
term2 = np.log(np.exp(-max_val) + np.exp(-self.inputs['X'] - max_val))
179+
out = term1 + pos_weight * (term2 + max_val)
180+
181+
self.outputs = {'Out': out}
180182

181183
def test_check_output(self):
182184
self.check_output(check_pir=True, check_prim_pir=True)

test/legacy_test/test_sigmoid_cross_entropy_with_logits_grad_with_auto_grad.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,13 @@ def fn_ref(x, label, weight):
5656

5757
def fn_comp(x, label, weight):
5858
zeros = paddle.full((self.batch_size, self.num_classes), 0.0)
59-
t1 = paddle.where(x > 0, x, zeros)
60-
t2 = x * label
61-
t3 = paddle.log(1 + paddle.exp(-paddle.abs(x)))
62-
t4 = t1 - t2 + t3 * weight
63-
t5 = paddle.full((self.batch_size, self.num_classes), -100.0)
64-
out = paddle.where(label == t5, zeros, t4)
59+
60+
max_val = paddle.where(x < zeros, -x, zeros)
61+
t1 = (1 - label) * x
62+
t2 = paddle.log((-max_val).exp() + (-x - max_val).exp())
63+
t3 = t1 + weight * (t2 + max_val)
64+
t4 = paddle.full((self.batch_size, self.num_classes), -100.0)
65+
out = paddle.where(label == t4, zeros, t3)
6566
loss = out.sum()
6667
loss.backward()
6768
return out, x.grad

0 commit comments

Comments
 (0)