Merge pull request #29 from Blaizzy/pc/add-chat-ui
Improve ChatUI
Blaizzy authored May 24, 2024
2 parents 3c47b80 + 1724c29 commit c918c8f
Showing 2 changed files with 9 additions and 8 deletions.
mlx_vlm/chat_ui.py (14 changes: 8 additions & 6 deletions)
@@ -55,10 +55,10 @@ def generate(
     tokenizer = processor.tokenizer
 
     image_token_index = model.config.image_token_index
-    input_ids, pixel_values = prepare_inputs(
+    input_ids, pixel_values, mask = prepare_inputs(
         image_processor, processor, image, prompt, image_token_index
     )
-    logits, cache = model(input_ids, pixel_values)
+    logits, cache = model(input_ids, pixel_values, mask=mask)
     logits = logits[:, -1, :]
     y, _ = sample(logits, temp, top_p)

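For context, this first hunk threads an attention mask from prepare_inputs into the model's prefill call. A minimal sketch of what such a mask encodes, assuming the usual Hugging Face convention (1 = real token, 0 = padding); PAD_ID and build_attention_mask are illustrative names, not mlx_vlm API:

    import numpy as np

    PAD_ID = 0  # assumed padding token id

    def build_attention_mask(input_ids: np.ndarray) -> np.ndarray:
        # 1 for real tokens, 0 for padding positions, mirroring the
        # attention_mask that HF processors emit and this diff forwards.
        return (input_ids != PAD_ID).astype(np.int32)

    ids = np.array([[5, 17, 42, 0, 0]])  # a prompt padded to length 5
    print(build_attention_mask(ids))     # [[1 1 1 0 0]]
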
@@ -71,6 +71,7 @@ def generate(
         generate_step(
             model.language_model,
             logits,
+            mask,
             cache,
             temp,
             repetition_penalty,
@@ -92,11 +93,12 @@ def generate(
 def chat(message, history, temperature, max_tokens):
 
     chat = []
-    if len(message["files"]) >= 0:
+    if len(message["files"]) >= 1:
         chat.append(get_message_json(config["model_type"], message["text"]))
     else:
-        raise Exception("Please upload an image. Text only chat is not supported.")
+        raise gr.Error("Please upload an image. Text only chat is not supported.")
 
+    files = message["files"][-1]
     if "chat_template" in processor.__dict__.keys():
         messages = processor.apply_chat_template(
             chat,
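
The guard fix in this hunk is worth spelling out: len(message["files"]) >= 0 is always true, since a length can never be negative, so text-only submissions slipped past the image check; >= 1 actually requires an uploaded file. Swapping the bare Exception for gr.Error also lets Gradio surface the message in the UI instead of killing the handler. A minimal repro of the old versus new check, with message shaped like Gradio's multimodal textbox payload:

    # message mimics the {"text": ..., "files": [...]} dict Gradio passes in.
    message = {"text": "describe this image", "files": []}

    print(len(message["files"]) >= 0)  # True  -- old check never rejects
    print(len(message["files"]) >= 1)  # False -- new check rejects text-only chat
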
@@ -105,7 +107,7 @@ def chat(message, history, temperature, max_tokens):
         )
 
     elif "tokenizer" in processor.__dict__.keys():
-        if processor.tokenizer.chat_template:
+        if model.config.model_type != "paligemma":
             messages = processor.tokenizer.apply_chat_template(
                 chat,
                 tokenize=False,
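
This hunk swaps the dispatch criterion from "does the tokenizer carry a chat template?" to "is the model PaliGemma?", since PaliGemma expects a raw task prompt rather than a chat-formatted one. A hedged sketch of the resulting branching; the pass-through in the paligemma arm and the add_generation_prompt flag are assumptions, as both fall outside the visible hunk:

    def format_prompt(model_type: str, tokenizer, chat: list) -> str:
        # Chat-tuned models get the tokenizer's chat template applied ...
        if model_type != "paligemma":
            return tokenizer.apply_chat_template(
                chat, tokenize=False, add_generation_prompt=True
            )
        # ... while PaliGemma takes the raw user text (assumed pass-through).
        return chat[-1]["content"]
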
@@ -118,7 +120,7 @@ def chat(message, history, temperature, max_tokens):
     for chunk in generate(
         model,
         processor,
-        message["files"][-1],
+        files,
         messages,
         image_processor,
         temperature,
mlx_vlm/utils.py (3 changes: 1 addition & 2 deletions)
@@ -674,11 +674,10 @@ def prepare_inputs(image_processor, processor, image, prompt, image_token_index)
     image = load_image(image)
 
     if image_processor is not None:
-
         text_chunks = [processor(chunk).input_ids for chunk in prompt.split("<image>")]
         input_ids = mx.array([text_chunks[0] + [image_token_index] + text_chunks[1]])
-
         pixel_values = image_processor.preprocess(images=[image])[0]
         pixel_values = mx.array(np.expand_dims(pixel_values, axis=0))
     else:
         inputs = processor(prompt, image, return_tensors="np")
         pixel_values = mx.array(inputs["pixel_values"])
