Fix some issues encountered by modelscope and community (#2428)
* fix modelscope

* fix llava model when input images have size (x, 1)

* larger interval

* skip get_font for xcomposer2d5

* fix custom image token position

* fix potential mismatching issues

* update docs
irexyc authored Sep 7, 2024
1 parent 3df11e7 commit 659a6b0
Showing 9 changed files with 97 additions and 41 deletions.
6 changes: 3 additions & 3 deletions docs/en/multi_modal/internvl.md
@@ -64,7 +64,7 @@ from lmdeploy.vl.constants import IMAGE_TOKEN
 pipe = pipeline('OpenGVLab/InternVL2-8B', log_level='INFO')
 messages = [
     dict(role='user', content=[
-        dict(type='text', text=f'<img>{IMAGE_TOKEN}{IMAGE_TOKEN}</img>\nDescribe the two images in detail.'),
+        dict(type='text', text=f'{IMAGE_TOKEN}{IMAGE_TOKEN}\nDescribe the two images in detail.'),
         dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg')),
         dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg'))
     ])
@@ -90,7 +90,7 @@ from lmdeploy.vl.constants import IMAGE_TOKEN
 pipe = pipeline('OpenGVLab/InternVL2-8B', log_level='INFO')
 messages = [
     dict(role='user', content=[
-        dict(type='text', text=f'Image-1: <img>{IMAGE_TOKEN}</img>\nImage-2: <img>{IMAGE_TOKEN}</img>\nDescribe the two images in detail.'),
+        dict(type='text', text=f'Image-1: {IMAGE_TOKEN}\nImage-2: {IMAGE_TOKEN}\nDescribe the two images in detail.'),
         dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg')),
         dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg'))
     ])
@@ -152,7 +152,7 @@ imgs = load_video(video_path, num_segments=8)
 
 question = ''
 for i in range(len(imgs)):
-    question = question + f'Frame{i+1}: <img>{IMAGE_TOKEN}</img>\n'
+    question = question + f'Frame{i+1}: {IMAGE_TOKEN}\n'
 
 question += 'What is the red panda doing?'
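The documentation change mirrors the template fix in lmdeploy/vl/templates.py further down: bare image tokens in a user prompt are now wrapped with InternVL's <img>...</img> tags automatically, so the examples no longer spell the tags out. A minimal sketch of the result, assuming IMAGE_TOKEN is the literal string '<IMAGE_TOKEN>' as in lmdeploy.vl.constants:

IMAGE_TOKEN = '<IMAGE_TOKEN>'  # assumed value of lmdeploy.vl.constants.IMAGE_TOKEN

user_text = f'{IMAGE_TOKEN}{IMAGE_TOKEN}\nDescribe the two images in detail.'
# after template processing the model receives:
#   '<img><IMAGE_TOKEN><IMAGE_TOKEN></img>\nDescribe the two images in detail.'
# i.e. adjacent tokens are merged into one <img>...</img> group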
6 changes: 3 additions & 3 deletions docs/zh_cn/multi_modal/internvl.md
@@ -64,7 +64,7 @@ from lmdeploy.vl.constants import IMAGE_TOKEN
 pipe = pipeline('OpenGVLab/InternVL2-8B', log_level='INFO')
 messages = [
     dict(role='user', content=[
-        dict(type='text', text=f'<img>{IMAGE_TOKEN}{IMAGE_TOKEN}</img>\nDescribe the two images in detail.'),
+        dict(type='text', text=f'{IMAGE_TOKEN}{IMAGE_TOKEN}\nDescribe the two images in detail.'),
         dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg')),
         dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg'))
     ])
@@ -90,7 +90,7 @@ from lmdeploy.vl.constants import IMAGE_TOKEN
 pipe = pipeline('OpenGVLab/InternVL2-8B', log_level='INFO')
 messages = [
     dict(role='user', content=[
-        dict(type='text', text=f'Image-1: <img>{IMAGE_TOKEN}</img>\nImage-2: <img>{IMAGE_TOKEN}</img>\nDescribe the two images in detail.'),
+        dict(type='text', text=f'Image-1: {IMAGE_TOKEN}\nImage-2: {IMAGE_TOKEN}\nDescribe the two images in detail.'),
         dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg')),
         dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg'))
     ])
@@ -151,7 +151,7 @@ imgs = load_video(video_path, num_segments=8)
 
 question = ''
 for i in range(len(imgs)):
-    question = question + f'Frame{i+1}: <img>{IMAGE_TOKEN}</img>\n'
+    question = question + f'Frame{i+1}: {IMAGE_TOKEN}\n'
 
 question += 'What is the red panda doing?'
3 changes: 3 additions & 0 deletions lmdeploy/model.py
@@ -765,6 +765,9 @@ def match(cls, model_path: str) -> Optional[str]:
         Args:
             model_path (str): the model path used for matching.
         """
+        # reject InternVL2-Llama3-76B
+        if 'internvl2' in model_path.lower():
+            return None
         if 'llama-3-' in model_path.lower() or 'llama3-' in model_path.lower():
             return 'llama3'
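Without the guard, 'InternVL2-Llama3-76B'.lower() contains 'llama3-', so the llama3 chat template would wrongly claim InternVL2 checkpoints. A self-contained sketch of the collision (a simplification, not the actual template class):

def match(model_path: str):
    # reject InternVL2-Llama3-76B
    if 'internvl2' in model_path.lower():
        return None
    if 'llama-3-' in model_path.lower() or 'llama3-' in model_path.lower():
        return 'llama3'

assert match('OpenGVLab/InternVL2-Llama3-76B') is None  # left for the InternVL templates
assert match('meta-llama/Meta-Llama-3-8B') == 'llama3'  # unaffected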
2 changes: 1 addition & 1 deletion lmdeploy/vl/engine.py
@@ -138,7 +138,7 @@ async def _forward_loop(self):
             while record.total == 0 or (self._que.qsize() and
                                         record.total < self.max_batch_size):
                 while self._que.qsize() == 0:
-                    await asyncio.sleep(0)
+                    await asyncio.sleep(0.01)
                 item = await self._que.get()
                 record.enqueue(item[0], item[1], item[2])
             inputs, kwargs = record.dequeue(self.max_batch_size)
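asyncio.sleep(0) only yields for a single event-loop tick, so waiting on an empty queue this way busy-spins a CPU core; sleeping 10 ms caps polling at roughly 100 checks per second at negligible cost in batching latency. A minimal sketch of the pattern (names are illustrative, not the engine's API):

import asyncio

async def wait_first_item(que: asyncio.Queue):
    # back off while the queue is empty instead of busy-waiting
    while que.qsize() == 0:
        await asyncio.sleep(0.01)
    return await que.get()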
5 changes: 3 additions & 2 deletions lmdeploy/vl/model/llava_hf.py
@@ -52,8 +52,9 @@ def build_model(self):
     @torch.no_grad()
     def forward(self, images: List[Image]) -> List[torch.Tensor]:
         """forward."""
-        pixel_values = self.processor(images,
-                                      return_tensors='pt')['pixel_values']
+        pixel_values = self.processor(
+            images, return_tensors='pt',
+            input_data_format='channels_last')['pixel_values']
         pixel_values = pixel_values.to(device=self.model.device,
                                        dtype=self.model.dtype)
         image_outputs = self.model.vision_tower.forward(
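The (x, 1) bug: when no layout is given, Hugging Face image processors infer the channel dimension from the array shape, and a 1-pixel-tall RGB image arrives as (1, x, 3), whose leading 1 can be mistaken for a channels-first axis. Passing input_data_format='channels_last' states the layout explicitly. A short repro of the ambiguity (illustrative, not lmdeploy code; llava_next.py below receives the same fix):

import numpy as np
from PIL import Image

img = Image.new('RGB', (224, 1))  # PIL size is (width, height)
arr = np.asarray(img)
print(arr.shape)  # (1, 224, 3): the leading 1 looks like a channel axis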
4 changes: 3 additions & 1 deletion lmdeploy/vl/model/llava_next.py
@@ -75,7 +75,9 @@ def forward(self, images: List[Image]) -> List[torch.Tensor]:
         from transformers.models.llava_next.modeling_llava_next import \
             image_size_to_num_patches
         """forward."""
-        processed_inputs = self.processor(images, return_tensors='pt')
+        processed_inputs = self.processor(images,
+                                          return_tensors='pt',
+                                          input_data_format='channels_last')
         pixel_values = processed_inputs['pixel_values'].to(
             device=self.model.device, dtype=self.model.dtype)
         image_sizes = processed_inputs['image_sizes'].to(
51 changes: 30 additions & 21 deletions lmdeploy/vl/model/utils.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import inspect
 import os
 import sys
 from contextlib import contextmanager
@@ -104,7 +105,7 @@ def hack_import_with(src: List[str], dst: str = 'torch'):
         sys.modules.pop(item, None)
 
 
-def _set_func(origin_func_path: str,
+def _set_func(origin_func_path: Union[str, None],
               rewrite_func: Callable,
               origin_func: Callable = None):
     """Replace old function with the new function.
@@ -115,22 +116,19 @@ def _set_func(origin_func_path: str,
         origin_func (Callable): function to replace
     """
     # import module
-    split_path = origin_func_path.split('.')
-    for i in range(len(split_path), 0, -1):
-        try:
-            exec('import {}'.format('.'.join(split_path[:i])))
-            break
-        except Exception:
-            continue
-
-    method_class = False
-    if len(split_path) > 1:
-        module_or_class = eval('.'.join(split_path[:-1]))
-        if isinstance(module_or_class, type):
-            method_class = True
-
-    origin_func = eval(origin_func_path) \
-        if origin_func is None else origin_func
+    if isinstance(origin_func_path, str):
+        split_path = origin_func_path.split('.')
+        for i in range(len(split_path), 0, -1):
+            try:
+                exec('import {}'.format('.'.join(split_path[:i])))
+                break
+            except Exception:
+                continue
+
+        origin_func = eval(origin_func_path) \
+            if origin_func is None else origin_func
+
+    method_class = inspect.ismethod(origin_func)
 
     # replace method
     if not method_class:
@@ -146,23 +144,34 @@
             for i, v in enumerate(ref):
                 if id(v) == obj_id:
                     ref[i] = rewrite_func
-        exec(f'{origin_func_path} = rewrite_func')
+        if isinstance(origin_func_path, str):
+            exec(f'{origin_func_path} = rewrite_func')
     elif method_class:
         raise NotImplementedError
 
     return origin_func
 
 
 @contextmanager
-def rewrite_ctx(origin_func_path: List[str], rewrite_func: List[Callable]):
+def rewrite_ctx(origin_func_path: List[Union[str, Callable]],
+                rewrite_func: List[Callable]):
     """rewrite context."""
     assert len(origin_func_path) == len(rewrite_func)
     origin_func_list = []
     for (func_path, dst_func) in zip(origin_func_path, rewrite_func):
-        origin_func = _set_func(func_path, dst_func)
+        if isinstance(func_path, Callable):
+            origin_func = _set_func(None, dst_func, func_path)
+        else:
+            origin_func = _set_func(func_path, dst_func)
         origin_func_list.append(origin_func)
     yield
     for (func_path, dst_func, origin_func) in zip(origin_func_path,
                                                   rewrite_func,
                                                   origin_func_list):
-        _set_func(func_path, origin_func, dst_func)
+        if isinstance(func_path, Callable):
+            _set_func(None, origin_func, dst_func)
+        else:
+            _set_func(func_path, origin_func, dst_func)
 
 
 def add_device_hook(module: torch.nn.Module,
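After this change rewrite_ctx accepts either a dotted import path (resolved as before) or a function object, which dynamically imported remote code requires because its module path is not known statically. A usage sketch under those assumptions (the patch target and stand-in are illustrative):

from lmdeploy.vl.model.utils import rewrite_ctx

def _fake_from_pretrained(*args, **kwargs):
    return None  # stand-in that skips the weight download

with rewrite_ctx(['transformers.CLIPVisionModel.from_pretrained'],
                 [_fake_from_pretrained]):
    pass  # inside the block the target is replaced
# on exit the original function is restored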
24 changes: 22 additions & 2 deletions lmdeploy/vl/model/xcomposer2.py
@@ -1,6 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 
 import enum
+import os
+import sys
 import warnings
 from contextlib import contextmanager
 from typing import Any, List, Tuple
@@ -48,14 +50,31 @@ def _CLIPVisionModel_from_pretrained(vision_tower_name):
 
 
 @contextmanager
-def init_empty_vit():
+def init_empty_vit(model_path):
     """skip download vision model."""
     origin_func_path = [
         'transformers.CLIPVisionModel.from_pretrained',
     ]
     rewrite_func = [
         _CLIPVisionModel_from_pretrained,
     ]
 
+    model_type, _ = get_xcomposer_type(model_path)
+    if model_type == ModelType.XCOMPOSER2D5:
+        from transformers.dynamic_module_utils import \
+            get_class_from_dynamic_module
+        from transformers.utils import TRANSFORMERS_DYNAMIC_MODULE_NAME
+        _ = get_class_from_dynamic_module(
+            'modeling_internlm_xcomposer2.get_font', model_path)
+        folder = model_path.rstrip(os.sep).split(os.sep)[-1]
+        module_path = '.'.join([
+            TRANSFORMERS_DYNAMIC_MODULE_NAME, folder,
+            'modeling_internlm_xcomposer2'
+        ])
+        origin_get_font_func = getattr(sys.modules[module_path], 'get_font')
+        origin_func_path.append(origin_get_font_func)
+        rewrite_func.append(lambda: None)
+
     with rewrite_ctx(origin_func_path, rewrite_func):
         yield
 
@@ -77,7 +96,8 @@ def match(cls, config: AutoConfig):
 
     def build_model(self):
         from accelerate import init_empty_weights
-        with init_empty_weights(), warnings.catch_warnings(), init_empty_vit():
+        with init_empty_weights(), warnings.catch_warnings(), \
+                init_empty_vit(self.model_path):
             warnings.simplefilter('ignore')
             config = self.hf_config
             model = AutoModelForCausalLM.from_config(config,
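Background for the get_font patch: XComposer2.5's remote modeling code downloads a font file through a module-level get_font() helper, which fails offline, so it is swapped for lambda: None during empty-weight initialization. Because the remote module lives under transformers_modules.<checkpoint folder>.modeling_internlm_xcomposer2, a path that depends on the local directory name, the function object itself, rather than a dotted path, is handed to rewrite_ctx. A sketch of resolving such a symbol, mirroring the patch above (model_path is hypothetical):

import sys
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.utils import TRANSFORMERS_DYNAMIC_MODULE_NAME  # 'transformers_modules'

model_path = './internlm-xcomposer2d5-7b'  # hypothetical local checkpoint
# importing any symbol from the remote file loads the module into sys.modules
get_class_from_dynamic_module('modeling_internlm_xcomposer2.get_font', model_path)
module_name = '.'.join([TRANSFORMERS_DYNAMIC_MODULE_NAME,
                        'internlm-xcomposer2d5-7b', 'modeling_internlm_xcomposer2'])
get_font = getattr(sys.modules[module_name], 'get_font')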
37 changes: 29 additions & 8 deletions lmdeploy/vl/templates.py
@@ -106,6 +106,8 @@ def _inner_call(i, images):
 
     def append_image_token(self, prompt, num_images: int):
         """append image token to user prompt."""
+        if IMAGE_TOKEN in prompt:
+            return prompt
         return (IMAGE_TOKEN + '\n') * num_images + prompt
 
     def convert_messages(self, messages, sequence_start=True):
@@ -130,9 +132,8 @@ def convert_messages(self, messages, sequence_start=True):
                     num_images += 1
                 elif item['type'] == 'text':
                     prompt = item['text']
-            # if IMAGE_TOKEN in user prompt, use user custom prompt instead
-            # of adding IMAGE_TOKEN to user prompt
-            if IMAGE_TOKEN not in prompt and num_images > 0:
+            if num_images > 0:
+                # add IMAGE_TOKEN to user prompt
                 prompt = self.append_image_token(prompt, num_images)
             new_item = {'role': 'user', 'content': prompt}
             new_messages.append(new_item)
@@ -161,15 +162,26 @@ class InternVLChatTemplateWrapper(VLChatTemplateWrapper):
 
     def append_image_token(self, prompt, num_images: int):
         """append image tokens to user prompt."""
-        # not sure whether support multi images.
-        return f'<img>{IMAGE_TOKEN * num_images}</img>\n' + prompt
+        # lmdeploy uses <IMAGE_TOKEN> as image token
+        # internvl uses special tags
+        if IMAGE_TOKEN in prompt and f'<img>{IMAGE_TOKEN}' not in prompt:
+            prompt = prompt.replace(f'{IMAGE_TOKEN}',
+                                    f'<img>{IMAGE_TOKEN}</img>')
+            prompt = prompt.replace('</img><img>', '')
+            prompt = prompt.replace('<img><img>', '<img>')
+            prompt = prompt.replace('</img></img>', '</img>')
+        elif IMAGE_TOKEN not in prompt:
+            prompt = f'<img>{IMAGE_TOKEN * num_images}</img>\n' + prompt
+        return prompt
 
 
 class DeepSeekVLChatTemplateWrapper(VLChatTemplateWrapper):
     """DeepSeek vl chat template."""
 
     def append_image_token(self, prompt, num_images: int):
         """append image tokens to user prompt."""
+        if IMAGE_TOKEN in prompt:
+            return prompt
         logger.error(
             f'for deepseek-vl model, the user should insert the {IMAGE_TOKEN} '
             'to user prompt manually, please read https://lmdeploy.readthedocs'
@@ -188,6 +200,8 @@ class QwenVLChatTemplateWrapper(VLChatTemplateWrapper):
 
     def append_image_token(self, prompt, num_images: int):
         """append image tokens to user prompt."""
+        if IMAGE_TOKEN in prompt:
+            return prompt
         res = ''
         for i in range(num_images):
             res += f'Picture {str(i)}:{IMAGE_TOKEN}\n'
@@ -256,6 +270,8 @@ class InternLMXComposer2TemplateWrapper(VLChatTemplateWrapper):
     """InternLM-XComposer2 chat template."""
 
     def append_image_token(self, prompt, num_images: int):
+        if IMAGE_TOKEN in prompt:
+            return prompt
         logger.warning(f'auto append {IMAGE_TOKEN} at the beginning, '
                        'the user can manually insert the token to prompt')
         return ' '.join([IMAGE_TOKEN] * num_images) + prompt
@@ -268,6 +284,8 @@ def append_image_token(self, prompt, num_images: int):
         """append image tokens to user prompt."""
         if num_images == 0:
             return prompt
+        if IMAGE_TOKEN in prompt:
+            return prompt
         res = f'{IMAGE_TOKEN}\n'
         assert num_images <= 1, 'MiniGeminiLlama accepts 1 input image'
         res = res + prompt
@@ -278,12 +296,15 @@ class MiniCPMVTempateWrapper(VLChatTemplateWrapper):
     """MiniCPM-Llama3-V-2_5 chat template."""
 
     def append_image_token(self, prompt, num_images: int):
-        return f'<image>{IMAGE_TOKEN}</image>\n' * num_images + prompt
+        if IMAGE_TOKEN in prompt:
+            return prompt
+        prompt = f'{IMAGE_TOKEN}\n' * num_images + prompt
+        return prompt
 
     def update_image_token(self, prompt, features):
         _features = []
         _prompt = []
-        segs = prompt.split(f'<image>{IMAGE_TOKEN}</image>\n')
+        segs = prompt.split(f'{IMAGE_TOKEN}\n')
         for i, seg in enumerate(segs):
             if i > 0 and i <= len(features):
                 _feat = features[i - 1]['embeddings'].split(1)
@@ -309,7 +330,7 @@ class MiniCPMV26TempateWrapper(MiniCPMVTempateWrapper):
     def update_image_token(self, prompt, features):
         _features = []
         _prompt = []
-        segs = prompt.split(f'<image>{IMAGE_TOKEN}</image>\n')
+        segs = prompt.split(f'{IMAGE_TOKEN}\n')
         idx = 0
         for i, seg in enumerate(segs):
             if i > 0 and i <= len(features):
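A worked example of the InternVL tag handling above, assuming IMAGE_TOKEN == '<IMAGE_TOKEN>' (its value in lmdeploy.vl.constants): the replace chain wraps every token, then deletes adjacent closing/opening pairs, so consecutive tokens collapse into a single <img>...</img> group while separated tokens keep their own:

IMAGE_TOKEN = '<IMAGE_TOKEN>'  # assumed value

prompt = f'{IMAGE_TOKEN}{IMAGE_TOKEN}\nDescribe the two images in detail.'
prompt = prompt.replace(IMAGE_TOKEN, f'<img>{IMAGE_TOKEN}</img>')
# '<img><IMAGE_TOKEN></img><img><IMAGE_TOKEN></img>\nDescribe ...'
prompt = prompt.replace('</img><img>', '')
# '<img><IMAGE_TOKEN><IMAGE_TOKEN></img>\nDescribe ...'

# 'Image-1: <IMAGE_TOKEN>\nImage-2: <IMAGE_TOKEN>\n...' instead keeps two
# separate <img>...</img> groups, matching the updated docs above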
