Mllama(single + dual) + InternVL(single) + Llava (single) #267

Merged
merged 28 commits on Feb 14, 2025

Commits
48d24da
Mllama Vision support (#254)
quic-rishinr Jan 30, 2025
0bdeea5
Compiler command fix
quic-rishinr Jan 30, 2025
c948607
Mllama single qpc support added (#258)
quic-amitraj Feb 4, 2025
649cd32
Export fix
quic-amitraj Feb 4, 2025
32f544c
Generate fix-1
quic-amitraj Feb 4, 2025
7ebf06b
minor-fix
quic-amitraj Feb 4, 2025
3bc06be
Model swap fix at the time of export and compile
quic-amitraj Feb 5, 2025
5accf3f
two_qpc_working
quic-amitraj Feb 7, 2025
6ae6835
Minor-fix-1
quic-amitraj Feb 7, 2025
1e181f4
working single and double with single soc
quic-amitraj Feb 7, 2025
5fb0acb
ruff checks and format
quic-amitraj Feb 7, 2025
87e07d0
Updated factory class
quic-rishinr Feb 7, 2025
fc323f4
Added support for Llava model single QPC (#265)
quic-amitraj Feb 10, 2025
1ac6b62
Added support for InternVL single QPC (#264)
ochougul Feb 10, 2025
b3a5d22
final revision VLM
asmigosw Feb 10, 2025
d33e9a5
fixed liscence
quic-amitraj Feb 10, 2025
905b703
Minor fixes-1
quic-amitraj Feb 11, 2025
990adb1
Added get_input_info support and other compilers arguments
quic-amitraj Feb 11, 2025
afe3cd2
generalized getting img_size from config
ochougul Feb 11, 2025
b458193
refactor basic
ochougul Feb 11, 2025
d1981c2
added warnings and auto_correct_inputs function
ochougul Feb 12, 2025
a32007e
remove unused imports
ochougul Feb 12, 2025
81cea10
addressed comments
ochougul Feb 12, 2025
2f1ec08
final commit changed documentation added better warnings
ochougul Feb 13, 2025
ad594d7
Addressed comments
quic-amitraj Feb 13, 2025
94e813c
minor bugfix
ochougul Feb 14, 2025
a661840
addressed comments
ochougul Feb 14, 2025
ed7d5f2
removed image_text models tests to avoid pytest issues
ochougul Feb 14, 2025
8 changes: 7 additions & 1 deletion QEfficient/__init__.py
@@ -25,7 +25,12 @@ def check_qaic_sdk():
# Conditionally import QAIC-related modules if the SDK is installed
__version__ = "0.0.1.dev0"
if QAIC_INSTALLED:
from QEfficient.base import QEFFAutoModel, QEFFAutoModelForCausalLM, QEFFCommonLoader
from QEfficient.base import (
QEFFAutoModel,
QEFFAutoModelForCausalLM,
QEFFAutoModelForImageTextToText,
QEFFCommonLoader,
)
from QEfficient.compile.compile_helper import compile
from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
@@ -43,6 +48,7 @@ def check_qaic_sdk():
"QEFFAutoModel",
"QEFFAutoModelForCausalLM",
"QEffAutoPeftModelForCausalLM",
"QEFFAutoModelForImageTextToText",
"QEFFCommonLoader",
]

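For context, a minimal usage sketch of the class newly added to QEfficient's public exports above. The model ID is illustrative only, and from_pretrained is assumed to mirror the existing QEFFAutoModelForCausalLM API rather than being defined in this diff:

from QEfficient import QEFFAutoModelForImageTextToText

# Model ID is illustrative; any supported image-text-to-text checkpoint would do.
model = QEFFAutoModelForImageTextToText.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")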
6 changes: 5 additions & 1 deletion QEfficient/base/__init__.py
@@ -6,4 +6,8 @@
# -----------------------------------------------------------------------------

from QEfficient.base.common import QEFFCommonLoader # noqa: F401
from QEfficient.transformers.models.modeling_auto import QEFFAutoModel, QEFFAutoModelForCausalLM # noqa: F401
from QEfficient.transformers.models.modeling_auto import ( # noqa: F401
QEFFAutoModel,
QEFFAutoModelForCausalLM,
QEFFAutoModelForImageTextToText,
)
2 changes: 2 additions & 0 deletions QEfficient/base/modeling_qeff.py
@@ -175,6 +175,7 @@ def _export(
}
if onnx_transform_kwargs is not None:
transform_kwargs.update(onnx_transform_kwargs)

for transform in self._onnx_transforms:
model, transformed = transform.apply(model, **transform_kwargs)
model.metadata_props.append(
@@ -187,6 +188,7 @@

except Exception as e:
logger.error(f"ONNX export (or) ONNXTransforms failed: {e}")

raise e

finally:
25 changes: 24 additions & 1 deletion QEfficient/base/pytorch_transforms.py
@@ -4,7 +4,8 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------
from typing import Dict, Tuple, Type
from types import MethodType
from typing import Callable, Dict, Tuple, Type

from torch import nn

@@ -87,3 +88,25 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
@classmethod
def mutate(cls, original_module: nn.Module, parent_module: nn.Module):
raise NotImplementedError("Please implement your own method by inheriting this class")


class ModuleMethodMapperTransform(PytorchTransform):
"""
Serves as a base class for any transform that wants to map a particular method of a class to a new method implementation.
"""

_match_class_replace_method: Dict[nn.Module, Dict[str, Callable]]
_match_string_replace_method: Dict[str, Dict[str, Callable]]

@classmethod
def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
transformed = False
for module in model.modules():
if (repl_method_map := cls._match_class_replace_method.get(type(module))) or (
repl_method_map := cls._match_string_replace_method.get(module.__class__.__name__)
):
for orig_method_name, mapped_method in repl_method_map.items():
setattr(module, orig_method_name, MethodType(mapped_method, module))
transformed = True

return model, transformed
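As an illustration of how this new transform base class might be subclassed, here is a hypothetical sketch that rebinds the forward method of every nn.LayerNorm in a model; the names CustomLayerNormTransform and patched_layer_norm_forward are made up for this example, not taken from the PR:

import torch
from torch import nn

from QEfficient.base.pytorch_transforms import ModuleMethodMapperTransform


def patched_layer_norm_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # Replacement body; a real transform would plug in a device-friendly implementation here.
    return nn.functional.layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)


class CustomLayerNormTransform(ModuleMethodMapperTransform):
    _match_class_replace_method = {nn.LayerNorm: {"forward": patched_layer_norm_forward}}
    _match_string_replace_method = {}


model, transformed = CustomLayerNormTransform.apply(nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8)))
# transformed is True because one nn.LayerNorm instance was patched in place.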
21 changes: 17 additions & 4 deletions QEfficient/generation/text_generation_inference.py
@@ -57,10 +57,23 @@ class CloudAI100ExecInfo:
perf_metrics: PerfMetrics

def __repr__(self):
return f"Average Prefill time a.k.a TTFT is= {round(self.perf_metrics.prefill_time, 2)}\
\nDecode token/sec is= {round(self.perf_metrics.decode_perf * self.batch_size, 2)}\
\nTotal token/sec is= {round(self.perf_metrics.total_perf * self.batch_size, 2)}\
\nTotal (E2E) inference time is= {round(self.perf_metrics.total_time, 2)}"
return f"Average Prefill time a.k.a TTFT is= {round(self.perf_metrics.prefill_time, 2)} sec\
\nDecode is= {round(self.perf_metrics.decode_perf * self.batch_size, 2)} tokens/sec\
\nTotal is= {round(self.perf_metrics.total_perf * self.batch_size, 2)} tokens/sec\
\nTotal (E2E) inference time is= {round(self.perf_metrics.total_time, 2)} sec"


@dataclass
class CloudAI100ExecInfoNew:
batch_size: int
generated_ids: Union[List[np.ndarray], np.ndarray]
perf_metrics: PerfMetrics

def __repr__(self):
return f"Average Prefill time a.k.a TTFT is= {round(self.perf_metrics.prefill_time, 2)} sec\
\nDecode is= {round(self.perf_metrics.decode_perf * self.batch_size, 2)} token/sec\
\nTotal is= {round(self.perf_metrics.total_perf * self.batch_size, 2)} token/sec\
\nTotal (E2E) inference time is= {round(self.perf_metrics.total_time, 2)} sec"


io_files = []
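A small, hypothetical construction of the new dataclass, assuming PerfMetrics is a dataclass in the same module exposing the four fields read in __repr__ (prefill_time, decode_perf, total_perf, total_time); all numbers are dummies:

import numpy as np

from QEfficient.generation.text_generation_inference import CloudAI100ExecInfoNew, PerfMetrics

metrics = PerfMetrics(prefill_time=0.12, decode_perf=48.0, total_perf=45.5, total_time=2.3)
info = CloudAI100ExecInfoNew(batch_size=1, generated_ids=np.zeros((1, 32), dtype=np.int64), perf_metrics=metrics)
print(info)  # prints the TTFT / decode / total throughput summary defined above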
95 changes: 94 additions & 1 deletion QEfficient/transformers/modeling_utils.py
@@ -6,8 +6,9 @@
# -----------------------------------------------------------------------------

from collections import namedtuple
from typing import Dict, Type
from typing import Dict, Optional, Tuple, Type

import torch
import torch.nn as nn
from transformers.models.codegen.modeling_codegen import (
CodeGenAttention,
@@ -242,3 +243,95 @@
GPTBigCodeBlock: QEffGPTBigCodeBlock,
GPTBigCodeModel: QEffGPTBigCodeModel,
}


def _prepare_cross_attention_mask(
cross_attention_mask: torch.Tensor,
num_vision_tokens: int,
dtype: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
# reshape so it can be used by attn module
batch_size, text_total_length, *_ = cross_attention_mask.shape
cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3)
cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
cross_attention_mask = cross_attention_mask.unsqueeze(1)

# invert the mask
inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
cross_attention_mask = inverted_cross_attn_mask.masked_fill(
inverted_cross_attn_mask.to(torch.bool), torch.tensor(-10000.0, dtype=torch.float32)
)

# apply full-row bias, which returns a 4D tensor of shape [B, H, S1, 1] whose value is 0 if a full row of the
# cross-attention mask's last dimension contains negative-infinity values, and 1 otherwise
negative_inf_value = torch.tensor(-10000.0, dtype=torch.float32)
full_text_row_masked_out_mask = (
(cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None]
)
cross_attention_mask *= full_text_row_masked_out_mask

return cross_attention_mask, full_text_row_masked_out_mask
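To make the tensor shapes concrete, an illustrative call follows; the shape convention [batch, text_seq_len, max_num_images, max_num_tiles] is assumed from the Mllama usage, and all sizes are dummies:

import torch

# 1 image with 4 tiles, 16 text tokens; the first 4 text tokens attend to no image.
mask = torch.ones(1, 16, 1, 4)
mask[:, :4] = 0
cross_mask, full_row_mask = _prepare_cross_attention_mask(mask, num_vision_tokens=1601, dtype=torch.float32)
# cross_mask has shape [1, 1, 16, 4 * 1601]; full_row_mask has shape [1, 1, 16, 1]
# and is 0 for the first 4 text rows, which are fully masked out.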


def _prepare_aspect_ratio_attention_mask(
aspect_ratio_mask: torch.Tensor,
num_patches: int,
target_length: int,
dtype: torch.dtype,
) -> torch.Tensor:
# Expand aspect ratio mask to target_length
batch_size, max_num_tiles = aspect_ratio_mask.shape
attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype)
attention_mask = attention_mask.repeat(1, 1, target_length, 1)

# Mask padding patches
pad_patches = target_length - num_patches
attention_mask[:, :, -pad_patches:] = 0

# Invert the mask (0 -> 1, 1 -> 0)
attention_mask = 1 - attention_mask

# Reshape to 2D and create 4D attention mask
# (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
attention_mask = attention_mask.reshape(batch_size, max_num_tiles * target_length, 1)
attention_mask = attention_mask @ attention_mask.transpose(-1, -2) * torch.tensor(-10000.0, dtype=torch.float32)
attention_mask = attention_mask.unsqueeze(1)

return attention_mask
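An illustrative call with dummy tile and patch counts (the numbers are chosen only for this sketch):

import torch

aspect_ratio_mask = torch.tensor([[1, 1, 0, 0]])  # batch of 1, 4 tile slots, 2 real tiles
attn_mask = _prepare_aspect_ratio_attention_mask(
    aspect_ratio_mask, num_patches=1601, target_length=1608, dtype=torch.float32
)
# attn_mask has shape [1, 1, 4 * 1608, 4 * 1608]; pairs where both the query and the key
# fall on padded tiles or padded patch positions receive the -10000 bias, all other pairs stay 0.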


def _create_causal_mask(
position_ids,
target_length,
sliding_window: Optional[int] = None,
):
"""
A utility attention mask function that allows one to:
- Create a causal 4d mask
- Create a causal 4d mask with sliding window
"""
if sliding_window is not None:
query_indices = position_ids.unsqueeze(-1)
kv_indices = torch.arange(target_length).view(1, -1)
# --- Rolling buffer ---
pos_max = position_ids.max(1, keepdim=True).values
kv_start = (pos_max // target_length) * target_length
kv_indices_high = kv_indices + kv_start
kv_indices_low = torch.where(kv_indices_high < target_length, kv_indices, kv_indices_high - target_length)
kv_indices = torch.where(kv_indices_high > pos_max, kv_indices_low, kv_indices_high)
kv_indices = kv_indices.unsqueeze(1)
# ------
causal_mask = kv_indices > query_indices
attention_mask = causal_mask

window_indices = query_indices - sliding_window + 1
window_mask = kv_indices < window_indices
attention_mask = attention_mask | window_mask
attention_mask = attention_mask.unsqueeze(1)
else:
query_indices = position_ids.unsqueeze(-1)
kv_indices = torch.arange(target_length).view(1, 1, -1)
attention_mask = kv_indices > query_indices
attention_mask = attention_mask.unsqueeze(1)

return attention_mask
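A tiny example of the non-sliding-window branch, to make the output shape and semantics concrete (sizes are dummies):

import torch

position_ids = torch.arange(6).unsqueeze(0)  # one sequence, query positions 0..5
causal = _create_causal_mask(position_ids, target_length=8)
# causal has shape [1, 1, 6, 8]; entry [.., q, k] is True when key index k exceeds
# query position q, i.e. True marks positions that must be masked out.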
6 changes: 6 additions & 0 deletions QEfficient/transformers/models/internvl/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------