Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[Fix] Check return type of FunctionalLoss #854

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions ppsci/loss/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from typing import Optional
from typing import Union

import paddle

from ppsci.loss import base


Expand All @@ -34,7 +36,7 @@ class FunctionalLoss(base.Loss):
$$

Args:
loss_expr (Callable): expression of loss calculation.
loss_expr (Callable[..., paddle.Tensor]): Function for custom loss computation.
weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.

Examples:
Expand Down Expand Up @@ -63,11 +65,21 @@ class FunctionalLoss(base.Loss):

def __init__(
self,
loss_expr: Callable,
loss_expr: Callable[..., paddle.Tensor],
weight: Optional[Union[float, Dict[str, float]]] = None,
):
super().__init__(None, weight)
self.loss_expr = loss_expr

def forward(self, output_dict, label_dict=None, weight_dict=None):
return self.loss_expr(output_dict, label_dict, weight_dict)
def forward(self, output_dict, label_dict=None, weight_dict=None) -> paddle.Tensor:
loss = self.loss_expr(output_dict, label_dict, weight_dict)

assert isinstance(
loss, (paddle.Tensor, paddle.static.Variable, paddle.pir.Value)
), (
"Loss computed by custom function should be type of 'paddle.Tensor', "
f"'paddle.static.Variable' or 'paddle.pir.Value', but got {type(loss)}."
" Please check the return type of custom loss function."
)

return loss
30 changes: 30 additions & 0 deletions test/loss/func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import paddle
import pytest

from ppsci import loss

__all__ = []


def test_non_tensor_return_type():
    """FunctionalLoss accepts Tensor returns and rejects non-Tensor returns."""

    def loss_func_return_tensor(input_dict, label_dict, weight_dict):
        # Valid custom loss: returns a paddle.Tensor scalar.
        return (0.5 * (input_dict["x"] - label_dict["x"]) ** 2).sum()

    def loss_func_return_builtin_float(input_dict, label_dict, weight_dict):
        # Invalid custom loss: `.item()` converts the Tensor to a built-in
        # float, which FunctionalLoss must reject.
        return (0.5 * (input_dict["x"] - label_dict["x"]) ** 2).sum().item()

    wrapped_loss1 = loss.FunctionalLoss(loss_func_return_tensor)
    wrapped_loss2 = loss.FunctionalLoss(loss_func_return_builtin_float)

    input_dict = {"x": paddle.randn([10, 1])}
    label_dict = {"x": paddle.zeros([10, 1])}

    # Tensor-returning function passes the type check.
    wrapped_loss1(input_dict, label_dict)
    # Float-returning function triggers the AssertionError from forward().
    with pytest.raises(AssertionError):
        wrapped_loss2(input_dict, label_dict)


# Allow running this test module directly: `python test/loss/func.py`.
if __name__ == "__main__":
    pytest.main()