Add StreamingLLM support to studio2 chat #2060

Merged: 23 commits, Jan 19, 2024
Changes from 1 commit
Streaming LLM (WIP)
monorimet committed Jan 9, 2024
commit f6b249ad1137218036de5a93fdff8dd196c7613e
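
For background: StreamingLLM keeps a model generating past its trained context length by pinning a few initial "attention sink" tokens in the KV cache, keeping a rolling window of recent tokens, and evicting everything in between. The sketch below illustrates that policy only; it is not code from this PR, and the sink/window sizes are made up. In the diff, evict_kvcache_space() plays the role of evict() here.

# Illustrative-only sketch of the StreamingLLM eviction policy: keep the
# first num_sinks "attention sink" entries plus the most recent `window`
# entries of the KV cache, and drop the middle.
def evict(kv_cache, num_sinks=4, window=512):
    if len(kv_cache) <= num_sinks + window:
        return kv_cache  # still fits, nothing to evict
    return kv_cache[:num_sinks] + kv_cache[-window:]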
32 changes: 29 additions & 3 deletions apps/shark_studio/api/llm.py
@@ -4,7 +4,7 @@
     get_iree_compiled_module,
     load_vmfb_using_mmap,
 )
-from apps.shark_studio.api.utils import get_resource_path
+from apps.shark_studio.web.utils import get_resource_path
 import iree.runtime as ireert
 from itertools import chain
 import gc
@@ -39,6 +39,7 @@ def __init__(
         precision="fp32",
         external_weights=None,
         use_system_prompt=True,
+        streaming_llm=False,
     ):
         print(llm_model_map[model_name])
         self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
@@ -50,12 +51,15 @@
         self.max_tokens = llm_model_map[model_name]["max_tokens"]
         self.iree_module_dict = None
         self.external_weight_file = None
+        self.streaming_llm = streaming_llm
         if external_weights is not None:
             self.external_weight_file = get_resource_path(
                 self.safe_name + "." + external_weights
             )
         self.use_system_prompt = use_system_prompt
         self.global_iter = 0
+        self.prev_token_len = 0
+
         if os.path.exists(self.vmfb_name) and (
             external_weights is None or os.path.exists(str(self.external_weight_file))
         ):
@@ -83,6 +87,7 @@ def __init__(
                 compile_to="torch",
                 external_weights=external_weights,
                 external_weight_file=self.external_weight_file,
+                streaming_llm=self.streaming_llm,
             )
             with open(self.tempfile_name, "w+") as f:
                 f.write(self.torch_ir)
@@ -106,7 +111,7 @@ def compile(self) -> None:
             frontend="torch",
             external_weight_file=self.external_weight_file,
             write_to=self.vmfb_name,
-            extra_args=["--iree-global-opt-enable-quantized-matmul-reassociation"],
+            extra_args=["--iree-global-opt-enable-quantized-matmul-reassociation"] if "cpu" in self.device else [],
         )
         # TODO: delete the temp file
@@ -129,20 +134,40 @@ def chat(self, prompt):
 
         input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids
 
+        if self.streaming_llm:
+            token_slice = max(self.prev_token_len - 1, 0)
+            input_tensor = input_tensor[:, token_slice:]
+
         def format_out(results):
             return torch.tensor(results.to_host()[0][0])
 
         history = []
         for iter in range(self.max_tokens):
+            if self.streaming_llm and self.iree_module_dict["vmfb"]["get_seq_step"]() > 600:
+                print("Evicting cache space!")
+                self.iree_module_dict["vmfb"]["evict_kvcache_space"]()
             st_time = time.time()
-            if iter == 0:
+            token_len = input_tensor.shape[-1]
+            if iter == 0 and not self.streaming_llm:
                 device_inputs = [
                     ireert.asdevicearray(
                         self.iree_module_dict["config"].device, input_tensor
                     )
                 ]
                 token = self.iree_module_dict["vmfb"]["run_initialize"](*device_inputs)
+                token_len += 1
+            elif iter == 0:
+                device_inputs = [
+                    ireert.asdevicearray(
+                        self.iree_module_dict["config"].device, input_tensor
+                    )
+                ]
+                token = self.iree_module_dict["vmfb"]["run_cached_initialize"](*device_inputs)
+                token_len += 1
             else:
+                if self.streaming_llm and self.iree_module_dict["vmfb"]["get_seq_step"]() > 600:
+                    print("Evicting cache space!")
+                    self.model["evict_kvcache_space"]()
                 device_inputs = [
                     ireert.asdevicearray(
                         self.iree_module_dict["config"].device,
@@ -153,6 +178,7 @@ def format_out(results):
 
             total_time = time.time() - st_time
             history.append(format_out(token))
+            self.prev_token_len = token_len + len(history)
            yield self.tokenizer.decode(history), total_time
 
             if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]:
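
Taken together, the streaming path in chat() works like the sketch below: on a follow-up prompt only the unseen suffix is re-fed (the KV cache already covers the earlier turns), the first iteration prefills via run_cached_initialize instead of run_initialize, and the cache is trimmed whenever the tracked sequence step passes 600. This is a minimal sketch against a dict of callables like iree_module_dict["vmfb"]; run_forward is a placeholder for the per-token decode call, which is elided from this diff excerpt.

def streaming_chat(module, tokenizer, prompt, prev_token_len, max_tokens):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    # Feed only the suffix the KV cache has not seen yet.
    input_ids = input_ids[:, max(prev_token_len - 1, 0):]
    history = []
    token_len = input_ids.shape[-1]
    for i in range(max_tokens):
        # Trim the cache once the tracked sequence outgrows the window.
        if module["get_seq_step"]() > 600:
            module["evict_kvcache_space"]()
        if i == 0:
            # Prefill on top of the existing cache rather than from scratch.
            token = module["run_cached_initialize"](input_ids)
            token_len += 1
        else:
            token = module["run_forward"](token)  # placeholder entry point
        history.append(token)
        yield tokenizer.decode(history)
    # The caller persists token_len + len(history) as prev_token_len so the
    # next prompt can be sliced the same way.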
17 changes: 13 additions & 4 deletions apps/shark_studio/web/ui/chat.py
@@ -41,6 +41,7 @@ def chat_fn(
     precision,
     download_vmfb,
     config_file,
+    streaming_llm,
     cli=False,
 ):
     global language_model
@@ -52,8 +53,8 @@
             device=device,
             precision=precision,
             external_weights="safetensors",
-            external_weight_file="llama2_7b.safetensors",
             use_system_prompt=prompt_prefix,
+            streaming_llm=streaming_llm,
         )
         history[-1][-1] = "Getting the model ready... Done"
         yield history, ""
@@ -213,12 +214,18 @@ def view_json_file(file_obj):
                 with gr.Column():
                     download_vmfb = gr.Checkbox(
                         label="Download vmfb from Shark tank if available",
                         value=False,
                         interactive=True,
+                        visible=False,
                     )
+                    streaming_llm = gr.Checkbox(
+                        label="Run in streaming mode (requires recompilation)",
+                        value=True,
+                        interactive=True,
+                    )
                     prompt_prefix = gr.Checkbox(
                         label="Add System Prompt",
-                        value=False,
+                        value=True,
                         interactive=True,
                     )
 
@@ -241,8 +248,8 @@ def view_json_file(file_obj):
         with gr.Row(visible=False):
             with gr.Group():
                 config_file = gr.File(label="Upload sharding configuration", visible=False)
-                json_view_button = gr.Button(label="View as JSON", visible=False)
-                json_view = gr.JSON(interactive=True, visible=False)
+                json_view_button = gr.Button("View as JSON", visible=False)
+                json_view = gr.JSON(visible=False)
                 json_view_button.click(
                     fn=view_json_file, inputs=[config_file], outputs=[json_view]
                 )
@@ -262,6 +269,7 @@ def view_json_file(file_obj):
                 precision,
                 download_vmfb,
                 config_file,
+                streaming_llm,
             ],
             outputs=[chatbot, tokens_time],
             show_progress=False,
@@ -283,6 +291,7 @@ def view_json_file(file_obj):
                 precision,
                 download_vmfb,
                 config_file,
+                streaming_llm,
             ],
             outputs=[chatbot, tokens_time],
             show_progress=False,
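
On the UI side this is standard Gradio plumbing: the new checkbox is appended to each event handler's inputs list, so its boolean value arrives positionally as the streaming_llm argument of chat_fn. A stripped-down sketch of the pattern (toy component names, not the Studio layout):

import gradio as gr

def chat_fn(message, streaming_llm):
    mode = "streaming" if streaming_llm else "one-shot"
    return f"({mode}) you said: {message}"

with gr.Blocks() as demo:
    streaming_llm = gr.Checkbox(label="Run in streaming mode", value=True)
    msg = gr.Textbox(label="Message")
    reply = gr.Textbox(label="Reply")
    # Appending the checkbox to `inputs` is what forwards its value.
    msg.submit(fn=chat_fn, inputs=[msg, streaming_llm], outputs=[reply])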
12 changes: 12 additions & 0 deletions apps/shark_studio/web/utils.py
@@ -0,0 +1,12 @@
+import os
+import sys
+
+
+def get_available_devices():
+    return ["cpu-task"]
+
+
+def get_resource_path(relative_path):
+    """Get absolute path to resource, works for dev and for PyInstaller"""
+    base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
+    return os.path.join(base_path, relative_path)
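
For illustration, here is how the new helper behaves in the two environments it targets; the filename is only an example (llm.py actually builds it from safe_name plus the weights extension):

from apps.shark_studio.web.utils import get_resource_path

# Under PyInstaller, sys._MEIPASS points at the unpacked bundle directory;
# in a dev checkout, getattr() falls back to the directory containing
# utils.py, so the same call resolves correctly in both cases.
weight_file = get_resource_path("llama2_7b.safetensors")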