Skip to content

Commit 4721c36

Browse files
committed
Merge branch 'main' into torch-sdpa-preliminary-support
2 parents f116cce + ce0bbd5 commit 4721c36

File tree

14 files changed

+286
-42
lines changed

14 files changed

+286
-42
lines changed

docs/source/en/pad_truncation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ The following table summarizes the recommended way to setup padding and truncati
5454
| | | `tokenizer(batch_sentences, padding='longest')` |
5555
| | padding to max model input length | `tokenizer(batch_sentences, padding='max_length')` |
5656
| | padding to specific length | `tokenizer(batch_sentences, padding='max_length', max_length=42)` |
57-
| | padding to a multiple of a value | `tokenizer(batch_sentences, padding=True, pad_to_multiple_of=8) |
57+
| | padding to a multiple of a value | `tokenizer(batch_sentences, padding=True, pad_to_multiple_of=8)` |
5858
| truncation to max model input length | no padding | `tokenizer(batch_sentences, truncation=True)` or |
5959
| | | `tokenizer(batch_sentences, truncation=STRATEGY)` |
6060
| | padding to max sequence in batch | `tokenizer(batch_sentences, padding=True, truncation=True)` or |

docs/source/es/_toctree.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@
7575
- sections:
7676
- local: philosophy
7777
title: Filosofía
78+
- local: pad_truncation
79+
title: Relleno y truncamiento
7880
- local: bertology
7981
title: BERTología
8082
- local: perplexity

docs/source/es/pad_truncation.md

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
the License. You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
specific language governing permissions and limitations under the License.
11+
12+
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
13+
rendered properly in your Markdown viewer.
14+
15+
-->
16+
17+
# Relleno y truncamiento
18+
19+
Las entradas agrupadas por lotes (batched) suelen tener longitudes diferentes, por lo que no se pueden convertir en tensores de tamaño fijo. El relleno (también conocido como "Padding") y el truncamiento (conocido como "Truncation") son estrategias para abordar este problema y crear tensores rectangulares a partir de lotes de longitudes variables. El relleno agrega un **padding token** especial para garantizar que las secuencias más cortas tengan la misma longitud que la secuencia más larga en un lote o la longitud máxima aceptada por el modelo. El truncamiento funciona en la otra dirección al truncar secuencias largas.
20+
21+
En la mayoría de los casos, es bastante eficaz rellenar el lote hasta la longitud de la secuencia más larga y truncar hasta la longitud máxima que un modelo puede aceptar. Sin embargo, la API admite más estrategias si las necesitas. Los tres argumentos que necesitas son: `padding`, `truncation` y `max_length`.
22+
23+
El argumento `padding` controla el relleno. Puede ser un booleano o una cadena:
24+
25+
- `True` o `'longest'`: rellena hasta la longitud de la secuencia más larga en el lote (no se aplica relleno si solo proporcionas una única secuencia).
26+
- `'max_length'`: rellena hasta una longitud especificada por el argumento `max_length` o la longitud máxima aceptada
27+
por el modelo si no se proporciona `max_length` (`max_length=None`). El relleno se aplicará incluso si solo proporcionas una única secuencia.
28+
- `False` o `'do_not_pad'`: no se aplica relleno. Este es el comportamiento predeterminado.
29+
30+
El argumento `truncation` controla el truncamiento. Puede ser un booleano o una cadena:
31+
32+
- `True` o `'longest_first'`: trunca hasta una longitud máxima especificada por el argumento `max_length` o
33+
la longitud máxima aceptada por el modelo si no se proporciona `max_length` (`max_length=None`). Esto
34+
truncará token por token, eliminando un token de la secuencia más larga en el par hasta alcanzar la longitud adecuada.
35+
- `'only_second'`: trunca hasta una longitud máxima especificada por el argumento `max_length` o la longitud máxima
36+
aceptada por el modelo si no se proporciona `max_length` (`max_length=None`). Esto solo truncará
37+
la segunda oración de un par si se proporciona un par de secuencias (o un lote de pares de secuencias).
38+
- `'only_first'`: trunca hasta una longitud máxima especificada por el argumento `max_length` o la longitud máxima
39+
aceptada por el modelo si no se proporciona `max_length` (`max_length=None`). Esto solo truncará
40+
la primera oración de un par si se proporciona un par de secuencias (o un lote de pares de secuencias).
41+
- `False` o `'do_not_truncate'`: no se aplica truncamiento. Este es el comportamiento predeterminado.
42+
43+
El argumento `max_length` controla la longitud del relleno y del truncamiento. Puede ser un número entero o `None`, en cuyo caso se establecerá automáticamente en la longitud máxima que el modelo puede aceptar. Si el modelo no tiene una longitud máxima de entrada específica, se desactiva el truncamiento o el relleno hasta `max_length`.
44+
45+
La siguiente tabla resume la forma recomendada de configurar el relleno y el truncamiento. Si usas pares de secuencias de entrada en alguno de los siguientes ejemplos, puedes reemplazar `truncation=True` por una `ESTRATEGIA` seleccionada en
46+
`['only_first', 'only_second', 'longest_first']`, es decir, `truncation='only_second'` o `truncation='longest_first'` para controlar cómo se truncan ambas secuencias en el par, como se detalló anteriormente.
47+
48+
| Truncation | Padding | Instrucción |
49+
|-----------------------------------------|--------------------------------------|---------------------------------------------------------------------------------------------|
50+
| sin truncamiento | sin relleno | `tokenizer(batch_sentences)` |
51+
| | relleno hasta la longitud máxima del lote | `tokenizer(batch_sentences, padding=True)` o |
52+
| | | `tokenizer(batch_sentences, padding='longest')` |
53+
| | relleno hasta la longitud máxima del modelo | `tokenizer(batch_sentences, padding='max_length')` |
54+
| | relleno hasta una longitud específica | `tokenizer(batch_sentences, padding='max_length', max_length=42)` |
55+
| | relleno hasta un múltiplo de un valor | `tokenizer(batch_sentences, padding=True, pad_to_multiple_of=8)` |
56+
| truncamiento hasta la longitud máxima del modelo | sin relleno | `tokenizer(batch_sentences, truncation=True)` o |
57+
| | | `tokenizer(batch_sentences, truncation=ESTRATEGIA)` |
58+
| | relleno hasta la longitud máxima del lote | `tokenizer(batch_sentences, padding=True, truncation=True)` o |
59+
| | | `tokenizer(batch_sentences, padding=True, truncation=ESTRATEGIA)` |
60+
| | relleno hasta la longitud máxima del modelo | `tokenizer(batch_sentences, padding='max_length', truncation=True)` o |
61+
| | | `tokenizer(batch_sentences, padding='max_length', truncation=ESTRATEGIA)` |
62+
| | relleno hasta una longitud específica | No es posible |
63+
| truncamiento hasta una longitud específica | sin relleno | `tokenizer(batch_sentences, truncation=True, max_length=42)` o |
64+
| | | `tokenizer(batch_sentences, truncation=ESTRATEGIA, max_length=42)` |
65+
| | relleno hasta la longitud máxima del lote | `tokenizer(batch_sentences, padding=True, truncation=True, max_length=42)` o |
66+
| | | `tokenizer(batch_sentences, padding=True, truncation=ESTRATEGIA, max_length=42)` |
67+
| | relleno hasta la longitud máxima del modelo | No es posible |
68+
| | relleno hasta una longitud específica | `tokenizer(batch_sentences, padding='max_length', truncation=True, max_length=42)` o |
69+
| | | `tokenizer(batch_sentences, padding='max_length', truncation=ESTRATEGIA, max_length=42)` |

src/transformers/cache_utils.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,21 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
3838
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
3939
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
4040

41+
def get_max_length(self) -> Optional[int]:
42+
"""Returns the maximum sequence length of the cached states, if there is any."""
43+
raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
44+
45+
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
46+
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
47+
# Cache without size limit -> all cache is usable
48+
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
49+
# length, we will need to evict part of the cache (and thus not all cache is usable)
50+
max_length = self.get_max_length()
51+
previous_seq_length = self.get_seq_length(layer_idx)
52+
if max_length is not None and previous_seq_length + new_seq_length > max_length:
53+
return max_length - new_seq_length
54+
return previous_seq_length
55+
4156

4257
class DynamicCache(Cache):
4358
"""
@@ -120,6 +135,10 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
120135
return 0
121136
return self.key_cache[layer_idx].shape[-2]
122137

138+
def get_max_length(self) -> Optional[int]:
139+
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
140+
return None
141+
123142
def reorder_cache(self, beam_idx: torch.LongTensor):
124143
"""Reorders the cache for beam search, given the selected beam indices."""
125144
for layer_idx in range(len(self.key_cache)):
@@ -209,8 +228,11 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
209228
# Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
210229
if len(self.key_cache) <= layer_idx:
211230
return 0
212-
cache_length = self.key_cache[layer_idx].shape[-2]
213-
return min(cache_length, self.window_length - 1)
231+
return self.key_cache[layer_idx].shape[-2]
232+
233+
def get_max_length(self) -> Optional[int]:
234+
"""Returns the maximum sequence length of the cached states."""
235+
return self.window_length
214236

215237
def update(
216238
self,
@@ -267,7 +289,9 @@ def update(
267289

268290
# On RoPE models, we need to recompute the Key rotation as the tokens are shifted
269291
if using_rope:
270-
rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(key_states, cos, sin)
292+
rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
293+
key_states, cos[: self.window_length], sin[: self.window_length]
294+
)
271295
if partial_rotation_size is not None:
272296
keys_to_keep, keys_pass = (
273297
keys_to_keep[..., :partial_rotation_size],

src/transformers/models/blip/image_processing_blip.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class BlipImageProcessor(BaseImageProcessor):
5757
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
5858
overridden by the `resample` parameter in the `preprocess` method.
5959
do_rescale (`bool`, *optional*, defaults to `True`):
60-
Wwhether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
60+
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
6161
`do_rescale` parameter in the `preprocess` method.
6262
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
6363
Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be

src/transformers/models/llama/modeling_llama.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def forward(
399399
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
400400
"with a layer index."
401401
)
402-
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
402+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
403403
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
404404
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
405405

@@ -504,7 +504,7 @@ def forward(
504504

505505
kv_seq_len = key_states.shape[-2]
506506
if past_key_value is not None:
507-
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
507+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
508508
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
509509

510510
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -998,7 +998,7 @@ def forward(
998998
use_legacy_cache = not isinstance(past_key_values, Cache)
999999
if use_legacy_cache:
10001000
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1001-
past_key_values_length = past_key_values.get_seq_length()
1001+
past_key_values_length = past_key_values.get_usable_length(seq_length)
10021002

10031003
if position_ids is None:
10041004
device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1224,8 +1224,10 @@ def prepare_inputs_for_generation(
12241224
if isinstance(past_key_values, Cache):
12251225
cache_length = past_key_values.get_seq_length()
12261226
past_length = past_key_values.seen_tokens
1227+
max_cache_length = past_key_values.get_max_length()
12271228
else:
12281229
cache_length = past_length = past_key_values[0][0].shape[2]
1230+
max_cache_length = None
12291231

12301232
# Keep only the unprocessed tokens:
12311233
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1239,10 +1241,13 @@ def prepare_inputs_for_generation(
12391241
input_ids = input_ids[:, past_length:]
12401242
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
12411243

1242-
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
1243-
# older attention values, as their corresponding values are not part of the input.
1244-
if cache_length < past_length and attention_mask is not None:
1245-
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
1244+
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1245+
if (
1246+
max_cache_length is not None
1247+
and attention_mask is not None
1248+
and cache_length + input_ids.shape[1] > max_cache_length
1249+
):
1250+
attention_mask = attention_mask[:, -max_cache_length:]
12461251

12471252
position_ids = kwargs.get("position_ids", None)
12481253
if attention_mask is not None and position_ids is None:

0 commit comments

Comments
 (0)