Llama: RoPE refactor #32135

Merged (30 commits) on Jul 23, 2024
Changes from 1 commit
Commits (30)
1e1d6b4  Add YaRN and Dynamic-YaRN RoPE Scaling Methods (mig-mfreitas, May 19, 2024)
7efbbc3  Refactor YaRN implementation for LLaMA (miguelm-almeida, Jun 16, 2024)
f2122cd  Refactor Tensor Building Logic for YaRN (miguelm-almeida, Jul 10, 2024)
b8df7a2  remove unwanted file (gante, Jul 16, 2024)
0166869  all diff except the llama folder (gante, Jul 22, 2024)
7e4e4d8  add updated config (gante, Jul 22, 2024)
6564195  add updated rope class (and break related copies) (gante, Jul 22, 2024)
95304b5  related classes (gante, Jul 22, 2024)
3904d32  llama attention (gante, Jul 22, 2024)
c7837eb  fa2 (and break a few more copies) (gante, Jul 22, 2024)
e5e1cde  sdpa (and break a few more copies) (gante, Jul 22, 2024)
f68b9cd  up to the model class (gante, Jul 22, 2024)
2f5ace3  up to ForSequenceClassification (gante, Jul 22, 2024)
5d35287  last set? (gante, Jul 22, 2024)
4c56e43  missing this one (gante, Jul 22, 2024)
f36ec3a  make fixup (gante, Jul 22, 2024)
35699b3  Update src/transformers/modeling_rope_utils.py (gante, Jul 22, 2024)
b095ebb  Update src/transformers/modeling_rope_utils.py (gante, Jul 22, 2024)
3f6458b  rename 'type' and 'scaling_type' to a clearer 'rope_type' (gante, Jul 22, 2024)
6d315ca  abstract out key validation (gante, Jul 22, 2024)
5809b5e  safety getattr; explicit docstring (gante, Jul 22, 2024)
a7502ed  docstring nit (gante, Jul 22, 2024)
3bc7c52  add tests (gante, Jul 22, 2024)
39e216a  remove external position_embeddings interface (gante, Jul 22, 2024)
000aeba  test nit (gante, Jul 22, 2024)
80a0422  Update src/transformers/models/llama/modeling_llama.py (gante, Jul 22, 2024)
48ed251  Update src/transformers/models/llama/modeling_llama.py (gante, Jul 22, 2024)
c824be0  make fixu (gante, Jul 22, 2024)
fc1255e  Merge branch 'main' into llama_rope_refactor (gante, Jul 23, 2024)
75b2391  make fixup and make fix-copies (gante, Jul 23, 2024)
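
Taken together, the commit messages sketch the refactor: RoPE construction moves into a shared src/transformers/modeling_rope_utils.py, the scaling config key 'type'/'scaling_type' is renamed to 'rope_type', and YaRN / Dynamic-YaRN join the supported scaling methods. Below is a minimal sketch of what a scaled-RoPE Llama config might look like after this change; the tiny model dimensions and the exact set of accepted keys beyond "rope_type" and "factor" are illustrative assumptions, not read from the diff.

from transformers import LlamaConfig, LlamaForCausalLM

# Minimal sketch, assuming the post-refactor interface implied by the commit
# messages above (renamed "rope_type" key, YaRN support). Values are illustrative.
config = LlamaConfig(
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=4,
    max_position_embeddings=8192,
    rope_scaling={
        "rope_type": "yarn",  # previously spelled "type" / "scaling_type"
        "factor": 2.0,        # stretch the pretrained context window by ~2x
    },
)
model = LlamaForCausalLM(config)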

Commit 39e216a2c9d3be56cc4937883650a2f01a176fd0: remove external position_embeddings interface
gante committed Jul 23, 2024

src/transformers/models/cohere/modeling_cohere.py (1 addition & 2 deletions)
@@ -928,8 +928,7 @@ def _update_causal_mask(
return causal_mask


-# copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
-# TODO(joao): add me back asap :)
+# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
class CohereForCausalLM(CoherePreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
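
Many hunks in this commit, like the one above, only adjust "# Copied from ..." markers: the comment tells the repository's consistency check which reference class a block mirrors, and the "Old->New" pairs after "with" describe the name substitutions to apply. A rough illustration of that substitution idea follows; this is not the real check script, just the pattern-mapping concept.

# Illustrative only: mimics how a "# Copied from ... with Llama->Cohere" directive
# maps a reference snippet onto a target class via plain name substitution.
def apply_copy_directive(source_code: str, replacements: str) -> str:
    # replacements looks like "Llama->Cohere" or "Llama->JetMoe, LLAMA->JETMOE"
    for pair in replacements.split(","):
        old, new = (part.strip() for part in pair.split("->"))
        source_code = source_code.replace(old, new)
    return source_code

reference = "class LlamaForCausalLM(LlamaPreTrainedModel):"
print(apply_copy_directive(reference, "Llama->Cohere"))
# class CohereForCausalLM(CoherePreTrainedModel):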

src/transformers/models/jamba/modeling_jamba.py (1 addition & 1 deletion)
@@ -1624,7 +1624,7 @@ def set_input_embeddings(self, value):
@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
def forward(
self,
-input_ids: torch.LongTensor = None,
+input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
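
The input_ids change repeated in this and several later files is a typing fix rather than a behavioural one: a parameter that defaults to None should be annotated Optional[...] so that strict type checkers accept the signature. A tiny sketch of the corrected pattern (a toy function, not the actual model code):

from typing import Optional
import torch

# A None default requires Optional[...]; the old "input_ids: torch.LongTensor = None"
# annotation is flagged by strict checkers such as mypy.
def forward(input_ids: Optional[torch.LongTensor] = None) -> int:
    return 0 if input_ids is None else input_ids.numel()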

src/transformers/models/jetmoe/modeling_jetmoe.py (2 additions & 2 deletions)
@@ -1343,7 +1343,7 @@ def prepare_inputs_for_generation(
""",
JETMOE_START_DOCSTRING,
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralForSequenceClassification with Mistral->JetMoe, MISTRAL->JETMOE
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->JetMoe, LLAMA->JETMOE
class JetMoeForSequenceClassification(JetMoePreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -1363,7 +1363,7 @@ def set_input_embeddings(self, value):
@add_start_docstrings_to_model_forward(JETMOE_INPUTS_DOCSTRING)
def forward(
self,
-input_ids: torch.LongTensor = None,
+input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,

src/transformers/models/llama/modeling_llama.py (2 additions & 16 deletions)
@@ -822,10 +822,6 @@ def _init_weights(self, module):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
-position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
-Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
-with `head_dim` being the embedding dimension of each attention head. This input is used to dynamically
-overwrite the default positional embeddings.
"""


@@ -876,7 +872,6 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
-position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -921,9 +916,8 @@
)
hidden_states = inputs_embeds

-# create position embeddings to be shared across the decoder layers, if not passed
-if position_embeddings is None:
-    position_embeddings = self.rotary_emb(hidden_states, position_ids)
+# create position embeddings to be shared across the decoder layers
+position_embeddings = self.rotary_emb(hidden_states, position_ids)

# decoder layers
all_hidden_states = () if output_hidden_states else None
@@ -1111,7 +1105,6 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
-position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -1156,7 +1149,6 @@
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
-position_embeddings=position_embeddings,
)

hidden_states = outputs[0]
@@ -1282,7 +1274,6 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
-position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1302,7 +1293,6 @@
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
-position_embeddings=position_embeddings,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
@@ -1401,7 +1391,6 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
-position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1424,7 +1413,6 @@
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
-position_embeddings=position_embeddings,
)

sequence_output = outputs[0]
@@ -1507,7 +1495,6 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
-position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1527,7 +1514,6 @@
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
-position_embeddings=position_embeddings,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
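
The net effect of the hunks above is that rotary embeddings can no longer be injected from outside: LlamaModel.forward computes the (cos, sin) pair once per call and shares it with every decoder layer, while the downstream heads simply stop forwarding the argument. A minimal sketch of that sharing pattern with toy stand-ins follows; the names and shapes are illustrative, not the exact transformers internals.

import torch
import torch.nn as nn

class ToyRotaryEmbedding(nn.Module):
    # Stand-in for LlamaRotaryEmbedding: returns a (cos, sin) pair per position.
    def __init__(self, head_dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, position_ids: torch.LongTensor):
        freqs = position_ids[:, :, None].float() * self.inv_freq[None, None, :]
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

class ToyDecoderLayer(nn.Module):
    def forward(self, hidden_states, position_embeddings):
        cos, sin = position_embeddings  # consumed by attention; precomputed upstream
        return hidden_states            # attention math omitted in this sketch

class ToyModel(nn.Module):
    def __init__(self, num_layers: int = 2, head_dim: int = 8):
        super().__init__()
        self.rotary_emb = ToyRotaryEmbedding(head_dim)
        self.layers = nn.ModuleList(ToyDecoderLayer() for _ in range(num_layers))

    def forward(self, hidden_states, position_ids):
        # Computed once here and shared by all layers; no longer a forward() argument.
        position_embeddings = self.rotary_emb(position_ids)
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_embeddings)
        return hidden_states

toy = ToyModel()
out = toy(torch.zeros(1, 5, 8), torch.arange(5)[None, :])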

src/transformers/models/mistral/modeling_mistral.py (3 additions & 5 deletions)
@@ -1129,8 +1129,7 @@ def prepare_inputs_for_generation(
""",
MISTRAL_START_DOCSTRING,
)
-# copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL
-# TODO(joao): add me back asap :)
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL
class MistralForSequenceClassification(MistralPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -1150,7 +1149,7 @@ def set_input_embeddings(self, value):
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
def forward(
self,
-input_ids: torch.LongTensor = None,
+input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
@@ -1246,8 +1245,7 @@ def forward(
""",
MISTRAL_START_DOCSTRING,
)
-# copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mistral, LLAMA->MISTRAL
-# TODO(joao): add me back asap :)
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mistral, LLAMA->MISTRAL
class MistralForTokenClassification(MistralPreTrainedModel):
def __init__(self, config):
super().__init__(config)

src/transformers/models/mixtral/modeling_mixtral.py (3 additions & 3 deletions)
@@ -1342,7 +1342,7 @@ def prepare_inputs_for_generation(
""",
MIXTRAL_START_DOCSTRING,
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralForSequenceClassification with Mistral->Mixtral, MISTRAL->MIXTRAL
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL
class MixtralForSequenceClassification(MixtralPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -1362,7 +1362,7 @@ def set_input_embeddings(self, value):
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
def forward(
self,
-input_ids: torch.LongTensor = None,
+input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
@@ -1458,7 +1458,7 @@ def forward(
""",
MIXTRAL_START_DOCSTRING,
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralForTokenClassification with Mistral->Mixtral, MISTRAL->MIXTRAL
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mixtral, LLAMA->MIXTRAL
class MixtralForTokenClassification(MixtralPreTrainedModel):
def __init__(self, config):
super().__init__(config)

src/transformers/models/olmo/modeling_olmo.py (1 addition & 2 deletions)
@@ -973,8 +973,7 @@ def _update_causal_mask(
return causal_mask


-# copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO,Llama->Olmo
-# TODO(joao): add me back asap :)
+# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO,Llama->Olmo
class OlmoForCausalLM(OlmoPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]

src/transformers/models/persimmon/modeling_persimmon.py (3 additions & 3 deletions)
@@ -979,7 +979,7 @@ def prepare_inputs_for_generation(
""",
PERSIMMON_START_DOCSTRING,
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralForSequenceClassification with MISTRAL->PERSIMMON,Mistral->Persimmon
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->PERSIMMON,Llama->Persimmon
class PersimmonForSequenceClassification(PersimmonPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -999,7 +999,7 @@ def set_input_embeddings(self, value):
@add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING)
def forward(
self,
-input_ids: torch.LongTensor = None,
+input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
@@ -1095,7 +1095,7 @@ def forward(
""",
PERSIMMON_START_DOCSTRING,
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralForTokenClassification with Mistral->Persimmon, MISTRAL->PERSIMMON
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Persimmon, LLAMA->PERSIMMON
class PersimmonForTokenClassification(PersimmonPreTrainedModel):
def __init__(self, config):
super().__init__(config)

src/transformers/models/phi/modeling_phi.py (2 additions & 2 deletions)
@@ -1262,7 +1262,7 @@ def prepare_inputs_for_generation(
""",
PHI_START_DOCSTRING,
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralForSequenceClassification with MISTRAL->PHI,Mistral->Phi with self.transformer->self.model, transformer_outputs->model_outputs
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->PHI,Llama->Phi with self.transformer->self.model, transformer_outputs->model_outputs
class PhiForSequenceClassification(PhiPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -1282,7 +1282,7 @@ def set_input_embeddings(self, value):
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
def forward(
self,
-input_ids: torch.LongTensor = None,
+input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,

src/transformers/models/phi3/modeling_phi3.py (2 additions & 2 deletions)
@@ -1258,7 +1258,7 @@ def prepare_inputs_for_generation(
""",
PHI3_START_DOCSTRING,
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralForSequenceClassification with MISTRAL->PHI3, Mistral->Phi3, self.transformer->self.model, transformer_outputs->model_outputs
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs
class Phi3ForSequenceClassification(Phi3PreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -1278,7 +1278,7 @@ def set_input_embeddings(self, value):
@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
def forward(
self,
-input_ids: torch.LongTensor = None,
+input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,

src/transformers/models/qwen2/modeling_qwen2.py (1 addition & 1 deletion)
@@ -1269,7 +1269,7 @@ def forward(
""",
QWEN2_START_DOCSTRING,
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralForTokenClassification with Mistral->Qwen2, MISTRAL->QWEN2
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2
class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
def __init__(self, config):
super().__init__(config)

src/transformers/models/qwen2_moe/modeling_qwen2_moe.py (3 additions & 3 deletions)
@@ -1350,7 +1350,7 @@ def prepare_inputs_for_generation(
""",
QWEN2MOE_START_DOCSTRING,
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralForSequenceClassification with Mistral->Qwen2Moe, MISTRAL->QWEN2MOE
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE
class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -1370,7 +1370,7 @@ def set_input_embeddings(self, value):
@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
def forward(
self,
-input_ids: torch.LongTensor = None,
+input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
@@ -1466,7 +1466,7 @@ def forward(
""",
QWEN2MOE_START_DOCSTRING,
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralForTokenClassification with Mistral->Qwen2Moe, MISTRAL->QWEN2MOE
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE
class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel):
def __init__(self, config):
super().__init__(config)

@@ -778,8 +778,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
return causal_mask


-# copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma
-# TODO(joao): add me back asap :)
+# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma
class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]

src/transformers/models/stablelm/modeling_stablelm.py (3 additions & 3 deletions)
@@ -1255,7 +1255,7 @@ def prepare_inputs_for_generation(
""",
STABLELM_START_DOCSTRING,
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralForSequenceClassification with MISTRAL->STABLELM,Mistral->StableLm
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->STABLELM,Llama->StableLm
class StableLmForSequenceClassification(StableLmPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -1275,7 +1275,7 @@ def set_input_embeddings(self, value):
@add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING)
def forward(
self,
-input_ids: torch.LongTensor = None,
+input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
@@ -1371,7 +1371,7 @@ def forward(
""",
STABLELM_START_DOCSTRING,
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralForTokenClassification with Mistral->StableLm, MISTRAL->STABLELM
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->StableLm, LLAMA->STABLELM
class StableLmForTokenClassification(StableLmPreTrainedModel):
def __init__(self, config):
super().__init__(config)

src/transformers/models/starcoder2/modeling_starcoder2.py (3 additions & 3 deletions)
@@ -1133,7 +1133,7 @@ def prepare_inputs_for_generation(
""",
STARCODER2_START_DOCSTRING,
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralForSequenceClassification with Mistral->Starcoder2, MISTRAL->STARCODER2
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Starcoder2, LLAMA->STARCODER2
class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -1153,7 +1153,7 @@ def set_input_embeddings(self, value):
@add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
def forward(
self,
-input_ids: torch.LongTensor = None,
+input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
@@ -1249,7 +1249,7 @@ def forward(
""",
STARCODER2_START_DOCSTRING,
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralForTokenClassification with Mistral->Starcoder2, MISTRAL->STARCODER2
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Starcoder2, LLAMA->STARCODER2
class Starcoder2ForTokenClassification(Starcoder2PreTrainedModel):
def __init__(self, config):
super().__init__(config)