diff --git a/nn/loss.py b/nn/loss.py index fef1d576..80d4a3a7 100644 --- a/nn/loss.py +++ b/nn/loss.py @@ -190,6 +190,18 @@ def kl_div(*, target: nn.Tensor, target_type: str, return kl +@nn.scoped +def mean_absolute_difference(a: nn.Tensor, b: nn.Tensor, *, axis: Optional[nn.Dim] = None) -> nn.Tensor: + """ + Mean absolute difference, mean absolute error (MAE), or L1 loss between two tensors, + i.e. mean_{axis}( abs(a - b) ), where axis is the feature dim by default. + """ + if not axis: + assert a.feature_dim + axis = a.feature_dim + return nn.reduce(nn.abs(a - b), mode="mean", axis=axis) + + @nn.scoped def mean_squared_difference(a: nn.Tensor, b: nn.Tensor, *, axis: Optional[nn.Dim] = None) -> nn.Tensor: """