
Commit 69b0f44

Add gemma 2 (#31659)
* inital commit
* Add doc
* protect?
* fixup stuffs
* update tests
* fix build documentation
* mmmmmmm config attributes
* style
* nit
* uodate
* nit
* Fix docs
* protect some stuff

Co-authored-by: Lysandre <lysandre@huggingface.co>
1 parent be50a03 commit 69b0f44

24 files changed · +3057 −69 lines

docs/source/en/index.md

Lines changed: 1 addition & 0 deletions
@@ -145,6 +145,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [Funnel Transformer](model_doc/funnel) ||||
 | [Fuyu](model_doc/fuyu) ||||
 | [Gemma](model_doc/gemma) ||||
+| [Gemma2](model_doc/gemma2) ||||
 | [GIT](model_doc/git) ||||
 | [GLPN](model_doc/glpn) ||||
 | [GPT Neo](model_doc/gpt_neo) ||||

docs/source/en/model_doc/gemma2.md

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+<!--Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+
+⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
+rendered properly in your Markdown viewer.
+
+-->
+
+# Gemma2
+
+## Overview
+
+The Gemma2 model was proposed in [Gemma2: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/Gemma2-open-models/) by the Gemma2 Team at Google.
+Gemma2 models are trained on 6T tokens and released in two sizes, 2b and 7b.
+
+The abstract from the paper is the following:
+
+*This work introduces Gemma2, a new family of open language models demonstrating strong performance across academic benchmarks for language understanding, reasoning, and safety. We release two sizes of models (2 billion and 7 billion parameters), and provide both pretrained and fine-tuned checkpoints. Gemma2 outperforms similarly sized open models on 11 out of 18 text-based tasks, and we present comprehensive evaluations of safety and responsibility aspects of the models, alongside a detailed description of our model development. We believe the responsible release of LLMs is critical for improving the safety of frontier models, and for enabling the next wave of LLM innovations*
+
+Tips:
+
+- The original checkpoints can be converted using the conversion script `src/transformers/models/Gemma2/convert_Gemma2_weights_to_hf.py`.
+
+This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ), [Pedro Cuenca](https://huggingface.co/pcuenq) and [Tom Arsen]().
+
+
+## Gemma2Config
+
+[[autodoc]] Gemma2Config
+
+## Gemma2Model
+
+[[autodoc]] Gemma2Model
+    - forward
+
+## Gemma2ForCausalLM
+
+[[autodoc]] Gemma2ForCausalLM
+    - forward
+
+## Gemma2ForSequenceClassification
+
+[[autodoc]] Gemma2ForSequenceClassification
+    - forward
+
+## Gemma2ForTokenClassification
+
+[[autodoc]] Gemma2ForTokenClassification
+    - forward
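
For readers following along: once the documentation page above and the model code added elsewhere in this commit are in place, the new classes are reachable through the usual `transformers` loading API. A minimal usage sketch (the checkpoint id is illustrative only, it is not part of this commit):

# Hedged sketch: "google/gemma-2-9b" is a placeholder id; substitute any Gemma2 checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-9b"  # illustrative, not taken from this commit
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The capital of France is", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))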

src/transformers/__init__.py

Lines changed: 18 additions & 0 deletions
@@ -435,6 +435,7 @@
     ],
     "models.fuyu": ["FuyuConfig"],
     "models.gemma": ["GemmaConfig"],
+    "models.gemma2": ["Gemma2Config"],
     "models.git": [
         "GitConfig",
         "GitProcessor",
@@ -2181,6 +2182,15 @@
            "GemmaPreTrainedModel",
        ]
    )
+   _import_structure["models.gemma2"].extend(
+       [
+           "Gemma2ForCausalLM",
+           "Gemma2ForSequenceClassification",
+           "Gemma2ForTokenClassification",
+           "Gemma2Model",
+           "Gemma2PreTrainedModel",
+       ]
+   )
    _import_structure["models.git"].extend(
        [
            "GitForCausalLM",
@@ -5062,6 +5072,7 @@
    )
    from .models.fuyu import FuyuConfig
    from .models.gemma import GemmaConfig
+   from .models.gemma2 import Gemma2Config
    from .models.git import (
        GitConfig,
        GitProcessor,
@@ -6694,6 +6705,13 @@
        GemmaModel,
        GemmaPreTrainedModel,
    )
+   from .models.gemma2 import (
+       Gemma2ForCausalLM,
+       Gemma2ForSequenceClassification,
+       Gemma2ForTokenClassification,
+       Gemma2Model,
+       Gemma2PreTrainedModel,
+   )
    from .models.git import (
        GitForCausalLM,
        GitModel,
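
The `_import_structure` entries and the matching `TYPE_CHECKING` imports above are what make the new symbols importable from the top-level package. A small sketch assuming a build that includes this commit; the tiny hyperparameters are invented so that instantiation stays cheap and nothing is downloaded:

# Sketch only: tiny, made-up hyperparameters, randomly initialized weights.
from transformers import Gemma2Config, Gemma2ForCausalLM

config = Gemma2Config(
    vocab_size=1024, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2, head_dim=16,
)
model = Gemma2ForCausalLM(config)  # built from config, no checkpoint download
print(sum(p.numel() for p in model.parameters()))  # parameter count of the toy model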

src/transformers/cache_utils.py

Lines changed: 122 additions & 0 deletions
@@ -970,3 +970,125 @@ def get_max_length(self) -> Optional[int]:
         # in theory there is no limit because the sliding window size is fixed
         # no matter how long the sentence is
         return None
+
+
+class HybridCache(Cache):
+    def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
+        if not hasattr(config, "sliding_window") or config.sliding_window is None:
+            raise ValueError(
+                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
+                "sliding window attention, please check if there is a `sliding_window` field in the model "
+                "config and it's not set to None."
+            )
+        self.max_cache_len = max_cache_len
+        self.max_batch_size = max_batch_size
+        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
+        self.head_dim = (
+            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+        )
+
+        self.dtype = dtype if dtype is not None else torch.float32
+        self.num_key_value_heads = (
+            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
+        )
+        self.is_sliding = torch.tensor(
+            [i % 2 for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device
+        )
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
+        global_cache_shape = (max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
+        sliding_cache_shape = (
+            max_batch_size,
+            self.num_key_value_heads,
+            min(config.sliding_window, max_cache_len),
+            self.head_dim,
+        )
+        for i in range(config.num_hidden_layers):
+            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
+            # breaks when updating the cache.
+            cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
+            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            torch._dynamo.mark_static_address(new_layer_key_cache)
+            torch._dynamo.mark_static_address(new_layer_value_cache)
+            self.key_cache.append(new_layer_key_cache)
+            self.value_cache.append(new_layer_value_cache)
+
+    def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
+        if cache_position.shape[0] > max_cache_len:
+            k_out = key_states[:, :, -max_cache_len:, :]
+            v_out = value_states[:, :, -max_cache_len:, :]
+            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
+            self.key_cache[layer_idx] += k_out
+            self.value_cache[layer_idx] += v_out
+            # we should return the whole states instead of k_out, v_out to take the whole prompt
+            # into consideration when building kv cache instead of just throwing away tokens outside of the window
+            return key_states, value_states
+
+        slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
+        cache_position = cache_position.clamp(0, max_cache_len - 1)
+        to_shift = cache_position >= max_cache_len - 1
+        indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
+        k_out = k_out[:, :, indices]
+        v_out = v_out[:, :, indices]
+
+        k_out[:, :, cache_position] = key_states
+        v_out[:, :, cache_position] = value_states
+        # `_.zero()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment)
+        self.key_cache[layer_idx].zero_()
+        self.value_cache[layer_idx].zero_()
+
+        self.key_cache[layer_idx] += k_out
+        self.value_cache[layer_idx] += v_out
+        return k_out, v_out
+
+    def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
+        k_out[:, :, cache_position] = key_states
+        v_out[:, :, cache_position] = value_states
+
+        self.key_cache[layer_idx] = k_out
+        self.value_cache[layer_idx] = v_out
+        return k_out, v_out
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+        sliding_window: Optional[int] = None,
+    ) -> Tuple[torch.Tensor]:
+        cache_position = cache_kwargs.get("cache_position")
+        self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
+        self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
+        k_out = self.key_cache[layer_idx]
+        v_out = self.value_cache[layer_idx]
+        if sliding_window:
+            update_fn = self._sliding_update
+        else:
+            update_fn = self._static_update
+
+        return update_fn(
+            cache_position,
+            layer_idx,
+            key_states,
+            value_states,
+            k_out,
+            v_out,
+            k_out.shape[2],
+        )
+
+    def get_max_length(self) -> Optional[int]:
+        # in theory there is no limit because the sliding window size is fixed
+        # no matter how long the sentence is
+        return self.max_cache_len
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        return None
+
+    def reset(self):
+        """Resets the cache values while preserving the objects"""
+        for layer_idx in range(len(self.key_cache)):
+            # In-place ops prevent breaking the static address
+            self.key_cache[layer_idx].zero_()
+            self.value_cache[layer_idx].zero_()
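
The interesting part of `HybridCache` is `_sliding_update`: instead of rolling the cache in place (which would break `torch.compile`), it builds a shifted index once the window is full, gathers the existing entries with it, and writes the new token into the recycled last slot. A standalone sketch of that indexing on a plain 1-D tensor, assuming a window of 4:

import torch

max_cache_len = 4                                   # sliding-window size
cache = torch.tensor([10, 11, 12, 13])              # oldest ... newest cached token (stand-in values)
cache_position = torch.tensor([5]).clamp(0, max_cache_len - 1)    # new token lands past the window

slicing = torch.ones(max_cache_len, dtype=torch.long).cumsum(0)   # [1, 2, 3, 4]
to_shift = cache_position >= max_cache_len - 1                    # window full -> shift by one
indices = (slicing + to_shift[-1].int() - 1) % max_cache_len      # [1, 2, 3, 0]

rolled = cache[indices]             # drops the oldest entry: [11, 12, 13, 10]
rolled[cache_position] = 99         # new token overwrites the recycled last slot
print(rolled)                       # tensor([11, 12, 13, 99])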

src/transformers/generation/configuration_utils.py

Lines changed: 1 addition & 1 deletion
@@ -400,7 +400,7 @@ def __init__(self, **kwargs):
         # Cache implementation
         self.cache_implementation = kwargs.pop("cache_implementation", None)
         self.cache_config = kwargs.pop("cache_config", None)
-        if self.cache_implementation is not None:
+        if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG:
             cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
             if self.cache_config is None:
                 self.cache_config = cache_config_class()
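
The change above is a guard: `cache_implementation` values with no entry in `NEEDS_CACHE_CONFIG` (such as the new "hybrid") would previously hit a `KeyError` on the lookup line; now a `cache_config` is only built for implementations that actually declare one. A small sketch of the resulting behaviour, assuming this commit is installed:

# Sketch: "hybrid" has no entry in NEEDS_CACHE_CONFIG, so no cache_config object is created for it.
from transformers import GenerationConfig

gen_config = GenerationConfig(cache_implementation="hybrid", max_new_tokens=32)
print(gen_config.cache_implementation)  # hybrid
print(gen_config.cache_config)          # None

quantized = GenerationConfig(cache_implementation="quantized")
print(type(quantized.cache_config).__name__)  # QuantizedCacheConfig, built from NEEDS_CACHE_CONFIG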

src/transformers/generation/utils.py

Lines changed: 11 additions & 6 deletions
@@ -28,6 +28,7 @@
     Cache,
     DynamicCache,
     HQQQuantizedCache,
+    HybridCache,
     QuantizedCacheConfig,
     QuantoQuantizedCache,
     SlidingWindowCache,
@@ -112,7 +113,7 @@
 if is_accelerate_available():
     from accelerate.hooks import AlignDevicesHook, add_hook_to_module

-NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache}
+NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache, "hybrid": HybridCache}
 QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}

@@ -1395,10 +1396,12 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):

         past_length = 0
         if model_kwargs.get("past_key_values") is not None:
-            if isinstance(model_kwargs["past_key_values"], Cache):
-                past_length = model_kwargs["past_key_values"].get_seq_length()
-            else:
-                past_length = model_kwargs["past_key_values"][0][0].shape[2]
+            cache = model_kwargs["past_key_values"]
+            if not isinstance(cache, Cache):
+                past_length = cache[0][0].shape[2]
+            elif hasattr(cache, "get_seq_length"):
+                past_length = cache.get_seq_length()
+
         if "inputs_embeds" in model_kwargs:
             cur_len = model_kwargs["inputs_embeds"].shape[1]
         else:
@@ -1739,7 +1742,9 @@ def generate(
                     "issue: https://github.com/huggingface/transformers/issues/28981"
                 )
                 model_kwargs["past_key_values"] = self._get_cache(
-                    generation_config.cache_implementation, batch_size, generation_config.max_length
+                    generation_config.cache_implementation,
+                    getattr(generation_config, "num_beams", 1) * batch_size,
+                    generation_config.max_length,
                 )
             elif generation_config.cache_implementation == "quantized":
                 if not self._supports_quantized_cache:
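
Two things change here for the new model: "hybrid" is registered in `NEED_SETUP_CACHE_CLASSES_MAPPING`, and the pre-allocated cache is sized for `num_beams * batch_size` rows so beam search fits. A rough sketch of the equivalent allocation, using a deliberately tiny, invented config so the tensors stay small:

# Sketch: illustrates how the hybrid cache is sized; the hyperparameters are invented.
from transformers import Gemma2Config
from transformers.cache_utils import HybridCache

config = Gemma2Config(
    num_hidden_layers=4, hidden_size=64, num_attention_heads=4,
    num_key_value_heads=2, head_dim=16, sliding_window=8,
)

batch_size, num_beams, max_length = 2, 3, 32
cache = HybridCache(config, max_batch_size=num_beams * batch_size, max_cache_len=max_length)

print(cache.key_cache[0].shape)  # torch.Size([6, 2, 32, 16]) -> global-attention layer, full length
print(cache.key_cache[1].shape)  # torch.Size([6, 2, 8, 16])  -> sliding layer, capped at sliding_window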

src/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -92,6 +92,7 @@
     funnel,
     fuyu,
     gemma,
+    gemma2,
     git,
     glpn,
     gpt2,

src/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions
@@ -108,6 +108,7 @@
         ("funnel", "FunnelConfig"),
         ("fuyu", "FuyuConfig"),
         ("gemma", "GemmaConfig"),
+        ("gemma2", "Gemma2Config"),
         ("git", "GitConfig"),
         ("glpn", "GLPNConfig"),
         ("gpt-sw3", "GPT2Config"),
@@ -385,6 +386,7 @@
         ("funnel", "Funnel Transformer"),
         ("fuyu", "Fuyu"),
         ("gemma", "Gemma"),
+        ("gemma2", "Gemma2"),
         ("git", "GIT"),
         ("glpn", "GLPN"),
         ("gpt-sw3", "GPT-Sw3"),

src/transformers/models/auto/modeling_auto.py

Lines changed: 4 additions & 0 deletions
@@ -105,6 +105,7 @@
         ("fsmt", "FSMTModel"),
         ("funnel", ("FunnelModel", "FunnelBaseModel")),
         ("gemma", "GemmaModel"),
+        ("gemma2", "Gemma2Model"),
         ("git", "GitModel"),
         ("glpn", "GLPNModel"),
         ("gpt-sw3", "GPT2Model"),
@@ -454,6 +455,7 @@
         ("falcon", "FalconForCausalLM"),
         ("fuyu", "FuyuForCausalLM"),
         ("gemma", "GemmaForCausalLM"),
+        ("gemma2", "Gemma2ForCausalLM"),
         ("git", "GitForCausalLM"),
         ("gpt-sw3", "GPT2LMHeadModel"),
         ("gpt2", "GPT2LMHeadModel"),
@@ -863,6 +865,7 @@
         ("fnet", "FNetForSequenceClassification"),
         ("funnel", "FunnelForSequenceClassification"),
         ("gemma", "GemmaForSequenceClassification"),
+        ("gemma2", "Gemma2ForSequenceClassification"),
         ("gpt-sw3", "GPT2ForSequenceClassification"),
         ("gpt2", "GPT2ForSequenceClassification"),
         ("gpt_bigcode", "GPTBigCodeForSequenceClassification"),
@@ -1044,6 +1047,7 @@
         ("fnet", "FNetForTokenClassification"),
         ("funnel", "FunnelForTokenClassification"),
         ("gemma", "GemmaForTokenClassification"),
+        ("gemma2", "Gemma2ForTokenClassification"),
         ("gpt-sw3", "GPT2ForTokenClassification"),
         ("gpt2", "GPT2ForTokenClassification"),
         ("gpt_bigcode", "GPTBigCodeForTokenClassification"),

src/transformers/models/auto/tokenization_auto.py

Lines changed: 7 additions & 0 deletions
@@ -188,6 +188,13 @@
                 "GemmaTokenizerFast" if is_tokenizers_available() else None,
             ),
         ),
+        (
+            "gemma2",
+            (
+                "GemmaTokenizer" if is_sentencepiece_available() else None,
+                "GemmaTokenizerFast" if is_tokenizers_available() else None,
+            ),
+        ),
         ("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
         ("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)),
         ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
