17 | 17 |
18 | 18 | import torch |
19 | 19 |
20 | | -from ...cache_utils import Cache |
| 20 | +from ...cache_utils import Cache, DynamicCache |
| 21 | +from ...masking_utils import create_causal_mask |
| 22 | +from ...modeling_outputs import MoeModelOutputWithPast |
| 23 | +from ...processing_utils import Unpack |
| 24 | +from ...utils import TransformersKwargs, auto_docstring |
| 25 | +from ...utils.generic import check_model_inputs |
21 | 26 | from ..mixtral.modeling_mixtral import MixtralModel, MixtralPreTrainedModel |
22 | 27 | from ..olmo2.modeling_olmo2 import Olmo2Attention, Olmo2RMSNorm, Olmo2RotaryEmbedding |
23 | 28 | from ..olmoe.configuration_olmoe import OlmoeConfig |
@@ -190,8 +195,6 @@ def __init__( |
190 | 195 | **kwargs, |
191 | 196 | ) |
192 | 197 |
193 | | - # Set this to None because it is needed by MixtralModel |
194 | | - self.sliding_window = None |
195 | 198 | del self.clip_qkv |
196 | 199 |
197 | 200 |
@@ -271,8 +274,71 @@ class FlexOlmoPreTrainedModel(MixtralPreTrainedModel): |
271 | 274 |
272 | 275 | # FlexOlmo uses Mixtral model as its base instead of OlmoE model since Mixtral is more up-to-date with the rest |
273 | 276 | # of the transformers library. For example, it uses the newer mechanisms of recording submodule outputs. |
| 277 | +# FlexOlmo model is identical to Mixtral model except: |
| 278 | +# - FlexOlmo does not use sliding window attention. |
274 | 279 | class FlexOlmoModel(MixtralModel): |
275 | | - pass |
| 280 | + @check_model_inputs |
| 281 | + @auto_docstring |
| 282 | + def forward( |
| 283 | + self, |
| 284 | + input_ids: Optional[torch.LongTensor] = None, |
| 285 | + attention_mask: Optional[torch.Tensor] = None, |
| 286 | + position_ids: Optional[torch.LongTensor] = None, |
| 287 | + past_key_values: Optional[Cache] = None, |
| 288 | + inputs_embeds: Optional[torch.FloatTensor] = None, |
| 289 | + use_cache: Optional[bool] = None, |
| 290 | + cache_position: Optional[torch.LongTensor] = None, |
| 291 | + **kwargs: Unpack[TransformersKwargs], |
| 292 | + ) -> MoeModelOutputWithPast: |
| 293 | + if (input_ids is None) ^ (inputs_embeds is not None): |
| 294 | + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") |
| 295 | + |
| 296 | + if use_cache and past_key_values is None: |
| 297 | + past_key_values = DynamicCache(config=self.config) |
| 298 | + |
| 299 | + if inputs_embeds is None: |
| 300 | + inputs_embeds = self.embed_tokens(input_ids) |
| 301 | + |
| 302 | + if cache_position is None: |
| 303 | + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| 304 | + cache_position = torch.arange( |
| 305 | + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device |
| 306 | + ) |
| 307 | + if position_ids is None: |
| 308 | + position_ids = cache_position.unsqueeze(0) |
| 309 | + |
| 310 | + causal_mask = create_causal_mask( |
| 311 | + config=self.config, |
| 312 | + input_embeds=inputs_embeds, |
| 313 | + attention_mask=attention_mask, |
| 314 | + cache_position=cache_position, |
| 315 | + past_key_values=past_key_values, |
| 316 | + position_ids=position_ids, |
| 317 | + ) |
| 318 | + |
| 319 | + hidden_states = inputs_embeds |
| 320 | + |
| 321 | + # create position embeddings to be shared across the decoder layers |
| 322 | + position_embeddings = self.rotary_emb(hidden_states, position_ids) |
| 323 | + |
| 324 | + for decoder_layer in self.layers[: self.config.num_hidden_layers]: |
| 325 | + hidden_states = decoder_layer( |
| 326 | + hidden_states, |
| 327 | + position_embeddings=position_embeddings, |
| 328 | + attention_mask=causal_mask, |
| 329 | + position_ids=position_ids, |
| 330 | + past_key_values=past_key_values, |
| 331 | + use_cache=use_cache, |
| 332 | + cache_position=cache_position, |
| 333 | + **kwargs, |
| 334 | + ) |
| 335 | + |
| 336 | + hidden_states = self.norm(hidden_states) |
| 337 | + |
| 338 | + return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE |
| 339 | + last_hidden_state=hidden_states, |
| 340 | + past_key_values=past_key_values, |
| 341 | + ) |
276 | 342 |
277 | 343 |
278 | 344 | class FlexOlmoForCausalLM(OlmoeForCausalLM): |
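The overridden forward above is a near-verbatim copy of Mixtral's, except that it always builds a plain causal mask via create_causal_mask (no sliding-window variant) and returns MoeModelOutputWithPast. A minimal sketch of exercising it follows; it assumes FlexOlmoConfig exists and is exported from transformers alongside the classes in this diff, and that its fields mirror OlmoeConfig (which the file imports). It uses a tiny random-weight config, not a real checkpoint.

```python
# Hedged sketch: FlexOlmoConfig and the top-level exports are assumptions based
# on this PR; field names are borrowed from OlmoeConfig, which the file imports.
import torch
from transformers import FlexOlmoConfig, FlexOlmoModel

config = FlexOlmoConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_experts=4,
    num_experts_per_tok=2,
)
model = FlexOlmoModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    # No attention_mask or cache_position is passed: forward derives
    # cache_position, builds the mask with create_causal_mask, and allocates
    # a DynamicCache because use_cache=True and no cache was supplied.
    out = model(input_ids=input_ids, use_cache=True)

print(out.last_hidden_state.shape)          # torch.Size([1, 8, 64])
print(type(out.past_key_values).__name__)   # DynamicCache
```

Passing both input_ids and inputs_embeds, or neither, raises the ValueError shown in the diff; everything else (mask, positions, cache) is derived inside forward.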