@@ -1,8 +1,8 @@
1 | 1 | # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2 |   | -# This file was automatically generated from src/transformers/models/deepseek_v2/modular_deepseek_V2.py.
  | 2 | +# This file was automatically generated from src/transformers/models/deepseek_v2/modular_deepseek_v2.py.
3 | 3 | # Do NOT edit this file manually as any edits will be overwritten by the generation of
4 | 4 | # the file from the modular. If any change should be done, please apply the change to the
5 |   | -# modular_deepseek_V2.py file directly. One of our CI enforces this.
  | 5 | +# modular_deepseek_v2.py file directly. One of our CI enforces this.
6 | 6 | # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7 | 7 | # coding=utf-8
8 | 8 | # Copyright 2025 Baidu Inc and The HuggingFace Inc. team.

@@ -20,7 +20,7 @@
20 | 20 | # limitations under the License.
21 | 21 |
22 | 22 | import warnings
23 |    | -from typing import Callable, List, Optional, Tuple, Union
   | 23 | +from typing import Callable, Optional, Tuple, Union
24 | 24 |
25 | 25 | import torch
26 | 26 | import torch.nn.functional as F

@@ -46,7 +46,7 @@
46 | 46 | replace_return_docstrings,
47 | 47 | )
48 | 48 | from ...utils.deprecation import deprecate_kwarg
49 |    | -from .configuration_deepseek_V2 import DeepseekV2Config
   | 49 | +from .configuration_deepseek_v2 import DeepseekV2Config
50 | 50 |
51 | 51 |
52 | 52 | if is_torch_flex_attn_available():
@@ -640,20 +640,12 @@ def _init_weights(self, module):
640 | 640 | config.n_positions - 1]`.
641 | 641 |
642 | 642 | [What are position IDs?](../glossary#position-ids)
643 |     | -past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
    | 643 | +past_key_values (`Cache`, *optional*):
644 | 644 | Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
645 | 645 | blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
646 | 646 | returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
647 | 647 |
648 |     | -Two formats are allowed:
649 |     | -- a [`~cache_utils.Cache`] instance, see our
650 |     | -  [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
651 |     | -- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
652 |     | -  shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
653 |     | -  cache format.
654 |     | -
655 |     | -The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
656 |     | -legacy cache format will be returned.
    | 648 | +It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
657 | 649 |
658 | 650 | If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
659 | 651 | have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
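The docstring hunk above leaves a `Cache` instance as the only accepted format for `past_key_values`. A minimal caller-side sketch of that API; the checkpoint name is an assumption for illustration and is not part of this diff:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

# assumed checkpoint name; any DeepseekV2 checkpoint works the same way
model_id = "deepseek-ai/DeepSeek-V2-Lite"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("Hello", return_tensors="pt")
cache = DynamicCache()  # an explicit `Cache` instance, per the new docstring
with torch.no_grad():
    outputs = model(**inputs, past_key_values=cache, use_cache=True)
# the model outputs the same cache format it was fed
print(type(outputs.past_key_values).__name__)  # DynamicCache
```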
@@ -744,6 +736,10 @@ def forward(
744 | 736 | )
745 | 737 | use_cache = False
746 | 738 |
    | 739 | +# TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
    | 740 | +if not isinstance(past_key_values, (type(None), Cache)):
    | 741 | +    raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
    | 742 | +
747 | 743 | if inputs_embeds is None:
748 | 744 |     inputs_embeds = self.embed_tokens(input_ids)
749 | 745 |
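The guard added above rejects the legacy tuple-of-tuples format instead of silently converting it. Callers that still build legacy tuples can convert them explicitly before the call; a sketch using `DynamicCache.from_legacy_cache`, with placeholder tensor shapes:

```python
import torch
from transformers import DynamicCache

# placeholder shapes: (batch_size, num_heads, sequence_length, head_dim)
legacy = (
    (torch.zeros(1, 8, 4, 16), torch.zeros(1, 8, 4, 16)),  # (key, value) for layer 0
)
cache = DynamicCache.from_legacy_cache(legacy)  # explicit conversion before calling the model
print(type(cache).__name__)  # DynamicCache -- passes the new isinstance(..., Cache) check
```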
@@ -992,7 +988,7 @@ def forward(
992 | 988 | input_ids: torch.LongTensor = None,
993 | 989 | attention_mask: Optional[torch.Tensor] = None,
994 | 990 | position_ids: Optional[torch.LongTensor] = None,
995 |     | -past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    | 991 | +past_key_values: Optional[Cache] = None,
996 | 992 | inputs_embeds: Optional[torch.FloatTensor] = None,
997 | 993 | labels: Optional[torch.LongTensor] = None,
998 | 994 | use_cache: Optional[bool] = None,
@@ -1114,7 +1110,7 @@ def forward(
1114 | 1110 | input_ids: Optional[torch.LongTensor] = None,
1115 | 1111 | attention_mask: Optional[torch.Tensor] = None,
1116 | 1112 | position_ids: Optional[torch.LongTensor] = None,
1117 |      | -past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
     | 1113 | +past_key_values: Optional[Cache] = None,
1118 | 1114 | inputs_embeds: Optional[torch.FloatTensor] = None,
1119 | 1115 | labels: Optional[torch.LongTensor] = None,
1120 | 1116 | use_cache: Optional[bool] = None,