Description
Hi,
I've been trying to use TripletMarginLoss for a project, but I find that it reduces the speed of my training by 5x. For example:
```python
from pytorch_metric_learning.losses import TripletMarginLoss
import torch

embeddings = torch.rand(4000, 8)
labels = torch.randint(0, 30, (4000,))
criterion = TripletMarginLoss(triplets_per_anchor=1)
criterion(embeddings, labels)
```
The final line takes ~2 seconds to compute on a Tesla V100 and an Intel Xeon Gold 6148. Digging into the code, I find that the bottleneck is the function `pytorch_metric_learning.utils.loss_and_miner_utils.get_random_triplet_indices`. Continuing the above example, `get_random_triplet_indices(labels, t_per_anchor=1)` also takes ~2 seconds.
I profiled the function and found that ~90% of the time is spent in the repeated calls to `np.random.choice` and `np.where` inside the per-label loop:
```python
# sample triplets, with a weighted distribution if weights is specified.
def get_random_triplet_indices(labels, ref_labels=None, t_per_anchor=None, weights=None):
    [...]
    for i, label in enumerate(labels):
        all_pos_pair_mask = ref_labels == label
        if ref_labels_is_labels:
            all_pos_pair_mask &= indices != i
        all_pos_pair_idx = np.where(all_pos_pair_mask)[0]  # hotspot
        curr_label_count = len(all_pos_pair_idx)
        if curr_label_count == 0:
            continue
        k = curr_label_count if t_per_anchor is None else t_per_anchor
        if weights is not None and not np.any(np.isnan(weights[i])):
            n_idx += c_f.NUMPY_RANDOM.choice(batch_size, k, p=weights[i]).tolist()  # hotspot
        else:
            possible_n_idx = list(np.where(ref_labels != label)[0])  # hotspot
            n_idx += c_f.NUMPY_RANDOM.choice(possible_n_idx, k).tolist()  # hotspot
        a_idx.extend([i] * k)
        curr_p_idx = c_f.safe_random_choice(all_pos_pair_idx, k)
        p_idx.extend(curr_p_idx.tolist())
    [...]
```
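For reference, a hotspot like this is easy to confirm with the standard library profiler. A minimal sketch using a stand-in loop (not the library's actual function, just the same per-element `np.where` pattern):

```python
import cProfile
import io
import pstats

import numpy as np

def per_label_loop(labels):
    # Stand-in for the loop above: one np.where call per element,
    # purely to illustrate the profiling workflow.
    idx = []
    for label in labels:
        idx.append(np.where(labels == label)[0][0])
    return idx

labels = np.random.randint(0, 30, size=2000)
profiler = cProfile.Profile()
profiler.enable()
per_label_loop(labels)
profiler.disable()

# Print the top entries by total time; np.where dominates.
out = io.StringIO()
pstats.Stats(profiler, stream=out).sort_stats("tottime").print_stats(5)
print(out.getvalue())
```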
It seems that the issue is that these calls happen inside a Python for loop, so nothing is vectorized. Perhaps this code can be optimized? I see two possible solutions:
- Write a C implementation of this function.
- Vectorize the numpy operations and eliminate the Python for loop.
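As a rough illustration of the second option, here is a hedged sketch of a loop-free sampler for the unweighted case. The function and variable names are my own, not the library's, and it trades the Python loop for O(n²) boolean masks, so it assumes the batch fits comfortably in memory:

```python
import numpy as np

def random_triplets_vectorized(labels, t_per_anchor=1, rng=None):
    # Hypothetical sketch (not the library's API): sample t_per_anchor
    # (anchor, positive, negative) index triplets per anchor, unweighted.
    rng = np.random.default_rng() if rng is None else rng
    labels = np.asarray(labels)
    n = len(labels)
    same = labels[:, None] == labels[None, :]     # n x n same-label mask
    pos_mask = same & ~np.eye(n, dtype=bool)      # positives exclude the anchor itself
    neg_mask = ~same
    pos_counts = pos_mask.sum(axis=1)
    neg_counts = neg_mask.sum(axis=1)
    valid = (pos_counts > 0) & (neg_counts > 0)   # anchors with both a pos and a neg
    a_idx = np.repeat(np.flatnonzero(valid), t_per_anchor)

    def sample(mask, counts):
        # Draw a random rank within each anchor's candidate set, then map
        # the rank to a column index via the row-wise cumulative sum.
        ranks = rng.integers(0, counts[a_idx])    # one draw per triplet
        cum = np.cumsum(mask[a_idx], axis=1)
        return np.argmax(cum > ranks[:, None], axis=1)

    p_idx = sample(pos_mask, pos_counts)
    n_idx = sample(neg_mask, neg_counts)
    return a_idx, p_idx, n_idx
```

On the example batch above (4000 embeddings, 30 labels) this kind of approach should cut the sampling time substantially, since all the random draws happen in a handful of array operations instead of 4000 loop iterations.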
I would be happy to attempt to submit a PR on this, if there is interest.