
Commit 5e1c914

Several fixes for Gemma3n (#39135)
* remove the skips
* fix the epsilon to a small value (does not make sense otherwise)
* safeguard
* overload test_eager_matches_sdpa
* Update test_modeling_common.py
* skip appropriate tests
* correct no_split_layer
* fix all devices issue
* fix backward
* fix
1 parent 8446e2a commit 5e1c914
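The epsilon bullet is the main numerical fix: the previous code built the clamp floor from torch.finfo().min (about -3.4e38 for the default float32), so torch.maximum(new_magnitude, epsilon_tensor) never clamped anything and an all-zero projection could divide by zero. A minimal sketch of the intended renormalization, with hypothetical tensor names and the new 1e-5 floor:

    import torch

    def renormalize(current: torch.Tensor, target_magnitude: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        # Mean of squares per token, floored at a small positive epsilon before the sqrt,
        # so the division below cannot blow up on an all-zero projection.
        mean_sq = torch.mean(current**2, dim=-1, keepdim=True)
        rms = torch.sqrt(torch.maximum(mean_sq, torch.tensor(eps, device=current.device)))
        return current * target_magnitude / rms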


5 files changed, +491 -390 lines changed


src/transformers/models/gemma3n/modeling_gemma3n.py

Lines changed: 46 additions & 36 deletions
@@ -1135,9 +1135,17 @@ def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor:
         corrected += predictions  # add the original input
         return corrected.contiguous().type_as(activated)
 
+    def forward(self, corrected: torch.Tensor) -> torch.Tensor:
+        """
+        This is only defined as the `forward` so that accelerate hooks can move correctly `correct_output_scale`
+        (which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
+        `scale_corrected_output`
+        """
+        return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
+
     def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
         """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
-        return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
+        return self.forward(corrected)
 
 
 class Gemma3nTextRotaryEmbedding(nn.Module):

@@ -1290,7 +1298,7 @@ def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
         self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
 
         first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
-        self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx
+        self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
         # Find the index of the last sliding or full layer before sharing starts (or None if no sharing)
         layer_type = config.layer_types[layer_idx]
         self.kv_shared_layer_index = (

@@ -1319,21 +1327,22 @@ def forward(
         query_states = query_states.transpose(1, 2)
 
         if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None:
-            # HybridCache has complex slicing when layer_type == "sliding_attention" that impact Shared KV Cache.
+            # Device of past layer may be different from current one
+            indices = cache_position.to(past_key_value.key_cache[self.kv_shared_layer_index].device)
+            # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond)
             if isinstance(past_key_value, HybridCache) and self.is_sliding:
                 max_length = past_key_value.sliding_window
-                if cache_position.shape[0] > max_length:
-                    # If in the prefill phase for a "sliding_attention" layer and the prefill is larger than the cache,
-                    # slice into the entire cache.
-                    indices = slice(0, max_length)
-                else:
-                    # If prefill fits or generating for a "sliding_attention" layer, clamp to max_cache_len - 1
-                    indices = cache_position.clamp(min=0, max=max_length - 1)
-            else:
-                indices = cache_position
+                indices = (
+                    slice(0, max_length)
+                    if cache_position.shape[0] > max_length
+                    else cache_position.clamp(min=0, max=max_length - 1)
+                )
 
-            key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices]
-            value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices]
+            # Device of past layer may be different from current one
+            key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices].to(query_states.device)
+            value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices].to(
+                query_states.device
+            )
         else:
             key_states = self.k_proj(hidden_states).view(hidden_shape)
             key_states = self.k_norm(key_states)

@@ -1447,10 +1456,9 @@ def forward(
         attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
         corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
 
-        first_prediction = corrected_predictions[self.config.altup_active_idx]
-        first_prediction_clone = first_prediction.clone()
+        first_prediction = corrected_predictions[self.config.altup_active_idx].clone()
         if self.config.altup_correct_scale:
-            first_prediction = self.altup.scale_corrected_output(first_prediction_clone)
+            first_prediction = self.altup.scale_corrected_output(first_prediction)
 
         # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
         first_prediction = self.per_layer_input_gate(first_prediction)

@@ -1475,7 +1483,7 @@ class Gemma3nPreTrainedModel(PreTrainedModel):
     config_class = Gemma3nConfig
     base_model_prefix = ""
     supports_gradient_checkpointing = True
-    _no_split_modules = ["Gemma3nDecoderLayer"]
+    _no_split_modules = ["Gemma3nTextDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True

@@ -1656,18 +1664,17 @@ def forward(
         position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids)
 
         # Expand hidden_states to support per-layer inputs
-        target_magnitude: torch.Tensor = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
-        epsilon_tensor = torch.tensor(torch.finfo().min)
+        target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
+        epsilon_tensor = torch.tensor(1e-5)
 
         temp_hidden_states = [hidden_states_0]
         for i in range(1, self.config.altup_num_inputs):
             # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
-            altup_proj: torch.Tensor = self.altup_projections[i - 1](hidden_states_0)
-            current_hidden_state = altup_proj.type(hidden_states_0.dtype)
-            new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
-            current_hidden_state = current_hidden_state * (
-                target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
-            )
+            altup_proj = self.altup_projections[i - 1](hidden_states_0)
+            current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
+            new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
+            new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
+            current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
             temp_hidden_states.append(current_hidden_state)
 
         hidden_states = torch.stack(temp_hidden_states, dim=0)  # [num_altup_inputs, batch, seq_len, hidden_size]

@@ -1685,9 +1692,9 @@ def forward(
 
             layer_outputs = decoder_layer(
                 hidden_states,
-                position_embeddings_global=position_embeddings_global,
-                position_embeddings_local=position_embeddings_local,
-                per_layer_input=per_layer_input,
+                position_embeddings_global,
+                position_embeddings_local,
+                per_layer_input,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_values,

@@ -1712,11 +1719,10 @@ def forward(
         for i in range(1, self.config.altup_num_inputs):
             # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
             altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
-            current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype)
-            new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
-            current_hidden_state = current_hidden_state * (
-                target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
-            )
+            current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
+            new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
+            new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
+            current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
             temp_hidden_states.append(current_hidden_state)
 
         hidden_states = torch.stack(temp_hidden_states)

@@ -1743,7 +1749,9 @@ def project_per_layer_inputs(
         per_layer_inputs: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
-        per_layer_projection *= self.per_layer_projection_scale.type(inputs_embeds.dtype)
+        per_layer_projection *= self.per_layer_projection_scale.to(
+            dtype=inputs_embeds.dtype, device=per_layer_projection.device
+        )
         per_layer_projection = per_layer_projection.reshape(
             *inputs_embeds.shape[:-1],
             self.config.num_hidden_layers,

@@ -1758,7 +1766,9 @@ def project_per_layer_inputs(
             # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
             per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]
 
-        return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.type(inputs_embeds.dtype)
+        return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to(
+            dtype=inputs_embeds.dtype, device=per_layer_projection.device
+        )
 
 
 @auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.")
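The shared-KV block above does two things worth calling out: the cache_position indices and the cached key/value states are moved with .to(...) because, under multi-device dispatch, the layer that owns the shared cache may sit on a different device than the current layer's query states; and sliding-attention layers keep a fixed-size cache, so a long prefill reads the whole window while shorter position sets are clamped to the last valid slot. A minimal standalone sketch of that slicing decision (hypothetical helper, not the transformers API):

    import torch

    def shared_kv_indices(cache_position: torch.Tensor, sliding_window: int, is_sliding: bool):
        # Full-attention layers: positions index the shared cache directly.
        if not is_sliding:
            return cache_position
        # Prefill longer than the window: read the entire fixed-size cache.
        if cache_position.shape[0] > sliding_window:
            return slice(0, sliding_window)
        # Prefill fits (or decoding): clamp positions into the valid range.
        return cache_position.clamp(min=0, max=sliding_window - 1)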

src/transformers/models/gemma3n/modular_gemma3n.py

Lines changed: 46 additions & 36 deletions
@@ -1685,9 +1685,17 @@ def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor:
         corrected += predictions  # add the original input
         return corrected.contiguous().type_as(activated)
 
+    def forward(self, corrected: torch.Tensor) -> torch.Tensor:
+        """
+        This is only defined as the `forward` so that accelerate hooks can move correctly `correct_output_scale`
+        (which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
+        `scale_corrected_output`
+        """
+        return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
+
     def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
         """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
-        return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
+        return self.forward(corrected)
 
 
 class Gemma3nTextRotaryEmbedding(Gemma2RotaryEmbedding):

@@ -1732,7 +1740,7 @@ def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
         self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
 
         first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
-        self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx
+        self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
         # Find the index of the last sliding or full layer before sharing starts (or None if no sharing)
         layer_type = config.layer_types[layer_idx]
         self.kv_shared_layer_index = (

@@ -1761,21 +1769,22 @@ def forward(
         query_states = query_states.transpose(1, 2)
 
         if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None:
-            # HybridCache has complex slicing when layer_type == "sliding_attention" that impact Shared KV Cache.
+            # Device of past layer may be different from current one
+            indices = cache_position.to(past_key_value.key_cache[self.kv_shared_layer_index].device)
+            # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond)
             if isinstance(past_key_value, HybridCache) and self.is_sliding:
                 max_length = past_key_value.sliding_window
-                if cache_position.shape[0] > max_length:
-                    # If in the prefill phase for a "sliding_attention" layer and the prefill is larger than the cache,
-                    # slice into the entire cache.
-                    indices = slice(0, max_length)
-                else:
-                    # If prefill fits or generating for a "sliding_attention" layer, clamp to max_cache_len - 1
-                    indices = cache_position.clamp(min=0, max=max_length - 1)
-            else:
-                indices = cache_position
+                indices = (
+                    slice(0, max_length)
+                    if cache_position.shape[0] > max_length
+                    else cache_position.clamp(min=0, max=max_length - 1)
+                )
 
-            key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices]
-            value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices]
+            # Device of past layer may be different from current one
+            key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices].to(query_states.device)
+            value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices].to(
+                query_states.device
+            )
         else:
             key_states = self.k_proj(hidden_states).view(hidden_shape)
             key_states = self.k_norm(key_states)

@@ -1880,10 +1889,9 @@ def forward(
         attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
         corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
 
-        first_prediction = corrected_predictions[self.config.altup_active_idx]
-        first_prediction_clone = first_prediction.clone()
+        first_prediction = corrected_predictions[self.config.altup_active_idx].clone()
         if self.config.altup_correct_scale:
-            first_prediction = self.altup.scale_corrected_output(first_prediction_clone)
+            first_prediction = self.altup.scale_corrected_output(first_prediction)
 
         # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
         first_prediction = self.per_layer_input_gate(first_prediction)

@@ -1906,7 +1914,7 @@ def forward(
 class Gemma3nPreTrainedModel(Gemma2PreTrainedModel):
     config_class = Gemma3nConfig
     base_model_prefix = ""
-    _no_split_modules = ["Gemma3nDecoderLayer"]
+    _no_split_modules = ["Gemma3nTextDecoderLayer"]
 
     def _init_weights(self, module):
         # important: this ported version of Gemma2 isn't meant for training from scratch - only

@@ -1995,7 +2003,9 @@ def project_per_layer_inputs(
         per_layer_inputs: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
-        per_layer_projection *= self.per_layer_projection_scale.type(inputs_embeds.dtype)
+        per_layer_projection *= self.per_layer_projection_scale.to(
+            dtype=inputs_embeds.dtype, device=per_layer_projection.device
+        )
         per_layer_projection = per_layer_projection.reshape(
             *inputs_embeds.shape[:-1],
             self.config.num_hidden_layers,

@@ -2010,7 +2020,9 @@ def project_per_layer_inputs(
             # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
             per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]
 
-        return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.type(inputs_embeds.dtype)
+        return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to(
+            dtype=inputs_embeds.dtype, device=per_layer_projection.device
+        )
 
     @can_return_tuple
     @auto_docstring

@@ -2091,18 +2103,17 @@ def forward(
         position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids)
 
         # Expand hidden_states to support per-layer inputs
-        target_magnitude: torch.Tensor = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
-        epsilon_tensor = torch.tensor(torch.finfo().min)
+        target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
+        epsilon_tensor = torch.tensor(1e-5)
 
         temp_hidden_states = [hidden_states_0]
         for i in range(1, self.config.altup_num_inputs):
             # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
-            altup_proj: torch.Tensor = self.altup_projections[i - 1](hidden_states_0)
-            current_hidden_state = altup_proj.type(hidden_states_0.dtype)
-            new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
-            current_hidden_state = current_hidden_state * (
-                target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
-            )
+            altup_proj = self.altup_projections[i - 1](hidden_states_0)
+            current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
+            new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
+            new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
+            current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
             temp_hidden_states.append(current_hidden_state)
 
         hidden_states = torch.stack(temp_hidden_states, dim=0)  # [num_altup_inputs, batch, seq_len, hidden_size]

@@ -2120,9 +2131,9 @@ def forward(
 
             layer_outputs = decoder_layer(
                 hidden_states,
-                position_embeddings_global=position_embeddings_global,
-                position_embeddings_local=position_embeddings_local,
-                per_layer_input=per_layer_input,
+                position_embeddings_global,
+                position_embeddings_local,
+                per_layer_input,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_values,

@@ -2147,11 +2158,10 @@ def forward(
         for i in range(1, self.config.altup_num_inputs):
             # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
             altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
-            current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype)
-            new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
-            current_hidden_state = current_hidden_state * (
-                target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
-            )
+            current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
+            new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
+            new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
+            current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
             temp_hidden_states.append(current_hidden_state)
 
         hidden_states = torch.stack(temp_hidden_states)
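The forward overload in both files follows the reasoning in its docstring: accelerate's offloading hooks wrap Module.forward and only then move that module's own parameters onto the execution device, so a bare nn.Parameter touched exclusively from a helper method is never relocated. A generic sketch of the pattern (standalone module with assumed hook behavior, not the Gemma3n class itself):

    import torch
    from torch import nn

    class ScaledOutput(nn.Module):
        def __init__(self, hidden_size: int):
            super().__init__()
            # A raw parameter, not a submodule: only a forward() call gives
            # device-placement hooks (e.g. accelerate's AlignDevicesHook) a chance to move it.
            self.scale = nn.Parameter(torch.ones(hidden_size))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return (x.type_as(self.scale) * self.scale).type_as(x)

        def scale_output(self, x: torch.Tensor) -> torch.Tensor:
            # Keep the descriptive helper, but route through forward so hooks fire.
            return self.forward(x)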

src/transformers/testing_utils.py

Lines changed: 0 additions & 1 deletion
@@ -1642,7 +1642,6 @@ def set_model_tester_for_less_flaky_test(test_case):
         "AriaVisionText2TextModelTester",
         "GPTNeoModelTester",
         "DPTModelTester",
-        "Gemma3nTextModelTester",  # cannot have a single layer combined with the cache sharing config attrs in the tester
     ]
     if test_case.model_tester.__class__.__name__ in exceptional_classes:
         target_num_hidden_layers = None
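Dropping Gemma3nTextModelTester from the exception list lines up with the `> 0` safeguard added in the model diffs: `layer_idx >= first_kv_shared_layer_idx > 0` is a chained comparison, so KV sharing switches off entirely when the tester shrinks the model to a single layer and there is no earlier layer whose cache could be reused. A small illustration with assumed config values:

    # Chained comparison: (layer_idx >= first_kv) and (first_kv > 0).
    num_hidden_layers, num_kv_shared_layers = 1, 1  # single-layer test config
    first_kv = num_hidden_layers - num_kv_shared_layers  # 0

    layer_idx = 0
    old_flag = layer_idx >= first_kv      # True: would point at a non-existent earlier cache
    new_flag = layer_idx >= first_kv > 0  # False: sharing is disabled by the safeguard
    print(old_flag, new_flag)             # True False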
