diff --git a/docs/en/multi_modal/internvl.md b/docs/en/multi_modal/internvl.md
index efa2b30a2..24c79357c 100644
--- a/docs/en/multi_modal/internvl.md
+++ b/docs/en/multi_modal/internvl.md
@@ -64,7 +64,7 @@ from lmdeploy.vl.constants import IMAGE_TOKEN
pipe = pipeline('OpenGVLab/InternVL2-8B', log_level='INFO')
messages = [
dict(role='user', content=[
- dict(type='text', text=f'<img>{IMAGE_TOKEN}{IMAGE_TOKEN}</img>\nDescribe the two images in detail.'),
+ dict(type='text', text=f'{IMAGE_TOKEN}{IMAGE_TOKEN}\nDescribe the two images in detail.'),
dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg')),
dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg'))
])
@@ -90,7 +90,7 @@ from lmdeploy.vl.constants import IMAGE_TOKEN
pipe = pipeline('OpenGVLab/InternVL2-8B', log_level='INFO')
messages = [
dict(role='user', content=[
- dict(type='text', text=f'Image-1: <img>{IMAGE_TOKEN}</img>\nImage-2: <img>{IMAGE_TOKEN}</img>\nDescribe the two images in detail.'),
+ dict(type='text', text=f'Image-1: {IMAGE_TOKEN}\nImage-2: {IMAGE_TOKEN}\nDescribe the two images in detail.'),
dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg')),
dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg'))
])
@@ -152,7 +152,7 @@ imgs = load_video(video_path, num_segments=8)
question = ''
for i in range(len(imgs)):
- question = question + f'Frame{i+1}: <img>{IMAGE_TOKEN}</img>\n'
+ question = question + f'Frame{i+1}: {IMAGE_TOKEN}\n'
question += 'What is the red panda doing?'
diff --git a/docs/zh_cn/multi_modal/internvl.md b/docs/zh_cn/multi_modal/internvl.md
index 1abcbc7d0..3d948353a 100644
--- a/docs/zh_cn/multi_modal/internvl.md
+++ b/docs/zh_cn/multi_modal/internvl.md
@@ -64,7 +64,7 @@ from lmdeploy.vl.constants import IMAGE_TOKEN
pipe = pipeline('OpenGVLab/InternVL2-8B', log_level='INFO')
messages = [
dict(role='user', content=[
- dict(type='text', text=f'<img>{IMAGE_TOKEN}{IMAGE_TOKEN}</img>\nDescribe the two images in detail.'),
+ dict(type='text', text=f'{IMAGE_TOKEN}{IMAGE_TOKEN}\nDescribe the two images in detail.'),
dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg')),
dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg'))
])
@@ -90,7 +90,7 @@ from lmdeploy.vl.constants import IMAGE_TOKEN
pipe = pipeline('OpenGVLab/InternVL2-8B', log_level='INFO')
messages = [
dict(role='user', content=[
- dict(type='text', text=f'Image-1: <img>{IMAGE_TOKEN}</img>\nImage-2: <img>{IMAGE_TOKEN}</img>\nDescribe the two images in detail.'),
+ dict(type='text', text=f'Image-1: {IMAGE_TOKEN}\nImage-2: {IMAGE_TOKEN}\nDescribe the two images in detail.'),
dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg')),
dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg'))
])
@@ -151,7 +151,7 @@ imgs = load_video(video_path, num_segments=8)
question = ''
for i in range(len(imgs)):
- question = question + f'Frame{i+1}: <img>{IMAGE_TOKEN}</img>\n'
+ question = question + f'Frame{i+1}: {IMAGE_TOKEN}\n'
question += 'What is the red panda doing?'
diff --git a/lmdeploy/model.py b/lmdeploy/model.py
index 5b77bedd0..1d54dd1b0 100644
--- a/lmdeploy/model.py
+++ b/lmdeploy/model.py
@@ -765,6 +765,9 @@ def match(cls, model_path: str) -> Optional[str]:
Args:
model_path (str): the model path used for matching.
"""
+ # reject InternVL2-Llama3-76B
+ if 'internvl2' in model_path.lower():
+ return None
if 'llama-3-' in model_path.lower() or 'llama3-' in model_path.lower():
return 'llama3'
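
The guard above changes what the chat-template matcher returns for InternVL2 checkpoints whose names contain "Llama3". A quick way to check the effect, assuming the public helper `lmdeploy.model.best_match_model` (which tries every registered template's `match`) — a sketch, not part of the patch:

```python
# Sketch: the 'llama3' matcher should no longer claim InternVL2 checkpoints.
from lmdeploy.model import best_match_model

# Previously the 'Llama3' substring could resolve this to the plain 'llama3'
# template; with the guard the llama3 matcher returns None, so a more
# specific InternVL2 template (if registered) can be picked instead.
print(best_match_model('OpenGVLab/InternVL2-Llama3-76B'))

# A genuine Llama-3 checkpoint still matches as before.
print(best_match_model('meta-llama/Meta-Llama-3-8B-Instruct'))
```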
diff --git a/lmdeploy/vl/engine.py b/lmdeploy/vl/engine.py
index 4cf5cb83d..124fd537c 100644
--- a/lmdeploy/vl/engine.py
+++ b/lmdeploy/vl/engine.py
@@ -138,7 +138,7 @@ async def _forward_loop(self):
while record.total == 0 or (self._que.qsize() and
record.total < self.max_batch_size):
while self._que.qsize() == 0:
- await asyncio.sleep(0)
+ await asyncio.sleep(0.01)
item = await self._que.get()
record.enqueue(item[0], item[1], item[2])
inputs, kwargs = record.dequeue(self.max_batch_size)
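
Context for the one-line change above: `asyncio.sleep(0)` merely yields to the event loop once and immediately re-checks the queue, so an empty queue turns the wait into a busy spin; a small positive interval keeps the poll cheap. A standalone toy (not lmdeploy code) that mimics the waiting pattern:

```python
import asyncio
import time


async def poll_until_filled(que: asyncio.Queue, interval: float) -> int:
    """Wait for work the way _forward_loop does, counting poll rounds."""
    polls = 0
    while que.qsize() == 0:
        polls += 1
        await asyncio.sleep(interval)  # 0 -> hot spin; 0.01 -> ~5 polls per 50 ms
    return polls


async def main():
    que: asyncio.Queue = asyncio.Queue()
    # Enqueue an item from the event loop after 50 ms.
    asyncio.get_running_loop().call_later(0.05, que.put_nowait, 'item')
    start = time.perf_counter()
    polls = await poll_until_filled(que, interval=0.01)
    print(f'got work after {time.perf_counter() - start:.3f}s, {polls} polls')


asyncio.run(main())
```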
diff --git a/lmdeploy/vl/model/llava_hf.py b/lmdeploy/vl/model/llava_hf.py
index 9187af98a..66faf4f46 100644
--- a/lmdeploy/vl/model/llava_hf.py
+++ b/lmdeploy/vl/model/llava_hf.py
@@ -52,8 +52,9 @@ def build_model(self):
@torch.no_grad()
def forward(self, images: List[Image]) -> List[torch.Tensor]:
"""forward."""
- pixel_values = self.processor(images,
- return_tensors='pt')['pixel_values']
+ pixel_values = self.processor(
+ images, return_tensors='pt',
+ input_data_format='channels_last')['pixel_values']
pixel_values = pixel_values.to(device=self.model.device,
dtype=self.model.dtype)
image_outputs = self.model.vision_tower.forward(
diff --git a/lmdeploy/vl/model/llava_next.py b/lmdeploy/vl/model/llava_next.py
index 78af894ba..9223ebea4 100644
--- a/lmdeploy/vl/model/llava_next.py
+++ b/lmdeploy/vl/model/llava_next.py
@@ -75,7 +75,9 @@ def forward(self, images: List[Image]) -> List[torch.Tensor]:
from transformers.models.llava_next.modeling_llava_next import \
image_size_to_num_patches
"""forward."""
- processed_inputs = self.processor(images, return_tensors='pt')
+ processed_inputs = self.processor(images,
+ return_tensors='pt',
+ input_data_format='channels_last')
pixel_values = processed_inputs['pixel_values'].to(
device=self.model.device, dtype=self.model.dtype)
image_sizes = processed_inputs['image_sizes'].to(
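
Both LLaVA hunks pass `input_data_format='channels_last'` so the HF image processor does not have to infer which axis holds the channels — a guess that can go wrong when an image dimension happens to equal 3. A minimal sketch with a default-constructed `CLIPImageProcessor` (not the exact processor lmdeploy builds):

```python
import numpy as np
from transformers import CLIPImageProcessor

processor = CLIPImageProcessor()  # default CLIP-style preprocessing

# A 3-pixel-tall RGB image in HWC layout: shape (3, 32, 3). Automatic
# inference sees a leading 3 and may treat it as the channel axis.
image = (np.random.rand(3, 32, 3) * 255).astype(np.uint8)

# Stating the layout explicitly (the keyword the diff adds) removes the guess.
pixel_values = processor(image, return_tensors='pt',
                         input_data_format='channels_last')['pixel_values']
print(pixel_values.shape)  # e.g. torch.Size([1, 3, 224, 224])
```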
diff --git a/lmdeploy/vl/model/utils.py b/lmdeploy/vl/model/utils.py
index 23f757744..447d5f8a1 100644
--- a/lmdeploy/vl/model/utils.py
+++ b/lmdeploy/vl/model/utils.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
import os
import sys
from contextlib import contextmanager
@@ -104,7 +105,7 @@ def hack_import_with(src: List[str], dst: str = 'torch'):
sys.modules.pop(item, None)
-def _set_func(origin_func_path: str,
+def _set_func(origin_func_path: Union[str, None],
rewrite_func: Callable,
origin_func: Callable = None):
"""Replace old function with the new function.
@@ -115,22 +116,19 @@ def _set_func(origin_func_path: str,
origin_func (Callable): function to replace
"""
# import module
- split_path = origin_func_path.split('.')
- for i in range(len(split_path), 0, -1):
- try:
- exec('import {}'.format('.'.join(split_path[:i])))
- break
- except Exception:
- continue
-
- method_class = False
- if len(split_path) > 1:
- module_or_class = eval('.'.join(split_path[:-1]))
- if isinstance(module_or_class, type):
- method_class = True
-
- origin_func = eval(origin_func_path) \
- if origin_func is None else origin_func
+ if isinstance(origin_func_path, str):
+ split_path = origin_func_path.split('.')
+ for i in range(len(split_path), 0, -1):
+ try:
+ exec('import {}'.format('.'.join(split_path[:i])))
+ break
+ except Exception:
+ continue
+
+ origin_func = eval(origin_func_path) \
+ if origin_func is None else origin_func
+
+ method_class = inspect.ismethod(origin_func)
# replace method
if not method_class:
@@ -146,23 +144,34 @@ def _set_func(origin_func_path: str,
for i, v in enumerate(ref):
if id(v) == obj_id:
ref[i] = rewrite_func
- exec(f'{origin_func_path} = rewrite_func')
+ if isinstance(origin_func_path, str):
+ exec(f'{origin_func_path} = rewrite_func')
+ elif method_class:
+ raise NotImplementedError
+
return origin_func
@contextmanager
-def rewrite_ctx(origin_func_path: List[str], rewrite_func: List[Callable]):
+def rewrite_ctx(origin_func_path: List[Union[str, Callable]],
+ rewrite_func: List[Callable]):
"""rewrite context."""
assert len(origin_func_path) == len(rewrite_func)
origin_func_list = []
for (func_path, dst_func) in zip(origin_func_path, rewrite_func):
- origin_func = _set_func(func_path, dst_func)
+ if isinstance(func_path, Callable):
+ origin_func = _set_func(None, dst_func, func_path)
+ else:
+ origin_func = _set_func(func_path, dst_func)
origin_func_list.append(origin_func)
yield
for (func_path, dst_func, origin_func) in zip(origin_func_path,
rewrite_func,
origin_func_list):
- _set_func(func_path, origin_func, dst_func)
+ if isinstance(func_path, Callable):
+ _set_func(None, origin_func, dst_func)
+ else:
+ _set_func(func_path, origin_func, dst_func)
def add_device_hook(module: torch.nn.Module,
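
With this change `rewrite_ctx` accepts either a dotted import path or the function object itself; the xcomposer2 patch below relies on the latter to swap a function that lives in a dynamically loaded module. A toy sketch of the new call style (standalone; `json.dumps` is just an arbitrary stand-in target):

```python
# Sketch: temporarily replace a function by passing the object directly.
import json

from lmdeploy.vl.model.utils import rewrite_ctx


def fake_dumps(*args, **kwargs):
    return '"patched"'


with rewrite_ctx([json.dumps], [fake_dumps]):
    print(json.dumps({'a': 1}))  # -> "patched" while the context is active
print(json.dumps({'a': 1}))      # original behaviour restored on exit
```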
diff --git a/lmdeploy/vl/model/xcomposer2.py b/lmdeploy/vl/model/xcomposer2.py
index 1d4393591..96bc900c0 100644
--- a/lmdeploy/vl/model/xcomposer2.py
+++ b/lmdeploy/vl/model/xcomposer2.py
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import enum
+import os
+import sys
import warnings
from contextlib import contextmanager
from typing import Any, List, Tuple
@@ -48,7 +50,7 @@ def _CLIPVisionModel_from_pretrained(vision_tower_name):
@contextmanager
-def init_empty_vit():
+def init_empty_vit(model_path):
"""skip download vision model."""
origin_func_path = [
'transformers.CLIPVisionModel.from_pretrained',
@@ -56,6 +58,23 @@ def init_empty_vit():
rewrite_func = [
_CLIPVisionModel_from_pretrained,
]
+
+ model_type, _ = get_xcomposer_type(model_path)
+ if model_type == ModelType.XCOMPOSER2D5:
+ from transformers.dynamic_module_utils import \
+ get_class_from_dynamic_module
+ from transformers.utils import TRANSFORMERS_DYNAMIC_MODULE_NAME
+ _ = get_class_from_dynamic_module(
+ 'modeling_internlm_xcomposer2.get_font', model_path)
+ folder = model_path.rstrip(os.sep).split(os.sep)[-1]
+ module_path = '.'.join([
+ TRANSFORMERS_DYNAMIC_MODULE_NAME, folder,
+ 'modeling_internlm_xcomposer2'
+ ])
+ origin_get_font_func = getattr(sys.modules[module_path], 'get_font')
+ origin_func_path.append(origin_get_font_func)
+ rewrite_func.append(lambda: None)
+
with rewrite_ctx(origin_func_path, rewrite_func):
yield
@@ -77,7 +96,8 @@ def match(cls, config: AutoConfig):
def build_model(self):
from accelerate import init_empty_weights
- with init_empty_weights(), warnings.catch_warnings(), init_empty_vit():
+ with init_empty_weights(), warnings.catch_warnings(), \
+ init_empty_vit(self.model_path):
warnings.simplefilter('ignore')
config = self.hf_config
model = AutoModelForCausalLM.from_config(config,
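
The `get_font` lookup above depends on how transformers names remote-code modules: `TRANSFORMERS_DYNAMIC_MODULE_NAME`, the last component of the checkpoint path, and the file name. A small illustration of the sys.modules key the patch composes (the local path is hypothetical; nothing is loaded):

```python
# Illustration of the module path the patch looks up; path is hypothetical.
import os

from transformers.utils import TRANSFORMERS_DYNAMIC_MODULE_NAME

model_path = '/models/internlm-xcomposer2d5-7b'  # hypothetical local checkpoint
folder = model_path.rstrip(os.sep).split(os.sep)[-1]
module_path = '.'.join(
    [TRANSFORMERS_DYNAMIC_MODULE_NAME, folder, 'modeling_internlm_xcomposer2'])
print(module_path)
# -> transformers_modules.internlm-xcomposer2d5-7b.modeling_internlm_xcomposer2
```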
diff --git a/lmdeploy/vl/templates.py b/lmdeploy/vl/templates.py
index 952a2f153..bf471225f 100644
--- a/lmdeploy/vl/templates.py
+++ b/lmdeploy/vl/templates.py
@@ -106,6 +106,8 @@ def _inner_call(i, images):
def append_image_token(self, prompt, num_images: int):
"""append image token to user prompt."""
+ if IMAGE_TOKEN in prompt:
+ return prompt
return (IMAGE_TOKEN + '\n') * num_images + prompt
def convert_messages(self, messages, sequence_start=True):
@@ -130,9 +132,8 @@ def convert_messages(self, messages, sequence_start=True):
num_images += 1
elif item['type'] == 'text':
prompt = item['text']
- # if IMAGE_TOKEN in user prompt, use user custom prompt instead
- # of adding IMAGE_TOKEN to user prompt
- if IMAGE_TOKEN not in prompt and num_images > 0:
+ if num_images > 0:
+ # add IMAGE_TOKEN to user prompt
prompt = self.append_image_token(prompt, num_images)
new_item = {'role': 'user', 'content': prompt}
new_messages.append(new_item)
@@ -161,8 +162,17 @@ class InternVLChatTemplateWrapper(VLChatTemplateWrapper):
def append_image_token(self, prompt, num_images: int):
"""append image tokens to user prompt."""
- # not sure whether support multi images.
- return f'<img>{IMAGE_TOKEN * num_images}</img>\n' + prompt
+ # lmdeploy uses <IMAGE_TOKEN> as image token
+ # internvl uses special <img></img> tags
+ if IMAGE_TOKEN in prompt and f'<img>{IMAGE_TOKEN}' not in prompt:
+ prompt = prompt.replace(f'{IMAGE_TOKEN}',
+ f'<img>{IMAGE_TOKEN}</img>')
+ prompt = prompt.replace('</img><img>', '')
+ prompt = prompt.replace('</img></img>', '</img>')
+ prompt = prompt.replace('<img><img>', '<img>')
+ elif IMAGE_TOKEN not in prompt:
+ prompt = f'<img>{IMAGE_TOKEN * num_images}</img>\n' + prompt
+ return prompt
class DeepSeekVLChatTemplateWrapper(VLChatTemplateWrapper):
@@ -170,6 +180,8 @@ class DeepSeekVLChatTemplateWrapper(VLChatTemplateWrapper):
def append_image_token(self, prompt, num_images: int):
"""append image tokens to user prompt."""
+ if IMAGE_TOKEN in prompt:
+ return prompt
logger.error(
f'for deepseek-vl model, the user should insert the {IMAGE_TOKEN} '
'to user prompt manually, please read https://lmdeploy.readthedocs'
@@ -188,6 +200,8 @@ class QwenVLChatTemplateWrapper(VLChatTemplateWrapper):
def append_image_token(self, prompt, num_images: int):
"""append image tokens to user prompt."""
+ if IMAGE_TOKEN in prompt:
+ return prompt
res = ''
for i in range(num_images):
res += f'Picture {str(i)}:{IMAGE_TOKEN}\n'
@@ -256,6 +270,8 @@ class InternLMXComposer2TemplateWrapper(VLChatTemplateWrapper):
"""InternLM-XComposer2 chat template."""
def append_image_token(self, prompt, num_images: int):
+ if IMAGE_TOKEN in prompt:
+ return prompt
logger.warning(f'auto append {IMAGE_TOKEN} at the beginning, '
'the user can manually insert the token to prompt')
return ' '.join([IMAGE_TOKEN] * num_images) + prompt
@@ -268,6 +284,8 @@ def append_image_token(self, prompt, num_images: int):
"""append image tokens to user prompt."""
if num_images == 0:
return prompt
+ if IMAGE_TOKEN in prompt:
+ return prompt
res = f'{IMAGE_TOKEN}\n'
assert num_images <= 1, 'MiniGeminiLlama accepts 1 input image'
res = res + prompt
@@ -278,12 +296,15 @@ class MiniCPMVTempateWrapper(VLChatTemplateWrapper):
"""MiniCPM-Llama3-V-2_5 chat template."""
def append_image_token(self, prompt, num_images: int):
- return f'<image>{IMAGE_TOKEN}</image>\n' * num_images + prompt
+ if IMAGE_TOKEN in prompt:
+ return prompt
+ prompt = f'{IMAGE_TOKEN}\n' * num_images + prompt
+ return prompt
def update_image_token(self, prompt, features):
_features = []
_prompt = []
- segs = prompt.split(f'<image>{IMAGE_TOKEN}</image>\n')
+ segs = prompt.split(f'{IMAGE_TOKEN}\n')
for i, seg in enumerate(segs):
if i > 0 and i <= len(features):
_feat = features[i - 1]['embeddings'].split(1)
@@ -309,7 +330,7 @@ class MiniCPMV26TempateWrapper(MiniCPMVTempateWrapper):
def update_image_token(self, prompt, features):
_features = []
_prompt = []
- segs = prompt.split(f'<image>{IMAGE_TOKEN}</image>\n')
+ segs = prompt.split(f'{IMAGE_TOKEN}\n')
idx = 0
for i, seg in enumerate(segs):
if i > 0 and i <= len(features):
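
Net effect of the template changes for InternVL: a plain IMAGE_TOKEN written by the user is now wrapped with InternVL's <img></img> tags instead of being passed through untouched. A standalone sketch of that behaviour (not the library code; lmdeploy's IMAGE_TOKEN constant is the string '<IMAGE_TOKEN>'):

```python
# Standalone sketch of the wrapping rule added to InternVLChatTemplateWrapper.
IMAGE_TOKEN = '<IMAGE_TOKEN>'


def wrap_internvl(prompt: str, num_images: int) -> str:
    if IMAGE_TOKEN in prompt and f'<img>{IMAGE_TOKEN}' not in prompt:
        # wrap each user-supplied token, then merge directly adjacent ones
        prompt = prompt.replace(IMAGE_TOKEN, f'<img>{IMAGE_TOKEN}</img>')
        prompt = prompt.replace('</img><img>', '')
    elif IMAGE_TOKEN not in prompt:
        prompt = f'<img>{IMAGE_TOKEN * num_images}</img>\n' + prompt
    return prompt


print(wrap_internvl(f'{IMAGE_TOKEN}{IMAGE_TOKEN}\nDescribe the two images.', 2))
# <img><IMAGE_TOKEN><IMAGE_TOKEN></img> followed by the question on a new line
print(wrap_internvl('Describe the image.', 1))
# <img><IMAGE_TOKEN></img> followed by the question on a new line
```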