-
Notifications
You must be signed in to change notification settings - Fork 27.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* added rag WIP * path fix * Formatting / renaming prior to actual work * added rag WIP * path fix * Formatting / renaming prior to actual work * added rag WIP * path fix * Formatting / renaming prior to actual work * added rag WIP * Formatting / renaming prior to actual work * First commit * improve comments * Retrieval evaluation scripts * refactor to include modeling outputs + MPI retriever * Fix rag-token model + refactor * Various fixes + finetuning logic * use_bos fix * Retrieval refactor * Finetuning refactoring and cleanup * Add documentation and cleanup * Remove set_up_rag_env.sh file * Fix retrieval wit HF index * Fix import errors * Fix quality errors * Refactor as per suggestions in #6813 (comment) * fix quality * Fix RAG Sequence generation * minor cleanup plus initial tests * fix test * fix tests 2 * Comments fix * post-merge fixes * Improve readme + post-rebase refactor * Extra dependencied for tests * Fix tests * Fix tests 2 * Refactor test requirements * Fix tests 3 * Post-rebase refactor * rename nlp->datasets * RAG integration tests * add tokenizer to slow integration test and allow retriever to run on cpu * add tests; fix position ids warning * change structure * change structure * add from encoder generator * save working solution * make all integration tests pass * add RagTokenizer.save/from_pretrained and RagRetriever.save/from_pretrained * don't save paths * delete unnecessary imports * pass config to AutoTokenizer.from_pretrained for Rag tokenizers * init wiki_dpr only once * hardcode legacy index and passages paths (todo: add the right urls) * finalize config * finalize retriver api and config api * LegacyIndex index download refactor * add dpr to autotokenizer * make from pretrained more flexible * fix ragfortokengeneration * small name changes in tokenizer * add labels to models * change default index name * add retrieval tests * finish token generate * align test with previous version and make all tests pass * add tests * finalize tests * implement thoms suggestions * add first version of test * make first tests work * make retriever platform agnostic * naming * style * add legacy index URL * docstrings + simple retrieval test for distributed * clean model api * add doc_ids to retriever's outputs * fix retrieval tests * finish model outputs * finalize model api * fix generate problem for rag * fix generate for other modles * fix some tests * save intermediate * set generate to default * big refactor generate * delete rag_api * correct pip faiss install * fix auto tokenization test * fix faiss install * fix test * move the distributed logic to examples * model page * docs * finish tests * fix dependencies * fix import in __init__ * Refactor eval_rag and finetune scripts * start docstring * add psutil to test * fix tf test * move require torch to top * fix retrieval test * align naming * finish automodel * fix repo consistency * test ragtokenizer save/load * add rag model output docs * fix ragtokenizer save/load from pretrained * fix tokenizer dir * remove torch in retrieval * fix docs * fixe finetune scripts * finish model docs * finish docs * remove auto model for now * add require torch * remove solved todos * integrate sylvains suggestions * sams comments * correct mistake on purpose * improve README * Add generation test cases * fix rag token * clean token generate * fix test * add note to test * fix attention mask * add t5 test for rag * Fix handling prefix in finetune.py * don't overwrite index_name Co-authored-by: Patrick Lewis <plewis@fb.com> Co-authored-by: Aleksandra Piktus <piktus@devfair0141.h2.fair> Co-authored-by: Aleksandra Piktus <piktus@learnfair5102.h2.fair> Co-authored-by: Aleksandra Piktus <piktus@learnfair5067.h2.fair> Co-authored-by: Your Name <you@example.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com>
- Loading branch information
1 parent
1ee2194
commit c754c41
Showing
37 changed files
with
5,176 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
RAG | ||
---------------------------------------------------- | ||
|
||
Overview | ||
~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
Retrieval-augmented generation ("RAG") models combine the powers of pretrained dense retrieval (DPR) and Seq2Seq models. | ||
RAG models retrieve docs, pass them to a seq2seq model, then marginalize to generate outputs. | ||
The retriever and seq2seq modules are initialized from pretrained models, and fine-tuned jointly, allowing both retrieval and generation to adapt to downstream tasks. | ||
|
||
It is based on the paper `Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks <https://arxiv.org/abs/2005.11401>`__ by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela. | ||
|
||
The abstract from the paper is the following: | ||
|
||
*Large pre-trained language models have been shown to store factual knowledge | ||
in their parameters, and achieve state-of-the-art results when fine-tuned on | ||
downstream NLP tasks. However, their ability to access and precisely manipulate | ||
knowledge is still limited, and hence on knowledge-intensive tasks, their | ||
performance lags behind task-specific architectures. Additionally, providing | ||
provenance for their decisions and updating their world knowledge remain open | ||
research problems. Pre-trained models with a differentiable access mechanism to | ||
explicit nonparametric memory can overcome this issue, but have so far been only | ||
investigated for extractive downstream tasks. We explore a general-purpose | ||
fine-tuning recipe for retrieval-augmented generation (RAG) — models which combine | ||
pre-trained parametric and non-parametric memory for language generation. We | ||
introduce RAG models where the parametric memory is a pre-trained seq2seq model and | ||
the non-parametric memory is a dense vector index of Wikipedia, accessed with | ||
a pre-trained neural retriever. We compare two RAG formulations, one which | ||
conditions on the same retrieved passages across the whole generated sequence, the | ||
other can use different passages per token. We fine-tune and evaluate our models | ||
on a wide range of knowledge-intensive NLP tasks and set the state-of-the-art | ||
on three open domain QA tasks, outperforming parametric seq2seq models and | ||
task-specific retrieve-and-extract architectures. For language generation tasks, we | ||
find that RAG models generate more specific, diverse and factual language than a | ||
state-of-the-art parametric-only seq2seq baseline.* | ||
|
||
|
||
|
||
RagConfig | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: transformers.RagConfig | ||
:members: | ||
|
||
|
||
RagTokenizer | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: transformers.RagTokenizer | ||
:members: | ||
|
||
|
||
Rag specific outputs | ||
~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: transformers.modeling_rag.RetrievAugLMMarginOutput | ||
:members: | ||
|
||
.. autoclass:: transformers.modeling_rag.RetrievAugLMOutput | ||
:members: | ||
|
||
|
||
RAGRetriever | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: transformers.RagRetriever | ||
:members: | ||
|
||
|
||
RagModel | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: transformers.RagModel | ||
:members: forward | ||
|
||
|
||
RagSequenceForGeneration | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: transformers.RagSequenceForGeneration | ||
:members: forward, generate | ||
|
||
|
||
RagTokenForGeneration | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: transformers.RagTokenForGeneration | ||
:members: forward, generate |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Intro | ||
RAG is a seq2seq model which encapsulates two core components: a question encoder and a generator. | ||
During a forward pass, we encode the input with the question encoder and pass it | ||
to the retriever to extract relevant context documents. The documents are then prepended to the input. | ||
Such contextualized inputs is passed to the generator. | ||
|
||
The question encoder can be any `autoencoding` model, preferably :obj:`~transformers.DPRQuestionEncoder`, and the generator can be any `seq2seq` model, preferably :obj:`~transformers.BartForConditionalGeneration`. | ||
|
||
The model can be initialized with a :obj:`~transformers.RagRetriever` for end-to-end generation or used in combination with the outputs of a retriever in multiple steps - see examples for more details. | ||
The model is compatible any `autoencoding` model as the ``question_encoder`` and any `seq2seq` model with language model head as the ``generator``. | ||
The model has been tested with :class:`~transformers.DPRQuestionEncoder` as the ``question_encoder`` and :class:`~transformers.BartForConditionalGeneration` or :class:`~transformers.T5ForConditionalGeneration` as the ``generator``. | ||
|
||
RAG models were released with the paper `Retrieval-Augmented Generation for | ||
Knowledge-Intensive NLP Tasks <https://arxiv.org/abs/2005.11401>`_ by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al. | ||
|
||
|
||
# Finetuning | ||
Our finetuning logic is based on scripts from [`examples/seq2seq`](https://github.com/huggingface/transformers/tree/master/examples/seq2seq). | ||
Follow instructions there regarding data preprocessing. A sample finetuning command: | ||
|
||
``` | ||
python examples/rag/finetune.py \ | ||
--data_dir $DATA_DIR \ | ||
--output_dir $OUTPUT_DIR \ | ||
--model_name_or_path $MODEL_NAME_OR_PATH \ | ||
--model_type rag_sequence \ | ||
--fp16 \ | ||
--gpus 8 | ||
``` | ||
|
||
|
||
# Evaluation | ||
Apart from the parameters specifying the model to evaluate and some extra parameters, the evaluation script expects paths to two files: | ||
- `evaluation_set` - a path to a file specifying the evaluation dataset, a single datapoint per line, e.g. | ||
```who is the owner of reading football club``` | ||
- `gold_data_path` - a path to a file contaning ground truth answers for datapoints from the `evaluation_set`. | ||
|
||
We expect the following formats of the gold data file: | ||
|
||
- for e2e evaluation, we support two formats of the gold file: | ||
- `qa` - where a single line in the following format: input [tab] output_list, e.g.: | ||
``` | ||
who is the owner of reading football club ['Xiu Li Dai', 'Dai Yongge', 'Dai Xiuli', 'Yongge Dai'] | ||
``` | ||
- `ans` - where a single line of the gold file contains the expected output string, e.g.: | ||
``` | ||
Xiu Li Dai | ||
``` | ||
- for retrieval evaluation, we expect a tab-separated list of Wikipedia page titles constituting positive contexts for a given query, e.g. given a question `who sings does he love me with reba`, a line with ground truth retrieval data could look as follows: | ||
``` | ||
Does He Love You Does He Love You Red Sandy Spika dress of Reba McEntire Greatest Hits Volume Two (Reba McEntire album) Shoot for the Moon (album) | ||
``` | ||
## Retrieval evaluation | ||
We demonstrate how to evaluate retrieval against DPR evaluation data. You can download respective files from links listed [here](https://github.com/facebookresearch/DPR/blob/master/data/download_data.py#L39-L45). | ||
1. Download and unzip the gold data file. We use the `biencoder-nq-dev` from https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz. | ||
2. Parse the unziped file using the `parse_dpr_relevance_data.py` | ||
``` | ||
python examples/rag/parse_dpr_relevance_data.py --src_path path/to/unziped/biencoder-nq-dev.json --evaluation_set path/to/output/biencoder-nq-dev.questions --gold_data_path path/to/output/biencoder-nq-dev.pages | ||
``` | ||
3. Run evaluation: | ||
``` | ||
python examples/rag/eval_rag.py \ | ||
--model_name_or_path $MODEL_NAME_OR_PATH \ # model name or path of the model we're evaluating | ||
--model_type rag_sequence \ # RAG model type (rag_token or rag_sequence) | ||
--evaluation_set path/to/output/biencoder-nq-dev.questions \ # an input dataset for evaluation | ||
--gold_data_path path/to/output/biencoder-nq-dev.pages \ # a dataset containing ground truth answers for samples from the evaluation_set | ||
--predictions_path path/to/retrieval_preds.tsv \ # name of file in which predictions will be stored | ||
--eval_mode retrieval \ # indicates whether we're performing retrieval evaluation or e2e evaluation | ||
--recalculate # if predictions_filename already exists, and this option is set - we regenerate the answers, otherwise we reuse the predicsion file to calculate metrics. | ||
``` | ||
## End-to-end evaluation | ||
``` | ||
python examples/rag/eval_rag.py \ | ||
--model_name_or_path $MODEL_NAME_OR_PATH \ | ||
--model_type rag_sequence \ | ||
--evaluation_set path/to/test.source \ | ||
--gold_data_path path/to/gold_data \ | ||
--predictions_path path/to/e2e_preds.txt \ | ||
--eval_mode e2e \ # indicates whether we're performing retrieval evaluation or e2e evaluation (default) | ||
--n_docs 5 \ # You can experiment with retrieving different number of documents at evaluation time | ||
--print_predictions | ||
``` |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import logging | ||
import os | ||
|
||
from pytorch_lightning.callbacks import ModelCheckpoint | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def get_checkpoint_callback(output_dir, metric): | ||
"""Saves the best model by validation EM score.""" | ||
if metric == "rouge2": | ||
exp = "{val_avg_rouge2:.4f}-{step_count}" | ||
elif metric == "bleu": | ||
exp = "{val_avg_bleu:.4f}-{step_count}" | ||
elif metric == "em": | ||
exp = "{val_avg_em:.4f}-{step_count}" | ||
else: | ||
raise NotImplementedError( | ||
f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function." | ||
) | ||
|
||
checkpoint_callback = ModelCheckpoint( | ||
filepath=os.path.join(output_dir, exp), | ||
monitor=f"val_{metric}", | ||
mode="max", | ||
save_top_k=3, | ||
period=0, # maybe save a checkpoint every time val is run, not just end of epoch. | ||
) | ||
return checkpoint_callback |
Oops, something went wrong.