benchmark script #1

Merged 1 commit on May 1, 2023
tests/benchmark.py (245 additions, 0 deletions)
@@ -0,0 +1,245 @@
import torch
from mlc_llm import utils
import argparse, os, time
from transformers import AutoTokenizer, AutoModelForCausalLM
import tvm
from tvm import relax
from mlc_llm.conversation import SeparatorStyle, conv_templates
from utils import get_tokenizer, get_pytorch_model, get_tvm_model, sample_top_p

torch_device = None


def _parse_args():
    args = argparse.ArgumentParser()
    args.add_argument("--device-name", type=str, default="auto")
    args.add_argument("--debug-dump", action="store_true", default=False)
    args.add_argument("--artifact-path", type=str, default="dist")
    args.add_argument("--model", type=str, default="vicuna-v1-7b")
    args.add_argument("--max-gen-len", type=int, default=2048)
    args.add_argument("--run-torch-model", action="store_true", default=False)
    args.add_argument(
        "--dtype", type=str, choices=["float32", "float16", "int4"], default="float16"
    )
    parsed = args.parse_args()
    parsed.model_path = os.path.join(parsed.artifact_path, "models", parsed.model)
    parsed.artifact_path = os.path.join(
        parsed.artifact_path, parsed.model, parsed.dtype
    )

    if parsed.device_name == "auto":
        if tvm.cuda().exist:
            parsed.device_name = "cuda"
        elif tvm.metal().exist:
            parsed.device_name = "metal"
        else:
            raise ValueError("Cannot auto deduce device-name, please set it")
    return parsed


class ModelWrapper:
    def __init__(self, tokenizer, max_gen_len):
        self.tokenizer = tokenizer
        self.max_gen_len = max_gen_len

    def generate(
        self,
        prompt: str,
        temperature: float = 0.8,
        top_p: float = 0.95,
        stream_interval: int = 2,
        stop_str: str = None,
        add_bos=True,
    ):
        raise NotImplementedError("Need to implement")


class TvmModelWrapper(ModelWrapper):
    def __init__(
        self, tokenizer, max_gen_len, artifact_path, model, device_name, dtype
    ):
        super().__init__(tokenizer, max_gen_len)
        self.model = get_tvm_model(artifact_path, model, device_name, dtype)

    def generate(
        self,
        prompt_tokens,
        prompt_len,
        temperature: float = 0.8,
        top_p: float = 0.95,
        stream_interval: int = 2,
        stop_str: str = None,
        add_bos=True,
    ):
        total_len = self.max_gen_len + len(prompt_tokens)
        # TODO: Replace Torch ops to TVM
        tokens = (
            torch.full((1, total_len), self.tokenizer.pad_token_id)
            .to(torch.int32)
            .to(torch_device)
        )
        tokens[0, : len(prompt_tokens)] = torch.tensor(prompt_tokens).to(torch_device)

        start_pos = len(prompt_tokens)
        for cur_pos in range(start_pos, total_len):
            if cur_pos == start_pos:
                # Prefill: run the whole prompt through the model once.
                logits = self.model(tokens[:, :cur_pos])
            else:
                # Decode: feed only the newest token; the KV cache holds the rest.
                logits = self.model(tokens[:, cur_pos - 1 : cur_pos])

            logits = logits[:, -1, :]
            if temperature > 0:
                probs = torch.softmax(
                    (logits / temperature).to(torch.float32), dim=-1
                ).to(torch_device)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1).to(torch_device)
            next_token = next_token.reshape(-1)
            tokens[:, cur_pos] = next_token
            # the following code assumes bsz == 1
            if next_token[0] == self.tokenizer.eos_token_id:
                stopped = True
            else:
                stopped = False

            i = cur_pos - start_pos
            if i % stream_interval == 0 or i == self.max_gen_len - 1 or stopped:
                # TODO: Parallelize decoding
                output = tokens[0, : cur_pos + 1]
                output = self.tokenizer.decode(output, skip_special_tokens=True)
                pos = output.rfind(stop_str, prompt_len)
                if pos != -1:
                    output = output[:pos]
                    stopped = True
                yield output
            if stopped:
                break


class TorchModelWrapper(ModelWrapper):
    def __init__(self, tokenizer, max_gen_len, model_path, torch_device, dtype):
        super().__init__(tokenizer, max_gen_len)
        self.model = get_pytorch_model(model_path, torch_device, dtype)

    def generate(
        self,
        prompt_tokens,
        prompt_len,
        temperature: float = 0.8,
        top_p: float = 0.95,
        stream_interval: int = 2,
        stop_str: str = None,
        add_bos=True,
    ):
        total_len = self.max_gen_len + len(prompt_tokens)
        tokens = (
            torch.full((1, total_len), self.tokenizer.pad_token_id)
            .to(torch.int32)
            .to(torch_device)
        )
        tokens[0, : len(prompt_tokens)] = torch.tensor(prompt_tokens).to(torch_device)
        start_pos = len(prompt_tokens)
        for cur_pos in range(start_pos, total_len):
            # No KV cache here: the full prefix is re-run every step (use_cache=False in utils).
            logits = self.model(tokens[:, :cur_pos])
            logits = logits[:, -1, :]
            if temperature > 0:
                probs = torch.softmax(
                    (logits / temperature).to(torch.float32), dim=-1
                ).to(torch_device)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1).to(torch_device)
            next_token = next_token.reshape(-1)
            tokens[:, cur_pos] = next_token
            # the following code assumes bsz == 1
            if next_token[0] == self.tokenizer.eos_token_id:
                stopped = True
            else:
                stopped = False

            i = cur_pos - start_pos
            if i % stream_interval == 0 or i == self.max_gen_len - 1 or stopped:
                # TODO: Parallelize decoding
                output = tokens[0, : cur_pos + 1]
                output = self.tokenizer.decode(output, skip_special_tokens=True)
                pos = output.rfind(stop_str, prompt_len)
                if pos != -1:
                    output = output[:pos]
                    stopped = True
                yield output
            if stopped:
                break


def chat(model_wrapper, user_inps):
    conv = conv_templates["vicuna_v1.1"].copy()
    add_bos = True

    for iid, inp in enumerate(user_inps):
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        # TODO: Torch does not work with the following function for multi-round conv.
        # Check this w/ authors.
        # prompt = conv.get_prompt_unprocessed()
        prompt = conv.get_prompt()

        print(f"=== Input {iid+1} ===")
        print(f"{conv.roles[0]}: {inp}", flush=True)
        print(f"{conv.roles[1]}: ", end="", flush=True)

        t0 = time.time()
        prompt_tokens = model_wrapper.tokenizer.encode(prompt)
        if not add_bos:
            prompt_tokens = prompt_tokens[1:]

        prompt_len = len(prompt)  # character length; used to trim the echoed prompt from decoded output
        pre = 0
        for outputs in model_wrapper.generate(
            prompt_tokens,
            prompt_len,
            temperature=0,  # Use greedy to make it deterministic for benchmarking
            stop_str=conv.sep if conv.sep_style == SeparatorStyle.SINGLE else conv.sep2,
            add_bos=add_bos,
        ):
            outputs = outputs[prompt_len + 1 :].strip()
            outputs = outputs.split(" ")
            now = len(outputs)
            if now - 1 > pre:
                print(" ".join(outputs[pre : now - 1]), end=" ", flush=True)
                pre = now - 1
        t1 = time.time()
        print(" ".join(outputs[pre:]), flush=True)
        print(f" - # input token: {len(prompt_tokens)}")
        print(f" - # output token: {len(outputs)}")
        print(f" - process time: {(t1-t0):.3f} s")

        conv.messages[-1][-1] = " ".join(outputs)
        add_bos = False


if __name__ == "__main__":
    ARGS = _parse_args()
    tokenizer = get_tokenizer(ARGS.model_path)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    if not ARGS.run_torch_model:
        torch_device = torch.device("cpu")
        model = TvmModelWrapper(
            tokenizer,
            ARGS.max_gen_len,
            ARGS.artifact_path,
            ARGS.model,
            ARGS.device_name,
            ARGS.dtype,
        )
    else:
        torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = TorchModelWrapper(
            tokenizer, ARGS.max_gen_len, ARGS.model_path, torch_device, ARGS.dtype
        )

    inputs = [
        "Hi",
        "Repeat this sentence: Sure! Amazon is a multinational technology company that operates in a variety of industries, including e-commerce, cloud computing, digital streaming, and more. The company's headquarters, also known as Amazon Headquarters or Amazon HQ, is located in Seattle, Washington. It is the main base of operations for Amazon.com, Inc., the parent company of Amazon's various subsidiaries and businesses. The headquarters is home to many of the company's executives, as well as its research and development teams, customer service teams, and other support staff.",
    ]
    chat(model, inputs)
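
For reference, one way to launch this benchmark locally is sketched below; the flags mirror the argparse defaults above, while the exact artifact layout under dist/ (compiled library, params, tokenizer files) is assumed to come from a prior build step rather than from this diff.

# TVM path (default): expects prebuilt artifacts under dist/vicuna-v1-7b/float16/
python tests/benchmark.py --model vicuna-v1-7b --dtype float16

# PyTorch baseline for comparison
python tests/benchmark.py --model vicuna-v1-7b --dtype float16 --run-torch-model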
tests/utils.py (65 additions, 0 deletions)
@@ -0,0 +1,65 @@
import torch
import tvm
from mlc_llm import utils
from tvm import relax
from transformers import AutoTokenizer, AutoModelForCausalLM
import time


def sample_top_p(probs, p):
    # Nucleus (top-p) sampling: sort probabilities, zero out everything past
    # cumulative mass p, renormalize, and sample one token from the remainder.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token


def get_tvm_model(artifact_path, model, device_name, dtype):
    device = tvm.device(device_name)
    const_params = utils.load_params(artifact_path, device)
    ex = tvm.runtime.load_module(f"{artifact_path}/{model}_{device_name}_{dtype}.so")
    vm = relax.VirtualMachine(ex, device)

    class Model:
        def __init__(self) -> None:
            self.tot_seq_len = 0
            self.kv_cache = vm["create_kv_cache"]()

        def forward(self, inputs: torch.Tensor) -> torch.Tensor:
            inputs = tvm.nd.array(inputs.numpy(), device=device)
            self.tot_seq_len += inputs.shape[1]
            seq_len_shape = tvm.runtime.ShapeTuple([self.tot_seq_len])
            if inputs.shape[1] > 1:
                # Prefill ("encoding"): multiple tokens in, populates the KV cache.
                logits, kv_cache = vm["encoding"](
                    inputs, seq_len_shape, self.kv_cache, const_params
                )
            else:
                # Single-token step ("decoding"): reuses the cached keys/values.
                logits, kv_cache = vm["decoding"](
                    inputs, seq_len_shape, self.kv_cache, const_params
                )
            self.kv_cache = kv_cache
            return torch.from_numpy(logits.numpy())

    model = Model()
    return model.forward


def get_pytorch_model(model_path, torch_device, dtype):
    model = AutoModelForCausalLM.from_pretrained(model_path)
    model.eval()
    if dtype == "float16":
        model = model.to(torch.float16)
    model = model.to(torch_device)

    def forward(inputs: torch.Tensor) -> torch.Tensor:
        logits = model(inputs, use_cache=False).logits
        return logits

    return forward


def get_tokenizer(model_path):
    return AutoTokenizer.from_pretrained(model_path)
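
To make the nucleus-sampling helper above easier to follow, here is a small, purely illustrative check that is not part of the diff; it assumes it is run from the tests/ directory so that utils resolves to the file above.

import torch

from utils import sample_top_p

# Toy next-token distribution over a 5-token vocabulary.
probs = torch.tensor([[0.50, 0.25, 0.15, 0.07, 0.03]])

# A token is kept only while the cumulative mass before it is <= p, so with
# p = 0.7 just the 0.50 and 0.25 entries survive; they are renormalized and
# one of them is drawn by torch.multinomial.
torch.manual_seed(0)
token = sample_top_p(probs, p=0.7)
print(token)  # tensor([[0]]) or tensor([[1]])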