@@ -1135,9 +1135,17 @@ def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.T
         corrected += predictions  # add the original input
         return corrected.contiguous().type_as(activated)
 
+    def forward(self, corrected: torch.Tensor) -> torch.Tensor:
+        """
+        This is only defined as `forward` so that accelerate hooks can correctly move `correct_output_scale`
+        (which is an nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
+        `scale_corrected_output`.
+        """
+        return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
+
     def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
         """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
-        return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
+        return self.forward(corrected)
 
 
 class Gemma3nTextRotaryEmbedding(nn.Module):
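A minimal, self-contained sketch of the pattern introduced above (the class name and sizes are made up, not the library code): the parameter multiplication lives in `forward`, and `scale_corrected_output` delegates to it, so anything that patches `forward` (such as accelerate's offloading hooks) runs before `correct_output_scale` is read.

import torch
import torch.nn as nn


class ScaleViaForward(nn.Module):  # hypothetical stand-in for the AltUp block
    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.correct_output_scale = nn.Parameter(torch.ones(hidden_size))

    def forward(self, corrected: torch.Tensor) -> torch.Tensor:
        # Same dtype round-trip as above: compute in the parameter's dtype, return in the input's.
        return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)

    def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
        return self.forward(corrected)  # delegating keeps the (possibly hook-wrapped) forward on the call path


x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
print(ScaleViaForward().scale_corrected_output(x).dtype)  # torch.bfloat16, shape unchanged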
@@ -1290,7 +1298,7 @@ def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
         self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
 
         first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
-        self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx
+        self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
         # Find the index of the last sliding or full layer before sharing starts (or None if no sharing)
         layer_type = config.layer_types[layer_idx]
         self.kv_shared_layer_index = (
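For reference, Python chains comparisons, so the new check is `layer_idx >= first_kv_shared_layer_idx and first_kv_shared_layer_idx > 0`; the extra condition only changes the outcome when `first_kv_shared_layer_idx` is zero or negative. A tiny illustration with made-up layer counts:

num_hidden_layers, num_kv_shared_layers = 4, 4  # hypothetical config values
first_kv_shared_layer_idx = num_hidden_layers - num_kv_shared_layers  # 0
layer_idx = 2
old_check = layer_idx >= first_kv_shared_layer_idx        # True: every layer would be flagged as KV-sharing
new_check = layer_idx >= first_kv_shared_layer_idx > 0    # False: the `> 0` guard disables sharing at the boundary
print(old_check, new_check)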
@@ -1319,21 +1327,22 @@ def forward(
         query_states = query_states.transpose(1, 2)
 
         if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None:
-            # HybridCache has complex slicing when layer_type == "sliding_attention" that impact Shared KV Cache.
+            # Device of past layer may be different from current one
+            indices = cache_position.to(past_key_value.key_cache[self.kv_shared_layer_index].device)
+            # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond)
             if isinstance(past_key_value, HybridCache) and self.is_sliding:
                 max_length = past_key_value.sliding_window
-                if cache_position.shape[0] > max_length:
-                    # If in the prefill phase for a "sliding_attention" layer and the prefill is larger than the cache,
-                    # slice into the entire cache.
-                    indices = slice(0, max_length)
-                else:
-                    # If prefill fits or generating for a "sliding_attention" layer, clamp to max_cache_len - 1
-                    indices = cache_position.clamp(min=0, max=max_length - 1)
-            else:
-                indices = cache_position
+                indices = (
+                    slice(0, max_length)
+                    if cache_position.shape[0] > max_length
+                    else cache_position.clamp(min=0, max=max_length - 1)
+                )
 
-            key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices]
-            value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices]
+            # Device of past layer may be different from current one
+            key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices].to(query_states.device)
+            value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices].to(
+                query_states.device
+            )
         else:
             key_states = self.k_proj(hidden_states).view(hidden_shape)
             key_states = self.k_norm(key_states)
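A rough, standalone sketch of the new index selection for a shared sliding-window layer (window size and positions are made up): a prefill longer than the window slices the whole fixed-size cache, while shorter position sets are clamped into it.

import torch

max_length = 4  # assumed sliding_window
for cache_position in (torch.arange(6), torch.tensor([9])):
    indices = (
        slice(0, max_length)
        if cache_position.shape[0] > max_length
        else cache_position.clamp(min=0, max=max_length - 1)
    )
    print(indices)  # slice(0, 4, None) for the long prefill, tensor([3]) for the single decode step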
@@ -1447,10 +1456,9 @@ def forward(
         attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
         corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
 
-        first_prediction = corrected_predictions[self.config.altup_active_idx]
-        first_prediction_clone = first_prediction.clone()
+        first_prediction = corrected_predictions[self.config.altup_active_idx].clone()
         if self.config.altup_correct_scale:
-            first_prediction = self.altup.scale_corrected_output(first_prediction_clone)
+            first_prediction = self.altup.scale_corrected_output(first_prediction)
 
         # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
         first_prediction = self.per_layer_input_gate(first_prediction)
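The `.clone()` matters because indexing `corrected_predictions` returns a view; a short illustration (shapes are arbitrary) of why copying the active row keeps later in-place work from leaking back into the stacked predictions:

import torch

corrected_predictions = torch.zeros(4, 2, 3)  # [altup_num_inputs, tokens, hidden]
view = corrected_predictions[0]               # plain indexing -> a view of the stack
view += 1.0                                   # in-place edit also mutates the stack
print(corrected_predictions[0].sum())         # tensor(6.)

copy = corrected_predictions[1].clone()       # detached copy
copy += 1.0
print(corrected_predictions[1].sum())         # tensor(0.) -- the stack is untouched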
@@ -1475,7 +1483,7 @@ class Gemma3nPreTrainedModel(PreTrainedModel):
     config_class = Gemma3nConfig
     base_model_prefix = ""
     supports_gradient_checkpointing = True
-    _no_split_modules = ["Gemma3nDecoderLayer"]
+    _no_split_modules = ["Gemma3nTextDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
@@ -1656,18 +1664,17 @@ def forward(
         position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids)
 
         # Expand hidden_states to support per-layer inputs
-        target_magnitude: torch.Tensor = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
-        epsilon_tensor = torch.tensor(torch.finfo().min)
+        target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
+        epsilon_tensor = torch.tensor(1e-5)
 
         temp_hidden_states = [hidden_states_0]
         for i in range(1, self.config.altup_num_inputs):
             # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
-            altup_proj: torch.Tensor = self.altup_projections[i - 1](hidden_states_0)
-            current_hidden_state = altup_proj.type(hidden_states_0.dtype)
-            new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
-            current_hidden_state = current_hidden_state * (
-                target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
-            )
+            altup_proj = self.altup_projections[i - 1](hidden_states_0)
+            current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
+            new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
+            new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
+            current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
             temp_hidden_states.append(current_hidden_state)
 
         hidden_states = torch.stack(temp_hidden_states, dim=0)  # [num_altup_inputs, batch, seq_len, hidden_size]
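Standalone sketch (random inputs, assumed shapes) of the rewritten normalization: each extra AltUp stream is rescaled so its RMS matches `hidden_states_0`, with the new `1e-5` floor applied to the squared mean before the square root; the previous `torch.finfo().min` was the most negative representable float, so the old `torch.maximum` never actually clamped anything.

import torch

hidden_states_0 = torch.randn(1, 3, 8)
altup_proj = 0.05 * torch.randn(1, 3, 8)  # stand-in for an altup_projections output

target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
new_magnitude = torch.mean(altup_proj**2, dim=-1, keepdim=True)
new_magnitude = torch.sqrt(torch.maximum(new_magnitude, torch.tensor(1e-5)))  # 1e-5 floors the squared mean
rescaled = altup_proj * target_magnitude / new_magnitude

# The rescaled stream now has (approximately) the same per-token RMS as hidden_states_0,
# as long as the floor is not hit at this scale.
print(torch.allclose(rescaled.pow(2).mean(dim=-1, keepdim=True).sqrt(), target_magnitude, atol=1e-5))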
@@ -1685,9 +1692,9 @@ def forward(
 
             layer_outputs = decoder_layer(
                 hidden_states,
-                position_embeddings_global=position_embeddings_global,
-                position_embeddings_local=position_embeddings_local,
-                per_layer_input=per_layer_input,
+                position_embeddings_global,
+                position_embeddings_local,
+                per_layer_input,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_values,
@@ -1712,11 +1719,10 @@ def forward(
         for i in range(1, self.config.altup_num_inputs):
             # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
             altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
-            current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype)
-            new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
-            current_hidden_state = current_hidden_state * (
-                target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
-            )
+            current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
+            new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
+            new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
+            current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
             temp_hidden_states.append(current_hidden_state)
 
         hidden_states = torch.stack(temp_hidden_states)
@@ -1743,7 +1749,9 @@ def project_per_layer_inputs(
         per_layer_inputs: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
-        per_layer_projection *= self.per_layer_projection_scale.type(inputs_embeds.dtype)
+        per_layer_projection *= self.per_layer_projection_scale.to(
+            dtype=inputs_embeds.dtype, device=per_layer_projection.device
+        )
         per_layer_projection = per_layer_projection.reshape(
             *inputs_embeds.shape[:-1],
             self.config.num_hidden_layers,
@@ -1758,7 +1766,9 @@ def project_per_layer_inputs(
         # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
         per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]
 
-        return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.type(inputs_embeds.dtype)
+        return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to(
+            dtype=inputs_embeds.dtype, device=per_layer_projection.device
+        )
 
 
 @auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.")
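Both `project_per_layer_inputs` changes follow the same pattern; a minimal illustration (scalar value and dtypes are made up) of swapping `.type(...)` for `.to(dtype=..., device=...)` so the scale is also moved to the device the activations live on under offloading or a multi-device setup:

import torch

per_layer_input_scale = torch.tensor(0.7071, dtype=torch.float32)  # e.g. parked on the CPU when offloaded
per_layer_projection = torch.randn(2, 3, dtype=torch.bfloat16)     # on the execution device
scale = per_layer_input_scale.to(dtype=per_layer_projection.dtype, device=per_layer_projection.device)
per_layer_projection *= scale  # the in-place multiply now sees matching dtype and device
print(per_layer_projection.dtype)  # torch.bfloat16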