Skip to content

Commit 652abb9

Browse files
Applied review suggestions for CamemBERT: restored API refs, added examples, badges, and attribution
1 parent b54d431 commit 652abb9

File tree

1 file changed

+130
-22
lines changed

1 file changed

+130
-22
lines changed
Lines changed: 130 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,143 @@
1-
# CamemBERT Base
1+
<!--Copyright 2020 The HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
the License. You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
specific language governing permissions and limitations under the License.
11+
12+
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
13+
rendered properly in your Markdown viewer.
14+
15+
-->
16+
# CamemBERT
217
<div style="float: right;">
318
<div class="flex flex-wrap space-x-1">
419
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
520
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
21+
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
622
</div>
723
</div>
824

9-
[CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert) is a language model based on RoBERTa, but it was trained specifically on French text from the OSCAR dataset.
25+
[CamemBERT](https://huggingface.co/papers/1911.03894) is a language model based on [RoBERTa](./roberta), but trained specifically on French text from the OSCAR dataset, making it more effective for French language tasks.
1026

27+
## CamembertConfig
1128
What sets CamemBERT apart is that it learned from a huge, high quality collection of French data, as opposed to mixing lots of languages. This helps it really understand French better than many multilingual models.
1229

1330
Common applications of CamemBERT include masked language modeling (Fill-mask prediction), text classification (sentiment analysis), token classification (entity recognition) and sentence pair classification (entailment tasks).
1431

15-
You can find all the original CamemBERT checkpoints under the [CamemBERT](https://huggingface.co/models?search=camembert) collection.
32+
[[autodoc]] CamembertConfig
33+
## CamembertTokenizer
34+
You can find all the original CamemBERT checkpoints under the [ALMAnaCH](https://huggingface.co/almanach/models?search=camembert) organization.
1635

1736
> [!TIP]
18-
> This model was contributed by the [Facebook AI](https://huggingface.co/facebook) team.
37+
> This model was contributed by the [ALMAnaCH (Inria)](https://huggingface.co/almanach) team.
1938
>
2039
> Click on the CamemBERT models in the right sidebar for more examples of how to apply CamemBERT to different NLP tasks.
40+
[[autodoc]] CamembertTokenizer
41+
## CamembertTokenizerFast
2142

22-
The examples below demonstrate how to perform masked language modeling with `pipeline` or the `AutoModel` class.
43+
The examples below demonstrate how to predict the `<mask>` token with [`Pipeline`], [`AutoModel`], and from the command line.
2344

2445
<hfoptions id="usage">
2546

2647
<hfoption id="Pipeline">
2748

2849
```python
50+
import torch
2951
from transformers import pipeline
3052

31-
fill_mask = pipeline("fill-mask", model="camembert-base")
32-
result = fill_mask("Le camembert est un délicieux fromage <mask>.")
33-
print(result)
53+
pipeline = pipeline("fill-mask", model="camembert-base", torch_dtype=torch.float16, device=0)
54+
pipeline("Le camembert est un délicieux fromage <mask>.")
3455
```
3556

3657
</hfoption>
3758

3859
<hfoption id="AutoModel">
3960

4061
```python
62+
import torch
4163
from transformers import AutoTokenizer, AutoModelForMaskedLM
4264

4365
tokenizer = AutoTokenizer.from_pretrained("camembert-base")
44-
model = AutoModelForMaskedLM.from_pretrained("camembert-base")
66+
model = AutoModelForMaskedLM.from_pretrained("camembert-base", torch_dtype="auto", device_map="auto", attn_implementation="sdpa")
67+
inputs = tokenizer("Le camembert est un délicieux fromage <mask>.", return_tensors="pt").to("cuda")
4568

46-
inputs = tokenizer("Le camembert est un délicieux fromage <mask>.", return_tensors="pt")
47-
outputs = model(**inputs)
69+
with torch.no_grad():
70+
outputs = model(**inputs)
71+
predictions = outputs.logits
72+
73+
masked_index = torch.where(inputs['input_ids'] == tokenizer.mask_token_id)[1]
74+
predicted_token_id = predictions[0, masked_index].argmax(dim=-1)
75+
predicted_token = tokenizer.decode(predicted_token_id)
76+
77+
print(f"The predicted token is: {predicted_token}")
4878
```
4979

5080
</hfoption>
5181

5282
</hfoptions>
83+
[[autodoc]] CamembertTokenizerFast
84+
85+
## CamembertModel
86+
[[autodoc]] CamembertModel
87+
88+
## CamembertForMaskedLM
89+
[[autodoc]] CamembertForMaskedLM
90+
91+
## CamembertForSequenceClassification
5392

54-
Quantization reduces the memory burden of large models by representing weights in lower precision. Refer to the [Quantization](https://huggingface.co/docs/transformers/main/en/quantization) overview for available options.
55-
The example below uses [BitsAndBytes](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#load-in-8bit-or-4bit-using-bitsandbytes) quantization to load the model in 8-bit precision.
93+
[[autodoc]] CamembertForSequenceClassification
94+
95+
## CamembertForMultipleChoice
96+
97+
You can also use CamemBERT for multiple-choice tasks via the command line:
98+
99+
```bash
100+
transformers-cli env
101+
python -m transformers.cli run camembert-base \
102+
--task multiple-choice \
103+
--context "The president of France is" \
104+
--choices "Emmanuel Macron" "Napoleon Bonaparte" "Marie Curie" \
105+
--framework pt
106+
```
107+
108+
[[autodoc]] CamembertForMultipleChoice
109+
110+
## CamembertForTokenClassification
111+
112+
You can use the `camembert/camembert-large` model for token classification tasks like Named Entity Recognition (NER).
56113

114+
```python
115+
from transformers import AutoTokenizer, CamembertForTokenClassification
116+
import torch
117+
118+
model_id = "camembert/camembert-large"
119+
tokenizer = AutoTokenizer.from_pretrained(model_id)
120+
model = CamembertForTokenClassification.from_pretrained(model_id)
121+
122+
text = "Emmanuel Macron est le président de la République française."
123+
tokens = tokenizer(text, return_tensors="pt", truncation=True)
124+
125+
with torch.no_grad():
126+
outputs = model(**tokens)
127+
128+
predictions = torch.argmax(outputs.logits, dim=2)
129+
130+
for token, pred_id in zip(tokenizer.convert_ids_to_tokens(tokens["input_ids"][0]), predictions[0]):
131+
label = model.config.id2label.get(pred_id.item(), "O")
132+
print(f"{token}: {label}")
133+
```
134+
135+
[[autodoc]] CamembertForTokenClassification
136+
137+
Quantization reduces the memory burden of large models by representing weights in lower precision. Refer to the [Quantization](../quantization/overview) overview for available options.
138+
139+
The example below uses [bitsandbytes](../quantization/bitsandbytes) quantization to quantize the weights to 8-bits.
140+
57141
```python
58142
from transformers import AutoTokenizer, AutoModelForMaskedLM, BitsAndBytesConfig
59143

@@ -65,15 +149,39 @@ model = AutoModelForMaskedLM.from_pretrained(
65149
)
66150
tokenizer = AutoTokenizer.from_pretrained("camembert-base")
67151
```
68-
69-
## Notes
70152

71-
- CamemBERT uses RoBERTa pretraining objectives.
72-
- It makes use of a SentencePiece tokenizer.
73-
- It does not support token type IDs (segment embeddings).
74-
- Special pre-processing/post-processing is not needed.
153+
[[autodoc]] CamembertForQuestionAnswering
154+
155+
</pt>
156+
<tf>
157+
158+
## TFCamembertModel
159+
160+
[[autodoc]] TFCamembertModel
161+
162+
## TFCamembertForCausalLM
163+
164+
[[autodoc]] TFCamembertForCausalLM
165+
166+
## TFCamembertForMaskedLM
167+
168+
[[autodoc]] TFCamembertForMaskedLM
169+
170+
## TFCamembertForSequenceClassification
171+
172+
[[autodoc]] TFCamembertForSequenceClassification
173+
174+
## TFCamembertForMultipleChoice
175+
176+
[[autodoc]] TFCamembertForMultipleChoice
177+
178+
## TFCamembertForTokenClassification
179+
180+
[[autodoc]] TFCamembertForTokenClassification
181+
182+
## TFCamembertForQuestionAnswering
75183

76-
## Resources
184+
[[autodoc]] TFCamembertForQuestionAnswering
77185

78-
- [Original Paper](https://arxiv.org/abs/1911.03894)
79-
- [Hugging Face Model Card](https://huggingface.co/camembert-base)
186+
</tf>
187+
</frameworkcontent>

0 commit comments

Comments
 (0)