149 changes: 146 additions & 3 deletions QEfficient/generation/embedding_handler.py
@@ -12,13 +12,14 @@
operations, separating them from the main text generation logic.
"""

from typing import Any, Dict, Optional, Tuple
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor
from transformers import AutoImageProcessor, AutoTokenizer

from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.utils.logging_utils import logger
@@ -37,6 +38,9 @@ def __init__(
qeff_model: Optional[QAICInferenceSession],
vision_session: Optional[QAICInferenceSession],
processor: Optional[AutoImageProcessor],
tokenizer: Optional[AutoTokenizer],
image_height: Optional[int] = None,
image_width: Optional[int] = None,
config: Optional[Dict[str, Any]] = None,
lang_session: Optional[QAICInferenceSession] = None,
):
@@ -46,12 +50,18 @@
Args:
vision_session: QAICInferenceSession for vision model
processor: AutoImageProcessor for image preprocessing
tokenizer: AutoTokenizer for text tokenization
image_height: Desired image height for resizing
image_width: Desired image width for resizing
config: Configuration dictionary with vision model parameters
lang_session: Optional language session for coordination (to avoid resource conflicts)
"""
self._qeff_model = qeff_model
self._vision_session = vision_session
self._processor = processor
self._tokenizer = tokenizer
self._image_height = image_height
self._image_width = image_width
self._config = config or {}
self._lang_session = lang_session # Store language session for coordination

@@ -70,13 +80,132 @@ def is_available(self) -> bool:
"""
return self._vision_session is not None and self._processor is not None

def prepare_internVL_inputs(self, img_url: str, prompt: str) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
"""
Prepare inputs for the InternVL model

Args:
img_url: URL or path to image
prompt: Text query to process with image

Returns:
Tuple of (vision_inputs, lang_inputs) dictionaries of numpy arrays

Raises:
ValueError: If the tokenizer is not available
"""
if not self._tokenizer:
raise ValueError("Tokenizer is required for InternVL input preparation")
pixel_values = []
num_patches_list = []
questions = []
# Load the image from a URL or a local path
if img_url.startswith(("http://", "https://")):
img = requests.get(img_url, stream=True)
image = Image.open(BytesIO(img.content)).convert("RGB")
else:
image = Image.open(img_url).convert("RGB")

if self._image_height and self._image_width:
# PIL's Image.resize expects (width, height)
image = image.resize((self._image_width, self._image_height))
else:
logger.warning("Height and width not specified. Using default image size for num_patches = 13.")
image = image.resize((1000, 747))

# preprocess the resized image
pixel_value = self._processor.load_image(image, max_num=12)
num_patches_list.append(pixel_value.shape[0])
pixel_values.append(pixel_value)

question = "<image>\n" + prompt
questions.append(question)

pixel_values = torch.cat(pixel_values, dim=0)

# Chat Template information for prompt preprocessing
messages: List[List[str]] = []
roles = ("<|im_start|>user\n", "<|im_start|>assistant\n")
prompt = self._processor(pixel_values, questions, messages, roles, num_patches_list=num_patches_list)

inputs = self._tokenizer(prompt, return_tensors="pt")
inputs["pixel_values"] = pixel_values.clone()

# Convert to numpy arrays
vision_inputs = {}
for k, v in inputs.items():
if k in {
"pixel_values",
"image_masks",
"image_input_idx",
"valid_idx",
"aspect_ratio_ids",
"aspect_ratio_mask",
}:
vision_inputs[k] = np.array(v)

# Convert specific inputs to float16
vision_inputs_fp16 = {"pixel_values", "image_masks"}
for k in vision_inputs_fp16:
if k in vision_inputs:
vision_inputs[k] = vision_inputs[k].astype("float16")

lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}

return vision_inputs, lang_inputs

def prepare_molmo_inputs(self, image_url: str, query: str) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
"""
Download and preprocess an image into Molmo model inputs

Args:
image_url: URL or path to image
query: Text query to process with image

Returns:
Tuple of (vision_inputs, lang_inputs) dictionaries of numpy arrays

Raises:
ValueError: If vision handler is not properly initialized
RuntimeError: If image processing fails
"""
if not self.is_available():
raise ValueError("Vision handler not properly initialized. Need both vision_session and processor.")

try:
# Download image
if image_url.startswith(("http://", "https://")):
image = Image.open(requests.get(image_url, stream=True).raw)
else:
image = Image.open(image_url)
# Resize to the fixed resolution used for Molmo preprocessing
image = image.resize((536, 354))
inputs = self._processor.process(images=[image], text=query)
inputs = {k: v.unsqueeze(0) for k, v in inputs.items()}
inputs["attention_mask"] = torch.ones((inputs["input_ids"].shape), dtype=torch.int64)
valid = inputs["image_input_idx"] > 0
valid = valid.reshape(1, -1)
inputs["valid_idx"] = torch.nonzero(valid)[:, 1].unsqueeze(0)
inputs["pixel_values"] = inputs.pop("images")

# Convert to numpy arrays
vision_inputs = {}
for k, v in inputs.items():
if k in {
"pixel_values",
"image_masks",
"image_input_idx",
"valid_idx",
"aspect_ratio_ids",
"aspect_ratio_mask",
}:
vision_inputs[k] = np.array(v)

# Convert specific inputs to float16
vision_inputs_fp16 = {"pixel_values", "image_masks"}
for k in vision_inputs_fp16:
if k in vision_inputs:
vision_inputs[k] = vision_inputs[k].astype("float16")

lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}

return vision_inputs, lang_inputs
except Exception as e:
raise RuntimeError(f"Failed to process image {image_url}: {str(e)}") from e

def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -> Dict[str, np.ndarray]:
"""
Download and preprocess image into model inputs

Args:
image_url: URL or path to image
query: Text query to process with image
prefill_seq_len: Padded sequence length for language model

Returns:
Tuple of (vision_inputs, lang_inputs) dictionaries of numpy arrays
@@ -95,6 +224,9 @@ def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -
else:
image = Image.open(image_url)

if "mistral3" in self._qeff_model.model.config.model_type:
image = image.resize((1540, 1540))
Contributor (review comment): nit: Can we move this to constants
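
A possible shape for that follow-up, sketched with illustrative names (the constant names and their exact home, e.g. QEfficient/utils/constants.py, are assumptions, not part of this PR):

```python
# Hypothetical constants module entries (illustrative names only).
# Resolutions are the tuples passed to PIL's Image.resize, i.e. (width, height).
MISTRAL3_IMAGE_SIZE = (1540, 1540)         # fixed resize used for mistral3 models
MOLMO_IMAGE_SIZE = (536, 354)              # fixed resize used in prepare_molmo_inputs
INTERNVL_DEFAULT_IMAGE_SIZE = (1000, 747)  # default that yields num_patches = 13

# embedding_handler.py would then read, e.g.:
# if "mistral3" in self._qeff_model.model.config.model_type:
#     image = image.resize(MISTRAL3_IMAGE_SIZE)
```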


# Prepare conversation format
conversation = [
{
@@ -323,7 +455,18 @@ def get_processed_inputs(

try:
## Get vlm inputs ##
vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len)
model_type = getattr(self._qeff_model.model.config, "model_type", None)
if model_type == "internvl_chat":
vision_inputs, lang_inputs = self.prepare_internVL_inputs(image_url, query)
elif model_type == "molmo":
vision_inputs, lang_inputs = self.prepare_molmo_inputs(image_url, query)
else:
vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len)

# Handle padding for language model
pad_token_id = 1
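For readers skimming the diff, the routing added to get_processed_inputs reduces to a dispatch on config.model_type; a minimal, self-contained sketch of that logic (stand-in function and names, not the actual handler code):

```python
# Minimal sketch of the model_type routing added in this PR (stand-in names).
from types import SimpleNamespace


def route_inputs(config, image_url, query, prefill_seq_len):
    model_type = getattr(config, "model_type", None)
    if model_type == "internvl_chat":
        return "prepare_internVL_inputs", (image_url, query)
    elif model_type == "molmo":
        return "prepare_molmo_inputs", (image_url, query)
    # Default path keeps the original prepare_vlm_inputs behaviour.
    return "prepare_vlm_inputs", (image_url, query, prefill_seq_len)


print(route_inputs(SimpleNamespace(model_type="molmo"), "demo.png", "Describe the image", 128))
# -> ('prepare_molmo_inputs', ('demo.png', 'Describe the image'))
```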
10 changes: 10 additions & 0 deletions QEfficient/generation/vlm_generation.py
@@ -88,6 +88,8 @@ def __init__(
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
full_batch_size: Optional[int] = None,
image_height: Optional[int] = None,
image_width: Optional[int] = None,
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
@@ -107,6 +109,8 @@ def __init__(
enable_debug_logs: Enable debug logging
write_io_dir: Directory for I/O file writing
full_batch_size: Enable continuous batching (new feature)
image_height: Desired image height for resizing
image_width: Desired image width for resizing
is_tlm: Target language model flag
include_sampler: Enable on-device sampling (new feature)
return_pdfs: Return probability distributions
@@ -143,6 +147,9 @@ def __init__(
)
self.qeff_model = qeff_model
self.processor = processor
self.tokenizer = tokenizer
self.image_height = image_height
self.image_width = image_width
self._vision_qpc_path = vision_qpc_path
self.device_id = device_id # Store device_id for vision components
self.enable_debug_logs = enable_debug_logs # Store for vision components
@@ -173,6 +180,9 @@ def _init_vision_components(self):
qeff_model=self.qeff_model,
vision_session=self._vision_session,
processor=self.processor,
tokenizer=self.tokenizer,
image_height=self.image_height,
image_width=self.image_width,
config=vision_config,
lang_session=self._session, # Pass language session for coordination
)
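
The vlm_generation.py side only threads the new tokenizer, image_height, and image_width arguments from the generation wrapper's __init__ into the vision handler; a self-contained sketch of that plumbing with illustrative class names (not the actual QEfficient classes):

```python
from typing import Optional


class _VisionHandlerSketch:
    """Stand-in for the handler in embedding_handler.py."""

    def __init__(self, image_height: Optional[int] = None, image_width: Optional[int] = None):
        self._image_height = image_height
        self._image_width = image_width


class _VLMGenerationSketch:
    """Stand-in for the generation wrapper in vlm_generation.py."""

    def __init__(self, image_height: Optional[int] = None, image_width: Optional[int] = None):
        # Stored on the wrapper...
        self.image_height = image_height
        self.image_width = image_width
        # ...and forwarded when the vision components are initialized.
        self._vision_handler = _VisionHandlerSketch(
            image_height=self.image_height,
            image_width=self.image_width,
        )


gen = _VLMGenerationSketch(image_height=747, image_width=1000)
assert gen._vision_handler._image_height == 747
```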