Commit 969d206: formatting

ThilinaRajapakse committed Oct 26, 2023
1 parent cb1dc41 commit 969d206
Showing 5 changed files with 109 additions and 36 deletions.
simpletransformers/config/model_args.py (3 additions & 1 deletion)
@@ -395,7 +395,9 @@ class RetrievalArgs(Seq2SeqArgs):
moving_average_loss_count: int = 10
nll_lambda: float = 1.0
output_dropout: float = 0.1
-pytrec_eval_metrics: list = field(default_factory=lambda: ["recip_rank", "recall_100", "ndcg_cut_10", "ndcg"])
+pytrec_eval_metrics: list = field(
+    default_factory=lambda: ["recip_rank", "recall_100", "ndcg_cut_10", "ndcg"]
+)
query_config: dict = field(default_factory=dict)
remove_duplicates_from_eval_passages: bool = False
repeat_high_loss_n: int = 0
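For context on the new default: these strings are pytrec_eval measure names, which eval_model later feeds to evaluator.evaluate (see retrieval_model.py below). A minimal, hypothetical sketch of that flow with made-up query and document IDs, restricted to measures certain to parse:

import pytrec_eval

# Hypothetical relevance judgments and one retrieval run.
qrels = {"q1": {"d1": 1, "d2": 0}}
run = {"q1": {"d1": 0.9, "d2": 0.4}}

# Cut variants such as "recall_100" and "ndcg_cut_10" are passed the same way.
evaluator = pytrec_eval.RelevanceEvaluator(qrels, {"recip_rank", "ndcg"})
print(evaluator.evaluate(run))
# {'q1': {'recip_rank': 1.0, 'ndcg': 1.0}}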
simpletransformers/retrieval/pytrec_eval_utils.py (3 additions & 1 deletion)
@@ -4,7 +4,9 @@
logger = logging.getLogger(__name__)


-def convert_predictions_to_pytrec_format(predicted_doc_ids, query_dataset, id_column="_id"):
+def convert_predictions_to_pytrec_format(
+    predicted_doc_ids, query_dataset, id_column="_id"
+):
run_dict = {}
for query_id, doc_ids in zip(query_dataset[id_column], predicted_doc_ids):
run_dict[query_id] = {doc_id: 1 / (i + 1) for i, doc_id in enumerate(doc_ids)}
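The loop above assigns each retrieved document a reciprocal-rank score, so pytrec_eval sees scores that decrease with retrieval position. A quick illustration with hypothetical IDs:

predicted_doc_ids = [["d9", "d2", "d7"]]
query_column = ["q1"]  # stands in for query_dataset[id_column]

run_dict = {
    query_id: {doc_id: 1 / (i + 1) for i, doc_id in enumerate(doc_ids)}
    for query_id, doc_ids in zip(query_column, predicted_doc_ids)
}
print(run_dict)
# {'q1': {'d9': 1.0, 'd2': 0.5, 'd7': 0.3333333333333333}}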
simpletransformers/retrieval/retrieval_model.py (72 additions & 25 deletions)
@@ -837,14 +837,20 @@ def train(
colbert_percentage = (
retrieval_output.colbert_correct_predictions_percentage
)
+reranking_correct_predictions_percentage = (
+    retrieval_output.reranking_correct_predictions_percentage
+)

if args.n_gpu > 1:
loss = loss.mean()

# Compare the current loss to the moving average loss
current_loss = loss.item()

-if (args.repeat_high_loss_n == 0 or moving_loss.size() < args.moving_average_loss_count):
+if (
+    args.repeat_high_loss_n == 0
+    or moving_loss.size() < args.moving_average_loss_count
+):
break

if current_loss > moving_loss.get_average_loss():
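The moving_loss object used here is not part of this diff; judging only from the calls above (size, add_loss, get_average_loss), a minimal stand-in might look like the sketch below. Names and behavior are inferred from usage, not the repository's actual implementation:

from collections import deque

class MovingLossAverage:
    """Hypothetical fixed-size buffer of recent losses."""

    def __init__(self, max_size):
        self.losses = deque(maxlen=max_size)

    def size(self):
        return len(self.losses)

    def add_loss(self, loss):
        self.losses.append(loss)

    def get_average_loss(self):
        return sum(self.losses) / len(self.losses)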
@@ -884,9 +890,14 @@ def train(
moving_loss.add_loss(current_loss)

if show_running_loss and (args.kl_div_loss or args.margin_mse_loss):
-batch_iterator.set_description(
-    f"Epochs {epoch_number + 1}/{args.num_train_epochs}. Running Loss: {current_loss:9.4f} Correct percentage: {correct_predictions_percentage:4.1f} Colbert percentage: {colbert_percentage:4.1f}"
-)
+if args.unified_cross_rr:
+    batch_iterator.set_description(
+        f"Epochs {epoch_number + 1}/{args.num_train_epochs}. Running Loss: {current_loss:9.4f} Correct percentage: {correct_predictions_percentage:4.1f} Colbert percentage: {colbert_percentage:4.1f} Reranking correct percentage: {reranking_correct_predictions_percentage:4.1f}"
+    )
+else:
+    batch_iterator.set_description(
+        f"Epochs {epoch_number + 1}/{args.num_train_epochs}. Running Loss: {current_loss:9.4f} Correct percentage: {correct_predictions_percentage:4.1f} Colbert percentage: {colbert_percentage:4.1f}"
+    )
elif show_running_loss:
if args.repeat_high_loss_n > 0:
batch_iterator.set_description(
@@ -950,13 +961,23 @@ def train(
}
else:
if args.kl_div_loss or args.margin_mse_loss:
-logging_dict = {
-    "Training loss": current_loss,
-    "lr": scheduler.get_last_lr()[0],
-    "global_step": global_step,
-    "correct_predictions_percentage": correct_predictions_percentage,
-    "colbert_correct_predictions_percentage": colbert_percentage,
-}
+if args.unified_cross_rr:
+    logging_dict = {
+        "Training loss": current_loss,
+        "lr": scheduler.get_last_lr()[0],
+        "global_step": global_step,
+        "correct_predictions_percentage": correct_predictions_percentage,
+        "colbert_correct_predictions_percentage": colbert_percentage,
+        "reranking_correct_predictions_percentage": reranking_correct_predictions_percentage,
+    }
+else:
+    logging_dict = {
+        "Training loss": current_loss,
+        "lr": scheduler.get_last_lr()[0],
+        "global_step": global_step,
+        "correct_predictions_percentage": correct_predictions_percentage,
+        "colbert_correct_predictions_percentage": colbert_percentage,
+    }
if args.include_nll_loss:
logging_dict[
"nll_loss"
@@ -1459,7 +1480,10 @@ def eval_model(
)
else:
_, query_dataset, qrels_dataset = load_trec_format(
-eval_data, qrels_name=eval_set, data_format=self.args.data_format, skip_passages=True
+eval_data,
+qrels_name=eval_set,
+data_format=self.args.data_format,
+skip_passages=True,
)

if self.args.data_format == "beir":
@@ -1481,14 +1505,18 @@ def eval_model(
)
self.prediction_passages = passage_index

-query_text_column = "text" if self.args.data_format == "beir" else "query_text"
+query_text_column = (
+    "text" if self.args.data_format == "beir" else "query_text"
+)

predicted_doc_ids, pre_rerank_doc_ids = self.predict(
to_predict=query_dataset[query_text_column], doc_ids_only=True
)

run_dict = convert_predictions_to_pytrec_format(
-predicted_doc_ids, query_dataset, id_column="_id" if self.args.data_format == "beir" else "query_id"
+predicted_doc_ids,
+query_dataset,
+id_column="_id" if self.args.data_format == "beir" else "query_id",
)
qrels_dict = convert_qrels_dataset_to_pytrec_format(qrels_dataset)

@@ -1502,7 +1530,8 @@ def eval_model(
except:
# Convert run_dict keys to strings
run_dict = {
-str(key): {str(k): v for k, v in value.items()} for key, value in run_dict.items()
+str(key): {str(k): v for k, v in value.items()}
+for key, value in run_dict.items()
}
results = evaluator.evaluate(run_dict)

@@ -2299,7 +2328,9 @@ def retrieve_docs_from_query_embeddings(
)
):
ids, vectors, *_ = passage_dataset.get_top_docs(
-query_embeddings_retr.astype(np.float32), retrieve_n_docs, return_indices=False,
+query_embeddings_retr.astype(np.float32),
+retrieve_n_docs,
+return_indices=False,
)
doc_ids_batched[
i * args.retrieval_batch_size : (i * args.retrieval_batch_size)
@@ -2985,6 +3016,20 @@ def _calculate_loss(
correct_predictions_count / len(nll_labels)
) * 100

+if self.args.unified_cross_rr:
+    rerank_max_score, rerank_max_idxs = torch.max(reranking_softmax_score, 1)
+    rerank_correct_predictions_count = (
+        (rerank_max_idxs == torch.tensor(nll_labels))
+        .sum()
+        .cpu()
+        .detach()
+        .numpy()
+        .item()
+    )
+    rerank_correct_predictions_percentage = (
+        rerank_correct_predictions_count / len(nll_labels)
+    ) * 100

if self.args.kl_div_loss or self.args.margin_mse_loss:
colbert_softmax_score = torch.nn.functional.softmax(label_scores, dim=-1)
colbert_max_score, colbert_max_idxs = torch.max(colbert_softmax_score, 1)
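The new unified_cross_rr block mirrors the existing accuracy bookkeeping: argmax over the softmax scores, compared against the gold passage indices. A small worked example with made-up tensors:

import torch

reranking_softmax_score = torch.tensor([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]])
nll_labels = [1, 2]  # index of the correct passage for each query

_, rerank_max_idxs = torch.max(reranking_softmax_score, 1)  # tensor([1, 0])
correct = (rerank_max_idxs == torch.tensor(nll_labels)).sum().item()  # 1
print(correct / len(nll_labels) * 100)  # 50.0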
@@ -3013,6 +3058,7 @@ def _calculate_loss(
reranking_loss=reranking_loss.item() if reranking_loss else None,
nll_loss=nll_loss.item(),
colbert_correct_predictions_percentage=colbert_correct_predictions_percentage,
+reranking_correct_predictions_percentage=rerank_correct_predictions_percentage,
)

return retrieval_output
@@ -3042,16 +3088,16 @@ def _rerank_passages(self, query_outputs, context_outputs, is_evaluating=False):
reranking_model_token_type_ids = torch.zeros_like(
reranking_model_inputs_embeds[:, :, 0]
)
-reranking_model_token_type_ids[:, 0] = torch.ones(
-    (query_outputs.size(0))
-)
+reranking_model_token_type_ids[:, 0] = torch.ones((query_outputs.size(0)))
else:
reranking_model_inputs_embeds = torch.zeros(
(query_outputs.size(0), self.args.max_seq_length, query_outputs.size(1))
)

reranking_model_inputs_embeds[:, 0, :] = query_outputs
-reranking_model_inputs_embeds[:, 1 : context_outputs.size(0) + 1, :] = context_outputs
+reranking_model_inputs_embeds[
+    :, 1 : context_outputs.size(0) + 1, :
+] = context_outputs

reranking_model_attention_mask = torch.zeros(
(query_outputs.size(0), self.args.max_seq_length)
@@ -3065,9 +3111,7 @@ def _rerank_passages(self, query_outputs, context_outputs, is_evaluating=False):
reranking_model_token_type_ids = torch.zeros(
(query_outputs.size(0), self.args.max_seq_length)
)
-reranking_model_token_type_ids[:, 0] = torch.ones(
-    (query_outputs.size(0))
-)
+reranking_model_token_type_ids[:, 0] = torch.ones((query_outputs.size(0)))

reranking_model_inputs = {
"inputs_embeds": reranking_model_inputs_embeds.to(self.device),
@@ -3078,10 +3122,13 @@ def _rerank_passages(self, query_outputs, context_outputs, is_evaluating=False):
reranking_outputs = self.reranking_model(**reranking_model_inputs)

reranking_query_outputs = reranking_outputs[0][:, 0, :]
-reranking_context_outputs = reranking_outputs[0][:, 1: context_outputs.size(1 if is_evaluating else 0) + 1, :]
+reranking_context_outputs = reranking_outputs[0][
+    :, 1 : context_outputs.size(1 if is_evaluating else 0) + 1, :
+]

reranking_dot_score = torch.bmm(
-reranking_query_outputs.unsqueeze(1), reranking_context_outputs.transpose(-1, -2)
+reranking_query_outputs.unsqueeze(1),
+reranking_context_outputs.transpose(-1, -2),
)
reranking_softmax_score = torch.nn.functional.log_softmax(
reranking_dot_score.squeeze(1), dim=-1
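To make the tensor shapes behind this score explicit: bmm multiplies a (B, 1, H) query slice with a (B, H, N) transposed context block, yielding one score per candidate passage. A sketch with illustrative dimensions:

import torch

B, H, N = 2, 8, 5  # batch size, hidden size, candidate passages (illustrative)
reranking_query_outputs = torch.randn(B, H)
reranking_context_outputs = torch.randn(B, N, H)

# (B, 1, H) @ (B, H, N) -> (B, 1, N)
scores = torch.bmm(
    reranking_query_outputs.unsqueeze(1),
    reranking_context_outputs.transpose(-1, -2),
)
log_probs = torch.nn.functional.log_softmax(scores.squeeze(1), dim=-1)
print(log_probs.shape)  # torch.Size([2, 5])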
simpletransformers/retrieval/retrieval_tools.py (5 additions & 1 deletion)
@@ -82,7 +82,11 @@ def generate_latex_row(
row += f"\\textbf{{{results[model_name][dataset_name][metric]:.3f}}}\\rlap{{\\textsuperscript{{*}}}} & "
else:
row += f"\\textbf{{{results[model_name][dataset_name][metric]:.3f}}} & "
-elif model_name == second_best_model_name and best_model_name != second_best_model_name and len(all_scores) > 2:
+elif (
+    model_name == second_best_model_name
+    and best_model_name != second_best_model_name
+    and len(all_scores) > 2
+):
row += f"\\textit{{{results[model_name][dataset_name][metric]:.3f}}} & "

else:
simpletransformers/retrieval/retrieval_utils.py (26 additions & 8 deletions)
@@ -336,7 +336,9 @@ def preprocess_batch_for_hf_dataset(
}


-def get_output_embeddings(embeddings, concatenate_embeddings=False, n_cls_tokens=3, use_pooler_output=False):
+def get_output_embeddings(
+    embeddings, concatenate_embeddings=False, n_cls_tokens=3, use_pooler_output=False
+):
"""
Extracts the embeddings from the output of the model.
Concatenates CLS embeddings if concatenate_embeddings is True.
@@ -910,7 +912,9 @@ def get_doc_dicts(self, doc_ids):
for i in tqdm(range(doc_ids.shape[0]), desc="Retrieving doc dicts")
]

-def get_top_docs(self, question_hidden_states, n_docs=5, passages_only=False, return_indices=True):
+def get_top_docs(
+    self, question_hidden_states, n_docs=5, passages_only=False, return_indices=True
+):
if passages_only:
_, docs = self.dataset.get_nearest_examples_batch(
"embeddings", question_hidden_states, n_docs
@@ -1652,7 +1656,9 @@ def select_batch_from_pandas(batch, df):
return ClusteredDataset(batch_datasets, len(clustered_batches))


-def load_trec_file(file_name, data_dir=None, header=False, loading_qrels=False, data_format=None):
+def load_trec_file(
+    file_name, data_dir=None, header=False, loading_qrels=False, data_format=None
+):
if data_dir:
if loading_qrels:
if os.path.exists(os.path.join(data_dir, "qrels", f"{file_name}.tsv")):
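The qrels/<name>.tsv path probed here follows the BEIR convention: a tab-separated file with a header row and one (query id, document id, relevance) triple per line. An illustrative example:

query-id	corpus-id	score
q1	d42	1
q1	d7	0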
@@ -1744,7 +1750,13 @@ def load_trec_format(
if not skip_passages:
collection = load_trec_file("corpus", data_dir, collection_header)
queries = load_trec_file("queries", data_dir, queries_header)
-qrels = load_trec_file(qrels_name, data_dir, qrels_header, loading_qrels=True, data_format=data_format)
+qrels = load_trec_file(
+    qrels_name,
+    data_dir,
+    qrels_header,
+    loading_qrels=True,
+    data_format=data_format,
+)

else:
if not collection_path or not queries_path or not qrels_path:
@@ -1759,7 +1771,11 @@ def load_trec_format(

# Also check if an index exists

-return None if skip_passages else collection["train"], queries["train"], qrels["train"]
+return (
+    None if skip_passages else collection["train"],
+    queries["train"],
+    qrels["train"],
+)


def convert_beir_columns_to_trec_format(
@@ -1841,9 +1857,7 @@ def embed_passages_trec_format(
logger.info("Generating embeddings for evaluation passages completed.")

if args.save_passage_dataset:
-output_dataset_directory = os.path.join(
-    args.output_dir, "passage_dataset"
-)
+output_dataset_directory = os.path.join(args.output_dir, "passage_dataset")
os.makedirs(output_dataset_directory, exist_ok=True)
passage_dataset.save_to_disk(output_dataset_directory)

@@ -1888,6 +1902,7 @@ def __init__(
reranking_loss=None,
nll_loss=None,
colbert_correct_predictions_percentage=None,
+reranking_correct_predictions_percentage=None,
):
self.loss = loss
self.context_outputs = context_outputs
@@ -1901,6 +1916,9 @@ def __init__(
self.colbert_correct_predictions_percentage = (
colbert_correct_predictions_percentage
)
+self.reranking_correct_predictions_percentage = (
+    reranking_correct_predictions_percentage
+)


class MarginMSELoss(nn.Module):
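The diff view cuts off the body of MarginMSELoss. For orientation only: the usual margin-MSE distillation objective this name refers to penalizes the difference between the student's positive-negative score margin and a teacher's margin. A generic sketch of that objective, not this repository's code:

from torch import nn

class MarginMSELossSketch(nn.Module):
    """Sketch of margin-MSE distillation: match the teacher's score margin."""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, pos_scores, neg_scores, teacher_pos, teacher_neg):
        return self.mse(pos_scores - neg_scores, teacher_pos - teacher_neg)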
