Skip to content

Commit

Permalink
[LLM] fix Qwen-7b-Chat precision issue
Browse files Browse the repository at this point in the history
fix qwen-7b0chat model batch inference precision issue
add Qwen-7B-Chat to PaddleNLP unit test
  • Loading branch information
ziangqin-baidu committed Jan 15, 2024
1 parent 04142e3 commit 2b55d7a
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 3 deletions.
1 change: 1 addition & 0 deletions llm/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1187,6 +1187,7 @@ def create_predictor(

tensor_parallel_rank, tensor_parallel_degree = init_dist_env()
if not predictor_args.inference_model:
tokenizer.padding_side = "left"
if predictor_args.mode == "dynamic":
if model_args.model_type == "gpt-3":
sys.path.append("./gpt-3")
Expand Down
4 changes: 1 addition & 3 deletions paddlenlp/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,10 +1059,8 @@ def greedy_search(

# pre-process distribution
next_token_logits = self.adjust_logits_during_generation(next_token_logits)
next_tokens_scores = logits_processors(input_ids, next_token_logits)
probs = logits_processors(input_ids, next_token_logits)
# greedy
probs = F.softmax(next_tokens_scores)
probs = paddle.log(probs)
next_tokens = paddle.argmax(probs, axis=-1).unsqueeze(-1)
next_scores = paddle.index_sample(probs, next_tokens)

Expand Down
1 change: 1 addition & 0 deletions paddlenlp/transformers/qwen/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(
self,
vocab_file,
errors="replace",
padding_side="left",
**kwargs,
):
super().__init__(**kwargs)
Expand Down
2 changes: 2 additions & 0 deletions tests/llm/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
ChatGLMForCausalLM,
ChatGLMv2ForCausalLM,
LlamaForCausalLM,
QWenForCausalLM,
)
from paddlenlp.utils.downloader import (
COMMUNITY_MODEL_PREFIX,
Expand All @@ -43,6 +44,7 @@
["__internal_testing__/tiny-fused-bloom", BloomForCausalLM],
["__internal_testing__/tiny-fused-chatglm", ChatGLMForCausalLM],
["__internal_testing__/tiny-fused-chatglm2", ChatGLMv2ForCausalLM],
["__internal_testing__/tiny-fused-qwen", QWenForCausalLM],
],
)
class PredictorTest(LLMTest, unittest.TestCase):
Expand Down

0 comments on commit 2b55d7a

Please sign in to comment.