Commit 3fcef13

Remove unneeded use of sliding_window
1 parent c414288 commit 3fcef13

File tree

3 files changed, +73 -13 lines changed


src/transformers/models/flex_olmo/configuration_flex_olmo.py

Lines changed: 0 additions & 3 deletions
@@ -195,8 +195,5 @@ def __init__(
             self.rope_scaling["rope_type"] = self.rope_scaling["type"]
         rope_config_validation(self)

-        # Set this to None because it is needed by MixtralModel
-        self.sliding_window = None
-

 __all__ = ["FlexOlmoConfig"]

src/transformers/models/flex_olmo/modeling_flex_olmo.py

Lines changed: 3 additions & 6 deletions
@@ -25,21 +25,19 @@
 import torch.nn.functional as F
 from torch import nn

-from transformers.utils.generic import check_model_inputs
-
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...integrations import use_kernel_forward_from_hub
-from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring
 from ...utils.deprecation import deprecate_kwarg
-from ...utils.generic import OutputRecorder
+from ...utils.generic import OutputRecorder, check_model_inputs
 from .configuration_flex_olmo import FlexOlmoConfig


@@ -449,8 +447,7 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
-        causal_mask = mask_function(
+        causal_mask = create_causal_mask(
             config=self.config,
             input_embeds=inputs_embeds,
             attention_mask=attention_mask,
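For context, the branch removed above only chose between a full causal mask and a windowed one. A minimal, illustrative sketch of the difference (not the actual transformers implementation, which returns backend-specific mask objects):

# Illustrative only: what the two mask helpers conceptually select between.
import torch

def full_causal_mask(seq_len: int) -> torch.Tensor:
    # each position attends to itself and every earlier position
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    # each position attends only to the `window` most recent positions, itself included
    idx = torch.arange(seq_len)
    dist = idx[:, None] - idx[None, :]
    return (dist >= 0) & (dist < window)

print(full_causal_mask(4).int())
print(sliding_window_causal_mask(4, window=2).int())

Since FlexOlmo never uses the windowed variant, the conditional (and the config attribute it depended on) can go.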

src/transformers/models/flex_olmo/modular_flex_olmo.py

Lines changed: 70 additions & 4 deletions
@@ -17,7 +17,12 @@

 import torch

-from ...cache_utils import Cache
+from ...cache_utils import Cache, DynamicCache
+from ...masking_utils import create_causal_mask
+from ...modeling_outputs import MoeModelOutputWithPast
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring
+from ...utils.generic import check_model_inputs
 from ..mixtral.modeling_mixtral import MixtralModel, MixtralPreTrainedModel
 from ..olmo2.modeling_olmo2 import Olmo2Attention, Olmo2RMSNorm, Olmo2RotaryEmbedding
 from ..olmoe.configuration_olmoe import OlmoeConfig

@@ -190,8 +195,6 @@ def __init__(
             **kwargs,
         )

-        # Set this to None because it is needed by MixtralModel
-        self.sliding_window = None
         del self.clip_qkv


@@ -271,8 +274,71 @@ class FlexOlmoPreTrainedModel(MixtralPreTrainedModel):

 # FlexOlmo uses Mixtral model as its base instead of OlmoE model since Mixtral is more up-to-date with the rest
 # of the transformers library. For example, it uses the newer mechanisms of recording submodule outputs.
+# FlexOlmo model is identical to Mixtral model except:
+# - FlexOlmo does not use sliding window attention.
 class FlexOlmoModel(MixtralModel):
-    pass
+    @check_model_inputs
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> MoeModelOutputWithPast:
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache(config=self.config)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+        )
+
+        hidden_states = inputs_embeds
+
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+            hidden_states = decoder_layer(
+                hidden_states,
+                position_embeddings=position_embeddings,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                **kwargs,
+            )
+
+        hidden_states = self.norm(hidden_states)
+
+        return MoeModelOutputWithPast(  # only diff with Mistral is the output type, we need MoE
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+        )


 class FlexOlmoForCausalLM(OlmoeForCausalLM):
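A minimal usage sketch of the rewritten forward path, not part of the commit. It assumes a transformers build that exports FlexOlmoConfig and FlexOlmoModel; the tiny config values are arbitrary and the field names follow the OlmoE-style MoE config that FlexOlmoConfig extends, so treat them as assumptions.

# Minimal sketch: run a forward pass through FlexOlmoModel with a tiny config.
import torch
from transformers import FlexOlmoConfig, FlexOlmoModel  # assumed exports

config = FlexOlmoConfig(
    vocab_size=128,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    num_experts=4,            # assumed OlmoE-style MoE field names
    num_experts_per_tok=2,
)
model = FlexOlmoModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    out = model(input_ids=input_ids)  # full causal attention, no sliding window branch
print(out.last_hidden_state.shape)    # expected: torch.Size([1, 8, 32])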
