Skip to content

Commit

Permalink
Fix broadcasting cosine_similarity (pytorch#109363)
Browse files Browse the repository at this point in the history
Fixes pytorch#109333
Pull Request resolved: pytorch#109363
Approved by: https://github.com/peterbell10
  • Loading branch information
lezcano authored and pytorchmergebot committed Sep 15, 2023
1 parent aed9bee commit 653c156
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 10 deletions.
11 changes: 6 additions & 5 deletions aten/src/ATen/ExpandUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,17 +187,18 @@ expand_inplace(
// See NOTE [ ExpandUtils Borrowing ] above for `MaybeOwned` explanation.
//
// Broadcasts two tensors against each other and returns the (possibly
// expanded) pair. Uses symbolic sizes (`sym_sizes`/`expand_symint`) so the
// result is correct under dynamic/symbolic shapes as well as concrete ones.
//
// Returns:
//   - If the sizes already match: both tensors borrowed (no copy, no new
//     TensorImpl).
//   - Otherwise: both tensors expanded to the common broadcast shape and
//     returned as owned `MaybeOwned<Tensor>` values.
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(const Tensor& to_expand1, const Tensor& to_expand2) {
  auto s1 = to_expand1.sym_sizes();
  auto s2 = to_expand2.sym_sizes();
  // Fast path: identical (symbolic) sizes need no expansion at all.
  if (s1.equals(s2)) {
    return std::make_tuple(
        c10::MaybeOwned<Tensor>::borrowed(to_expand1),
        c10::MaybeOwned<Tensor>::borrowed(to_expand2));
  }

  // Compute the broadcast shape symbolically, then expand both operands to it.
  auto expanded_size = infer_size_symdimvector(s1, s2);
  return std::make_tuple(
      c10::MaybeOwned<Tensor>::owned(to_expand1.expand_symint(expanded_size)),
      c10::MaybeOwned<Tensor>::owned(to_expand2.expand_symint(expanded_size)));
}

inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
Expand Down
12 changes: 7 additions & 5 deletions aten/src/ATen/native/Distance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,24 +308,26 @@ Tensor cosine_similarity(const Tensor& x1_, const Tensor& x2_, int64_t dim, doub
// We accept integral types (and bools lol) but vector_norm does not
auto x1_is_int = c10::isIntegralType(x1_.scalar_type(), /*includeBool=*/true);
auto x2_is_int = c10::isIntegralType(x2_.scalar_type(), /*includeBool=*/true);
auto x1 = x1_is_int ? x1_.to(commonDtype) : x1_;
auto x2 = x2_is_int ? x2_.to(commonDtype) : x2_;
auto x1_t = x1_is_int ? x1_.to(commonDtype) : x1_;
auto x2_t = x2_is_int ? x2_.to(commonDtype) : x2_;
c10::MaybeOwned<Tensor> x1, x2;
std::tie(x1, x2) = expand_outplace(x1_t, x2_t);


// We want to divide each tensor by its norm first, as it's more numerically stable.
// This keeps the result between -1.0 and 1.0
// We clone them, as we're going to modify them in-place
// This allows the gradients to propagate properly all the way to x1 and x2
auto x1_norm = at::linalg_vector_norm(x1, 2, /*dim=*/dim, /*keepdim=*/true).clone();
auto x2_norm = at::linalg_vector_norm(x2, 2, /*dim=*/dim, /*keepdim=*/true).clone();
auto x1_norm = at::linalg_vector_norm(*x1, 2, /*dim=*/dim, /*keepdim=*/true).clone();
auto x2_norm = at::linalg_vector_norm(*x2, 2, /*dim=*/dim, /*keepdim=*/true).clone();

{
at::NoGradGuard guard;
x1_norm.clamp_min_(eps);
x2_norm.clamp_min_(eps);
}

return ((x1 / x1_norm) * (x2 / x2_norm)).sum(dim);
return ((*x1 / x1_norm) * (*x2 / x2_norm)).sum(dim);
}

}} // namespace at::native
12 changes: 12 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5609,6 +5609,18 @@ def test_cosine_similarity(self):
out = F.cosine_similarity(input.to(torch.int8), input, dim=-1)
self.assertEqual(out, 1.)

# Check broadcasting #109333
a = torch.ones(2, 3, dtype=torch.float)
b = torch.ones(1, 1, dtype=torch.float)
out = F.cosine_similarity(a, b)
self.assertEqual(out, torch.ones(2, dtype=torch.float))

a = torch.ones(2, 3, dtype=torch.float)
b = torch.ones(1, dtype=torch.float)
out = F.cosine_similarity(a, b)
self.assertEqual(out, torch.ones(2, dtype=torch.float))


def test_grid_sample_error_checking(self):
input = torch.empty(1, 1, 2, 2)
grid = torch.empty(1, 1, 1, 2)
Expand Down

0 comments on commit 653c156

Please sign in to comment.