diff --git a/simpletransformers/config/model_args.py b/simpletransformers/config/model_args.py
index ebf600b5..e4d47621 100644
--- a/simpletransformers/config/model_args.py
+++ b/simpletransformers/config/model_args.py
@@ -395,7 +395,9 @@ class RetrievalArgs(Seq2SeqArgs):
     moving_average_loss_count: int = 10
     nll_lambda: float = 1.0
     output_dropout: float = 0.1
-    pytrec_eval_metrics: list = field(default_factory=lambda: ["recip_rank", "recall_100", "ndcg_cut_10", "ndcg"])
+    pytrec_eval_metrics: list = field(
+        default_factory=lambda: ["recip_rank", "recall_100", "ndcg_cut_10", "ndcg"]
+    )
     query_config: dict = field(default_factory=dict)
     remove_duplicates_from_eval_passages: bool = False
     repeat_high_loss_n: int = 0
diff --git a/simpletransformers/retrieval/pytrec_eval_utils.py b/simpletransformers/retrieval/pytrec_eval_utils.py
index 6f2baa09..a45af8b3 100644
--- a/simpletransformers/retrieval/pytrec_eval_utils.py
+++ b/simpletransformers/retrieval/pytrec_eval_utils.py
@@ -4,7 +4,9 @@
 logger = logging.getLogger(__name__)
 
 
-def convert_predictions_to_pytrec_format(predicted_doc_ids, query_dataset, id_column="_id"):
+def convert_predictions_to_pytrec_format(
+    predicted_doc_ids, query_dataset, id_column="_id"
+):
     run_dict = {}
     for query_id, doc_ids in zip(query_dataset[id_column], predicted_doc_ids):
         run_dict[query_id] = {doc_id: 1 / (i + 1) for i, doc_id in enumerate(doc_ids)}
diff --git a/simpletransformers/retrieval/retrieval_model.py b/simpletransformers/retrieval/retrieval_model.py
index a561e3c0..da0e0880 100644
--- a/simpletransformers/retrieval/retrieval_model.py
+++ b/simpletransformers/retrieval/retrieval_model.py
@@ -837,6 +837,9 @@ def train(
                     colbert_percentage = (
                         retrieval_output.colbert_correct_predictions_percentage
                     )
+                    reranking_correct_predictions_percentage = (
+                        retrieval_output.reranking_correct_predictions_percentage
+                    )
 
                 if args.n_gpu > 1:
                     loss = loss.mean()
@@ -844,7 +847,10 @@
 
                 # Compare the current loss to the moving average loss
                 current_loss = loss.item()
-                if (args.repeat_high_loss_n == 0 or moving_loss.size() < args.moving_average_loss_count):
+                if (
+                    args.repeat_high_loss_n == 0
+                    or moving_loss.size() < args.moving_average_loss_count
+                ):
                     break
 
                 if current_loss > moving_loss.get_average_loss():
@@ -884,9 +890,14 @@
                     moving_loss.add_loss(current_loss)
 
                 if show_running_loss and (args.kl_div_loss or args.margin_mse_loss):
-                    batch_iterator.set_description(
-                        f"Epochs {epoch_number + 1}/{args.num_train_epochs}. Running Loss: {current_loss:9.4f} Correct percentage: {correct_predictions_percentage:4.1f} Colbert percentage: {colbert_percentage:4.1f}"
-                    )
+                    if args.unified_cross_rr:
+                        batch_iterator.set_description(
+                            f"Epochs {epoch_number + 1}/{args.num_train_epochs}. Running Loss: {current_loss:9.4f} Correct percentage: {correct_predictions_percentage:4.1f} Colbert percentage: {colbert_percentage:4.1f} Reranking correct percentage: {reranking_correct_predictions_percentage:4.1f}"
+                        )
+                    else:
+                        batch_iterator.set_description(
+                            f"Epochs {epoch_number + 1}/{args.num_train_epochs}. Running Loss: {current_loss:9.4f} Correct percentage: {correct_predictions_percentage:4.1f} Colbert percentage: {colbert_percentage:4.1f}"
+                        )
                 elif show_running_loss:
                     if args.repeat_high_loss_n > 0:
                         batch_iterator.set_description(
@@ -950,13 +961,23 @@
                         }
                 else:
                     if args.kl_div_loss or args.margin_mse_loss:
-                        logging_dict = {
-                            "Training loss": current_loss,
-                            "lr": scheduler.get_last_lr()[0],
-                            "global_step": global_step,
-                            "correct_predictions_percentage": correct_predictions_percentage,
-                            "colbert_correct_predictions_percentage": colbert_percentage,
-                        }
+                        if args.unified_cross_rr:
+                            logging_dict = {
+                                "Training loss": current_loss,
+                                "lr": scheduler.get_last_lr()[0],
+                                "global_step": global_step,
+                                "correct_predictions_percentage": correct_predictions_percentage,
+                                "colbert_correct_predictions_percentage": colbert_percentage,
+                                "reranking_correct_predictions_percentage": reranking_correct_predictions_percentage,
+                            }
+                        else:
+                            logging_dict = {
+                                "Training loss": current_loss,
+                                "lr": scheduler.get_last_lr()[0],
+                                "global_step": global_step,
+                                "correct_predictions_percentage": correct_predictions_percentage,
+                                "colbert_correct_predictions_percentage": colbert_percentage,
+                            }
                     if args.include_nll_loss:
                         logging_dict[
                             "nll_loss"
                         ]
@@ -1459,7 +1480,10 @@ def eval_model(
                 )
             else:
                 _, query_dataset, qrels_dataset = load_trec_format(
-                    eval_data, qrels_name=eval_set, data_format=self.args.data_format, skip_passages=True
+                    eval_data,
+                    qrels_name=eval_set,
+                    data_format=self.args.data_format,
+                    skip_passages=True,
                 )
 
             if self.args.data_format == "beir":
@@ -1481,14 +1505,18 @@
                 )
                 self.prediction_passages = passage_index
 
-            query_text_column = "text" if self.args.data_format == "beir" else "query_text"
+            query_text_column = (
+                "text" if self.args.data_format == "beir" else "query_text"
+            )
 
             predicted_doc_ids, pre_rerank_doc_ids = self.predict(
                 to_predict=query_dataset[query_text_column], doc_ids_only=True
             )
 
             run_dict = convert_predictions_to_pytrec_format(
-                predicted_doc_ids, query_dataset, id_column="_id" if self.args.data_format == "beir" else "query_id"
+                predicted_doc_ids,
+                query_dataset,
+                id_column="_id" if self.args.data_format == "beir" else "query_id",
             )
 
             qrels_dict = convert_qrels_dataset_to_pytrec_format(qrels_dataset)
@@ -1502,7 +1530,8 @@
             except:
                 # Convert run_dict keys to strings
                 run_dict = {
-                    str(key): {str(k): v for k, v in value.items()} for key, value in run_dict.items()
+                    str(key): {str(k): v for k, v in value.items()}
+                    for key, value in run_dict.items()
                 }
 
                 results = evaluator.evaluate(run_dict)
@@ -2299,7 +2328,9 @@ def retrieve_docs_from_query_embeddings(
             )
         ):
             ids, vectors, *_ = passage_dataset.get_top_docs(
-                query_embeddings_retr.astype(np.float32), retrieve_n_docs, return_indices=False,
+                query_embeddings_retr.astype(np.float32),
+                retrieve_n_docs,
+                return_indices=False,
             )
             doc_ids_batched[
                 i * args.retrieval_batch_size : (i * args.retrieval_batch_size)
@@ -2985,6 +3016,20 @@ def _calculate_loss(
                 correct_predictions_count / len(nll_labels)
             ) * 100
 
+        if self.args.unified_cross_rr:
+            rerank_max_score, rerank_max_idxs = torch.max(reranking_softmax_score, 1)
+            rerank_correct_predictions_count = (
+                (rerank_max_idxs == torch.tensor(nll_labels))
+                .sum()
+                .cpu()
+                .detach()
+                .numpy()
+                .item()
+            )
+            rerank_correct_predictions_percentage = (
+                rerank_correct_predictions_count / len(nll_labels)
+            ) * 100
+
         if self.args.kl_div_loss or self.args.margin_mse_loss:
             colbert_softmax_score = torch.nn.functional.softmax(label_scores, dim=-1)
             colbert_max_score, colbert_max_idxs = torch.max(colbert_softmax_score, 1)
@@ -3013,6 +3058,7 @@ def _calculate_loss(
             reranking_loss=reranking_loss.item() if reranking_loss else None,
             nll_loss=nll_loss.item(),
             colbert_correct_predictions_percentage=colbert_correct_predictions_percentage,
+            reranking_correct_predictions_percentage=rerank_correct_predictions_percentage,
         )
 
         return retrieval_output
@@ -3042,16 +3088,16 @@ def _rerank_passages(self, query_outputs, context_outputs, is_evaluating=False):
             reranking_model_token_type_ids = torch.zeros_like(
                 reranking_model_inputs_embeds[:, :, 0]
             )
-            reranking_model_token_type_ids[:, 0] = torch.ones(
-                (query_outputs.size(0))
-            )
+            reranking_model_token_type_ids[:, 0] = torch.ones((query_outputs.size(0)))
         else:
             reranking_model_inputs_embeds = torch.zeros(
                 (query_outputs.size(0), self.args.max_seq_length, query_outputs.size(1))
             )
             reranking_model_inputs_embeds[:, 0, :] = query_outputs
-            reranking_model_inputs_embeds[:, 1 : context_outputs.size(0) + 1, :] = context_outputs
+            reranking_model_inputs_embeds[
+                :, 1 : context_outputs.size(0) + 1, :
+            ] = context_outputs
 
             reranking_model_attention_mask = torch.zeros(
                 (query_outputs.size(0), self.args.max_seq_length)
             )
@@ -3065,9 +3111,7 @@ def _rerank_passages(self, query_outputs, context_outputs, is_evaluating=False):
             reranking_model_token_type_ids = torch.zeros(
                 (query_outputs.size(0), self.args.max_seq_length)
             )
-            reranking_model_token_type_ids[:, 0] = torch.ones(
-                (query_outputs.size(0))
-            )
+            reranking_model_token_type_ids[:, 0] = torch.ones((query_outputs.size(0)))
 
         reranking_model_inputs = {
             "inputs_embeds": reranking_model_inputs_embeds.to(self.device),
@@ -3078,10 +3122,13 @@ def _rerank_passages(self, query_outputs, context_outputs, is_evaluating=False):
         reranking_outputs = self.reranking_model(**reranking_model_inputs)
 
         reranking_query_outputs = reranking_outputs[0][:, 0, :]
-        reranking_context_outputs = reranking_outputs[0][:, 1: context_outputs.size(1 if is_evaluating else 0) + 1, :]
+        reranking_context_outputs = reranking_outputs[0][
+            :, 1 : context_outputs.size(1 if is_evaluating else 0) + 1, :
+        ]
 
         reranking_dot_score = torch.bmm(
-            reranking_query_outputs.unsqueeze(1), reranking_context_outputs.transpose(-1, -2)
+            reranking_query_outputs.unsqueeze(1),
+            reranking_context_outputs.transpose(-1, -2),
         )
         reranking_softmax_score = torch.nn.functional.log_softmax(
             reranking_dot_score.squeeze(1), dim=-1
diff --git a/simpletransformers/retrieval/retrieval_tools.py b/simpletransformers/retrieval/retrieval_tools.py
index b95b0d1b..c2ae5050 100644
--- a/simpletransformers/retrieval/retrieval_tools.py
+++ b/simpletransformers/retrieval/retrieval_tools.py
@@ -82,7 +82,11 @@ def generate_latex_row(
                 row += f"\\textbf{{{results[model_name][dataset_name][metric]:.3f}}}\\rlap{{\\textsuperscript{{*}}}} & "
             else:
                 row += f"\\textbf{{{results[model_name][dataset_name][metric]:.3f}}} & "
-        elif model_name == second_best_model_name and best_model_name != second_best_model_name and len(all_scores) > 2:
+        elif (
+            model_name == second_best_model_name
+            and best_model_name != second_best_model_name
+            and len(all_scores) > 2
+        ):
             row += f"\\textit{{{results[model_name][dataset_name][metric]:.3f}}} & "
         else:
diff --git a/simpletransformers/retrieval/retrieval_utils.py b/simpletransformers/retrieval/retrieval_utils.py
index b3ed8a32..1367e04d 100644
--- a/simpletransformers/retrieval/retrieval_utils.py
+++ b/simpletransformers/retrieval/retrieval_utils.py
@@ -336,7 +336,9 @@ def preprocess_batch_for_hf_dataset(
     }
 
 
-def get_output_embeddings(embeddings, concatenate_embeddings=False, n_cls_tokens=3, use_pooler_output=False):
+def get_output_embeddings(
+    embeddings, concatenate_embeddings=False, n_cls_tokens=3, use_pooler_output=False
+):
     """
     Extracts the embeddings from the output of the model.
     Concatenates CLS embeddings if concatenate_embeddings is True.
@@ -910,7 +912,9 @@ def get_doc_dicts(self, doc_ids):
             for i in tqdm(range(doc_ids.shape[0]), desc="Retrieving doc dicts")
         ]
 
-    def get_top_docs(self, question_hidden_states, n_docs=5, passages_only=False, return_indices=True):
+    def get_top_docs(
+        self, question_hidden_states, n_docs=5, passages_only=False, return_indices=True
+    ):
         if passages_only:
             _, docs = self.dataset.get_nearest_examples_batch(
                 "embeddings", question_hidden_states, n_docs
@@ -1652,7 +1656,9 @@ def select_batch_from_pandas(batch, df):
     return ClusteredDataset(batch_datasets, len(clustered_batches))
 
 
-def load_trec_file(file_name, data_dir=None, header=False, loading_qrels=False, data_format=None):
+def load_trec_file(
+    file_name, data_dir=None, header=False, loading_qrels=False, data_format=None
+):
     if data_dir:
         if loading_qrels:
             if os.path.exists(os.path.join(data_dir, "qrels", f"{file_name}.tsv")):
@@ -1744,7 +1750,13 @@ def load_trec_format(
         if not skip_passages:
             collection = load_trec_file("corpus", data_dir, collection_header)
         queries = load_trec_file("queries", data_dir, queries_header)
-        qrels = load_trec_file(qrels_name, data_dir, qrels_header, loading_qrels=True, data_format=data_format)
+        qrels = load_trec_file(
+            qrels_name,
+            data_dir,
+            qrels_header,
+            loading_qrels=True,
+            data_format=data_format,
+        )
 
     else:
         if not collection_path or not queries_path or not qrels_path:
@@ -1759,7 +1771,11 @@
 
     # Also check if an index exists
 
-    return None if skip_passages else collection["train"], queries["train"], qrels["train"]
+    return (
+        None if skip_passages else collection["train"],
+        queries["train"],
+        qrels["train"],
+    )
 
 
 def convert_beir_columns_to_trec_format(
@@ -1841,9 +1857,7 @@ def embed_passages_trec_format(
     logger.info("Generating embeddings for evaluation passages completed.")
 
     if args.save_passage_dataset:
-        output_dataset_directory = os.path.join(
-            args.output_dir, "passage_dataset"
-        )
+        output_dataset_directory = os.path.join(args.output_dir, "passage_dataset")
         os.makedirs(output_dataset_directory, exist_ok=True)
         passage_dataset.save_to_disk(output_dataset_directory)
 
@@ -1888,6 +1902,7 @@ def __init__(
         reranking_loss=None,
         nll_loss=None,
         colbert_correct_predictions_percentage=None,
+        reranking_correct_predictions_percentage=None,
     ):
         self.loss = loss
         self.context_outputs = context_outputs
@@ -1901,6 +1916,9 @@
         self.colbert_correct_predictions_percentage = (
             colbert_correct_predictions_percentage
         )
+        self.reranking_correct_predictions_percentage = (
+            reranking_correct_predictions_percentage
+        )
 
 
 class MarginMSELoss(nn.Module):
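
For reference, the run and qrels dictionaries built by convert_predictions_to_pytrec_format and convert_qrels_dataset_to_pytrec_format plug straight into pytrec_eval, as eval_model does above. Below is a minimal sketch of that evaluation flow; the toy query/doc IDs are hypothetical, while the reciprocal-rank scoring and the default metric set are taken from the diff:

import pytrec_eval

# Qrels: query id -> {doc id: graded relevance}. pytrec_eval expects string
# keys, which is why eval_model falls back to str()-converting run_dict keys.
qrels_dict = {"q1": {"d1": 1, "d3": 2}}

# Run: query id -> {doc id: score}. The helper in the diff scores the doc at
# rank i as 1 / (i + 1), so descending score preserves the predicted order.
predicted_doc_ids = {"q1": ["d3", "d2", "d1"]}
run_dict = {
    qid: {doc_id: 1 / (i + 1) for i, doc_id in enumerate(doc_ids)}
    for qid, doc_ids in predicted_doc_ids.items()
}

# Default metric set from RetrievalArgs.pytrec_eval_metrics.
evaluator = pytrec_eval.RelevanceEvaluator(
    qrels_dict, {"recip_rank", "recall_100", "ndcg_cut_10", "ndcg"}
)
results = evaluator.evaluate(run_dict)  # {query id: {metric: value}}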