Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon No.117】add parameter value for thresholded_relu #60067

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2428,7 +2428,7 @@
func : tensor_unfold_grad

- backward_op : thresholded_relu_grad
forward : thresholded_relu (Tensor x, float threshold) -> Tensor(out)
forward : thresholded_relu (Tensor x, float threshold, float value) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float threshold)
output : Tensor(x_grad)
infer_meta :
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2619,7 +2619,7 @@
no_need_buffer : input

- op : thresholded_relu
args : (Tensor x, float threshold = 1.0)
args : (Tensor x, float threshold = 1.0, float value = 0.0)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/activation_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ DECLARE_ACTIVATION_KERNEL(Ceil)
DECLARE_ACTIVATION_KERNEL(Negative)

DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, alpha)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, threshold)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Mish, threshold)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(HardShrink, threshold)
Expand All @@ -84,6 +83,7 @@ DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Celu, alpha)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Logit, eps)

DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(HardTanh, t_min, t_max)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(ThresholdedRelu, threshold, value)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(STanh, scale_a, scale_b)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(Softplus, beta, threshold)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(HardSigmoid, slope, offset)
Expand Down
7 changes: 4 additions & 3 deletions paddle/phi/kernels/cpu/activation_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,17 @@ DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Exp, ExpFunctor)
DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Expm1, Expm1Functor)

DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
ThresholdedReluFunctor,
threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishFunctor, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, HardShrinkFunctor, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, SoftShrinkFunctor, lambda)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, ELUFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Celu, CELUFunctor, alpha)

DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardTanh, HardTanhFunctor, t_min, t_max)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(ThresholdedRelu,
ThresholdedReluFunctor,
threshold,
value)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(STanh, STanhFunctor, scale_a, scale_b)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(Softplus, SoftplusFunctor, beta, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid,
Expand Down
15 changes: 9 additions & 6 deletions paddle/phi/kernels/funcs/activation_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1795,14 +1795,17 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
template <typename T>
struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
float threshold;
float value;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
return {{"threshold", &threshold}, {"value", &value}};
}

template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
auto th = static_cast<T>(threshold); // NOLINT
out.device(d) = (x > th).template cast<T>() * x;
auto va = static_cast<T>(value); // NOLINT
out.device(d) =
(x > th).template cast<T>() * x + (x <= th).template cast<T>() * va;
}
};

Expand Down Expand Up @@ -4040,16 +4043,16 @@ struct CudaHardTanhGradFunctor : public BaseActivationFunctor<T> {

template <typename T>
struct CudaThresholdedReluFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
float threshold;
float value;

typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
return {{"threshold", &threshold}, {"value", &value}};
}

// thresholded_relu(x) = x > threshold ? x : 0
// thresholded_relu(x) = x > threshold ? x : value
__device__ __forceinline__ T operator()(const T x) const {
return x > static_cast<T>(threshold) ? x : zero;
return x > static_cast<T>(threshold) ? x : static_cast<T>(value);
}
};

Expand Down
7 changes: 4 additions & 3 deletions paddle/phi/kernels/gpu/activation_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,6 @@ DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Expm1, CudaExpm1Functor)

DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LogitCUDA, CudaLogitFunctor, eps)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
CudaThresholdedReluFunctor,
threshold)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink,
CudaHardShrinkFunctor,
threshold)
Expand All @@ -138,6 +135,10 @@ DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(HardTanh,
CudaHardTanhFunctor,
t_min,
t_max)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(ThresholdedRelu,
CudaThresholdedReluFunctor,
threshold,
value)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(Stanh, CudaSTanhFunctor, scale_a, scale_b)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(Softplus,
CudaSoftplusFunctor,
Expand Down
15 changes: 8 additions & 7 deletions python/paddle/nn/functional/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1547,7 +1547,7 @@ def tanhshrink(x, name=None):
return out


def thresholded_relu(x, threshold=1.0, name=None):
def thresholded_relu(x, threshold=1.0, value=0.0, name=None):
r"""
thresholded relu activation.

Expand All @@ -1557,14 +1557,15 @@ def thresholded_relu(x, threshold=1.0, name=None):
\left\{
\begin{array}{rl}
x,& \text{if } \ x > threshold \\
0,& \text{otherwise}
value,& \text{otherwise}
\end{array}
\right.


Parameters:
x (Tensor): The input Tensor with data type float32, float64.
threshold (float, optional): The value of threshold for thresholded_relu. Default is 1.0
threshold (float, optional): The value of threshold for thresholded_relu. Default is 1.0.
value (float, optional): The value to replace with. Default is 0.0.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

Returns:
Expand All @@ -1584,7 +1585,7 @@ def thresholded_relu(x, threshold=1.0, name=None):
"""

if in_dynamic_or_pir_mode():
return _C_ops.thresholded_relu(x, threshold)
return _C_ops.thresholded_relu(x, threshold, value)
else:
check_variable_and_dtype(
x,
Expand All @@ -1598,19 +1599,19 @@ def thresholded_relu(x, threshold=1.0, name=None):
type='thresholded_relu',
inputs={'X': x},
outputs={'Out': out},
attrs={'threshold': threshold},
attrs={'threshold': threshold, 'value': value},
)
return out


@inplace_apis_in_dygraph_only
def thresholded_relu_(x, threshold=1.0, name=None):
def thresholded_relu_(x, threshold=1.0, value=0.0, name=None):
r"""
Inplace version of ``thresholded_relu`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_nn_functional_thresholded_relu`.
"""
if in_dynamic_mode():
return _C_ops.thresholded_relu_(x, threshold)
return _C_ops.thresholded_relu_(x, threshold, value)


def log_softmax(x, axis=-1, dtype=None, name=None):
Expand Down
10 changes: 6 additions & 4 deletions python/paddle/nn/layer/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,13 +1172,14 @@ class ThresholdedReLU(Layer):
\left\{
\begin{array}{rl}
x,& \text{if } \ x > threshold \\
0,& \text{otherwise}
value,& \text{otherwise}
\end{array}
\right.


Parameters:
threshold (float, optional): The value of threshold for ThresholdedReLU. Default is 1.0
threshold (float, optional): The value of threshold for ThresholdedReLU. Default is 1.0.
value (float, optional): The value to replace with. Default is 0.0.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.

Expand All @@ -1199,13 +1200,14 @@ class ThresholdedReLU(Layer):
[2., 0., 0.])
"""

def __init__(self, threshold=1.0, name=None):
def __init__(self, threshold=1.0, value=0.0, name=None):
super().__init__()
self._threshold = threshold
self._value = value
self._name = name

def forward(self, x):
return F.thresholded_relu(x, self._threshold, self._name)
return F.thresholded_relu(x, self._threshold, self._value, self._name)

def extra_repr(self):
name_str = f', name={self._name}' if self._name else ''
Expand Down
61 changes: 43 additions & 18 deletions test/legacy_test/test_activation_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4310,8 +4310,8 @@ def test_errors(self):
F.softsign(x_fp16)


def ref_thresholded_relu(x, threshold=1.0):
out = (x > threshold) * x
def ref_thresholded_relu(x, threshold=1.0, value=0.0):
out = (x > threshold) * x + (x <= threshold) * value
return out


Expand All @@ -4320,20 +4320,23 @@ def setUp(self):
self.op_type = "thresholded_relu"
self.init_dtype()
self.init_shape()
self.set_attrs()
self.python_api = paddle.nn.functional.thresholded_relu

threshold = 15

np.random.seed(1024)
x = np.random.uniform(-20, 20, self.shape).astype(self.dtype)
x[np.abs(x) < 0.005] = 0.02
out = ref_thresholded_relu(x, threshold)
out = ref_thresholded_relu(x, self.threshold, self.value)

self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
self.outputs = {'Out': out}
self.attrs = {"threshold": threshold}
self.attrs = {"threshold": self.threshold, "value": self.value}
self.convert_input_output()

def set_attrs(self):
self.threshold = 15
self.value = 5

def init_shape(self):
self.shape = [10, 12]

Expand All @@ -4346,6 +4349,12 @@ def test_check_output(self):
self.check_output(check_pir=True)


class TestThresholdedRelu_ZeroValue(TestThresholdedRelu):
def set_attrs(self):
self.threshold = 15
self.value = 0


class TestThresholdedRelu_ZeroDim(TestThresholdedRelu):
def init_shape(self):
self.shape = []
Expand All @@ -4354,37 +4363,47 @@ def init_shape(self):
class TestThresholdedReluAPI(unittest.TestCase):
# test paddle.nn.ThresholdedReLU, paddle.nn.functional.thresholded_relu
def setUp(self):
self.threshold = 15
self.set_attrs()
np.random.seed(1024)
self.x_np = np.random.uniform(-20, 20, [10, 12]).astype(np.float64)
self.x_np[np.abs(self.x_np) < 0.005] = 0.02
self.place = (
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
else paddle.CPUPlace()
)
self.place = [paddle.CPUPlace()]
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))

def set_attrs(self):
self.threshold = 15
self.value = 5

@test_with_pir_api
def test_static_api(self):
with static_guard():
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
out1 = F.thresholded_relu(x, self.threshold)
thresholded_relu = paddle.nn.ThresholdedReLU(self.threshold)
out1 = F.thresholded_relu(x, self.threshold, self.value)
thresholded_relu = paddle.nn.ThresholdedReLU(
self.threshold, self.value
)
out2 = thresholded_relu(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
out_ref = ref_thresholded_relu(self.x_np, self.threshold)
out_ref = ref_thresholded_relu(
self.x_np, self.threshold, self.value
)
for r in res:
np.testing.assert_allclose(out_ref, r, rtol=1e-05)

def test_dygraph_api(self):
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
out1 = F.thresholded_relu(x, self.threshold)
thresholded_relu = paddle.nn.ThresholdedReLU(self.threshold)
out1 = F.thresholded_relu(x, self.threshold, self.value)
thresholded_relu = paddle.nn.ThresholdedReLU(
self.threshold, self.value
)
out2 = thresholded_relu(x)
out_ref = ref_thresholded_relu(self.x_np, self.threshold)
out_ref = ref_thresholded_relu(
self.x_np, self.threshold, self.value
)
for r in [out1, out2]:
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)

Expand All @@ -4405,6 +4424,12 @@ def test_errors(self):
F.thresholded_relu(x_fp16)


class TestThresholdedReluAPI_ZeroValue(TestThresholdedReluAPI):
def set_attrs(self):
self.threshold = 15
self.value = 0


def ref_hardsigmoid(x, slope=0.166666666666667, offset=0.5):
return np.maximum(np.minimum(x * slope + offset, 1.0), 0.0).astype(x.dtype)

Expand Down