Add StreamingLLM support to studio2 chat #2060

Merged
merged 23 commits into from Jan 19, 2024
Changes from 1 commit
Fixes to runner, device names, vmfb mgmt
monorimet committed Jan 17, 2024
commit dccd5857d87bea33410be89a9bec6156cdbbdacf
84 changes: 50 additions & 34 deletions apps/shark_studio/api/llm.py
@@ -2,10 +2,7 @@
from turbine_models.model_runner import vmfbRunner
from turbine_models.gen_external_params.gen_external_params import gen_external_params
import time
from shark.iree_utils.compile_utils import (
get_iree_compiled_module,
load_vmfb_using_mmap,
)
from shark.iree_utils.compile_utils import compile_module_to_flatbuffer
from apps.shark_studio.web.utils import get_resource_path
import iree.runtime as ireert
from itertools import chain
@@ -18,13 +15,31 @@
"llama2_7b": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
"compile_flags": ["--iree-opt-const-expr-hoisting=False"],
"stop_token": 2,
"max_tokens": 4096,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
},
"Trelis/Llama-2-7b-chat-hf-function-calling-v2": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
"compile_flags": ["--iree-opt-const-expr-hoisting=False"],
"stop_token": 2,
"max_tokens": 4096,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
},
"TinyPixel/small-llama2": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "TinyPixel/small-llama2",
"compile_flags": ["--iree-opt-const-expr-hoisting=True"],
"stop_token": 2,
"max_tokens": 1024,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
},
"anushehchaudry/llama-2-tiny-random": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "anushehchaudry/llama-2-tiny-random",
"compile_flags": ["--iree-opt-const-expr-hoisting=True"],
"stop_token": 2,
"max_tokens": 4096,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
Expand All @@ -38,6 +53,7 @@
Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n <</SYS>>\n\n
"""


def append_user_prompt(history, input_prompt):
user_prompt = f"{B_INST} {input_prompt} {E_INST}"
history += user_prompt
@@ -56,43 +72,47 @@ def __init__(
use_system_prompt=True,
streaming_llm=False,
):
print(llm_model_map[model_name])
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
self.device = (
device.split("=>")[-1].strip() if "cpu" not in device else "local-task"
)
self.driver = (
self.device.split("://")[0]
if not any(x in self.device for x in ["cpu", "local-task"])
else "llvm-cpu"
)
print(f"Selected {self.driver} as IREE target backend.")
self.precision = "f32" if "cpu" in self.driver else "f16"
self.device = device.split("=>")[-1].strip()
self.backend = self.device.split("://")[0]
self.driver = self.backend
if "cpu" in device:
self.device = "cpu"
self.backend = "llvm-cpu"
self.driver = "local-task"

print(f"Selected {self.backend} as IREE target backend.")
self.precision = "f32" if "cpu" in device else "f16"
self.quantization = quantization
self.safe_name = self.hf_model_name.replace("/", "_").replace("-", "_")
self.external_weight_file = None
# TODO: find a programmatic solution for model arch spec instead of hardcoding llama2
self.file_spec = "_".join(
[
self.safe_name,
self.precision,
self.quantization,
]
)
if self.quantization != "None":
self.file_spec += "_" + self.quantization

if external_weights is not None:
self.external_weight_file = get_resource_path(
self.file_spec + "." + external_weights
)

if streaming_llm:
# Add streaming suffix to file spec after setting external weights filename.
self.file_spec += "_streaming"
self.streaming_llm = streaming_llm

self.tempfile_name = get_resource_path(f"{self.file_spec}.tempfile")
# TODO: Tag vmfb with target triple of device instead of HAL backend
self.vmfb_name = get_resource_path(
f"{self.file_spec}_{self.driver}.vmfb.tempfile"
f"{self.file_spec}_{self.backend}.vmfb.tempfile"
)
self.max_tokens = llm_model_map[model_name]["max_tokens"]
self.iree_module_dict = None
self.external_weight_file = None
self.streaming_llm = streaming_llm
if external_weights is not None:
self.external_weight_file = get_resource_path(
self.file_spec + "." + external_weights
)
self.use_system_prompt = use_system_prompt
self.global_iter = 0
self.prev_token_len = 0
@@ -117,7 +137,7 @@ def __init__(
external_weights is None or os.path.exists(str(self.external_weight_file))
):
self.runner = vmfbRunner(
device=self.device,
device=self.driver,
vmfb_path=self.vmfb_name,
external_weight_path=self.external_weight_file,
)
@@ -163,28 +183,25 @@ def compile(self) -> None:
"--iree-llvmcpu-target-triple=x86_64-linux-gnu",
"--iree-stream-resource-index-bits=64",
"--iree-vm-target-index-bits=64",
"--iree-opt-const-expr-hoisting=False",
]
if "cpu" in self.driver:
if "cpu" in self.backend:
flags.extend(
[
"--iree-global-opt-enable-quantized-matmul-reassociation",
"--iree-llvmcpu-enable-ukernels=all",
]
)
elif self.driver == "vulkan":
elif self.backend == "vulkan":
flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"])
self.iree_module_dict = get_iree_compiled_module(
flags.extend(llm_model_map[self.hf_model_name]["compile_flags"])
flatbuffer_blob = compile_module_to_flatbuffer(
self.tempfile_name,
device=self.device,
mmap=True,
frontend="torch",
external_weight_file=self.external_weight_file,
write_to=self.vmfb_name,
model_config_path=None,
extra_args=flags,
write_to=self.vmfb_name,
)
del self.iree_module_dict
gc.collect()
self.runner = vmfbRunner(
device=self.driver,
vmfb_path=self.vmfb_name,
Expand All @@ -194,7 +211,6 @@ def compile(self) -> None:
self.model = self.runner.ctx.modules.streaming_state_update
else:
self.model = self.runner.ctx.modules.state_update
# TODO: delete the temp file

def sanitize_prompt(self, prompt):
if isinstance(prompt, list):
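The device-name handling is the heart of this commit: the old code collapsed everything into a single self.driver, while the new __init__ keeps device, backend, and driver separate so the vmfb filename can be tagged with the compile backend (self.backend) while vmfbRunner is constructed with the runtime driver (self.driver). A minimal sketch of that mapping, with hypothetical device strings for illustration (the helper below is not part of Shark Studio):

# Sketch of the device-name mapping introduced in __init__ above; the example
# device strings are hypothetical and parse_device() is illustrative only.
def parse_device(device: str):
    """Return (device, backend, driver) the way the new __init__ does."""
    dev = device.split("=>")[-1].strip()  # drop any "id => description" prefix
    backend = dev.split("://")[0]         # e.g. "vulkan://0" -> "vulkan"
    driver = backend
    if "cpu" in device:
        # CPU paths compile for the llvm-cpu backend but run on local-task.
        return "cpu", "llvm-cpu", "local-task"
    return dev, backend, driver


print(parse_device("vulkan://0"))  # ('vulkan://0', 'vulkan', 'vulkan')
print(parse_device("cpu"))         # ('cpu', 'llvm-cpu', 'local-task')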
30 changes: 28 additions & 2 deletions apps/shark_studio/tests/api_test.py
@@ -7,20 +7,44 @@
import logging
import unittest
from apps.shark_studio.api.llm import LanguageModel
import gc


class LLMAPITest(unittest.TestCase):
def testLLMSimple(self):
def test01_LLMSmall(self):
Member:
weird naming?

Collaborator (Author):
forces a certain ordering of tests, but won't be truly necessary until more tests are added.

lm = LanguageModel(
"TinyPixel/small-llama2",
hf_auth_token=None,
device="cpu",
external_weights="safetensors",
precision="fp32",
quantization="None",
)
count = 0
for msg, _ in lm.chat("hi, what are you?"):
# skip first token output
if count == 0:
count += 1
continue
assert (
msg.strip(" ") == "Turkish Turkish Turkish"
), f"LLM API failed to return correct response, expected 'Turkish Turkish Turkish', received {msg}"
break
del lm
gc.collect()

def test02_stream(self):
llama2 = LanguageModel(
"Trelis/Llama-2-7b-chat-hf-function-calling-v2",
hf_auth_token=None,
device="cpu",
external_weights="safetensors",
precision="fp32",
quantization="int4",
streaming_llm=True,
)
count = 0
for msg, _ in lm.chat("hi, what are you?"):
for msg, _ in llama2.chat("hi, what are you?"):
# skip first token output
if count == 0:
count += 1
@@ -29,6 +53,8 @@ def testLLMSimple(self):
msg.strip(" ") == "Hello!"
), f"LLM API failed to return correct response, expected 'Hello!', received {msg}"
break
del llama2
gc.collect()


if __name__ == "__main__":
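On the naming question raised in the review thread above: unittest's default loader sorts test methods by name, so the test01_/test02_ prefixes are what pin the small-model test ahead of the streaming test. A stand-alone sketch of that ordering (not part of the Studio test suite):

# Illustration only: the default unittest loader runs methods in alphabetical
# order, so numeric prefixes control execution order within a TestCase.
import unittest


class OrderingDemo(unittest.TestCase):
    def test02_runs_second(self):
        self.assertTrue(True)

    def test01_runs_first(self):
        self.assertTrue(True)


if __name__ == "__main__":
    # Reports test01_runs_first before test02_runs_second.
    unittest.main(verbosity=2)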