edit exllamav2_web_demo
Vivicai1005 committed Feb 6, 2024
1 parent ecd94df commit 39c8dbf
Showing 1 changed file with 1 addition and 116 deletions.
apps/exllamav2_web_demo.py: 117 changes (1 addition & 116 deletions)
@@ -1,8 +1,7 @@
 import streamlit as st
 import torch.cuda

-from other_infer.exllamav2_hf_infer import get_model
-from utils.streaming import generate_stream
+from other_infer.exllamav2_hf_infer import get_model, generate_stream
 import argparse

 parser = argparse.ArgumentParser()
@@ -21,11 +20,9 @@ def cached_get_model(model_path):


 model, tokenizer, generation_config = cached_get_model(args.model_path)
-generation_config.do_sample = False
 generation_config.max_length = args.max_input_length + args.max_generate_length
 generation_config.max_new_tokens = args.max_generate_length
-

 tok_ins = "\n\n### Instruction:\n"
 tok_res = "\n\n### Response:\n"
 tok_eos = "</s>"
@@ -34,118 +31,6 @@ def cached_get_model(model_path):

 device = f"cuda:{torch.cuda.current_device()}"

-
-def main(
-    model_path: str,
-    max_input_length: int = 512,
-    max_generate_length: int = 2048,
-    stream: bool = True,
-):
-    print(f"loading model: {model_path}...")
-
-    model = get_model(model_path)
-    device = torch.cuda.current_device()
-    model, tokenizer, generation_config = get_model(model_path)
-
-    generation_config.do_sample = False
-    generation_config.max_length = max_input_length + max_generate_length
-    generation_config.max_new_tokens = max_generate_length
-
-    sess_text = ""
-
-    streamer = TextIteratorStreamer(
-        tokenizer,
-        skip_prompt=True,
-        skip_special_tokens=True,
-        spaces_between_special_tokens=False,
-    )
-    generation_kwargs = generation_config.to_dict()
-
-    def eval_generate(**args):
-        with torch.inference_mode(mode=True):
-            model.eval()
-            model.generate(**args)
-
-    with torch.inference_mode(mode=True):
-        model.eval()
-        while True:
-            raw_text = input(
-                'prompt("exit" to end, "clear" to clear session) >>> '
-            )
-            if not raw_text:
-                print("prompt should not be empty!")
-                continue
-            if raw_text.strip() == "exit":
-                print("session ended.")
-                break
-            if raw_text.strip() == "clear":
-                print("session cleared.")
-                sess_text = ""
-                continue
-
-            query_text = raw_text.strip()
-            sess_text += tok_ins + query_text
-            input_text = prompt_input.format_map(
-                {"instruction": sess_text.split(tok_ins, 1)[1]}
-            )
-            inputs = tokenizer(
-                input_text,
-                return_tensors="pt",
-                truncation=True,
-                max_length=max_input_length,
-            )
-            tic = time.perf_counter()
-            if stream:
-                generation_kwargs["streamer"] = streamer
-                for k, v in inputs.items():
-                    generation_kwargs[k] = v.to(device)
-                thread = Thread(
-                    target=eval_generate, kwargs=generation_kwargs
-                )
-                thread.start()
-                answer = ""
-                flag = False
-                print("=" * 100)
-                for new_text in streamer:
-                    if new_text.endswith(tokenizer.eos_token):
-                        new_text = new_text.rsplit(
-                            tokenizer.eos_token, 1
-                        )[0].strip()
-                        flag = True
-                    print(new_text, end="")
-                    answer += new_text
-                    if flag:
-                        break
-                toc = time.perf_counter()
-                num_tok = len(tokenizer.encode(answer))
-            else:
-                if "streamer" in generation_kwargs:
-                    del generation_kwargs["streamer"]
-                inputs = {k: v.to(device) for k, v in inputs.items()}
-                output = model.generate(
-                    **inputs, **generation_config.to_dict()
-                )
-                toc = time.perf_counter()
-                num_tok = output.shape[1]
-                output_str = tokenizer.decode(
-                    output[0],
-                    skip_special_tokens=False,
-                    spaces_between_special_tokens=False,
-                )
-                answer = output_str.rsplit(tok_res, 1)[1].strip()
-                if answer.endswith(tokenizer.eos_token):
-                    answer = answer.rsplit(tokenizer.eos_token, 1)[
-                        0
-                    ].strip()
-                print(answer)
-
-            sess_text += tok_res + answer
-            res_time = toc - tic
-            print(
-                f"\n[time: {res_time:0.4f} sec, speed: {num_tok / res_time:0.4f} tok/sec]"
-            )
-            print("=" * 100)
-
 if "messages" not in st.session_state:
     st.session_state.messages = list()

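The remainder of apps/exllamav2_web_demo.py is collapsed in this diff, so only the imports, the configuration block, and the session-state initialization are visible. For orientation, a minimal Streamlit chat loop wired to the imports kept by this commit might look like the hypothetical sketch below; only get_model, generate_stream, cached_get_model, and st.session_state.messages come from the diff itself, and the widget layout, the placeholder model path, and the generate_stream call shape are assumptions rather than the file's actual contents.

# Hypothetical sketch, not the committed file: the collapsed portion of
# apps/exllamav2_web_demo.py is not shown in this diff.
import streamlit as st

# Consolidated import as it reads after this commit.
from other_infer.exllamav2_hf_infer import get_model, generate_stream


@st.cache_resource
def cached_get_model(model_path):
    # Assumed body of the cached loader referenced in the diff context.
    return get_model(model_path)


# Placeholder path; the real demo reads it from argparse (args.model_path).
model, tokenizer, generation_config = cached_get_model("path/to/exl2-model")

if "messages" not in st.session_state:
    st.session_state.messages = list()

# Replay the stored conversation so it survives Streamlit reruns.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("Type a prompt"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        placeholder = st.empty()
        answer = ""
        # Assumed call shape: the real generate_stream lives in
        # other_infer.exllamav2_hf_infer and may take different arguments.
        for chunk in generate_stream(
            model, tokenizer, st.session_state.messages, generation_config
        ):
            answer += chunk
            placeholder.markdown(answer)
    st.session_state.messages.append({"role": "assistant", "content": answer})

The chat_message/chat_input pattern above follows Streamlit's standard chat UI; the actual demo may instead format prompts with tok_ins and tok_res as the deleted CLI loop did.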