Commit 1378b1d

ruff fix
Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>
1 parent 17876fc commit 1378b1d

File tree

2 files changed: 28 additions, 40 deletions

QEfficient/transformers/models/mllama/modeling_mllama.py

Lines changed: 16 additions & 19 deletions

@@ -10,11 +10,9 @@
 import math
 from typing import List, Optional, Tuple, Union

-import requests
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from PIL import Image
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache, DynamicCache

@@ -1197,12 +1195,13 @@ def forward(
         return outputs

     def generate_input(self, processor, kv_offload):
-
-        #vision_inputs
+        # vision_inputs
         vision_inputs = {
-            "pixel_values": torch.zeros((bs, max_num_images,max_image_tiles,num_channel, image_length, image_width ), dtype=torch.int64),
+            "pixel_values": torch.zeros(
+                (bs, max_num_images, max_image_tiles, num_channel, image_length, image_width), dtype=torch.int64
+            ),
             "aspect_ratio_ids": torch.ones((bs, max_num_images), dtype=torch.int64),
-            "aspect_ratio_mask": torch.ones((bs, max_num_images, max_image_tiles,1 ), dtype=torch.int64)
+            "aspect_ratio_mask": torch.ones((bs, max_num_images, max_image_tiles, 1), dtype=torch.int64),
         }

         vision_output_names = []

@@ -1220,19 +1219,19 @@ def generate_input(self, processor, kv_offload):
             },
         }

-        #lang_inputs
+        # lang_inputs
         lang_inputs = {
-            "input_ids": torch.zeros((bs,seq_len),dtype=torch.int64),
+            "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
             "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1),
-            "cross_attention_mask": torch.ones((bs, max_image_tiles),dtype=torch.int64),
-            "attention_mask": torch.ones((bs,seq_len),dtype=torch.int64)
+            "cross_attention_mask": torch.ones((bs, max_image_tiles), dtype=torch.int64),
+            "attention_mask": torch.ones((bs, seq_len), dtype=torch.int64),
         }

         lang_inputs["position_ids"] = torch.where(
             lang_inputs.pop("attention_mask") == 1,
             torch.arange(lang_inputs["input_ids"].shape[1]).view(1, -1),
             -1,
-            )
+        )

         ctx_len = Constants.CTX_LEN
         txt_cfg = self.mllama.config.get_text_config()

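The `torch.where` idiom kept in the hunk above derives `position_ids` from `attention_mask` and then drops the mask: positions count up wherever the mask is 1 and become -1 at padded slots. A minimal, self-contained sketch (the batch size, sequence length, and mask values below are illustrative, not the repository's constants):

```python
import torch

# Illustrative values only; not the repository's constants.
bs, seq_len = 2, 6
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 0, 0, 0]])

# Positions count up where the mask is 1 and are -1 where the input is padding.
position_ids = torch.where(
    attention_mask == 1,
    torch.arange(seq_len).view(1, -1),
    -1,
)
print(position_ids)
# tensor([[ 0,  1,  2,  3, -1, -1],
#         [ 0,  1,  2, -1, -1, -1]])
```

Because `attention_mask` is popped from `lang_inputs` right there, the padding information survives only through these -1 entries.
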
@@ -1245,7 +1244,6 @@ def generate_input(self, processor, kv_offload):
         num_patches = (vis_cfg.image_size // vis_cfg.patch_size) ** 2 + 1
         image_tokens_len = vis_cfg.max_num_tiles * num_patches

-
         lang_inputs["past_key_values"] = DynamicCache(num_hidden_layers)
         lang_inputs["past_key_values"].key_cache = [0] * num_hidden_layers
         lang_inputs["past_key_values"].value_cache = [0] * num_hidden_layers

@@ -1254,20 +1252,21 @@ def generate_input(self, processor, kv_offload):
             if i in cross_attention_layers:
                 idx = cross_attention_layers.index(i)
                 assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}"
-                lang_inputs["past_key_values"].key_cache[i] = torch.zeros(1, num_key_value_heads, image_tokens_len, head_dim)
+                lang_inputs["past_key_values"].key_cache[i] = torch.zeros(
+                    1, num_key_value_heads, image_tokens_len, head_dim
+                )
                 lang_inputs["past_key_values"].value_cache[i] = torch.zeros(
                     1, num_key_value_heads, image_tokens_len, head_dim
                 )
             else:
                 lang_inputs["past_key_values"].key_cache[i] = torch.zeros(1, num_key_value_heads, ctx_len, head_dim)
                 lang_inputs["past_key_values"].value_cache[i] = torch.zeros(1, num_key_value_heads, ctx_len, head_dim)

-
         lang_output_names = [
             "logits",
             *[f"past_{kv}.{i}_RetainedState" for i in range(num_hidden_layers) for kv in ["key", "value"]],
         ]
-
+
         lang_dynamic_axes = {
             "input_ids": {0: "batch_size", 1: "seq_len"},
             "position_ids": {0: "batch_size", 1: "seq_len"},

@@ -1286,10 +1285,10 @@ def generate_input(self, processor, kv_offload):
             else:
                 lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"}
                 lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"}
-
+
         lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache()
         lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, ctx_len - 1)
-
+
         inputs = []
         output_names = []
         dynamic_axes = []

@@ -1304,5 +1303,3 @@ def generate_input(self, processor, kv_offload):
         dynamic_axes.append({**vision_dynamic_axes, **lang_dynamic_axes})

         return inputs, output_names, dynamic_axes
-
-

QEfficient/transformers/models/modeling_auto.py

Lines changed: 12 additions & 21 deletions

@@ -14,11 +14,9 @@
 from typing import List, Optional, Union

 import numpy as np
-import requests
 import torch
 import torch.nn as nn
 import transformers
-from PIL import Image
 from transformers import (
     AutoModel,
     AutoModelForCausalLM,

@@ -34,7 +32,6 @@
 from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
 from QEfficient.generation.cloud_infer import QAICInferenceSession
 from QEfficient.generation.text_generation_inference import get_compilation_dims
-from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.models.mllama.modeling_mllama import ModelWrapper, VisionEncoder
 from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform, SpDTransform
 from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers

@@ -746,8 +743,7 @@ def export(
         self,
         export_dir: Optional[str] = None,
         **kwargs,
-    ) -> str:
-
+    ) -> str:
         self.inputs, self.output_names, self.dynamic_axes = self.model.generate_input(self.processor)
         if self.kv_offload:
             self.vision_export_path = self.export_vision(export_dir)

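As a reading aid for the hunk above: when `kv_offload` is set, `export()` splits the model into separately exported vision and language graphs. The class below is a hypothetical stand-in for that control flow, not the project's actual wrapper, and the returned file names are placeholders:

```python
class ExportFlowSketch:
    """Hypothetical stand-in for the kv_offload export flow; not the project's class."""

    def __init__(self, kv_offload: bool):
        self.kv_offload = kv_offload

    def export(self, export_dir="onnx"):
        if self.kv_offload:
            # Vision and language graphs become two separate ONNX artifacts.
            return self.export_vision(export_dir), self.export_lang(export_dir)
        # Otherwise a single combined graph is exported.
        return self.export_single(export_dir)

    def export_vision(self, export_dir):
        return f"{export_dir}/vision.onnx"  # placeholder for self._export(VisionEncoder(...), ...)

    def export_lang(self, export_dir):
        return f"{export_dir}/lang.onnx"    # placeholder for self._export(ModelWrapper(...), ...)

    def export_single(self, export_dir):
        return f"{export_dir}/model.onnx"   # placeholder for self._export(self.model, ...)
```
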
@@ -757,12 +753,11 @@ def export(
         self._export(self.model, self.inputs[0], self.output_names[0], self.dynamic_axes[0], export_dir=export_dir)

     def export_vision(self, export_dir):
-
-        self.vision_encoder_model=VisionEncoder(self.model)
+        self.vision_encoder_model = VisionEncoder(self.model)

-        vision_inputs=self.inputs[0]
-        vision_output_names=self.output_names[0]
-        vision_dynamic_axes=self.dynamic_axes[0]
+        vision_inputs = self.inputs[0]
+        vision_output_names = self.output_names[0]
+        vision_dynamic_axes = self.dynamic_axes[0]

         self.vision_onnx_path = self._export(
             self.vision_encoder_model,

@@ -775,20 +770,16 @@ def export_vision(self, export_dir):
         return self.vision_onnx_path

     def export_lang(self, export_dir):
-        self.lang_model= ModelWrapper(self.model)
+        self.lang_model = ModelWrapper(self.model)

-        lang_inputs=self.inputs[1]
-        lang_output_names=self.output_names[1]
-        lang_dynamic_axes=self.dynamic_axes[1]
+        lang_inputs = self.inputs[1]
+        lang_output_names = self.output_names[1]
+        lang_dynamic_axes = self.dynamic_axes[1]

         self.lang_onnx_path = self._export(
-            self.lang_model,
-            lang_inputs,
-            lang_output_names,
-            lang_dynamic_axes,
-            export_dir=export_dir
-        )
-
+            self.lang_model, lang_inputs, lang_output_names, lang_dynamic_axes, export_dir=export_dir
+        )
+
         return self.lang_onnx_path

     def compile(
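
The `(inputs, output_names, dynamic_axes)` triples threaded through `_export` in these hunks ultimately parametrize an ONNX export. A generic, self-contained `torch.onnx.export` sketch showing how such names and dynamic axes are consumed (the tiny model, shapes, and output file are hypothetical; this is not the project's `_export` helper):

```python
import torch
import torch.nn as nn


class TinyLM(nn.Module):
    """Hypothetical stand-in for the wrapped language model."""

    def __init__(self, vocab=32000, dim=64):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.head = nn.Linear(dim, vocab)

    def forward(self, input_ids):
        return self.head(self.emb(input_ids))


model = TinyLM()
dummy_input_ids = torch.zeros((1, 8), dtype=torch.int64)

# Dummy inputs, output names, and dynamic axes play the same roles as the structures
# built by generate_input; the axis names mirror the diff, the values are illustrative.
torch.onnx.export(
    model,
    (dummy_input_ids,),
    "tiny_lm.onnx",
    input_names=["input_ids"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "logits": {0: "batch_size", 1: "seq_len"},
    },
)
```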
