Commit 3d77360

add negative clipping for softmax.
1 parent 360bde9 commit 3d77360
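
Subtracting the row maximum already bounds the shifted logits above by 0; the new clip bounds them below at -64, so every exp() term lies in [exp(-64) ≈ 1.6e-28, 1], comfortably inside float32 range (smallest normal ≈ 1.18e-38). As a consequence no probability underflows to exactly zero, which keeps a downstream log(softmax) (e.g. in a cross-entropy loss) finite. A minimal NumPy sketch of the idea; the function name and sample input are illustrative, not part of this commit:

import numpy as np

def clipped_stable_softmax(x, threshold=-64.0):
    shifted = (x - np.max(x)).clip(threshold)  # shifted logits lie in [threshold, 0]
    exps = np.exp(shifted)                     # every term lies in [exp(threshold), 1]
    return exps / np.sum(exps)

x = np.array([0.0, -1000.0], dtype=np.float32)
print(clipped_stable_softmax(x))  # [1.0, ~1.6e-28]; the tiny entry is clipped, not 0.0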

Showing 2 changed files with 11 additions and 2 deletions.

paddle/operators/math/softmax.h

Lines changed: 10 additions & 1 deletion

@@ -25,6 +25,14 @@ template <typename T, int MajorType = Eigen::RowMajor,
                   typename IndexType = Eigen::DenseIndex>
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
+template <typename T>
+struct ValueClip {
+  HOSTDEVICE T operator()(const T& x) const {
+    const T kThreshold = -64.;
+    return x < kThreshold ? kThreshold : x;
+  }
+};
+
 template <typename Place, typename T>
 class SoftmaxFunctor {
  public:
@@ -47,7 +55,8 @@ class SoftmaxFunctor {
         logits.maximum(along_class)
             .eval()
             .reshape(batch_by_one)
-            .broadcast(one_by_class));
+            .broadcast(one_by_class))
+            .unaryExpr(ValueClip<T>());
 
     softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();
     softmax.device(context.GetEigenDevice<Place>()) =
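
ValueClip is a plain functor whose operator() is marked HOSTDEVICE, so Eigen can apply it elementwise on both CPU and GPU via unaryExpr. Note the parenthesis placement in the second hunk: the clip runs on the already-shifted logits, i.e. after the broadcast subtraction of the per-row maximum. A hypothetical NumPy mirror of that Eigen chain, for intuition only (names here are illustrative):

import numpy as np

def shifted_logits(logits, threshold=-64.0):
    row_max = logits.max(axis=1).reshape(-1, 1)  # maximum(along_class) + reshape(batch_by_one)
    shifted = logits - row_max                   # broadcast(one_by_class) subtraction
    return np.where(shifted < threshold, threshold, shifted)  # unaryExpr(ValueClip<T>())

print(shifted_logits(np.array([[10.0, -2000.0], [3.0, 5.0]])))
# [[  0. -64.]
#  [ -2.   0.]]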

python/paddle/v2/framework/tests/test_softmax_op.py

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@
 
 def stable_softmax(x):
     """Compute the softmax of vector x in a numerically stable way."""
-    shiftx = x - np.max(x)
+    shiftx = (x - np.max(x)).clip(-64.)
     exps = np.exp(shiftx)
     return exps / np.sum(exps)
 
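
This keeps the Python reference in step with the C++ op: the clip is applied to the shifted values, not to the maximum. A usage sketch (restating stable_softmax so the snippet runs standalone; the extreme input is illustrative):

import numpy as np

def stable_softmax(x):
    """Compute the softmax of vector x in a numerically stable way."""
    shiftx = (x - np.max(x)).clip(-64.)
    exps = np.exp(shiftx)
    return exps / np.sum(exps)

x = np.array([0.0, -1e6], dtype=np.float32)
p = stable_softmax(x)
print(p)          # second entry is about exp(-64) ~ 1.6e-28, not 0.0
print(np.log(p))  # stays finite; without the clip it would be [0., -inf]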
