Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TFDPR #8203

Merged
merged 27 commits into from
Nov 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
28e589c
Merge pull request #1 from huggingface/master
ratthachat Oct 30, 2020
08ac715
Create modeling_tf_dpr.py
ratthachat Oct 30, 2020
767ce1f
Add TFDPR
ratthachat Oct 30, 2020
4590576
Add back TFPegasus, TFMarian, TFMBart, TFBlenderBot
ratthachat Oct 30, 2020
30cc105
Add TFDPR
ratthachat Oct 30, 2020
94f8226
Add TFDPR
ratthachat Oct 31, 2020
33c8268
clean up some comments, add TF input-style doc string
ratthachat Oct 31, 2020
3ef6d9d
Add TFDPR
ratthachat Oct 31, 2020
44e7399
Make return_dict=False as default
ratthachat Nov 4, 2020
70ccaed
Fix return_dict bug (in .from_pretrained)
ratthachat Nov 4, 2020
94c72a8
Add get_input_embeddings()
ratthachat Nov 4, 2020
0c81407
Create test_modeling_tf_dpr.py
ratthachat Nov 4, 2020
a0da436
fix quality
patrickvonplaten Nov 10, 2020
8a54032
Merge remote-tracking branch 'main/master'
patrickvonplaten Nov 10, 2020
e4e8a9e
delete init weights
patrickvonplaten Nov 10, 2020
a01cb28
run fix copies
patrickvonplaten Nov 10, 2020
bcbe6a8
fix repo consis
patrickvonplaten Nov 10, 2020
ccfd0f4
del config_class, load_tf_weights
ratthachat Nov 11, 2020
9ab698a
add config_class back
ratthachat Nov 11, 2020
a619f6b
newline after .. note::
ratthachat Nov 11, 2020
7525403
import tf, np (Necessary for ModelIntegrationTest)
ratthachat Nov 11, 2020
4c2d085
slow_test from_pretrained with from_pt=True
ratthachat Nov 11, 2020
e029755
Add simple TFDPRModelIntegrationTest
ratthachat Nov 11, 2020
619b271
Merge pull request #2 from ratthachat/tf-dpr-finalize
ratthachat Nov 11, 2020
df6decc
Merge branch 'master' into master
patrickvonplaten Nov 11, 2020
735febe
upload correct tf model
Nov 11, 2020
6dfefc6
remove position_ids as missing keys
patrickvonplaten Nov 11, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions docs/source/model_doc/dpr.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,22 @@ DPRReader

.. autoclass:: transformers.DPRReader
:members: forward

TFDPRContextEncoder
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFDPRContextEncoder
:members: call

TFDPRQuestionEncoder
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFDPRQuestionEncoder
:members: call


TFDPRReader
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFDPRReader
:members: call
14 changes: 14 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,9 @@
DistilBertPreTrainedModel,
)
from .modeling_dpr import (
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPRContextEncoder,
DPRPretrainedContextEncoder,
DPRPretrainedQuestionEncoder,
Expand Down Expand Up @@ -713,6 +716,17 @@
TFDistilBertModel,
TFDistilBertPreTrainedModel,
)
from .modeling_tf_dpr import (
TF_DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
TF_DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
TF_DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDPRContextEncoder,
TFDPRPretrainedContextEncoder,
TFDPRPretrainedQuestionEncoder,
TFDPRPretrainedReader,
TFDPRQuestionEncoder,
TFDPRReader,
)
from .modeling_tf_electra import (
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFElectraForMaskedLM,
Expand Down
22 changes: 22 additions & 0 deletions src/transformers/convert_pytorch_checkpoint_to_tf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
Expand All @@ -43,6 +46,7 @@
CamembertConfig,
CTRLConfig,
DistilBertConfig,
DPRConfig,
ElectraConfig,
FlaubertConfig,
GPT2Config,
Expand All @@ -59,6 +63,9 @@
TFCTRLLMHeadModel,
TFDistilBertForMaskedLM,
TFDistilBertForQuestionAnswering,
TFDPRContextEncoder,
TFDPRQuestionEncoder,
TFDPRReader,
TFElectraForPreTraining,
TFFlaubertWithLMHeadModel,
TFGPT2LMHeadModel,
Expand Down Expand Up @@ -98,6 +105,9 @@
CTRLLMHeadModel,
DistilBertForMaskedLM,
DistilBertForQuestionAnswering,
DPRContextEncoder,
DPRQuestionEncoder,
DPRReader,
ElectraForPreTraining,
FlaubertWithLMHeadModel,
GPT2LMHeadModel,
Expand Down Expand Up @@ -147,6 +157,18 @@
BertForSequenceClassification,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"dpr": (
DPRConfig,
TFDPRQuestionEncoder,
TFDPRContextEncoder,
TFDPRReader,
DPRQuestionEncoder,
DPRContextEncoder,
DPRReader,
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
),
"gpt2": (
GPT2Config,
TFGPT2LMHeadModel,
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
replace_list_option_in_docstrings,
)
from .configuration_blenderbot import BlenderbotConfig
from .configuration_dpr import DPRConfig
from .configuration_marian import MarianConfig
from .configuration_mbart import MBartConfig
from .configuration_pegasus import PegasusConfig
Expand Down Expand Up @@ -87,6 +88,7 @@
TFDistilBertForTokenClassification,
TFDistilBertModel,
)
from .modeling_tf_dpr import TFDPRQuestionEncoder
from .modeling_tf_electra import (
TFElectraForMaskedLM,
TFElectraForMultipleChoice,
Expand Down Expand Up @@ -192,6 +194,7 @@
(CTRLConfig, TFCTRLModel),
(ElectraConfig, TFElectraModel),
(FunnelConfig, TFFunnelModel),
(DPRConfig, TFDPRQuestionEncoder),
]
)

Expand Down
Loading