
Qwen/Qwen2-7B-Instruct gives garbled outputs in LongBench with load_in_low_bit="fp16" and optimize_model=False #11796

Open
ATMxsp01 opened this issue Aug 14, 2024 · 0 comments

ATMxsp01 (Contributor) commented Aug 14, 2024

device: Intel(R) Data Center GPU Max 1100
ipex-llm: 2.1.0b20240813
transformers: 4.37.0
model: Qwen/Qwen2-7B-Instruct


It's confirmed that:

  • Setting optimize_model=True prevents Qwen2 from outputting garbled words (see the sketch after this list).
  • The same issue also occurs on an Arc GPU with precision=sym_int4.
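
For reference, a minimal sketch of the load call with the workaround applied, assuming the same ipex_llm.transformers API as the repro script below (only the optimize_model flag differs):

from ipex_llm.transformers import AutoModelForCausalLM
import torch

# Same load call as in the repro script, but with optimize_model=True;
# in our tests this avoids the garbled output
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",
    optimize_model=True,   # False reproduces the bug
    load_in_low_bit="fp16",
    use_cache=True,
    torch_dtype=torch.float16,
).to("xpu").eval()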

Here is a minimal reproducible case, using the multi_news dataset as an example. Run the following script:

from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM
from datasets import load_dataset
import json
from tqdm import tqdm
import numpy as np
import random
import torch


def seed_everything(seed):
    # Seed every RNG source for reproducibility
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


@torch.inference_mode()
def get_pred_single_gpu(data, max_gen,
                        prompt_format, model_path, out_path):
    # Load the model in fp16 without ipex-llm optimizations;
    # this is the configuration that produces the garbled output.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        optimize_model=False,
        load_in_low_bit="fp16",
        use_cache=True,
        torch_dtype=torch.float16,
    ).to("xpu").eval()
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        padding_side="right",
        use_fast=False,
        trust_remote_code=True,
    )

    device = model.device
    print(f"model_device: {model.device}")

    for json_obj in tqdm(data):
        prompt = prompt_format.format(**json_obj)
        inputs = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
        context_length = inputs.input_ids.shape[-1]
        print(f'context_length = {context_length}')

        # Greedy decoding (num_beams=1, do_sample=False)
        output = model.generate(
            **inputs,
            max_new_tokens=max_gen,
            num_beams=1,
            do_sample=False,
            temperature=1.0,
            min_length=context_length + 1,
        )[0]
        pred = tokenizer.decode(output[context_length:], skip_special_tokens=True)

        with open(out_path, "a", encoding="utf-8") as f:
            json.dump({
                "prompt": prompt,
                "pred": pred,
                "answers": json_obj["answers"],
                "all_classes": json_obj["all_classes"],
                "length": json_obj["length"],
            }, f, ensure_ascii=False, indent=4)
            f.write('\n')


if __name__ == '__main__':
    seed_everything(42)

    model_name = "qwen-2"
    model_path = "Qwen/Qwen2-7B-Instruct"

    dataset = "multi_news"
 
    # prompt format in config/dataset2prompt.json
    prompt_format = "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:"
    max_gen = 512    # 512 in config/dataset2maxlen.json

    data = load_dataset('THUDM/LongBench', dataset, split='test')
    data_all = list(data)[0:4]    # only evaluate the first 4 samples
    
    out_path = "./qwen-check-out.jsonl"
    
    get_pred_single_gpu(data_all, max_gen, prompt_format, model_path, out_path)
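
Note that the records are written with indent=4, so the output file is not strictly line-delimited JSON despite its .jsonl extension. A minimal sketch for reading the predictions back, assuming the record structure written by the script above:

import json

decoder = json.JSONDecoder()
with open("./qwen-check-out.jsonl", encoding="utf-8") as f:
    text = f.read()

# Each record spans multiple lines (indent=4), so decode the stream
# record by record rather than line by line
pos = 0
while pos < len(text):
    while pos < len(text) and text[pos].isspace():
        pos += 1  # skip whitespace between records
    if pos >= len(text):
        break
    record, pos = decoder.raw_decode(text, pos)
    print(record["pred"][:200])  # first 200 characters of each prediction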

Here is some of Qwen's garbled output:

[screenshot: garbled model output]
