Support llama-3 #8307

Merged 3 commits on Apr 23, 2024
3 changes: 2 additions & 1 deletion llm/finetune_generation.py
@@ -45,6 +45,7 @@
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
+    Llama3Tokenizer,
     LlamaTokenizer,
 )
 from paddlenlp.utils.log import logger
@@ -232,7 +233,7 @@ def neft_post_hook(module, input, output):
     if tokenizer.chat_template is not None:
         data_args.eval_with_do_generation = False

-    if isinstance(tokenizer, LlamaTokenizer):
+    if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, Llama3Tokenizer):
         tokenizer.pad_token_id = tokenizer.eos_token_id

     if data_args.dataset_name_or_path is None:
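Note on the change above: Llama-style tokenizers do not define a dedicated pad token, so batched fine-tuning reuses the EOS token for padding; the new `Llama3Tokenizer` branch extends that fallback to Llama-3 checkpoints. A minimal sketch of the same behavior outside the training script, assuming a Llama-3 checkpoint identifier that PaddleNLP can resolve (the name below is illustrative):

```python
# Illustrative sketch, not part of this PR; the checkpoint name is an assumption.
from paddlenlp.transformers import AutoTokenizer, Llama3Tokenizer, LlamaTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

# Llama/Llama-3 tokenizers ship without a pad token, so reuse EOS for padding,
# mirroring the isinstance check added in finetune_generation.py.
if isinstance(tokenizer, (LlamaTokenizer, Llama3Tokenizer)):
    tokenizer.pad_token_id = tokenizer.eos_token_id
```

Using a tuple in `isinstance` is equivalent to the chained `or` in the diff.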
21 changes: 14 additions & 7 deletions paddlenlp/transformers/auto/tokenizer.py
@@ -190,13 +190,20 @@
         init_class = init_kwargs.pop("tokenizer_class", None)

         if init_class:
-            class_name = cls._name_mapping[init_class]
-            import_class = import_module(f"paddlenlp.transformers.{class_name}.tokenizer")
-            tokenizer_class = getattr(import_class, init_class)
-            if use_fast:
-                fast_tokenizer_class = cls._get_fast_tokenizer_class(init_class, class_name)
-                tokenizer_class = fast_tokenizer_class if fast_tokenizer_class else tokenizer_class
-            return tokenizer_class
+            if init_class in cls._name_mapping:
+                class_name = cls._name_mapping[init_class]
+                import_class = import_module(f"paddlenlp.transformers.{class_name}.tokenizer")
+                tokenizer_class = getattr(import_class, init_class)
+                if use_fast:
+                    fast_tokenizer_class = cls._get_fast_tokenizer_class(init_class, class_name)
+                    tokenizer_class = fast_tokenizer_class if fast_tokenizer_class else tokenizer_class
+                return tokenizer_class
+            else:
+                import_class = import_module("paddlenlp.transformers")
+                tokenizer_class = getattr(import_class, init_class, None)
+                assert tokenizer_class is not None, f"Can't find tokenizer {init_class}"
+                return tokenizer_class

         # If no `init_class`, we use pattern recognition to recognize the tokenizer class.
         else:
             # TODO: Potential issue https://github.com/PaddlePaddle/PaddleNLP/pull/3786#discussion_r1024689810
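The change above adds a fallback for tokenizer classes, such as `Llama3Tokenizer`, that are exported from `paddlenlp.transformers` but not listed in `AutoTokenizer._name_mapping`. A standalone sketch of the resolution order, simplified from the diff (the function name and signature are illustrative, and the `use_fast` handling is omitted):

```python
from importlib import import_module

def resolve_tokenizer_class(init_class: str, name_mapping: dict):
    """Simplified illustration of the class lookup shown in the diff above."""
    if init_class in name_mapping:
        # Known tokenizer: import its dedicated module,
        # e.g. paddlenlp.transformers.llama.tokenizer for LlamaTokenizer.
        module = import_module(f"paddlenlp.transformers.{name_mapping[init_class]}.tokenizer")
        return getattr(module, init_class)
    # Fallback: classes such as Llama3Tokenizer that are only reachable
    # through the top-level paddlenlp.transformers namespace.
    module = import_module("paddlenlp.transformers")
    tokenizer_class = getattr(module, init_class, None)
    assert tokenizer_class is not None, f"Can't find tokenizer {init_class}"
    return tokenizer_class
```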
2 changes: 2 additions & 0 deletions paddlenlp/transformers/llama/configuration.py
@@ -147,6 +147,7 @@ def __init__(
         num_key_value_heads=None,
         initializer_range=0.02,
         rms_norm_eps=1e-6,
+        rope_theta=10000.0,
         use_cache=True,
         use_recompute=False,
         recompute_granularity="full",
@@ -188,6 +189,7 @@

         self.initializer_range = initializer_range
         self.rms_norm_eps = rms_norm_eps
+        self.rope_theta = rope_theta

         self.use_cache = use_cache
         self.use_recompute = use_recompute
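With `rope_theta` carried in the config (default 10000.0, the value used by Llama 1/2), Llama-3 checkpoints can supply their much larger rotary base without code changes. A hedged usage sketch; the 500000.0 figure is the base published with the Llama 3 release:

```python
from paddlenlp.transformers import LlamaConfig

default_cfg = LlamaConfig()                    # rope_theta == 10000.0, Llama 1/2 style
llama3_cfg = LlamaConfig(rope_theta=500000.0)  # larger rotary base used by Llama 3
print(default_cfg.rope_theta, llama3_cfg.rope_theta)
```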
9 changes: 6 additions & 3 deletions paddlenlp/transformers/llama/modeling.py
@@ -813,24 +813,28 @@ def _init_rope(self):
             self.rotary_emb = LlamaRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
+                base=self.config.rope_theta,
             )
         elif self.config.rope_scaling_type == "linear":
             self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 scaling_factor=self.config.rope_scaling_factor,
+                base=self.config.rope_theta,
             )
         elif self.config.rope_scaling_type == "ntk":
             self.rotary_emb = LlamaNTKScalingRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 scaling_factor=self.config.rope_scaling_factor,
+                base=self.config.rope_theta,
             )
         elif self.config.rope_scaling_type == "dynamic_ntk":
             self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 scaling_factor=self.config.rope_scaling_factor,
+                base=self.config.rope_theta,
             )
         else:
             raise ValueError(f"Unknown RoPE scaling type {self.config.rope_scaling_type}")
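Threading `base=self.config.rope_theta` into every rotary-embedding variant changes the per-dimension rotation frequencies. A small NumPy sketch of the standard RoPE frequency schedule, independent of the PaddleNLP classes (which may differ in implementation detail):

```python
import numpy as np

def rope_inv_freq(head_dim: int, base: float) -> np.ndarray:
    # Standard RoPE schedule: inv_freq_i = 1 / base**(2i / head_dim)
    return 1.0 / (base ** (np.arange(0, head_dim, 2, dtype=np.float64) / head_dim))

# A larger base (as in Llama 3) gives smaller frequencies, i.e. longer wavelengths,
# which is what the bigger rope_theta buys for long contexts.
print(rope_inv_freq(128, 10000.0)[-1])   # ~1.2e-04
print(rope_inv_freq(128, 500000.0)[-1])  # ~2.5e-06
```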
@@ -903,6 +907,7 @@ def forward(
             query_states = self.q_proj(hidden_states)
             key_states = self.k_proj(hidden_states)
             value_states = self.v_proj(hidden_states)
+
             if self.reshard_layer is not None:
                 if self.sequence_parallel:
                     assert self.seq_length % self.config.sep_parallel_degree == 0
@@ -1027,7 +1032,6 @@ def forward(
                 value_states = paddle.concat([past_key_value[1], value_states], axis=1)

         past_key_value = (key_states, value_states) if use_cache else None
-
         if self.kv_indices is not None:
             key_states = paddle.index_select(key_states, self.kv_indices, axis=2)
             value_states = paddle.index_select(value_states, self.kv_indices, axis=2)
@@ -1036,7 +1040,7 @@
         # repeat k/v heads if n_kv_heads < n_heads
         # paddle version > 2.6 or develop support flash-attn with gqa/mqa
         paddle_version = float(paddle.__version__[:3])
-        if (paddle_version != 0.0) and (paddle_version <= 2.6):
+        if not self.config.use_flash_attention or ((paddle_version != 0.0) and (paddle_version <= 2.6)):
             key_states = repeat_kv(key_states, self.num_key_value_groups)
             value_states = repeat_kv(value_states, self.num_key_value_groups)
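The relaxed condition above keeps the K/V head expansion only when it is still needed: newer Paddle flash-attention kernels (versions above 2.6 or the develop build) accept GQA/MQA layouts directly, so `repeat_kv` can be skipped when flash attention is enabled. For readers unfamiliar with the helper, a hedged sketch of the operation, assuming the `[batch, seq_len, num_heads, head_dim]` layout used in this file; this is an illustration, not the library's exact implementation:

```python
import paddle

def repeat_kv_sketch(x: paddle.Tensor, n_rep: int) -> paddle.Tensor:
    # Tile K/V heads so grouped-query attention matches the query head count:
    # [batch, seq_len, num_kv_heads, head_dim] -> [batch, seq_len, num_kv_heads * n_rep, head_dim]
    if n_rep == 1:
        return x
    batch, seq_len, num_kv_heads, head_dim = x.shape
    x = x.unsqueeze(3).tile([1, 1, 1, n_rep, 1])
    return x.reshape([batch, seq_len, num_kv_heads * n_rep, head_dim])
```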

@@ -1560,7 +1564,6 @@ def forward(
         else:
             attention_mask = attention_mask.astype("bool")
         hidden_states = inputs_embeds
-
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None