🐛 Bug
Hi TorchMetrics Team,
In the following example, NDCG calculation with GPU tensors takes about twice as long as with CPU tensors or NumPy arrays.
To Reproduce
The code was tested on both Google Colab and a Slurm cluster.
Code sample
import timeit

import numpy as np
import torch
from sklearn.metrics import ndcg_score
from torchmetrics.functional.retrieval import retrieval_normalized_dcg

# p and t are examples given by both sklearn and torchmetrics
p = [.1, .2, .3, 4, 70] * 100
t = [10, 0, 0, 1, 5] * 100
number = int(1e4)

# 1. BENCHMARK: numpy array
preds = np.asarray([p])
target = np.asarray([t])

def a():
    return ndcg_score(target, preds)

print(f'numpy array: {timeit.timeit("a()", setup="from __main__ import a", number=number):.4f}')

# 2. cpu tensor
preds_cpu = torch.tensor(p)
target_cpu = torch.tensor(t)
assert preds_cpu.device == torch.device("cpu")

def b():
    retrieval_normalized_dcg(preds_cpu, target_cpu)

print(f'CPU tensor: {timeit.timeit("b()", setup="from __main__ import b", number=number):.4f}')

# 3. gpu tensor
preds_gpu = torch.tensor(p, device="cuda")
target_gpu = torch.tensor(t, device="cuda")
assert preds_gpu.device == torch.device("cuda:0")

def c():
    retrieval_normalized_dcg(preds_gpu, target_gpu)

print(f'GPU tensor: {timeit.timeit("c()", setup="from __main__ import c", number=number):.4f}')
Results:
# Tesla T4
numpy array: 6.4896
CPU tensor: 5.8501
GPU tensor: 10.4120
I also tested the code on the Slurm cluster I'm currently using, where the GPU is an A100.
numpy array: 3.8700
CPU tensor: 2.9305
GPU tensor: 7.7575
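One caveat about the numbers above: timeit does not call torch.cuda.synchronize(), so work that is still queued on the GPU when an iteration ends may not be fully reflected in the measurement. Below is a minimal sketch (my own addition, not part of the original benchmark; the bench helper is an assumed name) that adds a warm-up call and explicit synchronization so each GPU iteration is measured end to end:

import timeit

import torch
from torchmetrics.functional.retrieval import retrieval_normalized_dcg

def bench(preds, target, number=int(1e4)):
    # Warm up once so one-time CUDA setup is excluded from the timing.
    retrieval_normalized_dcg(preds, target)
    if preds.is_cuda:
        torch.cuda.synchronize()

    def run():
        retrieval_normalized_dcg(preds, target)
        if preds.is_cuda:
            # Wait for all queued kernels so every iteration is fully counted.
            torch.cuda.synchronize()

    return timeit.timeit(run, number=number)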
Expected behavior
The performance with GPU tensors should be at least close to that with CPU tensors, if not better.
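If the gap is dominated by fixed per-call overhead (many small CUDA kernel launches for a length-500 input) rather than by the computation itself, it should narrow as the input grows. A rough sketch for checking that, again my own addition; the list lengths are arbitrary choices for illustration, not numbers from this report:

import timeit

import torch
from torchmetrics.functional.retrieval import retrieval_normalized_dcg

p = [.1, .2, .3, 4, 70]
t = [10, 0, 0, 1, 5]

for repeat in (100, 10_000, 1_000_000):  # arbitrary sizes, for illustration only
    preds_gpu = torch.tensor(p * repeat, device="cuda")
    target_gpu = torch.tensor(t * repeat, device="cuda")
    retrieval_normalized_dcg(preds_gpu, target_gpu)  # warm-up call
    torch.cuda.synchronize()

    def run():
        retrieval_normalized_dcg(preds_gpu, target_gpu)
        torch.cuda.synchronize()

    print(f"length {len(p) * repeat}: {timeit.timeit(run, number=100):.4f}")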
Environment
- TorchMetrics version (and how you installed TM, e.g. conda, pip, build from source): 1.2.1 (pip)
- Python & PyTorch Version (e.g., 1.0): Python 3.10.12 and 3.10.13, Torch 2.1.0 and 2.1.1
- Any other relevant information such as OS (e.g., Linux): Ubuntu 22.04.3 LTS and Linux 5.4.204-ql-generic-12.0-19 x86_64