Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add StreamingLLM support to studio2 chat #2060

Merged
merged 23 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 109 additions & 12 deletions apps/shark_studio/api/llm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from turbine_models.custom_models import stateless_llama
from turbine_models.gen_external_params.gen_external_params import gen_external_params
import time
from shark.iree_utils.compile_utils import (
get_iree_compiled_module,
load_vmfb_using_mmap,
)
from apps.shark_studio.api.utils import get_resource_path
from apps.shark_studio.web.utils import get_resource_path
import iree.runtime as ireert
from itertools import chain
import gc
Expand All @@ -29,33 +30,79 @@
},
}

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<s>", "</s>"

def append_user_prompt(history, input_prompt):
user_prompt = f"{B_INST} {input_prompt} {E_INST}"
history += user_prompt
return history

def append_bot_prompt(history, input_prompt):
user_prompt = f"{B_SYS} {input_prompt}{E_SYS} {E_SYS}"
history += user_prompt
return history

class LanguageModel:
def __init__(
self,
model_name,
hf_auth_token=None,
device=None,
precision="fp32",
quantization="int4",
precision="",
external_weights=None,
use_system_prompt=True,
streaming_llm=False,
):
print(llm_model_map[model_name])
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
self.tempfile_name = get_resource_path("llm.torch.tempfile")
self.vmfb_name = get_resource_path("llm.vmfb.tempfile")
self.device = device
self.precision = precision
self.safe_name = self.hf_model_name.strip("/").replace("/", "_")
self.device = device.split("=>")[-1].strip()
self.driver = self.device.split("://")[0]
print(f"Selected {self.driver} as device driver")
self.precision = "f32" if "cpu" in self.driver else "f16"
self.quantization = quantization
#TODO: find a programmatic solution for model arch spec instead of hardcoding llama2
self.file_spec = "_".join([
"llama2",
"streaming" if streaming_llm else "chat",
self.precision,
self.quantization,
])
self.tempfile_name = get_resource_path(f"{self.file_spec}.tempfile")
#TODO: Tag vmfb with target triple of device instead of HAL backend
self.vmfb_name = get_resource_path(f"{self.file_spec}_{self.driver}.vmfb.tempfile")
self.safe_name = self.hf_model_name.split("/")[-1].replace("-", "_")
self.max_tokens = llm_model_map[model_name]["max_tokens"]
self.iree_module_dict = None
self.external_weight_file = None
self.streaming_llm = streaming_llm
if external_weights is not None:
self.external_weight_file = get_resource_path(
self.safe_name + "." + external_weights
self.safe_name
+ "_" + self.precision
+ "_" + self.quantization
+ "." + external_weights
)
self.use_system_prompt = use_system_prompt
self.global_iter = 0
self.prev_token_len = 0
if self.external_weight_file is not None:
if not os.path.exists(self.external_weight_file):
print(
f"External weight file {self.external_weight_file} does not exist. Generating..."
)
gen_external_params(
hf_model_name=self.hf_model_name,
quantization=self.quantization,
weight_path=self.external_weight_file,
hf_auth_token=hf_auth_token,
precision=self.precision,
)
else:
print(
f"External weight file {self.external_weight_file} found for {self.vmfb_name}"
)
if os.path.exists(self.vmfb_name) and (
external_weights is None or os.path.exists(str(self.external_weight_file))
):
Expand All @@ -66,7 +113,7 @@ def __init__(
self.iree_module_dict["temp_file_to_unlink"],
) = load_vmfb_using_mmap(
self.vmfb_name,
device,
self.driver,
device_idx=0,
rt_flags=[],
external_weight_file=self.external_weight_file,
Expand All @@ -83,6 +130,9 @@ def __init__(
compile_to="torch",
external_weights=external_weights,
external_weight_file=self.external_weight_file,
precision=self.precision,
quantization=self.quantization,
streaming_llm=self.streaming_llm,
)
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
Expand All @@ -99,14 +149,38 @@ def __init__(

def compile(self) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer not moving all the compile code into this non reusable place, unless this is part of a migration intended to deprecate the shark api

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mistake, should just be model arch/api-specific compile flags

# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
flags = [
"--iree-input-type=torch",
"--mlir-print-debuginfo",
"--mlir-print-op-on-diagnostic=false",
"--iree-llvmcpu-target-cpu-features=host",
"--iree-llvmcpu-target-triple=x86_64-linux-gnu",
"--iree-stream-resource-index-bits=64",
"--iree-vm-target-index-bits=64",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-opt-const-expr-hoisting=False",
]
if "cpu" in self.driver:
flags.extend(
[
"--iree-global-opt-enable-quantized-matmul-reassociation",
"--iree-llvmcpu-enable-ukernels=all"
]
)
elif self.driver == "vulkan":
flags.extend(
[
"--iree-stream-resource-max-allocation-size=4294967296"
]
)
self.iree_module_dict = get_iree_compiled_module(
self.tempfile_name,
device=self.device,
mmap=True,
frontend="torch",
external_weight_file=self.external_weight_file,
write_to=self.vmfb_name,
extra_args=["--iree-global-opt-enable-quantized-matmul-reassociation"],
extra_args=flags,
)
# TODO: delete the temp file

Expand All @@ -129,20 +203,40 @@ def chat(self, prompt):

input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids

if self.streaming_llm:
token_slice = max(self.prev_token_len - 1, 0)
input_tensor = input_tensor[:, token_slice:]

def format_out(results):
return torch.tensor(results.to_host()[0][0])

history = []
for iter in range(self.max_tokens):
if self.streaming_llm and self.iree_module_dict["vmfb"]["get_seq_step"]() > 600:
print("Evicting cache space!")
self.iree_module_dict["vmfb"]["evict_kvcache_space"]()
st_time = time.time()
if iter == 0:
token_len = input_tensor.shape[-1]
if iter == 0 and not self.streaming_llm:
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"].device, input_tensor
)
]
token = self.iree_module_dict["vmfb"]["run_initialize"](*device_inputs)
token_len += 1
elif iter == 0:
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"].device, input_tensor
)
]
token = self.iree_module_dict["vmfb"]["run_cached_initialize"](*device_inputs)
token_len += 1
else:
if self.streaming_llm and self.iree_module_dict["vmfb"]["get_seq_step"]() > 600:
print("Evicting cache space!")
self.iree_module_dict["vmfb"]["evict_kvcache_space"]()
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"].device,
Expand All @@ -153,7 +247,10 @@ def format_out(results):

total_time = time.time() - st_time
history.append(format_out(token))
yield self.tokenizer.decode(history), total_time
self.prev_token_len = token_len + len(history)
res = self.tokenizer.decode(history, skip_special_tokens=True)
prompt = append_bot_prompt(prompt, res)
yield prompt, total_time

if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]:
break
Expand Down
Loading
Loading