Skip to content

Commit

Permalink
mean_squared_difference
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Apr 13, 2022
1 parent 74f75ca commit 68d45d6
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions nn/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,15 @@ def kl_div(*, target: nn.Tensor, target_type: str,
kl = nn.dot(nn.exp(log_target), log_target - log_est, reduce=axis)

return kl


@nn.scoped
def mean_squared_difference(a: nn.Tensor, b: nn.Tensor, *, axis: Optional[nn.Dim] = None) -> nn.Tensor:
"""
Mean squared difference between two tensors,
i.e. mean_{axis}( (a - b) ** 2 ), where axis is the feature dim by default.
"""
if not axis:
assert a.feature_dim
axis = a.feature_dim
return nn.reduce(nn.squared_difference(a, b), mode="mean", axis=axis)

0 comments on commit 68d45d6

Please sign in to comment.