Llama: RoPE refactor #32135

Merged
merged 30 commits into main from llama_rope_refactor on Jul 23, 2024
Changes from 1 commit
Commits (30 commits)
1e1d6b4
Add YaRN and Dynamic-YaRN RoPE Scaling Methods
mig-mfreitas May 19, 2024
7efbbc3
Refactor YaRN implementation for LLaMA
miguelm-almeida Jun 16, 2024
f2122cd
Refactor Tensor Building Logic for YaRN
miguelm-almeida Jul 10, 2024
b8df7a2
remove unwanted file
gante Jul 16, 2024
0166869
all diff except the llama folder
gante Jul 22, 2024
7e4e4d8
add updated config
gante Jul 22, 2024
6564195
add updated rope class (and break related copies)
gante Jul 22, 2024
95304b5
related classes
gante Jul 22, 2024
3904d32
llama attention
gante Jul 22, 2024
c7837eb
fa2 (and break a few more copies)
gante Jul 22, 2024
e5e1cde
sdpa (and break a few more copies)
gante Jul 22, 2024
f68b9cd
up to the model class
gante Jul 22, 2024
2f5ace3
up to ForSequenceClassification
gante Jul 22, 2024
5d35287
last set?
gante Jul 22, 2024
4c56e43
missing this one
gante Jul 22, 2024
f36ec3a
make fixup
gante Jul 22, 2024
35699b3
Update src/transformers/modeling_rope_utils.py
gante Jul 22, 2024
b095ebb
Update src/transformers/modeling_rope_utils.py
gante Jul 22, 2024
3f6458b
rename 'type' and 'scaling_type' to a clearer 'rope_type'
gante Jul 22, 2024
6d315ca
abstract out key validation
gante Jul 22, 2024
5809b5e
safety getattr; explicit docstring
gante Jul 22, 2024
a7502ed
docstring nit
gante Jul 22, 2024
3bc7c52
add tests
gante Jul 22, 2024
39e216a
remove external position_embeddings interface
gante Jul 22, 2024
000aeba
test nit
gante Jul 22, 2024
80a0422
Update src/transformers/models/llama/modeling_llama.py
gante Jul 22, 2024
48ed251
Update src/transformers/models/llama/modeling_llama.py
gante Jul 22, 2024
c824be0
make fixu
gante Jul 22, 2024
fc1255e
Merge branch 'main' into llama_rope_refactor
gante Jul 23, 2024
75b2391
make fixup and make fix-copies
gante Jul 23, 2024
up to the model class
gante committed Jul 23, 2024
commit f68b9cd5b57d85a8dfed5dd8c3a416134598feb4
3 changes: 2 additions & 1 deletion src/transformers/models/chameleon/modeling_chameleon.py
@@ -581,7 +581,8 @@ def forward(
}


# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON
# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON
# TODO(joao): add me back asap :)
class ChameleonDecoderLayer(nn.Module):
def __init__(self, config: ChameleonConfig, layer_idx: int):
super().__init__()
3 changes: 2 additions & 1 deletion src/transformers/models/cohere/modeling_cohere.py
@@ -699,7 +699,8 @@ def _init_weights(self, module):
"The bare Cohere Model outputting raw hidden-states without any specific head on top.",
COHERE_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Cohere
# copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Cohere
# TODO(joao): add me back asap :)
class CohereModel(CoherePreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`CohereDecoderLayer`]
19 changes: 17 additions & 2 deletions src/transformers/models/llama/modeling_llama.py
@@ -642,6 +642,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -659,6 +660,9 @@ def forward(
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
@@ -676,6 +680,7 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
@@ -813,6 +818,10 @@ def _init_weights(self, module):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head. This input is used to dynamically
overwrite the default positional embeddings.
"""


@@ -838,6 +847,7 @@ def __init__(self, config: LlamaConfig):
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.gradient_checkpointing = False

# Initialize weights and apply final processing
@@ -862,6 +872,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -904,10 +915,12 @@ def forward(
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)

# embed positions
hidden_states = inputs_embeds

# create position embeddings to be shared across the decoder layers, if not passed
if position_embeddings is None:
position_embeddings = self.rotary_emb(hidden_states, position_ids)

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
@@ -927,6 +940,7 @@ def forward(
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
@@ -937,6 +951,7 @@
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)

hidden_states = layer_outputs[0]
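
To make the new data flow in the Llama diff concrete, here is a small self-contained sketch. It is not the PR's implementation: ToyRotaryEmbedding is a simplified stand-in for LlamaRotaryEmbedding with none of the scaling logic from modeling_rope_utils.py, keeping only the shape contract described in the docstring above.

import torch
import torch.nn as nn

class ToyRotaryEmbedding(nn.Module):
    """Toy stand-in: returns a (cos, sin) tuple of shape (batch_size, seq_len, head_dim)."""

    def __init__(self, head_dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x: torch.Tensor, position_ids: torch.LongTensor):
        # position_ids: (batch_size, seq_len) -> freqs: (batch_size, seq_len, head_dim // 2)
        freqs = position_ids[..., None].float() * self.inv_freq
        emb = torch.cat((freqs, freqs), dim=-1)  # (batch_size, seq_len, head_dim)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)

rotary_emb = ToyRotaryEmbedding(head_dim=64)
hidden_states = torch.randn(1, 10, 512)       # (batch_size, seq_len, hidden_size)
position_ids = torch.arange(10).unsqueeze(0)  # (batch_size, seq_len)

# Computed once at the model level (self.rotary_emb in the diff)...
position_embeddings = rotary_emb(hidden_states, position_ids)

# ...and the same tuple is then passed to every decoder layer instead of each
# attention module recomputing it, e.g.
#   layer_outputs = decoder_layer(hidden_states, ..., position_embeddings=position_embeddings)
cos, sin = position_embeddings
print(cos.shape, sin.shape)  # torch.Size([1, 10, 64]) torch.Size([1, 10, 64])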
3 changes: 2 additions & 1 deletion src/transformers/models/mistral/modeling_mistral.py
@@ -494,7 +494,8 @@ def forward(
}


# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Mistral, LLAMA->MISTRAL
# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Mistral, LLAMA->MISTRAL
# TODO(joao): add me back asap :)
class MistralDecoderLayer(nn.Module):
def __init__(self, config: MistralConfig, layer_idx: int):
super().__init__()
6 changes: 4 additions & 2 deletions src/transformers/models/olmo/modeling_olmo.py
@@ -553,7 +553,8 @@ def __init__(self, config: OlmoConfig, layer_idx: int):
self.input_layernorm = OlmoLayerNorm(config.hidden_size)
self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)

# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward
# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward
# TODO(joao): add me back asap :)
def forward(
self,
hidden_states: torch.Tensor,
@@ -772,7 +773,8 @@ def set_input_embeddings(self, value):
self.embed_tokens = value

@add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING)
# Copied from transformers.models.llama.modeling_llama.LlamaModel.forward
# copied from transformers.models.llama.modeling_llama.LlamaModel.forward
# TODO(joao): add me back asap :)
def forward(
self,
input_ids: torch.LongTensor = None,
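
On the consumer side (the attention changes listed in commits 3904d32, c7837eb and e5e1cde), here is a hedged sketch of how a layer applies the precomputed tuple to its query/key states; rotate_half and apply_rotary_pos_emb mirror the helpers already present in modeling_llama.py, and the shapes are illustrative only.

import torch

def rotate_half(x):
    # Rotates the second half of the last dimension in front of the first half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin arrive as (batch_size, seq_len, head_dim); add a heads axis so they broadcast
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

# query/key states: (batch_size, num_heads, seq_len, head_dim)
q = torch.randn(1, 8, 10, 64)
k = torch.randn(1, 8, 10, 64)
cos, sin = torch.randn(1, 10, 64), torch.randn(1, 10, 64)  # normally the tuple from the model-level rotary_emb
q, k = apply_rotary_pos_emb(q, k, cos, sin)
print(q.shape, k.shape)  # both torch.Size([1, 8, 10, 64])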