4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/backward.yaml
@@ -1299,8 +1299,8 @@
func : inverse_grad

- backward_op : kldiv_loss_grad
forward : kldiv_loss(Tensor x, Tensor label, str reduction="mean") -> Tensor(out)
args : (Tensor x, Tensor label, Tensor out_grad, str reduction)
forward : kldiv_loss(Tensor x, Tensor label, str reduction="mean", bool log_target = false) -> Tensor(out)
args : (Tensor x, Tensor label, Tensor out_grad, str reduction, bool log_target)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
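For context, the gradient with respect to x that this backward op produces depends on the new attribute (this matches the KLDivLossBackward functor change further down; y is the label and g the incoming out_grad):

$$\frac{\partial l}{\partial x} =
\begin{cases}
-\,y \cdot g, & \text{log\_target = False (and } 0 \text{ where } y \le 0) \\
-\,e^{y} \cdot g, & \text{log\_target = True}
\end{cases}$$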
8 changes: 8 additions & 0 deletions paddle/phi/api/yaml/op_version.yaml
@@ -326,6 +326,14 @@
comment : "The arg 'dispensable' of Input 'Scale' is changed: from 'False' to 'True'."
default : "true"

- op : kldiv_loss
version :
- checkpoint : Upgrade kldiv_loss, add a new attribute [log_target]
action :
- add_attr : log_target
comment : In order to specify whether 'label' is passed in log space.
default : "false"

- op : lamb
version :
- checkpoint : Upgrade lamb, add two new outputs [Beta1PowOut] and [Beta2PowOut].
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/ops.yaml
@@ -1566,7 +1566,7 @@
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : kldiv_loss
args : (Tensor x, Tensor label, str reduction = "mean")
args : (Tensor x, Tensor label, str reduction = "mean", bool log_target = false)
output : Tensor(out)
infer_meta :
func : KLDivInferMeta
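Because log_target defaults to false, existing call sites keep their old behavior. A minimal sanity sketch of that, through the public Python API (illustrative only):

```python
import paddle
import paddle.nn.functional as F

x = paddle.log(paddle.rand([5, 20]))   # log-probabilities
label = paddle.rand([5, 20])           # plain probabilities

# Omitting log_target is the same as passing log_target=False explicitly.
out_default = F.kl_div(x, label, reduction='none')
out_explicit = F.kl_div(x, label, reduction='none', log_target=False)
assert paddle.allclose(out_default, out_explicit)
```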
1 change: 1 addition & 0 deletions paddle/phi/infermeta/binary.cc
@@ -98,6 +98,7 @@ void AllValueCompareInferMeta(const MetaTensor& x,
void KLDivInferMeta(const MetaTensor& x,
const MetaTensor& label,
const std::string& reduction,
bool log_target,
MetaTensor* out,
MetaConfig config) {
auto dim_x = x.dims();
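For reference, the shape contract this infer-meta backs, sketched through the Python API rather than the C++ (assuming a Paddle build with 0-D tensor outputs): with reduction='none' the loss keeps the input shape, otherwise it reduces to a scalar.

```python
import paddle
import paddle.nn.functional as F

x = paddle.log(paddle.rand([5, 20]))
label = paddle.rand([5, 20])

print(F.kl_div(x, label, reduction='none').shape)  # [5, 20], same shape as the input
print(F.kl_div(x, label, reduction='mean').shape)  # [], reduced to a 0-D tensor
```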
1 change: 1 addition & 0 deletions paddle/phi/infermeta/binary.h
@@ -42,6 +42,7 @@ void AllValueCompareInferMeta(const MetaTensor& x,
void KLDivInferMeta(const MetaTensor& x,
const MetaTensor& label,
const std::string& reduction,
bool log_target,
MetaTensor* out,
MetaConfig config = MetaConfig());

19 changes: 13 additions & 6 deletions paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h
@@ -23,13 +23,19 @@ namespace phi {
using Array1 = Eigen::DSizes<int64_t, 1>;
template <typename T>
struct KLDivLossBackward {
HOSTDEVICE KLDivLossBackward() {}
bool log_target = false;

HOSTDEVICE KLDivLossBackward(bool logTarget) : log_target(logTarget) {}

HOSTDEVICE T operator()(const T& target, const T& grad) const {
if (target <= 0) {
return 0;
if (log_target) {
return static_cast<T>(-1.) * std::exp(target) * grad;
} else {
return static_cast<T>(-1.) * grad;
if (target <= 0) {
return 0;
} else {
return static_cast<T>(-1.) * target * grad;
}
}
}
};
@@ -40,6 +46,7 @@ void KLDivLossGradKernel(const Context& dev_ctx,
const DenseTensor& label,
const DenseTensor& d_out,
const std::string& reduction,
bool log_target,
DenseTensor* d_x) {
auto& place = *dev_ctx.eigen_device();
auto* target = &label;
@@ -58,9 +65,9 @@ void KLDivLossGradKernel(const Context& dev_ctx,
auto loss_grad_t = phi::EigenVector<T>::Flatten(*loss_grad);

auto loss_grad_expand = loss_grad_t.broadcast(Array1(expand));
auto grad_t = target_t * loss_grad_expand;
auto grad_t = loss_grad_expand;
input_grad_t.device(place) =
target_t.binaryExpr(grad_t, KLDivLossBackward<T>());
target_t.binaryExpr(grad_t, KLDivLossBackward<T>(log_target));

if ("mean" == reduction) {
input_grad_t.device(place) = input_grad_t / static_cast<T>(numel);
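A small NumPy transcription of the per-element rule in KLDivLossBackward above (an illustration of the math, not patch code; the helper name is made up). For reduction='mean' the kernel additionally divides by the element count.

```python
import numpy as np

def kldiv_grad_elem(target, out_grad, log_target=False):
    # Mirrors KLDivLossBackward::operator(): d loss / d x per element.
    if log_target:
        return -np.exp(target) * out_grad
    return 0.0 if target <= 0 else -target * out_grad

# Passing log(y) with log_target=True matches passing y directly:
y, g = 0.3, 1.0
assert np.isclose(kldiv_grad_elem(y, g),
                  kldiv_grad_elem(np.log(y), g, log_target=True))
```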
18 changes: 13 additions & 5 deletions paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h
@@ -24,21 +24,29 @@ namespace phi {
using Array1 = Eigen::DSizes<int64_t, 1>;
template <typename T>
struct KLDivLossForward {
HOSTDEVICE KLDivLossForward() {}
bool log_target = false;

HOSTDEVICE KLDivLossForward(bool logTarget) : log_target(logTarget) {}

HOSTDEVICE T operator()(const T& target, const T& input) const {
if (target <= 0) {
return 0;
if (log_target) {
return std::exp(target) * (target - input);
} else {
return target * (std::log(target) - input);
if (target <= 0) {
return 0;
} else {
return target * (std::log(target) - input);
}
}
}
};

template <typename T, typename Context>
void KLDivLossKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const std::string& reduction,
bool log_target,
DenseTensor* out) {
auto& place = *(dev_ctx.eigen_device());
auto* input = &x;
@@ -51,7 +59,7 @@ void KLDivLossKernel(const Context& dev_ctx,
auto input_t = phi::EigenVector<T>::Flatten(*input);
auto target_t = phi::EigenVector<T>::Flatten(*target);
auto loss_t = phi::EigenVector<T>::Flatten(*loss);
auto output = target_t.binaryExpr(input_t, KLDivLossForward<T>());
auto output = target_t.binaryExpr(input_t, KLDivLossForward<T>(log_target));
if ("none" == reduction) {
loss_t.device(place) = output;
} else if ("batchmean" == reduction) {
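Likewise, a NumPy transcription of KLDivLossForward's per-element rule (illustrative only; the helper name is made up):

```python
import numpy as np

def kldiv_loss_elem(target, x, log_target=False):
    # Mirrors KLDivLossForward::operator().
    if log_target:
        return np.exp(target) * (target - x)
    return 0.0 if target <= 0 else target * (np.log(target) - x)

# Passing log(y) with log_target=True matches passing y directly:
y, x = 0.3, -1.2
assert np.isclose(kldiv_loss_elem(y, x),
                  kldiv_loss_elem(np.log(y), x, log_target=True))
```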
1 change: 1 addition & 0 deletions paddle/phi/kernels/kldiv_loss_grad_kernel.h
@@ -24,5 +24,6 @@ void KLDivLossGradKernel(const Context& dev_ctx,
const DenseTensor& label,
const DenseTensor& d_out,
const std::string& reduction,
bool log_target,
DenseTensor* d_x);
} // namespace phi
1 change: 1 addition & 0 deletions paddle/phi/kernels/kldiv_loss_kernel.h
@@ -26,5 +26,6 @@ void KLDivLossKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const std::string& reduction,
bool log_target,
DenseTensor* out);
} // namespace phi
34 changes: 28 additions & 6 deletions paddle/phi/kernels/xpu/kldiv_loss_grad_kernel.cc
@@ -25,6 +25,7 @@ void KLDivLossGradKernel(const Context& dev_ctx,
const DenseTensor& label,
const DenseTensor& d_out,
const std::string& reduction,
bool log_target,
DenseTensor* d_x) {
using XPUType = typename XPUTypeTrait<T>::Type;
dev_ctx.template Alloc<T>(d_x);
@@ -33,12 +34,33 @@
}

int r = XPU_SUCCESS;
r = xpu::kldiv_loss_grad(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(label.data<T>()),
reinterpret_cast<const XPUType*>(d_out.data<T>()),
reinterpret_cast<XPUType*>(d_x->data<T>()),
d_x->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss_grad");

if (log_target) {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
XPUType* label_exp = RAII_GUARD.alloc_l3_or_gm<XPUType>(label.numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(label_exp);

r = xpu::exp(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(label.data<T>()),
label_exp,
label.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "exp");

r = xpu::kldiv_loss_grad(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(label_exp),
reinterpret_cast<const XPUType*>(d_out.data<T>()),
reinterpret_cast<XPUType*>(d_x->data<T>()),
d_x->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss_grad");
} else {
r = xpu::kldiv_loss_grad(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(label.data<T>()),
reinterpret_cast<const XPUType*>(d_out.data<T>()),
reinterpret_cast<XPUType*>(d_x->data<T>()),
d_x->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss_grad");
}

if ("none" != reduction) {
PADDLE_THROW(phi::errors::Unavailable(
"Not supported reduction [%s] in kldiv_loss_grad", reduction));
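Design note on the XPU kernel above: instead of adding a log_target variant of xpu::kldiv_loss_grad, the label is exponentiated into a scratch buffer and the existing call is reused. This works because the log-space gradient -exp(y) * g is exactly the plain-target gradient evaluated at target = exp(y), and exp(y) > 0 means the target <= 0 branch can never fire. A NumPy sketch of that identity (illustrative, not patch code):

```python
import numpy as np

def grad_plain(target, g):
    # Plain-target gradient: -target * g, clamped to 0 where target <= 0.
    return np.where(target > 0, -target * g, 0.0)

log_y = np.log(np.array([0.1, 0.5, 0.9]))
g = np.array([1.0, 2.0, 3.0])

# Exponentiating the log-space label and reusing the plain gradient
# reproduces the log_target=True formula.
assert np.allclose(grad_plain(np.exp(log_y), g), -np.exp(log_y) * g)
```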
34 changes: 28 additions & 6 deletions paddle/phi/kernels/xpu/kldiv_loss_kernel.cc
@@ -24,6 +24,7 @@ void KLDivLossKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const std::string& reduction,
bool log_target,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
dev_ctx.template Alloc<T>(out);
@@ -32,12 +33,33 @@
}

int r = XPU_SUCCESS;
r = xpu::kldiv_loss(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<const XPUType*>(label.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss");

if (log_target) {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
XPUType* label_exp = RAII_GUARD.alloc_l3_or_gm<XPUType>(label.numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(label_exp);

r = xpu::exp(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(label.data<T>()),
label_exp,
label.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "exp");

r = xpu::kldiv_loss(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<const XPUType*>(label_exp),
reinterpret_cast<XPUType*>(out->data<T>()),
out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss");
} else {
r = xpu::kldiv_loss(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<const XPUType*>(label.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss");
}

if ("none" != reduction) {
PADDLE_THROW(phi::errors::Unavailable(
"Not supported reduction [%s] in kldiv_loss", reduction));
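The forward XPU kernel uses the same trick: materialize exp(label) once and reuse the unmodified xpu::kldiv_loss, relying on exp(y) * (y - x) = t * (log t - x) for t = exp(y) (again t > 0, so the plain path's clamp never triggers). A quick NumPy check of that identity (illustrative):

```python
import numpy as np

x = np.array([-1.2, -0.3, -2.0])
log_y = np.log(np.array([0.1, 0.5, 0.9]))
t = np.exp(log_y)

lhs = np.exp(log_y) * (log_y - x)  # log_target=True formula
rhs = t * (np.log(t) - x)          # plain-target formula applied to exp(label)
assert np.allclose(lhs, rhs)
```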
22 changes: 19 additions & 3 deletions python/paddle/nn/functional/loss.py
@@ -1618,16 +1618,22 @@ def poisson_nll_loss(
return loss_out


def kl_div(input, label, reduction='mean', name=None):
def kl_div(input, label, reduction='mean', log_target=False, name=None):
r"""
Calculate the Kullback-Leibler divergence loss
between Input(X) and Input(Target). Note that Input(X) is the
log-probability and Input(Target) is the probability.

KL divergence loss is calculated as follows:

If `log_target` is False:

$$l(x, y) = y * (\log(y) - x)$$

If `log_target` is True:

$$l(x, y) = \exp(y) * (y - x)$$

Here :math:`x` is input and :math:`y` is label.

If `reduction` is ``'none'``, the output loss is the same shape as the input, and the loss at each point is calculated separately. There is no reduction to the result.
@@ -1649,6 +1655,7 @@ def kl_div(input, label, reduction='mean', name=None):
if `reduction` is ``'sum'``, the reduced sum loss is returned;
if `reduction` is ``'none'``, no reduction will be applied.
Default is ``'mean'``.
log_target (bool, optional): Indicate whether `label` is passed in log space. Default is False.
Contributor:
The example code below could also include a case with log_target=True.

Contributor Author:
Updated.

name(str, optional): Name for the operation (optional, default is None). For more information,
please refer to :ref:`api_guide_Name`.

@@ -1689,6 +1696,15 @@ def kl_div(input, label, reduction='mean', name=None):
>>> print(pred_loss.shape)
[5, 20]

>>> # if label is in the log space, set log_target = True
>>> target = paddle.uniform(shape, min=0, max=10).astype('float32')
>>> log_target = paddle.log(target)
>>> pred_loss_1 = F.kl_div(x, target, reduction='none')
>>> pred_loss_2 = F.kl_div(x, log_target, reduction='none', log_target=True)
>>> print(paddle.allclose(pred_loss_1, pred_loss_2))
Tensor(shape=[], dtype=bool, place=Place(cpu), stop_gradient=True,
True)

"""
# ugly type promotion
if (
@@ -1703,7 +1719,7 @@ def kl_div(input, label, reduction='mean', name=None):
label = paddle.cast(label, 'float64')

if in_dynamic_or_pir_mode():
out = _C_ops.kldiv_loss(input, label, 'none')
out = _C_ops.kldiv_loss(input, label, 'none', log_target)
if reduction == 'mean':
out = paddle.mean(out)
elif reduction == 'sum':
@@ -1729,7 +1745,7 @@ def kl_div(input, label, reduction='mean', name=None):
type='kldiv_loss',
inputs={'X': input, 'Target': label},
outputs={'Loss': loss},
attrs={'reduction': 'none'},
attrs={'reduction': 'none', 'log_target': log_target},
)

if reduction == 'mean':
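Note how the wrapper always invokes the op with reduction 'none' and applies the requested reduction in Python afterwards. A quick sketch of that contract through the public API (illustrative):

```python
import paddle
import paddle.nn.functional as F

x = paddle.log(paddle.rand([5, 20]))
log_label = paddle.log(paddle.rand([5, 20]))

per_elem = F.kl_div(x, log_label, reduction='none', log_target=True)
mean_loss = F.kl_div(x, log_label, reduction='mean', log_target=True)
assert paddle.allclose(mean_loss, paddle.mean(per_elem))
```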
22 changes: 20 additions & 2 deletions python/paddle/nn/layer/loss.py
@@ -1034,8 +1034,14 @@ class KLDivLoss(Layer):

KL divergence loss is calculated as follows:

If `log_target` is False:

$$l(x, y) = y * (\log(y) - x)$$

If `log_target` is True:

$$l(x, y) = \exp(y) * (y - x)$$

Here :math:`x` is input and :math:`y` is label.

If `reduction` is ``'none'``, the output loss is the same shape as the input, and the loss at each point is calculated separately. There is no reduction to the result.
@@ -1054,6 +1060,7 @@ class KLDivLoss(Layer):
if `reduction` is ``'sum'``, the reduced sum loss is returned;
if `reduction` is ``'none'``, no reduction will be applied.
Default is ``'mean'``.
log_target (bool, optional): Indicate whether `label` is passed in log space. Default is False.
Contributor:
The example code below could also include a case with log_target=True.

Contributor Author:
Updated.


Shape:

@@ -1097,14 +1104,25 @@ class KLDivLoss(Layer):
>>> print(pred_loss.shape)
[5, 20]

>>> # if label is in the log space, set log_target = True
>>> target = paddle.uniform(shape, min=0, max=10).astype('float32')
>>> log_target = paddle.log(target)
>>> kldiv_criterion_1 = nn.KLDivLoss(reduction='none')
>>> kldiv_criterion_2 = nn.KLDivLoss(reduction='none', log_target=True)
>>> pred_loss_1 = kldiv_criterion_1(x, target)
>>> pred_loss_2 = kldiv_criterion_2(x, log_target)
>>> print(paddle.allclose(pred_loss_1, pred_loss_2))
Tensor(shape=[], dtype=bool, place=Place(cpu), stop_gradient=True,
True)
"""

def __init__(self, reduction='mean'):
def __init__(self, reduction='mean', log_target=False):
super().__init__()
self.reduction = reduction
self.log_target = log_target

def forward(self, input, label):
out = F.kl_div(input, label, self.reduction)
out = F.kl_div(input, label, self.reduction, self.log_target)
return out


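Since KLDivLoss.forward simply forwards its stored reduction and log_target to F.kl_div, the Layer and the functional API remain interchangeable. A short usage sketch (illustrative):

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

x = paddle.log(paddle.rand([5, 20]))
log_label = paddle.log(paddle.rand([5, 20]))

layer = nn.KLDivLoss(reduction='sum', log_target=True)
assert paddle.allclose(layer(x, log_label),
                       F.kl_div(x, log_label, reduction='sum', log_target=True))
```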