
TripletMarginLoss Performance #192

Closed
@AlexSchuy

Description

Hi,

I've been trying to use TripletMarginLoss for a project, but I find that it slows my training down by about 5x. For example:

from pytorch_metric_learning.losses import TripletMarginLoss
import torch

embeddings = torch.rand(4000, 8)
labels = torch.randint(0, 30, (4000,))

criterion = TripletMarginLoss(triplets_per_anchor=1)

criterion(embeddings, labels)

The final line takes ~2 seconds to compute using a Tesla V100 and Intel Xeon Gold 6148. Digging into the code, I find that the issue is the function pytorch_metric_learning.utils.loss_and_miner_utils.get_random_triplet_indices. Continuing the above example, get_random_triplet_indices(labels, t_per_anchor=1) also takes ~2 seconds.
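For anyone who wants to reproduce the measurement, here is a small generic timing helper (illustrative only, not part of the library; `time_call` is a name I made up for this sketch):

```python
import time

def time_call(fn, *args, repeats=5, **kwargs):
    """Return the best wall-clock time of fn(*args, **kwargs) over a few repeats.

    Taking the minimum over several runs reduces noise from one-off
    interpreter or allocator overhead.
    """
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args, **kwargs)
        best = min(best, time.perf_counter() - start)
    return best

# e.g. time_call(get_random_triplet_indices, labels, t_per_anchor=1)
```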

I profiled the above function and found that 90% of the time is spent in three repeated calls to np.random.choice and np.where inside the per-anchor loop:


# sample triplets, with a weighted distribution if weights is specified.
def get_random_triplet_indices(labels, ref_labels=None, t_per_anchor=None, weights=None):
    [...]
    for i, label in enumerate(labels):
        all_pos_pair_mask = ref_labels == label
        if ref_labels_is_labels:
            all_pos_pair_mask &= indices != i
        all_pos_pair_idx = np.where(all_pos_pair_mask)[0]
        curr_label_count = len(all_pos_pair_idx)
        if curr_label_count == 0:
            continue
        k = curr_label_count if t_per_anchor is None else t_per_anchor

        if weights is not None and not np.any(np.isnan(weights[i])):
            n_idx += c_f.NUMPY_RANDOM.choice(batch_size, k, p=weights[i]).tolist()
        else:
            possible_n_idx = list(np.where(ref_labels != label)[0])
            n_idx += c_f.NUMPY_RANDOM.choice(possible_n_idx, k).tolist()

        a_idx.extend([i] * k)
        curr_p_idx = c_f.safe_random_choice(all_pos_pair_idx, k)
        p_idx.extend(curr_p_idx.tolist())
    [...]

It seems that the issue is that these calls sit inside a Python for loop over every anchor, so none of the work is vectorized. Perhaps this code can be optimized? I see two possible solutions:

  1. Write a C implementation of this function.
  2. Vectorize the numpy operations and eliminate the Python for loop.
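As a rough illustration of option 2, here is a sketch of how the per-anchor sampling could be vectorized with a masked-argmax-over-uniform-scores trick. This is not the library's implementation: `random_triplets_vectorized` and its signature are my own, it uses O(n^2) memory for the pairwise masks, and it assumes every anchor has at least one sample with a different label:

```python
import numpy as np

def random_triplets_vectorized(labels, t_per_anchor=1, rng=None):
    """Hypothetical sketch: sample t_per_anchor (anchor, positive, negative)
    index triplets per anchor without a Python loop over anchors."""
    rng = np.random.default_rng() if rng is None else rng
    labels = np.asarray(labels)
    n = len(labels)

    # n x n mask of same-label pairs, excluding self-pairs.
    same = labels[:, None] == labels[None, :]
    np.fill_diagonal(same, False)

    # Keep only anchors that have at least one positive, then repeat each
    # qualifying anchor t_per_anchor times.
    has_pos = same.any(axis=1)
    anchors = np.repeat(np.where(has_pos)[0], t_per_anchor)

    # Sample positives: draw uniform scores per candidate, mask out invalid
    # columns, and take argmax. Each valid column is equally likely to win.
    pos_scores = rng.random((len(anchors), n))
    pos_scores[~same[anchors]] = -1.0
    positives = pos_scores.argmax(axis=1)

    # Sample negatives the same way over the different-label mask
    # (assumes every anchor has at least one negative).
    diff = labels[anchors][:, None] != labels[None, :]
    neg_scores = rng.random((len(anchors), n))
    neg_scores[~diff] = -1.0
    negatives = neg_scores.argmax(axis=1)

    return anchors, positives, negatives
```

The O(n^2) masks are the trade-off for eliminating the loop; for the 4000-element batch in the example above that is ~16M booleans, which should still be much cheaper than 4000 separate np.where/np.random.choice calls.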

I would be happy to submit a PR for this, if there is interest.
