Commit
add qwen conversation
BlueZeros committed Mar 14, 2024
1 parent 8b31023 commit 5b1a732
Showing 2 changed files with 27 additions and 2 deletions.
23 changes: 23 additions & 0 deletions fastchat/conversation.py
@@ -12,6 +12,7 @@ class SeparatorStyle(Enum):
     CHATGLM = auto()
     DOCTOR = auto()
     BLOOM = auto()
+    QWEN = auto()
 
 
 @dataclasses.dataclass
@@ -92,6 +93,15 @@ def get_prompt(self):
                 else:
                     ret += role + ":"
             return ret
+        elif self.sep_style == SeparatorStyle.QWEN:
+            seps = [self.sep, self.sep2]
+            ret = f"<|im_start|>system\n{self.system}<|im_end|>\n"
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + "\n" + message + seps[i % 2]
+                else:
+                    ret += role + "\n"
+            return ret
         else:
             raise ValueError(f"Invalid style: {self.sep_style}")

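For reference, the new QWEN branch renders a conversation in Qwen's ChatML layout: a system block followed by alternating role blocks, each terminated by <|im_end|>. Below is a minimal standalone sketch (illustrative only, not part of the commit) of the string it builds for a one-turn exchange, using the same system text, roles, and separators as the conv_qwen template added further down:

# Standalone sketch of the prompt the QWEN branch builds; the values
# mirror the conv_qwen template defined in the next hunk.
system = "You are a helpful assistant."
messages = [
    ("<|im_start|>user", "Hello!"),
    ("<|im_start|>assistant", None),  # empty slot: the model continues from here
]
seps = ["<|im_end|>\n", "<|im_end|>\n"]  # sep and sep2 are identical for Qwen

ret = f"<|im_start|>system\n{system}<|im_end|>\n"
for i, (role, message) in enumerate(messages):
    if message:
        ret += role + "\n" + message + seps[i % 2]
    else:
        ret += role + "\n"

print(ret)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hello!<|im_end|>
# <|im_start|>assistant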
@@ -237,20 +247,33 @@ def dict(self):
     sep2="</s>",
 )
 
+conv_qwen = Conversation(
+    system="You are a helpful assistant.",
+    roles=("<|im_start|>user", "<|im_start|>assistant"),
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.QWEN,
+    sep="<|im_end|>\n",
+    sep2="<|im_end|>\n",
+)
+
 conv_templates = {
     "conv_one_shot": conv_one_shot,
     "vicuna_v1.1": conv_vicuna_v1_1,
     "koala_v1": conv_koala_v1,
     "dolly": conv_dolly,
     "baichuan": conv_baichuan,
     "bloom": conv_bloom,
+    "qwen": conv_qwen
 }
 
 
 def get_default_conv_template(model_name):
     model_name = model_name.lower()
     if "vicuna" in model_name or "output" in model_name:
         return conv_vicuna_v1_1
+    elif "qwen" in model_name:
+        return conv_qwen
     elif "baichuan" in model_name:
         # print("load conv_baichuan")
         return conv_baichuan
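Since get_default_conv_template matches on substrings of the lowered model name, any model path containing "qwen" now resolves to the new template. A hedged usage sketch (the model name is only an example):

# Illustrative only: the substring match on the lowered name picks the template.
conv = get_default_conv_template("Qwen-7B-Chat")
assert conv is conv_qwen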
6 changes: 4 additions & 2 deletions fastchat/serve/inference.py
@@ -123,7 +123,9 @@ def generate_stream(model, tokenizer, params, device, beam_size,
         temperature=temperature,
     )
 
-    output = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+    outputs = outputs[0][len(input_ids[0]):]
+    output = tokenizer.decode(outputs, skip_special_tokens=True)
+
 
     return output

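The inference.py change replaces batch-decoding the full output with decoding only the newly generated tokens: outputs[0] still begins with the prompt, so slicing off the first len(input_ids[0]) tokens strips the echoed input before decoding. This is also why the chat_loop call site in the next hunk can drop its [-1][skip_echo_len:].strip() post-processing. A sketch of the idea with a Hugging Face-style model and tokenizer (the checkpoint name is an example; any causal LM with a compatible tokenizer should work):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Example checkpoint; substitute any causal LM with a compatible tokenizer.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)

input_ids = tokenizer("Hello!", return_tensors="pt").input_ids
outputs = model.generate(input_ids, max_new_tokens=32)

# generate() returns prompt tokens followed by new tokens; slicing off the
# prompt before decoding avoids echoing the input back to the caller.
completion_ids = outputs[0][len(input_ids[0]):]
text = tokenizer.decode(completion_ids, skip_special_tokens=True)
print(text)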
@@ -204,7 +206,7 @@ def chat_loop(model_path: str, device: str, num_gpus: str,
         chatio.prompt_for_output(conv.roles[1])
         context_len = len(prompt) + max_new_tokens + 8
         T1 = time.time()
-        output_stream = generate_stream_func(model, tokenizer, params, device, beam_size, context_len=context_len)[-1][skip_echo_len:].strip()
+        output_stream = generate_stream_func(model, tokenizer, params, device, beam_size, context_len=context_len)
         T2 = time.time()
         if debug:
             print('Runtime: %s seconds' % (T2 - T1))
