Merge pull request #648 from OptimalScale/lianqing/multi_modal_training
[FIX] Fix multi-modal training
research4pan authored Sep 21, 2023
2 parents aa055d4 + 19f4ac0 commit c6b2f14
Showing 12 changed files with 105 additions and 47 deletions.
3 changes: 2 additions & 1 deletion examples/finetune_multi_modal.py
@@ -62,7 +62,8 @@ def main():
model = AutoModel.get_model(model_args, tune_strategy='none',
ds_config=pipeline_args.deepspeed,
custom_model=True,
with_deepspeed=False)
with_deepspeed=False,
pipeline_args=pipeline_args)
# FIXME check if need to move this part to hf_encoder_decoder.py
for param in model.backend_model.parameters():
param.requires_grad = False
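
The new pipeline_args keyword threads the trainer settings into model construction so the DeepSpeed config can be processed against them (see the HfTrainerDeepSpeedConfig change in src/lmflow/models/hf_encoder_decoder_model.py further down). A minimal sketch of the resulting call pattern in this example script; the AutoModel import path is assumed from the LMFlow source layout, and model_args / pipeline_args come from LMFlow's argument parser:

# Sketch only: model_args and pipeline_args are produced by LMFlow's argument parsing.
from lmflow.models.auto_model import AutoModel  # assumed import path

model = AutoModel.get_model(
    model_args,
    tune_strategy='none',
    ds_config=pipeline_args.deepspeed,
    custom_model=True,
    with_deepspeed=False,
    pipeline_args=pipeline_args,  # new: lets the model fill DeepSpeed "auto" fields
)
for param in model.backend_model.parameters():
    param.requires_grad = False   # only the selected module (e.g. language_projection) is trained
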
9 changes: 5 additions & 4 deletions scripts/multimodal/run_finetune_multi_modal_stage1.sh
@@ -9,7 +9,8 @@ model_name_or_path=Salesforce/blip2-flan-t5-xxl
# https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K
dataset_path=./data/llava_cc3m_pretrain_595k/chat.json
image_folder=./data/llava_cc3m_pretrain_595k/images
output_dir=output_models/finetune
output_dir=output_models/finetune_llava-336px-vicuna-7b-v1.3_stage1

deepspeed_args="--master_port=12000"

while [[ $# -ge 1 ]]; do
@@ -62,9 +63,8 @@ deepspeed ${deepspeed_args} \
--llm_model_name_or_path lmsys/vicuna-7b-v1.5 \
--image_aspect_ratio None \
--fp16 True \
--learning_rate 2e-5 \
--gradient_accumulation_steps 1 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--per_device_train_batch_size 8 \
--learning_rate 2e-3 \
--weight_decay 0. \
--warmup_ratio 0.03 \
@@ -77,5 +77,6 @@ deepspeed ${deepspeed_args} \
--save_steps 5000 \
--dataloader_num_workers 1 \
--num_train_epochs 1 \
--save_language_projection True \
| tee ${log_dir}/train.log \
2> ${log_dir}/train.err
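
With --save_language_projection True, this stage-1 run writes a language_projection.pth state dict into its output directory (see the finetuner.py change below). A hedged sketch of how such a checkpoint can be reloaded; the path is assumed from the output_dir above, and the loading logic inside LMFlow is not part of this diff:

import torch

ckpt = "output_models/finetune_llava-336px-vicuna-7b-v1.3_stage1/language_projection.pth"
state_dict = torch.load(ckpt, map_location="cpu")
# `model` is assumed to be a vision2seq model exposing a `language_projection` module,
# e.g. the one built by examples/finetune_multi_modal.py.
model.language_projection.load_state_dict(state_dict)

Stage 2 consumes a projection checkpoint of this kind through its --pretrained_language_projection_path flag.
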
26 changes: 17 additions & 9 deletions scripts/multimodal/run_finetune_multi_modal_stage2.sh
@@ -8,9 +8,9 @@
model_name_or_path=Salesforce/blip2-flan-t5-xxl
# Please download the coco dataset and the conversation file from
# https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_instruct_80k.json
dataset_path=./data/llava_instruct_80k.json
dataset_path=./data/llava_instruct_80k_truncated.json
image_folder=./data/coco2017/train2017
output_dir=output_models/finetune
output_dir=output_models/finetune_llava-336px-vicuna-7b-v1.3_stage2
deepspeed_args="--master_port=12000"

while [[ $# -ge 1 ]]; do
@@ -50,31 +50,39 @@ fi
if [ ! -f data/llava_instruct_80k.json ]; then
cd data && ./download.sh llava_instruction_finetune_80k && cd -
fi

if [ ! -f data/llava_instruct_80k_truncated.json ]; then
python utils/preprocess_multimodal_data.py \
--data_path data/llava_instruct_80k.json \
--save_path data/llava_instruct_80k_truncated.json
fi

# Finetune
exp_id=finetune
project_dir=$(cd "$(dirname $0)"/..; pwd)
log_dir=${project_dir}/log/${exp_id}
mkdir -p ${output_dir} ${log_dir}

# train batch size is set to 4
# the default in llava is 16
deepspeed ${deepspeed_args} \
examples/finetune_multi_modal.py \
--deepspeed configs/ds_config_multimodal.json \
--deepspeed configs/ds_config_zero2.json \
--arch_type vision_encoder_decoder \
--llava_loading True \
--model_name_or_path ${model_name_or_path} \
--image_encoder_name_or_path openai/clip-vit-large-patch14 \
--image_encoder_name_or_path openai/clip-vit-large-patch14-336 \
--pretrained_language_projection_path output_models/llava-336px-pretrain-vicuna-7b-v1.3_language_projection.pth \
--dataset_path ${dataset_path} \
--output_dir ${output_dir} --overwrite_output_dir \
--image_folder ${image_folder} \
--custom_vision_model True \
--llm_model_name_or_path lmsys/vicuna-7b-v1.5 \
--llm_model_name_or_path lmsys/vicuna-7b-v1.3 \
--image_aspect_ratio None \
--fp16 True \
--learning_rate 2e-5 \
--gradient_accumulation_steps 1 \
--per_device_train_batch_size 2 \
--learning_rate 2e-3 \
--gradient_accumulation_steps 8 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 4 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
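
Stage 2 now trains against the generic ZeRO stage-2 DeepSpeed config instead of the multimodal-specific one. configs/ds_config_zero2.json itself is not shown in this commit; the snippet below is only a hypothetical example of the kind of config it typically holds, with "auto" placeholders that HfTrainerDeepSpeedConfig resolves from the training arguments (see hf_encoder_decoder_model.py below):

# Hypothetical ZeRO-2 config expressed as a Python dict; the repository's
# configs/ds_config_zero2.json may differ.
ds_config_zero2 = {
    "fp16": {"enabled": "auto"},
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
    "gradient_accumulation_steps": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "train_batch_size": "auto",
}
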
5 changes: 3 additions & 2 deletions scripts/multimodal/run_vis_chatbot_llava.sh
@@ -1,3 +1,4 @@
# only works when GPU memory > 25G; 4-bit and 8-bit inference are not supported.
model_name_or_path=Salesforce/blip2-flan-t5-xxl
llava_pretrain_model_path="output_models/llava-v1-0719-336px-lora-merge-vicuna-13b-v1.3/"
deepspeed_args="--master_port=12000"
@@ -23,13 +24,13 @@ deepspeed ${deepspeed_args} \
--custom_model True \
--chatbot_type llava \
--prompt_structure '{input_text} ASSISTANT:' \
--low_resource True \
--llava_loading True \
--model_name_or_path ${model_name_or_path} \
--image_encoder_name_or_path openai/clip-vit-large-patch14 \
--image_encoder_name_or_path openai/clip-vit-large-patch14-336 \
--custom_vision_model True \
--llm_model_name_or_path lmsys/vicuna-13b-v1.5 \
--llava_pretrain_model_path ${llava_pretrain_model_path}"*.bin" \
--with_deepspeed False \
--save_pretrain_model_path "output_models/lmflow_llava-v1-0719-336px-lora-merge-vicuna-13b-v1.3" \
${@:1}
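
--llava_pretrain_model_path is passed a *.bin glob, so the chatbot has to merge several checkpoint shards and rename their keys into LMFlow's layout. A hedged sketch of that step, reusing adapt_llava_model_to_lmflow_type from src/lmflow/utils/multimodal.py (shown at the end of this diff); the directory is the one set at the top of this script, and the loading code itself is not part of this commit:

import glob
import torch

from lmflow.utils.multimodal import adapt_llava_model_to_lmflow_type

shards = sorted(glob.glob(
    "output_models/llava-v1-0719-336px-lora-merge-vicuna-13b-v1.3/*.bin"))
state_dict = {}
for shard in shards:
    state_dict.update(torch.load(shard, map_location="cpu"))
state_dict = adapt_llava_model_to_lmflow_type(state_dict)
# model.load_state_dict(state_dict, strict=False)  # model construction omitted here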

8 changes: 7 additions & 1 deletion src/lmflow/args.py
@@ -538,14 +538,20 @@ class FinetunerArguments(TrainingArguments):
remove_unused_columns: Optional[bool] = field(
default=False,
metadata={
"help": "wheather to remove the unused columns in collate fn"}
"help": "whether to remove the unused columns in collate fn"}
)
finetune_part: Optional[str] = field(
default="language_projection",
metadata={
"help": "the module to finetune."
}
)
save_language_projection: Optional[bool] = field(
default=False,
metadata={
"help": "whether to save language projection layer in multi-modal models."
}
)


@dataclass
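
Both new fields are plain dataclass fields on FinetunerArguments, so they surface as CLI flags through transformers' HfArgumentParser, which is how the shell scripts above pass --save_language_projection True. A minimal sketch parsing only FinetunerArguments (the example scripts parse it together with the model and dataset argument classes):

from transformers import HfArgumentParser

from lmflow.args import FinetunerArguments

# Run with at least --output_dir <dir>, which TrainingArguments requires.
parser = HfArgumentParser(FinetunerArguments)
(finetuner_args,) = parser.parse_args_into_dataclasses()
print(finetuner_args.finetune_part)             # "language_projection" by default
print(finetuner_args.save_language_projection)  # set via --save_language_projection True
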
13 changes: 9 additions & 4 deletions src/lmflow/models/hf_encoder_decoder_model.py
@@ -33,7 +33,7 @@
)

import torch
from transformers.deepspeed import HfDeepSpeedConfig
from transformers.deepspeed import HfDeepSpeedConfig, HfTrainerDeepSpeedConfig

from transformers.testing_utils import CaptureLogger

@@ -89,6 +89,7 @@ def __init__(
use_accelerator=False,
custom_model=False,
with_deepspeed=True,
pipeline_args=None,
*args,
**kwargs
):
@@ -119,7 +120,10 @@ def __init__(
raise NotImplementedError(
f"Currently encoder2decoder model is not supported with accelerator"
)
dschf = HfDeepSpeedConfig(ds_config)
# dschf = HfDeepSpeedConfig(ds_config)
dschf = HfTrainerDeepSpeedConfig(ds_config)
if pipeline_args is not None:
dschf.trainer_config_process(pipeline_args)
peft_model_id = model_args.lora_model_path
# NOTE: Currently offload is not supported by llama
if "llama" in model_args.model_name_or_path and model_args.use_ram_optimized_load:
@@ -198,10 +202,11 @@ def __init__(
kwargs = dict(
torch_dtype=torch.float16,
load_in_8bit=True,
device_map="auto"
device_map="auto",
)
else:
kwargs = {}
# kwargs = dict(torch_dtype=torch.float16)
kwargs = dict(device_map="auto")
if (model_args.image_encoder_name_or_path is None and
model_args.qformer_name_or_path is None and
model_args.llm_model_name_or_path is None):
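
Replacing HfDeepSpeedConfig with HfTrainerDeepSpeedConfig and calling trainer_config_process is what lets the "auto" entries of the DeepSpeed JSON (batch sizes, gradient accumulation, fp16, …) be filled from the HF training arguments instead of staying unresolved. A minimal sketch of the pattern, assuming a transformers version where these classes still live in transformers.deepspeed, as imported in this file; the config path and argument values are placeholders:

from transformers import TrainingArguments
from transformers.deepspeed import HfTrainerDeepSpeedConfig

ds_config = "configs/ds_config_zero2.json"   # path or dict, as passed via --deepspeed
args = TrainingArguments(                    # FinetunerArguments subclasses this
    output_dir="out",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    fp16=True,
)

dschf = HfTrainerDeepSpeedConfig(ds_config)  # must stay alive while the model is built
dschf.trainer_config_process(args)           # resolves the "auto" fields from the args
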
31 changes: 19 additions & 12 deletions src/lmflow/models/vision2seq_model.py
@@ -22,6 +22,7 @@
PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.deepspeed import is_deepspeed_zero3_enabled

from lmflow.models.base_model import BaseModel
from lmflow.models.vision_encoder import build_vision_tower
@@ -77,15 +78,18 @@ def __init__(self,
torch.zeros(1, config.num_query_tokens,
config.qformer_config.hidden_size))
self.qformer = Blip2QFormerModel(config.qformer_config)
if low_resource:
kwargs = dict(
torch_dtype=torch.float16,
load_in_8bit=True,
device_map="auto"
)
else:
kwargs = {}
kwargs = dict()
if language_model_name_or_path is not None:
if low_resource:
kwargs = dict(
torch_dtype=torch.float16,
load_in_8bit=True,
device_map="auto",
low_cpu_mem_usage=True)
else:
if not is_deepspeed_zero3_enabled():
kwargs = dict(device_map="auto",
torch_dtype=torch.float16)
language_model = AutoModelForCausalLM.from_pretrained(
language_model_name_or_path, **kwargs)
config.text_config = language_model.config
@@ -97,7 +101,7 @@ def __init__(self,
language_model = AutoModelForSeq2SeqLM.from_config(
config.text_config, **kwargs)
# Update _tied_weights_keys using the base model used.
if language_model._tied_weights_keys is not None:
if getattr(language_model, "_tied_weights_keys", None) is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]

self.language_model = language_model
@@ -237,7 +241,6 @@ def forward(
batch_size = pixel_values.shape[0]
else:
batch_size = 1

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -298,6 +301,9 @@ def forward(
# if inputs_embeds is not None:
# print("input_embeds", inputs_embeds.shape)
# attention_mask.shape, inputs_embeds.shape)
# TODO remove this code by fixing the ddp training issue
inputs_embeds = inputs_embeds.to(
self.language_model.lm_head.weight.dtype)
outputs = self.language_model(
input_ids=input_ids,
attention_mask=attention_mask,
@@ -356,7 +362,7 @@ def processor_image_token_in_minigpt4(self,

# concatenate query embeddings with prompt embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)
inputs_embeds = inputs_embeds.to(language_model_inputs.device)
inputs_embeds = inputs_embeds.to(device=language_model_inputs.device)
# concatenate the text embeddings with image embeddings
inputs_embeds_with_images = []
attention_mask_with_images = []
@@ -426,7 +432,6 @@ def generate(
batch_size = pixel_values.shape[0]
else:
batch_size = 1

if not self.custom_vision_model:
# do the processing as blip2 and mini gpt-4;
image_embeds = self.vision_model(
@@ -473,6 +478,8 @@ def generate(
self.language_model.model)
# convert the dtype.
# FIXME check when need to do this
inputs_embeds = inputs_embeds.to(
device=self.language_model.lm_head.weight.device)
inputs_embeds = inputs_embeds.to(
self.language_model.lm_head.weight.dtype)
outputs = self.language_model.generate(
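
Several of the additions in this file follow one pattern: before inputs_embeds (the concatenated text and image embeddings) reach the language model, they are moved to the device and cast to the dtype of the language model's lm_head, so half-precision weights and float32 image features do not clash during DDP training or generation. A condensed sketch of that pattern as used here:

# Condensed from forward()/generate() above; attention_mask and generate_kwargs
# are the values already prepared by those methods.
target = self.language_model.lm_head.weight
inputs_embeds = inputs_embeds.to(device=target.device)
inputs_embeds = inputs_embeds.to(target.dtype)
outputs = self.language_model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=attention_mask,
    **generate_kwargs,
)
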
2 changes: 1 addition & 1 deletion src/lmflow/models/vision_encoder/clip_encoder.py
@@ -128,7 +128,6 @@ def prepare_inputs_labels_for_multimodal(
image_features = [x.flatten(0, 1) for x in image_features]
else:
image_features = self.encode_images(images, language_projection)

new_input_embeds = []
new_labels = [] if labels is not None else None
cur_image_idx = 0
@@ -164,6 +163,7 @@ def prepare_inputs_labels_for_multimodal(
cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
cur_labels = cur_labels[image_token_start+2:]
else:
cur_input_ids = cur_input_ids.to(device=language_model.device)
cur_new_input_embeds.append(language_model.embed_tokens(cur_input_ids[:image_token_start]))
cur_new_input_embeds.append(cur_image_features)
if labels is not None:
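
prepare_inputs_labels_for_multimodal splices the projected image features into the token-embedding sequence at each image-token position; the added line only moves the token ids onto the language model's device before embedding them. A simplified, hedged sketch of the splicing for one sample with a single image token (the real method also handles labels, several images, and the im_start/im_end variant):

import torch

# cur_input_ids, image_token_id, language_model and cur_image_features are the
# per-sample values used in the loop above; they are assumed to be in scope.
image_token_start = int(torch.where(cur_input_ids == image_token_id)[0][0])
cur_input_ids = cur_input_ids.to(device=language_model.device)
text_before = language_model.embed_tokens(cur_input_ids[:image_token_start])
text_after = language_model.embed_tokens(cur_input_ids[image_token_start + 1:])
cur_new_input_embeds = torch.cat([text_before, cur_image_features, text_after], dim=0)
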
10 changes: 8 additions & 2 deletions src/lmflow/pipeline/finetuner.py
@@ -296,7 +296,6 @@ def compute_metrics(eval_preds):
callbacks=trainer_callbacks
)
# Training
# import pdb; pdb.set_trace()
if training_args.do_train:
checkpoint = None
last_checkpoint = self.last_checkpoint
@@ -312,7 +311,14 @@ def compute_metrics(eval_preds):
if model_args.save_aggregated_lora:
model.merge_lora_weights()
model.save(finetuner_args.output_dir,model_args.save_aggregated_lora)

# save language_projection for multi-modal model;
if self.finetuner_args.save_language_projection:
language_projection_state = trainer.model.language_projection.state_dict()
torch.save(
    language_projection_state,
    osp.join(
        self.finetuner_args.output_dir,
        "language_projection.pth"))
metrics = train_result.metrics

max_train_samples = (
13 changes: 6 additions & 7 deletions src/lmflow/pipeline/inferencer.py
@@ -123,7 +123,7 @@ def inference(
temperature: float=0.0,
prompt_structure: str='{input}',
remove_image_flag: bool=False,
chatbot_format: str="mini_gpt",
chatbot_type: str="mini_gpt",
):
"""
Perform inference for a model
@@ -165,9 +165,9 @@
input['images'] = np.array(input['images'])
if remove_image_flag:
# remove the image flag <ImageHere> in tokenization;
if chatbot_format == "mini_gpt":
if chatbot_type == "mini_gpt":
image_split_flag = "<ImageHere>"
elif chatbot_format:
elif chatbot_type:
image_split_flag = "<image>"
else:
raise NotImplementedError
@@ -186,7 +186,7 @@
).to(device=self.local_rank)
input_ids.append(temp_inputs['input_ids'])
attention_mask.append(temp_inputs['attention_mask'])
if chatbot_format == "llava":
if chatbot_type == "llava":
# add the flag for inserting the image.
# TODO should merge the way of handling image flag in minigpt and llava.
index_tensor = torch.tensor(
@@ -200,7 +200,7 @@
temp_inputs["input_ids"].shape[1])
if len(image_token_indexes) > 1:
image_token_indexes = image_token_indexes[:-1]
if chatbot_format == "llava":
if chatbot_type == "llava":
input_ids = input_ids[:-1]
attention_mask = attention_mask[:-1]
inputs = temp_inputs
@@ -219,7 +219,6 @@
raise NotImplementedError(
f"device \"{self.inferencer_args.device}\" is not supported"
)

if remove_image_flag:
inputs["image_token_indexes"] = image_token_indexes
inputs["one_sample_multiple_images"] = True
@@ -553,4 +552,4 @@ def speculative_sampling(input_ids: torch.Tensor,


def stream_inference(self):
raise NotImplementedError("Streaming output for SpeculativeInferencer is not supported yet")
raise NotImplementedError("Streaming output for SpeculativeInferencer is not supported yet")
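
The chatbot_type value renamed above also selects the placeholder that marks where the image goes: "<ImageHere>" for the MiniGPT-4 style chatbot and "<image>" for other types such as llava. The prompt is split on that flag and each piece is tokenized separately so the image embeddings can be inserted between them. A hedged sketch of that split; `tokenizer` stands in for the model's tokenizer and the prompt text is only an example:

chatbot_type = "llava"
image_split_flag = "<ImageHere>" if chatbot_type == "mini_gpt" else "<image>"

prompt = "USER: <image>\nWhat is shown in this picture? ASSISTANT:"
segments = prompt.split(image_split_flag)
input_ids = [tokenizer(seg, return_tensors="pt").input_ids for seg in segments]
# image_token_indexes then records the boundaries where image embeddings are spliced in.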
7 changes: 3 additions & 4 deletions src/lmflow/utils/multimodal.py
@@ -42,10 +42,9 @@ def adapt_llava_model_to_lmflow_type(state_dict):
key = key.replace("model.embed_tokens",
"language_model.model.embed_tokens")
key = key.replace("model.mm_projector", "language_projection")
key = key.replace("lm_head", "model.language_model.lm_head")
key = key.replace("model.norm", "language_model.model.layers")
key = key.replace("lm_head", "language_model.lm_head")
key = key.replace("model.norm", "language_model.model.norm")
if "vision_tower" in key:
continue
new_state_dict[key] = item
return new_state_dict

return new_state_dict
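
For reference, the corrected mapping sends LLaVA's lm_head and final norm to language_model.lm_head and language_model.model.norm, renames the projector, and drops vision-tower weights (those come from the separate CLIP checkpoint). A brief usage sketch; the shard file name is only an example:

import torch

from lmflow.utils.multimodal import adapt_llava_model_to_lmflow_type

llava_state_dict = torch.load("pytorch_model-00001-of-00003.bin", map_location="cpu")
lmflow_state_dict = adapt_llava_model_to_lmflow_type(llava_state_dict)
# "model.mm_projector.weight" -> "language_projection.weight"
# "lm_head.weight"            -> "language_model.lm_head.weight"
# "model.vision_tower..."     -> dropped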