Support ovis 1.6 #2211

Merged: 6 commits merged on Oct 9, 2024. Changes shown from 4 commits.
4 changes: 3 additions & 1 deletion README.md
@@ -36,7 +36,7 @@
- [Citation](#-citation)

## 📝 Introduction
SWIFT supports training (pre-training, fine-tuning, RLHF), inference, evaluation, and deployment of **350+ LLMs and 90+ MLLMs** (multimodal large models). Developers can apply our framework directly to their own research and production environments to realize the complete workflow from model training and evaluation to application. In addition to the lightweight training solutions provided by [PEFT](https://github.com/huggingface/peft), we also provide a complete **Adapters library** supporting the latest training techniques such as NEFTune, LoRA+, and LLaMA-PRO. This adapter library can be used directly in your own custom workflow without our training scripts.
SWIFT supports training (pre-training, fine-tuning, RLHF), inference, evaluation, and deployment of **350+ LLMs and 100+ MLLMs** (multimodal large models). Developers can apply our framework directly to their own research and production environments to realize the complete workflow from model training and evaluation to application. In addition to the lightweight training solutions provided by [PEFT](https://github.com/huggingface/peft), we also provide a complete **Adapters library** supporting the latest training techniques such as NEFTune, LoRA+, and LLaMA-PRO. This adapter library can be used directly in your own custom workflow without our training scripts.

To make it easier for users unfamiliar with deep learning, we provide a Gradio web-ui for controlling training and inference, as well as accompanying deep learning courses and best practices for beginners. The SWIFT web-ui is available on both [Huggingface space](https://huggingface.co/spaces/tastelikefeet/swift) and [ModelScope studio](https://www.modelscope.cn/studios/iic/Scalable-lightWeight-Infrastructure-for-Fine-Tuning/summary); please feel free to try it!

@@ -55,6 +55,7 @@ You can contact us and communicate with us by adding our group:
<img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">

## 🎉 News
- 2024.10.09: Support for training and deploying ovis1.6-gemma2 series models. Experience it using `swift infer --model_type ovis1_6-gemma2-9b`.
- 2024.09.26: Support for training and deploying llama3.2-vision series models. Experience it using `swift infer --model_type llama3_2-11b-vision-instruct`.
- 2024.09.26: Support for training and deploying llama3.2 series models. Experience it using `swift infer --model_type llama3_2-1b-instruct`.
- 2024.09.25: Support for training through deployment of got-ocr2. Best practices can be found [here](https://github.com/modelscope/ms-swift/issues/2122).
@@ -642,6 +643,7 @@ The complete list of supported models and datasets can be found at [Supported Mo
| Idefics3 | [HuggingFaceM4](https://huggingface.co/HuggingFaceM4) | English | 8B | chat model |
| Pixtral | [mistralai](https://huggingface.co/mistralai) | English | 12B | chat model |
| Llama3.1-Omni | [LLaMA-Omni](https://github.com/ictnlp/LLaMA-Omni) | English | 8B | chat model |
| Ovis | [Ovis](https://github.com/AIDC-AI/Ovis) | English | 9B | chat model |


#### Diffusion Models
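For users who prefer Python over the CLI one-liner shown in the news section above, here is a minimal sketch using swift's programmatic entry points; it assumes the `InferArguments`/`infer_main`/`ModelType` names exposed by `swift.llm` in this release and leaves all other settings at their defaults:

```python
# Sketch: rough Python equivalent of `swift infer --model_type ovis1_6-gemma2-9b`.
from swift.llm import InferArguments, ModelType, infer_main

if __name__ == '__main__':
    infer_main(InferArguments(model_type=ModelType.ovis1_6_gemma2_9b))
```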
4 changes: 3 additions & 1 deletion README_CN.md
@@ -37,7 +37,7 @@
- [Citation](#-引用)

## 📝 Introduction
SWIFT supports training (pre-training, fine-tuning, alignment), inference, evaluation, and deployment of **350+ LLMs and 90+ MLLMs** (multimodal large models). Developers can apply the framework directly to their own research and production environments to realize the complete pipeline from model training and evaluation to application. In addition to the lightweight training solutions provided by [PEFT](https://github.com/huggingface/peft), we also provide a complete **Adapters library** supporting the latest training techniques such as NEFTune, LoRA+, and LLaMA-PRO; this adapter library can be used in your own custom workflow independently of our training scripts.
SWIFT supports training (pre-training, fine-tuning, alignment), inference, evaluation, and deployment of **350+ LLMs and 100+ MLLMs** (multimodal large models). Developers can apply the framework directly to their own research and production environments to realize the complete pipeline from model training and evaluation to application. In addition to the lightweight training solutions provided by [PEFT](https://github.com/huggingface/peft), we also provide a complete **Adapters library** supporting the latest training techniques such as NEFTune, LoRA+, and LLaMA-PRO; this adapter library can be used in your own custom workflow independently of our training scripts.

To make it easier for users unfamiliar with deep learning, we provide a Gradio web-ui for controlling training and inference, along with accompanying deep learning courses and best practices for beginners. The SWIFT web-ui can be experienced on [Huggingface space](https://huggingface.co/spaces/tastelikefeet/swift) and [ModelScope studio](https://www.modelscope.cn/studios/iic/Scalable-lightWeight-Infrastructure-for-Fine-Tuning/summary).

@@ -56,6 +56,7 @@ SWIFT has rich and comprehensive documentation; please check our documentation site:


## 🎉 News
- 2024.10.09: Support for training through deployment of ovis1.6-gemma2. Try it with `swift infer --model_type ovis1_6-gemma2-9b`.
- 2024.09.26: Support for training through deployment of the llama3.2-vision series models. Try it with `swift infer --model_type llama3_2-11b-vision-instruct`.
- 2024.09.26: Support for training through deployment of the llama3.2 series models. Try it with `swift infer --model_type llama3_2-1b-instruct`.
- 2024.09.25: Support for training through deployment of got-ocr2. Best practices can be found [here](https://github.com/modelscope/ms-swift/issues/2122).
@@ -635,6 +636,7 @@ CUDA_VISIBLE_DEVICES=0 swift deploy \
| Idefics3 | [HuggingFaceM4](https://huggingface.co/HuggingFaceM4) | English | 8B | chat model |
| Pixtral | [mistralai](https://huggingface.co/mistralai) | English | 12B | chat model |
| Llama3.1-Omni | [LLaMA-Omni](https://github.com/ictnlp/LLaMA-Omni) | English | 8B | chat model |
| Ovis | [Ovis](https://github.com/AIDC-AI/Ovis) | English | 9B | chat model |


#### Diffusion Models
1 change: 1 addition & 0 deletions docs/source/Instruction/支持的模型和数据集.md
@@ -493,6 +493,7 @@
|internvl2-llama3-76b-awq|[OpenGVLab/InternVL2-Llama3-76B-AWQ](https://modelscope.cn/models/OpenGVLab/InternVL2-Llama3-76B-AWQ/summary)|^(language_model\|mlp1)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|internvl2|&#x2714;|&#x2714;|&#x2714;|&#x2718;|transformers>=4.36, timm|vision, video|[OpenGVLab/InternVL2-Llama3-76B-AWQ](https://huggingface.co/OpenGVLab/InternVL2-Llama3-76B-AWQ)|
|deepseek-vl-1_3b-chat|[deepseek-ai/deepseek-vl-1.3b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-vl-1.3b-chat/summary)|^(language_model\|aligner)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|deepseek-vl|&#x2714;|&#x2718;|&#x2714;|&#x2718;||vision|[deepseek-ai/deepseek-vl-1.3b-chat](https://huggingface.co/deepseek-ai/deepseek-vl-1.3b-chat)|
|deepseek-vl-7b-chat|[deepseek-ai/deepseek-vl-7b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-vl-7b-chat/summary)|^(language_model\|aligner)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|deepseek-vl|&#x2714;|&#x2718;|&#x2714;|&#x2718;||vision|[deepseek-ai/deepseek-vl-7b-chat](https://huggingface.co/deepseek-ai/deepseek-vl-7b-chat)|
|ovis1_6-gemma2-9b|[AIDC-AI/Ovis1.6-Gemma2-9B](https://modelscope.cn/models/AIDC-AI/Ovis1.6-Gemma2-9B/summary)|^(llm)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|ovis1_6|&#x2714;|&#x2718;|&#x2718;|&#x2718;|transformers>=4.42|vision|[AIDC-AI/Ovis1.6-Gemma2-9B](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-9B)|
|paligemma-3b-pt-224|[AI-ModelScope/paligemma-3b-pt-224](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-224/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|paligemma|&#x2714;|&#x2714;|&#x2718;|&#x2718;|transformers>=4.41|vision|[google/paligemma-3b-pt-224](https://huggingface.co/google/paligemma-3b-pt-224)|
|paligemma-3b-pt-448|[AI-ModelScope/paligemma-3b-pt-448](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-448/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|paligemma|&#x2714;|&#x2714;|&#x2718;|&#x2718;|transformers>=4.41|vision|[google/paligemma-3b-pt-448](https://huggingface.co/google/paligemma-3b-pt-448)|
|paligemma-3b-pt-896|[AI-ModelScope/paligemma-3b-pt-896](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-896/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|paligemma|&#x2714;|&#x2714;|&#x2718;|&#x2718;|transformers>=4.41|vision|[google/paligemma-3b-pt-896](https://huggingface.co/google/paligemma-3b-pt-896)|
1 change: 1 addition & 0 deletions docs/source_en/Instruction/Supported-models-datasets.md
@@ -493,6 +493,7 @@ The table below introduces all models supported by SWIFT:
|internvl2-llama3-76b-awq|[OpenGVLab/InternVL2-Llama3-76B-AWQ](https://modelscope.cn/models/OpenGVLab/InternVL2-Llama3-76B-AWQ/summary)|^(language_model\|mlp1)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|internvl2|&#x2714;|&#x2714;|&#x2714;|&#x2718;|transformers>=4.36, timm|vision, video|[OpenGVLab/InternVL2-Llama3-76B-AWQ](https://huggingface.co/OpenGVLab/InternVL2-Llama3-76B-AWQ)|
|deepseek-vl-1_3b-chat|[deepseek-ai/deepseek-vl-1.3b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-vl-1.3b-chat/summary)|^(language_model\|aligner)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|deepseek-vl|&#x2714;|&#x2718;|&#x2714;|&#x2718;||vision|[deepseek-ai/deepseek-vl-1.3b-chat](https://huggingface.co/deepseek-ai/deepseek-vl-1.3b-chat)|
|deepseek-vl-7b-chat|[deepseek-ai/deepseek-vl-7b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-vl-7b-chat/summary)|^(language_model\|aligner)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|deepseek-vl|&#x2714;|&#x2718;|&#x2714;|&#x2718;||vision|[deepseek-ai/deepseek-vl-7b-chat](https://huggingface.co/deepseek-ai/deepseek-vl-7b-chat)|
|ovis1_6-gemma2-9b|[AIDC-AI/Ovis1.6-Gemma2-9B](https://modelscope.cn/models/AIDC-AI/Ovis1.6-Gemma2-9B/summary)|^(llm)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|ovis1_6|&#x2714;|&#x2718;|&#x2718;|&#x2718;|transformers>=4.42|vision|[AIDC-AI/Ovis1.6-Gemma2-9B](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-9B)|
|paligemma-3b-pt-224|[AI-ModelScope/paligemma-3b-pt-224](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-224/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|paligemma|&#x2714;|&#x2714;|&#x2718;|&#x2718;|transformers>=4.41|vision|[google/paligemma-3b-pt-224](https://huggingface.co/google/paligemma-3b-pt-224)|
|paligemma-3b-pt-448|[AI-ModelScope/paligemma-3b-pt-448](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-448/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|paligemma|&#x2714;|&#x2714;|&#x2718;|&#x2718;|transformers>=4.41|vision|[google/paligemma-3b-pt-448](https://huggingface.co/google/paligemma-3b-pt-448)|
|paligemma-3b-pt-896|[AI-ModelScope/paligemma-3b-pt-896](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-896/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|paligemma|&#x2714;|&#x2714;|&#x2718;|&#x2718;|transformers>=4.41|vision|[google/paligemma-3b-pt-896](https://huggingface.co/google/paligemma-3b-pt-896)|
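The `lora_target_modules` column in the new `ovis1_6-gemma2-9b` row is a regular expression over module names (the `\|` and `.\*` sequences are markdown escaping). A small sketch of what the unescaped pattern selects; the module names below are illustrative, not taken from the checkpoint:

```python
# Sketch: which (illustrative) module names the unescaped LoRA target regex accepts.
import re

pattern = re.compile(r'^(llm)(?!.*(lm_head|output|emb|wte|shared)).*')
for name in ('llm.model.layers.0.self_attn.q_proj',
             'llm.lm_head',
             'visual_tokenizer.backbone.blocks.0.attn.qkv'):
    print(f'{name}: {bool(pattern.match(name))}')
# True / False / False: only non-embedding, non-head modules under `llm` become LoRA targets.
```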
47 changes: 43 additions & 4 deletions swift/llm/utils/model.py
@@ -458,6 +458,8 @@ class ModelType:
gemma2_2b_instruct = 'gemma2-2b-instruct'
gemma2_9b_instruct = 'gemma2-9b-instruct'
gemma2_27b_instruct = 'gemma2-27b-instruct'

ovis1_6_gemma2_9b = 'ovis1_6-gemma2-9b'
# paligemma
paligemma_3b_pt_224 = 'paligemma-3b-pt-224'
paligemma_3b_pt_448 = 'paligemma-3b-pt-448'
@@ -652,6 +654,7 @@ class LoRATM(NamedTuple):
llama3_1_omni = 'llama3_1_omni'
got_ocr2 = 'got_ocr2'
llama3_2_vision = 'llama3_2_vision'
ovis1_6 = 'ovis1_6'
# default lora target modules for nlp llms.
minicpm3 = ['q_a_proj', 'q_b_proj', 'kv_a_proj_with_mqa', 'kv_b_proj']
baichuan = ['W_pack']
@@ -2745,6 +2748,40 @@ def get_model_tokenizer_with_flash_attn(model_dir: str,
model_dir, torch_dtype, model_kwargs, load_model, model_config=model_config, **kwargs)


@register_model(
ModelType.ovis1_6_gemma2_9b,
'AIDC-AI/Ovis1.6-Gemma2-9B',
LoRATM.ovis1_6,
TemplateType.ovis1_6,
requires=['transformers>=4.42'],
support_flash_attn=True,
tags=['multi-modal', 'vision'],
hf_model_id='AIDC-AI/Ovis1.6-Gemma2-9B')
def get_model_tokenizer_ovis(*args, **kwargs):
model, tokenizer = get_model_tokenizer_with_flash_attn(*args, **kwargs)
if model is not None:
func_list = ['generate', 'forward', 'get_input_embeddings']
_use_submodel_func(model, 'llm', func_list)
embedding = model.get_input_embeddings()
embedding.register_forward_hook(_clone_hook)
try:
# fix device_map
from transformers.cache_utils import HybridCache

def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, *args,
**kwargs) -> Tuple[torch.Tensor]:
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
return self._update_origin(key_states, value_states, layer_idx, *args, **kwargs)

if not hasattr(HybridCache, '_update_origin'):
HybridCache._update_origin = HybridCache.update
HybridCache.update = update
except ImportError:
pass
return model, tokenizer


@register_model(
ModelType.mplug_owl3_7b_chat,
'iic/mPLUG-Owl3-7B-240728',
@@ -2762,8 +2799,9 @@ def get_model_tokenizer_mplug_owl3(model_dir: str,
model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, torch_dtype, model_kwargs, load_model, **kwargs)
processor = model.init_processor(tokenizer)
tokenizer.processor = processor
func_list = ['generate', 'forward']
_use_submodel_func(model, 'language_model', func_list)
if model is not None:
func_list = ['generate', 'forward']
_use_submodel_func(model, 'language_model', func_list)
return model, tokenizer


@@ -2958,8 +2996,9 @@ def get_model_tokenizer_florence(model_dir: str,
model_dir, torch_dtype, model_kwargs, load_model, tokenizer=processor.tokenizer, **kwargs)

tokenizer.processor = processor
# model.vision_tower.enable_checkpoint = True
_use_submodel_func(model, 'language_model', ['generate', 'forward'])
if model is not None:
model.vision_tower.enable_checkpoint = True
_use_submodel_func(model, 'language_model', ['generate', 'forward'])
return model, tokenizer


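For orientation, a minimal sketch of how the Ovis1.6 registration added above is exercised through swift's generic loader; `get_model_tokenizer` and `ModelType` are existing helpers in `swift.llm`, while the dtype and `model_kwargs` shown here are assumptions, not requirements:

```python
# Sketch only: load the newly registered Ovis1.6 checkpoint via the generic entry point.
# get_model_tokenizer dispatches to get_model_tokenizer_ovis above, which delegates
# generate/forward/get_input_embeddings to model.llm and patches HybridCache.update
# so the key/value caches follow the sharded device map.
import torch
from swift.llm import ModelType, get_model_tokenizer

model, tokenizer = get_model_tokenizer(
    ModelType.ovis1_6_gemma2_9b, torch.bfloat16, model_kwargs={'device_map': 'auto'})
print(type(model.llm).__name__)  # the wrapped Gemma2 language model
```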
60 changes: 60 additions & 0 deletions swift/llm/utils/template.py
@@ -144,6 +144,7 @@ class TemplateType:
c4ai = 'c4ai'
chatml = 'chatml'
got_ocr2 = 'got_ocr2'
ovis1_6 = 'ovis1_6'
# compatibility. (Deprecated)
default_generation_bos = 'default-generation-bos'
yi = 'yi'
@@ -1285,6 +1286,65 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] =
register_template(TemplateType.got_ocr2, GOT_OCR2Template(), lazy_tokenize=True, use_model=True)


class OVIS1_6Template(Template):

def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
example: Dict[str, Any]) -> List[Context]:
assert media_type == 'image'
return [[-200], '\n']

def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
inputs, tokenizer_kwargs = super()._encode(example)
if len(inputs) == 0:
return inputs, {}
images = example['images']
input_ids = inputs['input_ids']
labels = inputs['labels']
idx_list = _findall(input_ids, [-200])
added_tokens_len = 0
pixel_values = []
for i, idx in enumerate(idx_list):
max_partition = get_env_args('max_partition', int, 9)
raw_pixel_values, image_placeholders = self.model.visual_tokenizer.preprocess_image(
images[i], max_partition=max_partition)
input_ids = input_ids[:idx + added_tokens_len] + image_placeholders + input_ids[idx + added_tokens_len + 1:]
if labels is not None:
labels = labels[:idx + added_tokens_len] + [-100] * len(image_placeholders) + labels[idx + added_tokens_len + 1:]
pixel_values.append(raw_pixel_values)
added_tokens_len += len(image_placeholders) - 1
if pixel_values:
pixel_values = torch.cat(pixel_values, dim=0).to(self.model.visual_tokenizer.dtype)
else:
pixel_values = None
inputs = {'labels': labels}
if labels is not None:
labels = torch.tensor(labels)[None]
inputs['_data'] = {'input_ids': torch.tensor(input_ids)[None], 'labels': labels, 'pixel_values': [pixel_values]}
return inputs, {}

def _post_encode(self, model, data: Any) -> Dict[str, Any]:
_, inputs_embeds, labels, _ = self.model.merge_multimodal(
text_input_ids=data['input_ids'],
text_attention_masks=torch.ones_like(data['input_ids']),  # not used; kept for compatibility
text_labels=data['labels'],
pixel_values=data['pixel_values'],
left_padding=True)
return {'inputs_embeds': inputs_embeds[0], 'labels': labels}

@staticmethod
def _get_generate_ids(generate_ids: List[int], input_token_len: int) -> List[int]:
return generate_ids


register_template(
TemplateType.ovis1_6,
OVIS1_6Template(['<bos>'], ['<start_of_turn>user\n{{QUERY}}<end_of_turn>\n<start_of_turn>model\n'],
['<end_of_turn>\n'], ['<end_of_turn>'], None,
['<bos><start_of_turn>system\n{{SYSTEM}}<end_of_turn>\n']),
lazy_tokenize=True,
use_model=True)


class _QwenVLTemplateMixin:
load_medias = False

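For readers unfamiliar with the Gemma2 chat format, the registration above composes prompts roughly as sketched below. This is a plain-string illustration under the assumption that the base `Template` class concatenates the system prefix, per-turn prompt, and suffix pieces in order; the `<image>` tag in the query becomes the `[-200]` marker from `replace_tag`, which `_encode` then replaces with the visual tokenizer's image placeholders:

```python
# Illustrative only: the rough text layout the ovis1_6 template produces for one user turn.
system = 'You are a helpful assistant.'   # example system prompt (assumption)
query = '<image>\nDescribe the image.'    # '<image>' -> [-200] marker -> image placeholders
prompt = ('<bos><start_of_turn>system\n' + system + '<end_of_turn>\n'
          + '<start_of_turn>user\n' + query + '<end_of_turn>\n'
          + '<start_of_turn>model\n')
print(prompt)  # during training, the assistant reply is closed with '<end_of_turn>'
```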
7 changes: 1 addition & 6 deletions swift/utils/import_utils.py
@@ -60,12 +60,7 @@ def __getattr__(self, name: str) -> Any:
return value

def _get_module(self, module_name: str):
try:
return importlib.import_module('.' + module_name, self.__name__)
except Exception as e:
raise RuntimeError(
f'Failed to import {self.__name__}.{module_name} because of the following error (look up to see its'
f' traceback):\n{e}') from e
return importlib.import_module('.' + module_name, self.__name__)

def __reduce__(self):
return self.__class__, (self._name, self.__file__, self._import_structure)
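With the `try/except` wrapper gone, a failed lazy import now surfaces its original exception type (an `ImportError` subclass with the original traceback) instead of a generic `RuntimeError`, so callers can still branch on it. A minimal sketch of the resulting behavior; the package name is a stand-in, not one of swift's modules:

```python
# Sketch: the simplified lazy-import path re-raises the original ImportError unchanged.
import importlib

def _get_module(module_name: str, package: str):
    return importlib.import_module('.' + module_name, package)

try:
    _get_module('does_not_exist', 'json')  # 'json' stands in for the lazily loaded package
except ImportError as e:
    print(type(e).__name__)  # ModuleNotFoundError, still catchable as ImportError
```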
6 changes: 6 additions & 0 deletions swift/utils/module_mapping.py
@@ -302,6 +302,11 @@ def __post_init__(self):
vision_tower='vision_model',
)

OVIS1_6 = MultiModelKeys(
language_model='llm',
vision_tower='visual_tokenizer',
)

MODEL_KEYS_MAPPING = OrderedDict([
# MLLM here
('qwen_audio', QWEN_AUDIO_KEYS),
@@ -324,6 +329,7 @@ def __post_init__(self):
('llama3_1_omni', LLAMA3_1_OMNI),
('got_ocr2', GOT_OCR2),
('llama3_2_vision', LLAMA3_2_VISION),
('ovis1_6', OVIS1_6),
# LLM begins here
('llama', LLAMA_KEYS),
('mistral', LLAMA_KEYS),
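The new mapping entry is what tells the tuner utilities which submodules form the language model versus the vision stack; a minimal lookup sketch (note that `__post_init__` may normalize the string fields, e.g. into lists, so treat the printed form as indicative):

```python
# Sketch: inspect the freeze/target scopes registered above for ovis1_6.
from swift.utils.module_mapping import MODEL_KEYS_MAPPING

ovis_keys = MODEL_KEYS_MAPPING['ovis1_6']
print(ovis_keys.language_model)  # 'llm': everything under model.llm.* is treated as the LLM
print(ovis_keys.vision_tower)    # 'visual_tokenizer': the vision stack, typically kept frozen
```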