chat_ui.py
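# Usage sketch: because this module uses relative imports, it is meant to be
# launched as a module from the installed mlx_vlm package (module path assumed,
# the exact entry point may differ), e.g.:
#   python -m mlx_vlm.chat_ui --model qnguyen3/nanoLLaVA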
import argparse

import gradio as gr

from .prompt_utils import apply_chat_template
from .utils import load, load_config, load_image_processor, stream_generate


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Generate text from an image using a model."
    )
    parser.add_argument(
        "--model",
        type=str,
        default="qnguyen3/nanoLLaVA",
        help="The path to the local model directory or Hugging Face repo.",
    )
    return parser.parse_args()


# Parse CLI arguments and load the model, processor, and config once at startup.
args = parse_arguments()
config = load_config(args.model)
model, processor = load(args.model, {"trust_remote_code": True})
image_processor = load_image_processor(args.model)


def chat(message, history, temperature, max_tokens):
    chat = []
    if len(message.files) >= 1:
        chat.append(message.text)
    else:
        raise gr.Error("Please upload an image. Text-only chat is not supported.")

    # Use the most recently uploaded image.
    files = message.files[-1].path

    # PaliGemma takes the raw prompt text; other models use the chat template.
    if model.config.model_type != "paligemma":
        messages = apply_chat_template(processor, config, chat)
    else:
        messages = message.text

    # Stream tokens and yield the accumulated response so the UI updates live.
    response = ""
    for chunk in stream_generate(
        model, processor, files, messages, image_processor, max_tokens, temp=temperature
    ):
        response += chunk
        yield response
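

# The two sliders defined in additional_inputs below are forwarded to chat()
# as its (temperature, max_tokens) arguments, in the order they are listed.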
demo = gr.ChatInterface(
    fn=chat,
    title="MLX-VLM Chat UI",
    additional_inputs_accordion=gr.Accordion(
        label="⚙️ Parameters", open=False, render=False
    ),
    additional_inputs=[
        gr.Slider(
            minimum=0, maximum=1, step=0.1, value=0.1, label="Temperature", render=False
        ),
        gr.Slider(
            minimum=128,
            maximum=4096,
            step=1,
            value=200,
            label="Max new tokens",
            render=False,
        ),
    ],
    description=f"Now running {args.model}",
    multimodal=True,
)

demo.launch(inbrowser=True)