
Weird behavior of CosineSimilarityMetric when used with tensors of shape [d] #2240

Closed
@ValerianRey


🐛 Bug?

According to the documentation, CosineSimilarity requires tensors of shape [N, d], where N is the batch size and d is the dimension of the vectors.

Passing vectors of shape [d] instead does not raise any error, but the subsequent call to compute() returns a confusing result. I'm not sure whether this is expected behavior, a bug, or simply unsupported usage (in which case raising an error would help).
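If [d]-shaped inputs are unsupported, a shape check before the metric sees the data would surface the mistake immediately. The following is a minimal sketch of such a guard; check_cosine_similarity_inputs is a hypothetical helper, not part of torchmetrics. It is duck-typed on the .ndim and .shape attributes, which both torch.Tensor and numpy.ndarray expose.

```python
def check_cosine_similarity_inputs(preds, target):
    """Hypothetical guard (not a torchmetrics API): reject inputs not shaped [N, d].

    Only relies on the ``.ndim`` and ``.shape`` attributes, so it accepts
    torch.Tensor, numpy.ndarray, or any object exposing both.
    """
    # Both inputs must describe the same N vectors of dimension d.
    if tuple(preds.shape) != tuple(target.shape):
        raise ValueError(
            f"preds and target must have the same shape, "
            f"got {tuple(preds.shape)} and {tuple(target.shape)}"
        )
    # A 1-D input of shape [d] is ambiguous: it would be treated as one long
    # vector once several updates are concatenated, so refuse it explicitly.
    if preds.ndim != 2:
        raise ValueError(
            f"Expected inputs of shape [N, d], got a {preds.ndim}-D input of shape "
            f"{tuple(preds.shape)}; add a batch dimension, e.g. x.unsqueeze(0)"
        )
```

With the tensors from the first reproduction below, check_cosine_similarity_inputs(a, b) would raise ValueError, while a.unsqueeze(0) and b.unsqueeze(0) (shape [1, 3]) pass the check.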

To Reproduce

>>> from torchmetrics import CosineSimilarity
>>> from torch import tensor
>>> 
>>> cosine_similarity = CosineSimilarity(reduction="mean")
>>> a = tensor([1., 1., 1.])
>>> b = tensor([100., 100., 100.])
>>> cosine_similarity(a, b)
tensor(1.)  # a and b have the same direction, so this is normal.
>>> cosine_similarity(b, a)  
tensor(1.)  # same for b and a.
>>> cosine_similarity.compute()
tensor(0.0200)  # I would expect this to be 1 too (the average of the 2 previous calls).

The result (0.02) is actually the cosine similarity between [1, 1, 1, 100, 100, 100] and [100, 100, 100, 1, 1, 1], i.e. between the inputs of the two calls concatenated into single vectors. I would instead have expected the average of the cosine similarity of [1, 1, 1] and [100, 100, 100] and the cosine similarity of [100, 100, 100] and [1, 1, 1], which is 1.
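The 0.02 can be checked by hand. The sketch below is plain Python (no torch) and rests on an assumption inferred from the observed numbers, not confirmed from the torchmetrics source: that the metric stores each batch, concatenates the batches along dimension 0, and then computes one cosine similarity per row along the last dimension. Under that assumption, [d] inputs collapse into a single vector of length 2d, while [1, d] inputs remain separate rows. The helper cosine_of is hypothetical and unrelated to the torchmetrics class.

```python
import math


def cosine_of(x, y):
    """Cosine similarity of two equal-length 1-D sequences (hypothetical helper)."""
    dot = sum(xi * yi for xi, yi in zip(x, y))
    norm_x = math.sqrt(sum(xi * xi for xi in x))
    norm_y = math.sqrt(sum(yi * yi for yi in y))
    return dot / (norm_x * norm_y)


a = [1.0, 1.0, 1.0]
b = [100.0, 100.0, 100.0]

# The two update calls from the reproduction: (preds=a, target=b), then (preds=b, target=a).
preds_batches = [a, b]
target_batches = [b, a]

# Assumption: with shape [d] inputs, concatenating along dim 0 yields one vector
# of length 2*d, so a single cosine similarity is computed over all six entries.
flat_preds = [v for batch in preds_batches for v in batch]
flat_target = [v for batch in target_batches for v in batch]
concatenated_result = cosine_of(flat_preds, flat_target)  # 600 / 30003

# With shape [1, d] inputs, concatenation yields a [2, d] matrix: one similarity
# per row, which reduction="mean" then averages.
per_row = [cosine_of(p, t) for p, t in zip(preds_batches, target_batches)]
stacked_result = sum(per_row) / len(per_row)

print(round(concatenated_result, 4))  # 0.02
print(round(stacked_result, 4))  # 1.0
```

The dot product of the concatenated vectors is 3 × 100 + 3 × 100 = 600 and each norm is sqrt(30003), giving 600 / 30003 ≈ 0.0200, which matches the value returned by compute().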

If we instead use it as documented, with tensors of shape [N, d], compute() returns the expected value:

>>> from torchmetrics import CosineSimilarity
>>> from torch import tensor
>>> 
>>> cosine_similarity = CosineSimilarity(reduction="mean")
>>> a = tensor([[1., 1., 1.]])  # tensor of shape [1, 3] instead of [3]
>>> b = tensor([[100., 100., 100.]])  # tensor of shape [1, 3] instead of [3]
>>> cosine_similarity(a, b)
tensor(1.)
>>> cosine_similarity(b, a)
tensor(1.)
>>> cosine_similarity.compute()
tensor(1.)  # 1 instead of 0.02

Environment:

  • TorchMetrics 1.2.0
  • Python 3.10.10
  • torch 2.1.1
  • Ubuntu 20.04.6 LTS

