
Commit 7126879

[Llama3] Fix: add pad token fallback and improve tensor reshaping (#3025)
1 parent 21e2fdc · commit 7126879

7 files changed: +20 additions, -10 deletions

paddleformers/cli/train/auto_parallel/workflow.py

Lines changed: 2 additions & 0 deletions
@@ -208,6 +208,8 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args):

     config = config_class.from_pretrained(model_args.model_name_or_path)
     tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
     # config = AutoConfig.from_pretrained(model_args.model_name_or_path)
     LlmMetaConfig.set_llm_config(config, training_args)
     config.use_fast_layer_norm = model_args.use_fast_layer_norm
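
The guard added here (and repeated in the DPO, pretrain, and SFT workflows below) covers Llama3-style checkpoints whose tokenizer ships without an explicit pad token. A minimal standalone sketch of the pattern, assuming AutoTokenizer is importable from paddleformers.transformers and using the tiny test checkpoint referenced later in this commit:

from paddleformers.transformers import AutoTokenizer

# Some Llama3 tokenizers define only bos/eos tokens and leave pad_token unset.
tokenizer = AutoTokenizer.from_pretrained("Paddleformers/tiny-random-llama3")

# Fall back to the EOS token so batching, collation, and loss masking that
# rely on pad_token_id do not fail on a None value.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id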

paddleformers/cli/train/dpo/workflow.py

Lines changed: 2 additions & 0 deletions
@@ -191,6 +191,8 @@ def run_dpo(
         tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
     else:
         tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = tokenizer.eos_token_id

     logger.info("Loading model & tokenizer successfully !")

paddleformers/cli/train/pretrain/workflow.py

Lines changed: 2 additions & 0 deletions
@@ -409,6 +409,8 @@ def run_dsv3_pretrain(model_args, data_args, generating_args, training_args):
     )

     tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
     config = DeepseekV2FastConfig.from_pretrained(model_args.model_name_or_path)

     # set all llm config

paddleformers/cli/train/sft/workflow.py

Lines changed: 2 additions & 0 deletions
@@ -266,6 +266,8 @@ def neft_post_hook(module, input, output):

     # Load tokenizer & dataset
     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = tokenizer.eos_token_id

     # if using chat_template, data_args.eval_with_do_generation must be false
     if tokenizer.chat_template is not None:

paddleformers/transformers/llama/modeling.py

Lines changed: 3 additions & 3 deletions
@@ -159,9 +159,9 @@ def forward(
         q_shape = (batch_size, seq_len, self.num_heads, self.head_dim)
         kv_shape = (batch_size, seq_len, self.num_key_value_heads, self.head_dim)

-        query_states = self.q_proj(hidden_states).view(q_shape).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(kv_shape).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(kv_shape).transpose(1, 2)
+        query_states = self.q_proj(hidden_states).reshape(q_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).reshape(kv_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).reshape(kv_shape).transpose(1, 2)

         cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
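
The switch from view to reshape is the "improve tensor reshaping" part of the commit title: reshape produces the requested shape regardless of the source tensor's memory layout (copying when a zero-copy view is not possible), whereas view-style reshaping requires a compatible layout. A minimal sketch of the same pattern on a toy tensor, not the model code itself:

import paddle

# Toy dimensions standing in for (batch_size, seq_len, num_heads, head_dim).
batch_size, seq_len, num_heads, head_dim = 2, 4, 3, 8
hidden_states = paddle.randn([batch_size, seq_len, num_heads * head_dim])

# reshape works even when the input is not laid out contiguously in memory.
query_states = hidden_states.reshape([batch_size, seq_len, num_heads, head_dim])

# Move the head axis ahead of the sequence axis for attention:
# (batch, seq, heads, dim) -> (batch, heads, seq, dim).
query_states = paddle.transpose(query_states, perm=[0, 2, 1, 3])
print(query_states.shape)  # [2, 3, 4, 8]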

tests/transformers/llama/test_modeling.py

Lines changed: 8 additions & 6 deletions
@@ -439,6 +439,7 @@ def test_inference_no_attention(self):
             "Paddleformers/tiny-random-llama3",
             download_hub="aistudio",
             convert_from_hf=True,
+            dtype="float32",
         )
         model.eval()
         input_ids = paddle.to_tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
@@ -452,9 +453,9 @@
         expected_slice = paddle.to_tensor(
             [
                 [
-                    [0.02366970, -0.42482421, 0.47202760],
-                    [-0.12180223, 0.00559035, 0.83846688],
-                    [0.45073321, 0.25703996, 1.36826384],
+                    [0.01802453, -0.42128855, 0.45844582],
+                    [-0.12787277, 0.00660499, 0.83033413],
+                    [0.44403678, 0.26123494, 1.36080980],
                 ]
             ],
             dtype=output.dtype,
@@ -467,6 +468,7 @@ def test_inference_with_attention(self):
             "Paddleformers/tiny-random-llama3",
             download_hub="aistudio",
             convert_from_hf=True,
+            dtype="float32",
         )
         model.eval()
         input_ids = paddle.to_tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
@@ -479,9 +481,9 @@
         expected_slice = paddle.to_tensor(
             [
                 [
-                    [0.02366970, -0.42482421, 0.47202760],
-                    [-0.12180223, 0.00559035, 0.83846688],
-                    [0.45073321, 0.25703996, 1.36826384],
+                    [0.01802453, -0.42128855, 0.45844582],
+                    [-0.12787277, 0.00660499, 0.83033413],
+                    [0.44403678, 0.26123494, 1.36080980],
                 ]
             ],
             dtype=output.dtype,
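
With the tiny model now loaded in float32, the reference logits in both tests shift slightly; the test pattern itself is the usual expected-slice comparison. A generic sketch of that check, where the helper name and tolerance are illustrative rather than the test's actual settings:

import paddle

# Reference values copied from the updated test above (float32 inference).
expected_slice = paddle.to_tensor(
    [[[0.01802453, -0.42128855, 0.45844582],
      [-0.12787277, 0.00660499, 0.83033413],
      [0.44403678, 0.26123494, 1.36080980]]],
    dtype="float32",
)

def check_hidden_slice(output):
    # Compare the first 3x3 block of the hidden states against the reference;
    # a small absolute tolerance absorbs kernel-to-kernel numeric differences.
    return bool(paddle.allclose(output[:, :3, :3], expected_slice, atol=1e-4))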

tests/transformers/test_shard_checkpoint.py

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ def test_from_pretrained_low_cpu_mem_usage_functional(self):
             convert_from_hf=convert,
         )
         for p1, p2 in zip(m1.parameters(), m2.parameters()):
-            self.assertTrue(paddle.allclose(p1, p2))
+            self.assertTrue(paddle.allclose(p1.float(), p2.float()))

     @unittest.skipIf(not is_paddle_cuda_available(), "some op is missing in cpu mode")
     def test_keep_in_fp32_modules(self):
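
Casting both parameters before allclose keeps the comparison valid when the two models are materialized in different dtypes (for example bfloat16 under low_cpu_mem_usage vs float32). A small hypothetical helper sketching the same idea, using astype instead of float() and with the tolerance as an assumption:

import paddle

def params_allclose(model_a, model_b, atol=1e-6):
    # Return True if all corresponding parameters match after a float32 cast.
    for p1, p2 in zip(model_a.parameters(), model_b.parameters()):
        # Cast both sides to a common dtype so allclose compares values rather
        # than failing or mis-comparing when checkpoints were loaded in bf16/fp16.
        if not bool(paddle.allclose(p1.astype("float32"), p2.astype("float32"), atol=atol)):
            return False
    return True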
