Skip to content

Commit

Permalink
check return type of FunctionalLoss
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Apr 17, 2024
1 parent 8d7dadd commit 9862490
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 4 deletions.
20 changes: 16 additions & 4 deletions ppsci/loss/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from typing import Optional
from typing import Union

import paddle

from ppsci.loss import base


Expand All @@ -34,7 +36,7 @@ class FunctionalLoss(base.Loss):
$$
Args:
loss_expr (Callable): expression of loss calculation.
loss_expr (Callable[..., paddle.Tensor]): Function for custom loss computation.
weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.
Examples:
Expand Down Expand Up @@ -63,11 +65,21 @@ class FunctionalLoss(base.Loss):

def __init__(
    self,
    loss_expr: Callable[..., paddle.Tensor],
    weight: Optional[Union[float, Dict[str, float]]] = None,
):
    """Initialize the functional loss wrapper.

    Args:
        loss_expr (Callable[..., paddle.Tensor]): Function for custom loss
            computation; called as ``loss_expr(output_dict, label_dict,
            weight_dict)`` and expected to return a ``paddle.Tensor``.
        weight (Optional[Union[float, Dict[str, float]]]): Weight for loss.
            Defaults to None.
    """
    # Pass None as the reduction/name argument expected by base.Loss;
    # only the weight is forwarded to the parent.
    super().__init__(None, weight)
    self.loss_expr = loss_expr

def forward(self, output_dict, label_dict=None, weight_dict=None) -> paddle.Tensor:
    """Evaluate the custom loss expression and validate its return type.

    Args:
        output_dict: Dict of model outputs passed to the custom loss function.
        label_dict: Optional dict of labels. Defaults to None.
        weight_dict: Optional dict of sample weights. Defaults to None.

    Returns:
        paddle.Tensor: Loss value computed by ``self.loss_expr``.
    """
    loss = self.loss_expr(output_dict, label_dict, weight_dict)

    # Guard against custom functions that return Python scalars (e.g. via
    # ``.item()``), which would silently break gradient propagation downstream.
    # Kept as an assert (not a raise) because callers/tests pin AssertionError.
    assert isinstance(
        loss, (paddle.Tensor, paddle.static.Variable, paddle.pir.Value)
    ), (
        "Loss computed by custom function should be type of 'paddle.Tensor', "
        f"'paddle.static.Variable' or 'paddle.pir.Value', but got {type(loss)}."
        " Please check the return type of custom loss function."
    )

    return loss
30 changes: 30 additions & 0 deletions test/loss/func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import paddle
import pytest

from ppsci import loss

__all__ = []


def test_non_tensor_return_type():
    """FunctionalLoss accepts Tensor returns and rejects builtin scalars."""

    def loss_func_return_tensor(input_dict, label_dict, weight_dict):
        # Returns a paddle.Tensor — accepted by FunctionalLoss.forward.
        return (0.5 * (input_dict["x"] - label_dict["x"]) ** 2).sum()

    def loss_func_return_builtin_float(input_dict, label_dict, weight_dict):
        # ``.item()`` converts the Tensor to a Python float — must be rejected.
        return (0.5 * (input_dict["x"] - label_dict["x"]) ** 2).sum().item()

    wrapped_loss1 = loss.FunctionalLoss(loss_func_return_tensor)
    wrapped_loss2 = loss.FunctionalLoss(loss_func_return_builtin_float)

    input_dict = {"x": paddle.randn([10, 1])}
    label_dict = {"x": paddle.zeros([10, 1])}

    # Tensor return passes the type check; scalar return trips the assert.
    wrapped_loss1(input_dict, label_dict)
    with pytest.raises(AssertionError):
        wrapped_loss2(input_dict, label_dict)


if __name__ == "__main__":
    # Propagate pytest's exit code so CI treats test failures as failures
    # when this file is executed directly.
    raise SystemExit(pytest.main())

0 comments on commit 9862490

Please sign in to comment.