MNRL with Multiple hard negatives per query and NoDuplicatesBatchSampler #2954

ArthurCamara opened this issue Sep 23, 2024 · 0 comments


As stated in the documentation, the MNRL loss (and, by extension, its Cached variant) can handle more than one hard negative per anchor: the extra negatives are simply added to the pool of all the negatives that each anchor in the batch is scored against.
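For context, this is the documented setup; a minimal sketch with made-up data:

    from datasets import Dataset
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.losses import MultipleNegativesRankingLoss

    model = SentenceTransformer("all-MiniLM-L6-v2")
    # (anchor, positive, negative) rows: with a batch of size B, each anchor is scored
    # against its own positive, the other B - 1 positives, and all B negatives.
    train_dataset = Dataset.from_dict({
        "anchor": ["What is Python?"],
        "positive": ["Python is a programming language."],
        "negative": ["Pythons are large constricting snakes."],
    })
    loss = MultipleNegativesRankingLoss(model)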

That said, there are two issues with this in the current version:
First, using the recommended NoDuplicatesBatchSampler ignores all but the first hard negative for each anchor. For instance, take the training_nli_v3.py example. The dataset has the format (a_1, p_1, n_1), (a_1, p_1, n_2), ... with multiple hard negatives per query (see rows 17-21 of the triplet subset). However, the sampler will skip every row after the first when building each batch, as both the anchor and the positive are already present in the batch_values set. One could try to work around this by setting valid_label_columns, but that only makes the sampler ignore those columns when deciding what is already present in the batch.
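A minimal sketch that reproduces this (the constructor call assumes the sampler's current (dataset, batch_size, drop_last, ...) signature; the toy data is made up):

    from datasets import Dataset
    from sentence_transformers.sampler import NoDuplicatesBatchSampler

    # Two hard negatives per anchor, in the same triplet layout as training_nli_v3.py
    dataset = Dataset.from_dict({
        "anchor": ["a_1", "a_1", "a_2", "a_2"],
        "positive": ["p_1", "p_1", "p_2", "p_2"],
        "negative": ["n_11", "n_12", "n_21", "n_22"],
    })
    sampler = NoDuplicatesBatchSampler(dataset, batch_size=4, drop_last=False)
    # Prints something like [[0, 2], [1, 3]]: the second row for each anchor is deferred
    # to a later batch, so a_1 is never scored against both n_11 and n_12 in the same pool.
    print(list(iter(sampler)))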

Addressing this is fairly straightforward: we can skip a row only when its negative has already been seen among the values in the batch (or, when there is no negative column, revert to the current behaviour of checking the positive):

    def __iter__(self) -> Iterator[list[int]]:
        """
        Iterate over the remaining non-yielded indices. For each index, check if the sample values are already in the
        batch. If not, add the sample values to the batch and keep going until the batch is full, then yield the
        batch indices and continue with the next batch.
        """
        if self.generator and self.seed:
            self.generator.manual_seed(self.seed + self.epoch)
        anchor_column = self.dataset.column_names[0]
        positive_column = self.dataset.column_names[1]
        negative_column = (
            self.dataset.column_names[2] if len(self.dataset.column_names) > 2 else None
        )
        remaining_indices = set(
            torch.randperm(len(self.dataset), generator=self.generator).tolist()
        )
        while remaining_indices:
            batch_values = set()
            batch_indices = []
            for index in remaining_indices:
                sample = self.dataset[index]
                # Skip this row if its negative (or, when there is no negative column, its positive) was already seen
                if negative_column:
                    if sample[negative_column] in batch_values:
                        continue
                elif sample[positive_column] in batch_values:
                    continue
                batch_indices.append(index)
                if len(batch_indices) == self.batch_size:
                    yield batch_indices
                    break

                batch_values.add(sample[anchor_column])
                batch_values.add(sample[positive_column])
            else:
                # NOTE: some indices might still have been ignored here
                if not self.drop_last:
                    yield batch_indices

            remaining_indices -= set(batch_indices)

However, even with this fix, MNRL would still behave differently from what we would expect. In the current implementation, if the dataset has multiple hard negatives in the format (a_1, p_1, n_1), (a_1, p_1, n_2), the loss is computed n_hard_negatives times for each anchor: every time (a_1, p_1, n_k) appears in the dataset, the loss w.r.t. a_1 is computed again.
Ideally, we would want to add all hard negatives to the (larger) pool of negatives, and compute the positive score just once.
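To see why, here is roughly what MNRL computes per batch (a simplified sketch of the loss, not the library's actual code):

    import torch
    import torch.nn.functional as F

    def mnrl_sketch(anchors, positives, negatives):
        # Each anchor is scored against every positive and every negative in the batch.
        candidates = torch.cat([positives, negatives], dim=0)  # (2B, dim)
        scores = anchors @ candidates.T                        # (B, 2B)
        labels = torch.arange(len(anchors))                    # anchor i's positive is column i
        return F.cross_entropy(scores, labels)

    # With rows (a_1, p_1, n_1) and (a_1, p_1, n_2) in one batch, a_1 occupies two rows of
    # `scores`, so its cross-entropy term (including the positive) is computed twice, rather
    # than once against a single pool that contains both n_1 and n_2.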

The easiest way around this (IMO) is to allow for multiple negative columns in the sampler, similar to the output of mine_hard_negatives when as_triplets is set to False.
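For reference, that output format looks roughly like this (a sketch; I have not double-checked the exact output column names of mine_hard_negatives, so treat them as illustrative):

    from datasets import load_dataset
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.util import mine_hard_negatives

    model = SentenceTransformer("all-MiniLM-L6-v2")
    dataset = load_dataset("sentence-transformers/natural-questions", split="train")
    # as_triplets=False keeps one row per query: (anchor, positive, negative_1, ..., negative_n)
    dataset = mine_hard_negatives(dataset, model, num_negatives=3, as_triplets=False)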

This means changing the __iter__ snippet above to something like this:


    negative_columns = self.dataset.column_names[2:]
    (…)
    if negative_columns:
        if any(sample[negative_column] in batch_values for negative_column in negative_columns):
            continue
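
Putting the two changes together, the whole method would look roughly like this (the same sketch as above, extended to any number of negative columns; only lightly tested):

    def __iter__(self) -> Iterator[list[int]]:
        """
        Iterate over the remaining non-yielded indices. A row is skipped if any of its negatives (or,
        when there are no negative columns, its positive) has already been seen in the current batch.
        """
        if self.generator and self.seed:
            self.generator.manual_seed(self.seed + self.epoch)
        anchor_column = self.dataset.column_names[0]
        positive_column = self.dataset.column_names[1]
        negative_columns = self.dataset.column_names[2:]
        remaining_indices = set(torch.randperm(len(self.dataset), generator=self.generator).tolist())
        while remaining_indices:
            batch_values = set()
            batch_indices = []
            for index in remaining_indices:
                sample = self.dataset[index]
                if negative_columns:
                    if any(sample[column] in batch_values for column in negative_columns):
                        continue
                elif sample[positive_column] in batch_values:
                    continue
                batch_indices.append(index)
                if len(batch_indices) == self.batch_size:
                    yield batch_indices
                    break
                batch_values.add(sample[anchor_column])
                batch_values.add(sample[positive_column])
            else:
                # NOTE: some indices might still have been ignored here
                if not self.drop_last:
                    yield batch_indices
            remaining_indices -= set(batch_indices)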

From some initial tests, this seems to be enough to make it work with MNRL and CMNRL, but there could be more to it.
