
Conversation

patrickvonplaten (Contributor) commented Jul 15, 2020

This PR adds TF Longformer.

As a first step, the code is cleaned up and it is made sure that all tests pass.

ToDo List:

  • same output for local attention only
  • same output for local + global attention
  • same output for aggressive test
  • add all other tests
  • add Longformer QA
  • refactor code and run benchmark
  • add weights to all QA models and check performance via notebook

ToDo after PR is merged:

  • Add Longformer for SeqClass, MC, ... ("good first issue")
  • Speed up performance and make GPU XLA work -> use the benchmark tools and possibly the TF Profiler

For Review

This PR adds TFLongformer and the two most important task classes, TFLongformerForMaskedLM and TFLongformerForQuestionAnswering. Many tests are added to verify that TFLongformer gives identical results to PT Longformer, and a colab notebook (see below) is attached to show performance on a real task.
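For illustration, the PT/TF equivalence checks are conceptually along the following lines (a minimal sketch, not the actual test code; it assumes the TF weights mentioned below are available on the model hub):

```python
import numpy as np
import torch
from transformers import LongformerModel, LongformerTokenizer, TFLongformerModel

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

pt_model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
tf_model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")

# PyTorch forward pass.
with torch.no_grad():
    pt_last_hidden_state = pt_model(**inputs)[0].numpy()

# TensorFlow forward pass on the same inputs.
tf_inputs = {name: tensor.numpy() for name, tensor in inputs.items()}
tf_last_hidden_state = tf_model(tf_inputs)[0].numpy()

# Both implementations should agree up to small numerical differences.
assert np.allclose(pt_last_hidden_state, tf_last_hidden_state, atol=1e-4)
```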

Below you can find a benchmark showing that TFLongformer is about 1.5x slower than PT on GPU. For now this is acceptable IMO, but in a future PR I want to take a deeper look at how the TF code can be optimized and also solve a problem that currently exists with TF XLA.

I spent a lot of time trying to solve issue #5815 for TFLongformer and didn't manage to find a good solution. The corresponding tests are in SLOW mode, so they won't fail on this PR. Since we are currently thinking about a better solution than using cast_bool_to_primitive to work around the known TF graph boolean error, I will leave this small bug in TFLongformer for now (it's quite an edge case IMO anyway).

Docs are added and checked, comments are added, performance on TriviaQA is verified in a TF colab notebook: https://colab.research.google.com/drive/1UmU3T1nPmJ2LgXQtPcaEVtXUnoBJ1onF?usp=sharing and TF weights were added to all Longformer models here: https://huggingface.co/models?search=longformer.

Would be happy about a review @jplu @ibeltagy @LysandreJik @sshleifer @sgugger @julien-c @thomwolf

patrickvonplaten (Contributor, Author) commented Aug 6, 2020

With RUN_SLOW=1, the new tests from #5468, test_saved_model_with_attentions_outputs and test_saved_model_with_hidden_states_output, fail @jplu. The problem is that I have to use tf.cond(...), and it seems like this forces me to also use cast_bool_.... Not sure if you have any ideas on how to fix this @jplu.
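For context, a minimal sketch of the kind of pattern that triggers the problem (illustrative code, not the actual TFLongformer implementation): when the boolean flag is a symbolic tensor inside a traced function, a plain Python `if` cannot be used, so the code falls back to tf.cond with two structurally identical branches.

```python
import tensorflow as tf

@tf.function
def forward(hidden_states, output_attentions):
    attn_probs = tf.nn.softmax(hidden_states, axis=-1)

    # Inside a traced graph `output_attentions` may be a symbolic tensor, so
    # `if output_attentions:` fails and tf.cond has to be used instead. Both
    # branches must return tensors of the same structure, which is what makes
    # the saved-model tests above awkward.
    attentions = tf.cond(
        tf.cast(output_attentions, tf.bool),
        lambda: attn_probs,
        lambda: tf.zeros_like(attn_probs),
    )
    return attn_probs, attentions

outputs = forward(tf.random.normal((2, 8, 8)), tf.constant(True))
```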

jplu (Contributor) commented Aug 6, 2020

Yes, it is still an issue with the AutoGraph thing 😢 I suggest commenting them out for now.

jplu (Contributor) commented Aug 6, 2020

I'm currently thinking about how to properly rework all the boolean handling in TF, as this is the main issue.

patrickvonplaten (Contributor, Author) commented

OK... I will leave it for now since the tests are only in RUN_SLOW mode, so they won't show up in CircleCI.

patrickvonplaten (Contributor, Author) commented Aug 6, 2020

Confirmed that this PR will not slow down the Longformer PyTorch version.

Running the following benchmark on master:

python examples/benchmarking/run_benchmark.py --models allenai/longformer-base-4096 --no_memory --sequence_length 512 1024

gives the same performance as in #5811.

patrickvonplaten (Contributor, Author) commented

Benchmarking the model in TF shows a slow-down vs. PyTorch of ca. 1.5x, which is reasonable IMO:

Running:

python examples/benchmarking/run_benchmark_tf.py --models allenai/longformer-base-4096 --no_memory --sequence_length 512 1024

gives:

                                                                                                                                                                                                                   
====================       INFERENCE - SPEED - RESULT       ====================
--------------------------------------------------------------------------------
          Model Name             Batch Size     Seq Length     Time in s
--------------------------------------------------------------------------------
 allenai/longformer-base-4096        8              512            0.226
 allenai/longformer-base-4096        8              1024           0.446
--------------------------------------------------------------------------------

====================        ENVIRONMENT INFORMATION         ====================
- transformers_version: 3.0.2
- framework: TensorFlow
- eager_mode: False
- use_xla: False
- framework_version: 2.3.0
- python_version: 3.6.10
- system: Linux
- cpu: x86_64
- architecture: 64bit
- date: 2020-08-06
- time: 15:55:55.754100
- fp16: False
- use_multiprocessing: True
- only_pretrain_model: False
- cpu_ram_mb: 32088
- use_gpu: True
- num_gpus: 1
- gpu: TITAN RTX
- gpu_ram_mb: 24217
- gpu_power_watts: 280.0
- gpu_performance_state: 0
- use_tpu: False

At the moment, running the model in XLA on GPU fails => I should take a closer look in a follow-up PR.
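For reference, this is roughly how one would try to run the model under XLA (an illustrative sketch; in TF 2.3 the flag is experimental_compile, renamed to jit_compile in later releases):

```python
import tensorflow as tf
from transformers import LongformerTokenizer, TFLongformerModel

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("Hello world!", return_tensors="tf", padding="max_length", max_length=512)

# Compile the forward pass with XLA; this is the path that currently fails on GPU.
@tf.function(experimental_compile=True)
def xla_forward(input_ids, attention_mask):
    return model(input_ids, attention_mask=attention_mask)

outputs = xla_forward(inputs["input_ids"], inputs["attention_mask"])
```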

codecov bot commented Aug 6, 2020

Codecov Report

Merging #5764 into master will increase coverage by 0.05%.
The diff coverage is 24.52%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #5764      +/-   ##
==========================================
+ Coverage   79.33%   79.38%   +0.05%     
==========================================
  Files         148      149       +1     
  Lines       27196    27670     +474     
==========================================
+ Hits        21577    21967     +390     
- Misses       5619     5703      +84     
Impacted Files Coverage Δ
src/transformers/modeling_tf_longformer.py 15.76% <15.76%> (ø)
src/transformers/modeling_longformer.py 89.76% <92.98%> (+0.54%) ⬆️
src/transformers/__init__.py 99.25% <100.00%> (+<0.01%) ⬆️
src/transformers/modeling_tf_auto.py 66.86% <100.00%> (+0.20%) ⬆️
src/transformers/modeling_tf_flaubert.py 24.53% <0.00%> (-63.20%) ⬇️
src/transformers/tokenization_bert.py 92.41% <0.00%> (+1.33%) ⬆️
src/transformers/tokenization_openai.py 84.09% <0.00%> (+1.51%) ⬆️
src/transformers/tokenization_utils.py 91.60% <0.00%> (+1.59%) ⬆️
... and 7 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 1b8a7ff...41cb64f. Read the comment docs.

codeninja commented

Where do I send the pizza & beer to get this merged?

patrickvonplaten changed the title from "[WIP - TF Longformer] Add TF Longformer" to "TF Longformer" on Aug 7, 2020
LysandreJik (Member) left a comment

Great, looks good to me! Looking forward to the post-PR upgrades as well!

Comment on lines 573 to 600
-        global_query_vectors_only_global /= math.sqrt(self.head_dim)
+        global_query_vectors_only_global /= torch.sqrt(
+            torch.tensor(
+                self.head_dim,
+                device=global_query_vectors_only_global.device,
+                dtype=global_query_vectors_only_global.dtype,
+            )
+        )
LysandreJik (Member):

Out of curiosity, is the creation of a tensor + PyTorch operation faster than the previous math operation on a tensor?

patrickvonplaten (Contributor, Author) replied on Aug 7, 2020:

Thanks for noting that; I actually didn't time it and thought it would be nice to get rid of the math library. From these tests on a GPU: https://colab.research.google.com/drive/1q830_JmCLyxmlQCTBxe3Rj2G3pZuniCY?usp=sharing it seems that in PyTorch it's actually slower to first create a tensor and use torch.sqrt. In TF it seems to be equally fast on GPU and a bit faster on CPU (in tf.function mode); in eager mode math.sqrt is faster.

So for PyTorch I will revert the change, and I guess I have to update my beliefs a bit about "doing everything in tensors".

Is it obvious to you @LysandreJik that using math is faster? I would have assumed that under the hood a tensor is created anyway... do you have an idea why math is faster?

Also cc @mfuntowicz, maybe you have a good opinion on that :-)
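For anyone who wants to reproduce the comparison locally, a minimal timing sketch along these lines (illustrative, not the notebook linked above) contrasts the two scaling variants; on GPU, adding torch.cuda.synchronize() around the timed calls gives more reliable numbers:

```python
import math
import timeit

import torch

head_dim = 64
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(8, 12, 1024, head_dim, device=device)

def scale_with_math():
    # Divide by a Python float computed with math.sqrt.
    return x / math.sqrt(head_dim)

def scale_with_torch():
    # Divide by a freshly created scalar tensor passed through torch.sqrt.
    return x / torch.sqrt(torch.tensor(head_dim, device=x.device, dtype=x.dtype))

print("math.sqrt :", timeit.timeit(scale_with_math, number=1000))
print("torch.sqrt:", timeit.timeit(scale_with_torch, number=1000))
```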

LysandreJik (Member):

No, it isn't obvious to me at all! I would have thought that dividing a tensor by a Python value would create a tensor for the Python value, resulting in the same number of operations. I think it's cleaner to only use torch-like functions too, but your benchmarking shows that the cleaner option isn't always the faster one!

Thanks for diving into it, really useful!

Comment on lines 1319 to 1348
-        if global_attention_mask is None:
+        if global_attention_mask is None and input_ids is not None:
LysandreJik (Member):

(Curious) The global_attention_mask doesn't need to be computed when passing in inputs_embeds?

patrickvonplaten (Contributor, Author) replied on Aug 7, 2020:

Good catch! In this case it cannot really be set automatically, since it's hard to trace back where the sep_token_id was in inputs_embeds -> I will change it and display a warning in this case.
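Roughly, the fallback described here works along the lines of the sketch below (illustrative helper names, not the exact code added in this PR): the global attention mask is only derived automatically when input_ids are available, because the sep_token_id positions cannot be recovered from inputs_embeds.

```python
import logging

import torch

logger = logging.getLogger(__name__)

def _compute_global_attention_mask(input_ids, sep_token_id):
    """Put global attention on every token up to and including the first separator token
    (a common setup for question-answering style inputs)."""
    is_sep = (input_ids == sep_token_id).long()
    # cumsum(is_sep) - is_sep is 0 exactly for tokens before or at the first sep token.
    return ((is_sep.cumsum(dim=1) - is_sep) == 0).long()

def resolve_global_attention_mask(global_attention_mask, input_ids, inputs_embeds, sep_token_id):
    if global_attention_mask is None and input_ids is not None:
        global_attention_mask = _compute_global_attention_mask(input_ids, sep_token_id)
    elif global_attention_mask is None and inputs_embeds is not None:
        # Only inputs_embeds were passed: the sep_token_id positions cannot be recovered,
        # so the caller has to provide global_attention_mask explicitly.
        logger.warning(
            "global_attention_mask cannot be generated automatically from inputs_embeds; "
            "please pass it explicitly for best performance."
        )
    return global_attention_mask
```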

sgugger (Collaborator) left a comment

Let me know if you'd prefer that I do the return_dict part. I think it would be good to have it in the PR before merging but I can push it on this branch.

codeninja commented

I'm curious if you think the longformer should be added to the language_generation_model?

patrickvonplaten (Contributor, Author) commented

@sgugger - I added the return_dict functionality and adapted the docstrings -> the new docstring functions are awesome!
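For callers, the effect of return_dict is roughly the following (a usage sketch; it assumes the TriviaQA checkpoint below has TF weights, which the PR description says were uploaded):

```python
from transformers import LongformerTokenizer, TFLongformerForQuestionAnswering

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-large-4096-finetuned-triviaqa")
model = TFLongformerForQuestionAnswering.from_pretrained(
    "allenai/longformer-large-4096-finetuned-triviaqa"
)

question, context = "Who wrote Faust?", "Faust is a tragic play written by Goethe."
inputs = tokenizer(question, context, return_tensors="tf")

# With return_dict=True the model returns a named output object instead of a plain tuple.
outputs = model(inputs, return_dict=True)
print(outputs.start_logits.shape, outputs.end_logits.shape)
```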

sgugger (Collaborator) left a comment

Thanks a lot for the changes, @patrickvonplaten!

patrickvonplaten (Contributor, Author) commented

I'm curious if you think the longformer should be added to the language_generation_model?

We will add it to the EncoderDecoderModel framework, where it can be used with generate()
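For context, a rough sketch of how that could look once Longformer can be plugged into EncoderDecoderModel (a hypothetical pairing with a BERT decoder; without fine-tuning the generated text is meaningless, this only illustrates the intended API):

```python
from transformers import BertTokenizer, EncoderDecoderModel, LongformerTokenizer

# Hypothetical pairing: a Longformer encoder for long inputs with a BERT decoder.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "allenai/longformer-base-4096", "bert-base-uncased"
)

encoder_tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
decoder_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

inputs = encoder_tokenizer("A very long document ...", return_tensors="pt")
generated_ids = model.generate(
    inputs["input_ids"],
    decoder_start_token_id=decoder_tokenizer.cls_token_id,
    max_length=64,
)
print(decoder_tokenizer.decode(generated_ids[0], skip_special_tokens=True))
```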

sshleifer (Contributor) left a comment

LGTM. Left some nits. Nice tests!

-    This class overrides :class:`~transformers.RobertaModel` to provide the ability to process
-    long sequences following the selfattention approach described in `Longformer: the Long-Document Transformer
-    <https://arxiv.org/abs/2004.05150>`__ by Iz Beltagy, Matthew E. Peters, and Arman Cohan. Longformer selfattention
+    This class copied code from :class:`~transformers.RobertaModel` and overwrote standard self-attention with longformer self-attention to provide the ability to process
sshleifer (Contributor):

(nittiest of all nits)

Present-tense verbs for docstrings: "copies", "overwrites". The class is a thing that exists in the present, at least from the reader's perspective.

ibeltagy (Contributor) left a comment

This is great. Thanks, @patrickvonplaten.

patrickvonplaten merged commit 00bb0b2 into huggingface:master on Aug 10, 2020
patrickvonplaten deleted the longformer_tf branch on Aug 11, 2020