diff --git a/mmdet/models/losses/triplet_loss.py b/mmdet/models/losses/triplet_loss.py index d9c9604b8c7..4528239beb4 100644 --- a/mmdet/models/losses/triplet_loss.py +++ b/mmdet/models/losses/triplet_loss.py @@ -40,7 +40,7 @@ def hard_mining_triplet_loss_forward( inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim). targets (torch.LongTensor): ground truth labels with shape - (num_classes). + (batch_size). Returns: torch.Tensor: triplet loss with hard mining.