[FIX] Fix multi-modal training #648

Merged
merged 5 commits on Sep 21, 2023
3 changes: 2 additions & 1 deletion examples/finetune_multi_modal.py
@@ -62,7 +62,8 @@ def main():
model = AutoModel.get_model(model_args, tune_strategy='none',
ds_config=pipeline_args.deepspeed,
custom_model=True,
with_deepspeed=False)
with_deepspeed=False,
pipeline_args=pipeline_args)
# FIXME check if need to move this part to hf_encoder_decoder.py
for param in model.backend_model.parameters():
param.requires_grad = False
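For reference, a minimal sketch of the updated call site (not part of the diff): the only functional change is that pipeline_args is now forwarded, so the model wrapper can hand the trainer arguments to DeepSpeed before the backend model is built (see the hf_encoder_decoder_model.py hunk further down). Argument names are taken from the diff; the surrounding parsing in examples/finetune_multi_modal.py is assumed unchanged.

# Sketch of the call after this change.
model = AutoModel.get_model(
    model_args,
    tune_strategy='none',
    ds_config=pipeline_args.deepspeed,
    custom_model=True,
    with_deepspeed=False,
    pipeline_args=pipeline_args,  # new: lets "auto" DeepSpeed fields be resolved from the trainer args
)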
9 changes: 5 additions & 4 deletions scripts/multimodal/run_finetune_multi_modal_stage1.sh
@@ -9,7 +9,8 @@ model_name_or_path=Salesforce/blip2-flan-t5-xxl
# https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K
dataset_path=./data/llava_cc3m_pretrain_595k/chat.json
image_folder=./data/llava_cc3m_pretrain_595k/images
output_dir=output_models/finetune
output_dir=output_models/finetune_llava-336px-vicuna-7b-v1.3_stage1

deepspeed_args="--master_port=12000"

while [[ $# -ge 1 ]]; do
@@ -62,9 +63,8 @@ deepspeed ${deepspeed_args} \
--llm_model_name_or_path lmsys/vicuna-7b-v1.5 \
--image_aspect_ratio None \
--fp16 True \
--learning_rate 2e-5 \
--gradient_accumulation_steps 1 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--per_device_train_batch_size 8 \
--learning_rate 2e-3 \
--weight_decay 0. \
--warmup_ratio 0.03 \
@@ -77,5 +77,6 @@ deepspeed ${deepspeed_args} \
--save_steps 5000 \
--dataloader_num_workers 1 \
--num_train_epochs 1 \
--save_language_projection True \
| tee ${log_dir}/train.log \
2> ${log_dir}/train.err
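The stage-1 hyper-parameters move from batch 2 × accumulation 1 at lr 2e-5 to batch 8 × accumulation 4 at lr 2e-3, i.e. a 16× larger effective batch. A quick arithmetic sketch; the GPU count is an assumption, since it comes from the deepspeed launch rather than from this script.

# Effective global batch size before and after this change (illustrative only).
num_gpus = 8  # assumed; set by the deepspeed launcher, not by the script
old = 2 * 1 * num_gpus   # 16 with 8 GPUs
new = 8 * 4 * num_gpus   # 256 with 8 GPUs
print(old, new)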
26 changes: 17 additions & 9 deletions scripts/multimodal/run_finetune_multi_modal_stage2.sh
@@ -8,9 +8,9 @@
model_name_or_path=Salesforce/blip2-flan-t5-xxl
# Please download the coco dataset and the conversation file from
# https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_instruct_80k.json
dataset_path=./data/llava_instruct_80k.json
dataset_path=./data/llava_instruct_80k_truncated.json
image_folder=./data/coco2017/train2017
output_dir=output_models/finetune
output_dir=output_models/finetune_llava-336px-vicuna-7b-v1.3_stage2
deepspeed_args="--master_port=12000"

while [[ $# -ge 1 ]]; do
@@ -50,31 +50,39 @@ fi
if [ ! -f data/llava_instruct_80k.json ]; then
cd data && ./download.sh llava_instruction_finetune_80k && cd -
fi

if [ ! -f data/llava_instruct_80k_truncated.json ]; then
python utils/preprocess_multimodal_data.py \
--data_path data/llava_instruct_80k.json \
--save_path data/llava_instruct_80k_truncated.json
fi

# Finetune
exp_id=finetune
project_dir=$(cd "$(dirname $0)"/..; pwd)
log_dir=${project_dir}/log/${exp_id}
mkdir -p ${output_dir} ${log_dir}

# train batch size is set to 4
# default in llava is 16
deepspeed ${deepspeed_args} \
examples/finetune_multi_modal.py \
--deepspeed configs/ds_config_multimodal.json \
--deepspeed configs/ds_config_zero2.json \
--arch_type vision_encoder_decoder \
--llava_loading True \
--model_name_or_path ${model_name_or_path} \
--image_encoder_name_or_path openai/clip-vit-large-patch14 \
--image_encoder_name_or_path openai/clip-vit-large-patch14-336 \
--pretrained_language_projection_path output_models/llava-336px-pretrain-vicuna-7b-v1.3_language_projection.pth \
--dataset_path ${dataset_path} \
--output_dir ${output_dir} --overwrite_output_dir \
--image_folder ${image_folder} \
--custom_vision_model True \
--llm_model_name_or_path lmsys/vicuna-7b-v1.5 \
--llm_model_name_or_path lmsys/vicuna-7b-v1.3 \
--image_aspect_ratio None \
--fp16 True \
--learning_rate 2e-5 \
--gradient_accumulation_steps 1 \
--per_device_train_batch_size 2 \
--learning_rate 2e-3 \
--gradient_accumulation_steps 8 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 4 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
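The new guard generates data/llava_instruct_80k_truncated.json once before training. The internals of utils/preprocess_multimodal_data.py are not part of this diff; the sketch below only illustrates the kind of length-based truncation such a preprocessing pass typically performs, and every limit and field name in it is an assumption rather than the script's actual behaviour.

# Hypothetical sketch of a truncation pass like utils/preprocess_multimodal_data.py;
# the real script may truncate by token count and use different field names.
import json

MAX_TURNS = 10  # assumed cutoff

with open("data/llava_instruct_80k.json") as f:
    samples = json.load(f)

for sample in samples:
    # "conversations" is the usual key in the LLaVA instruction data; assumed here.
    sample["conversations"] = sample["conversations"][:MAX_TURNS]

with open("data/llava_instruct_80k_truncated.json", "w") as f:
    json.dump(samples, f)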
5 changes: 3 additions & 2 deletions scripts/multimodal/run_vis_chatbot_llava.sh
@@ -1,3 +1,4 @@
# only works with GPU memory > 25G; 4-bit and 8-bit inference fail.
model_name_or_path=Salesforce/blip2-flan-t5-xxl
llava_pretrain_model_path="output_models/llava-v1-0719-336px-lora-merge-vicuna-13b-v1.3/"
deepspeed_args="--master_port=12000"
@@ -23,13 +24,13 @@ deepspeed ${deepspeed_args} \
--custom_model True \
--chatbot_type llava \
--prompt_structure '{input_text} ASSISTANT:' \
--low_resource True \
--llava_loading True \
--model_name_or_path ${model_name_or_path} \
--image_encoder_name_or_path openai/clip-vit-large-patch14 \
--image_encoder_name_or_path openai/clip-vit-large-patch14-336 \
--custom_vision_model True \
--llm_model_name_or_path lmsys/vicuna-13b-v1.5 \
--llava_pretrain_model_path ${llava_pretrain_model_path}"*.bin" \
--with_deepspeed False \
--save_pretrain_model_path "output_models/lmflow_llava-v1-0719-336px-lora-merge-vicuna-13b-v1.3" \
${@:1}

8 changes: 7 additions & 1 deletion src/lmflow/args.py
@@ -538,14 +538,20 @@ class FinetunerArguments(TrainingArguments):
remove_unused_columns: Optional[bool] = field(
default=False,
metadata={
"help": "wheather to remove the unused columns in collate fn"}
"help": "whether to remove the unused columns in collate fn"}
)
finetune_part: Optional[str] = field(
default="language_projection",
metadata={
"help": "the module to finetune."
}
)
save_language_projection: Optional[bool] = field(
default=False,
metadata={
"help": "whether to save language projection layer in multi-modal models."
}
)


@dataclass
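With a boolean field (default False), HfArgumentParser turns the stage-1 flag --save_language_projection True into a real bool rather than the string "True", which is what the later finetuner_args.save_language_projection check relies on. A standalone parsing sketch, outside LMFlow's own argument classes:

# Standalone sketch, not LMFlow code: the flag round-trips through HfArgumentParser
# the same way FinetunerArguments consumes it.
from dataclasses import dataclass, field
from typing import Optional
from transformers import HfArgumentParser

@dataclass
class DemoArgs:
    save_language_projection: Optional[bool] = field(default=False)

(args,) = HfArgumentParser(DemoArgs).parse_args_into_dataclasses(
    ["--save_language_projection", "True"])
print(type(args.save_language_projection), args.save_language_projection)  # <class 'bool'> True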
13 changes: 9 additions & 4 deletions src/lmflow/models/hf_encoder_decoder_model.py
@@ -33,7 +33,7 @@
)

import torch
from transformers.deepspeed import HfDeepSpeedConfig
from transformers.deepspeed import HfDeepSpeedConfig, HfTrainerDeepSpeedConfig

from transformers.testing_utils import CaptureLogger

@@ -89,6 +89,7 @@ def __init__(
use_accelerator=False,
custom_model=False,
with_deepspeed=True,
pipeline_args=None,
*args,
**kwargs
):
@@ -119,7 +120,10 @@ def __init__(
raise NotImplementedError(
f"Currently encoder2decoder model is not supported with accelerator"
)
dschf = HfDeepSpeedConfig(ds_config)
# dschf = HfDeepSpeedConfig(ds_config)
dschf = HfTrainerDeepSpeedConfig(ds_config)
if pipeline_args is not None:
dschf.trainer_config_process(pipeline_args)
peft_model_id = model_args.lora_model_path
# NOTE: Currently offload is not supported by llama
if "llama" in model_args.model_name_or_path and model_args.use_ram_optimized_load:
@@ -198,10 +202,11 @@ def __init__(
kwargs = dict(
torch_dtype=torch.float16,
load_in_8bit=True,
device_map="auto"
device_map="auto",
)
else:
kwargs = {}
# kwargs = dict(torch_dtype=torch.float16)
kwargs = dict(device_map="auto")
if (model_args.image_encoder_name_or_path is None and
model_args.qformer_name_or_path is None and
model_args.llm_model_name_or_path is None):
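The switch to HfTrainerDeepSpeedConfig is the receiving end of the pipeline_args plumbing above: unlike the plain HfDeepSpeedConfig, the trainer subclass can resolve the "auto" placeholders in the DeepSpeed JSON (learning rate, batch sizes, fp16, ...) from the training arguments before the model is constructed, while still keeping the config object alive so that from_pretrained sees the ZeRO settings. A minimal sketch of that call sequence, assuming pipeline_args is a populated TrainingArguments:

# Sketch only: why trainer_config_process() is called with pipeline_args.
from transformers.deepspeed import HfTrainerDeepSpeedConfig

dschf = HfTrainerDeepSpeedConfig("configs/ds_config_zero2.json")  # must stay referenced
dschf.trainer_config_process(pipeline_args)  # fills the "auto" entries from the args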
31 changes: 19 additions & 12 deletions src/lmflow/models/vision2seq_model.py
@@ -22,6 +22,7 @@
PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.deepspeed import is_deepspeed_zero3_enabled

from lmflow.models.base_model import BaseModel
from lmflow.models.vision_encoder import build_vision_tower
@@ -77,15 +78,18 @@ def __init__(self,
torch.zeros(1, config.num_query_tokens,
config.qformer_config.hidden_size))
self.qformer = Blip2QFormerModel(config.qformer_config)
if low_resource:
kwargs = dict(
torch_dtype=torch.float16,
load_in_8bit=True,
device_map="auto"
)
else:
kwargs = {}
kwargs = dict()
if language_model_name_or_path is not None:
if low_resource:
kwargs = dict(
torch_dtype=torch.float16,
load_in_8bit=True,
device_map="auto",
low_cpu_mem_usage=True)
else:
if not is_deepspeed_zero3_enabled():
kwargs = dict(device_map="auto",
torch_dtype=torch.float16)
language_model = AutoModelForCausalLM.from_pretrained(
language_model_name_or_path, **kwargs)
config.text_config = language_model.config
@@ -97,7 +101,7 @@ def __init__(self,
language_model = AutoModelForSeq2SeqLM.from_config(
config.text_config, **kwargs)
# Update _tied_weights_keys using the base model used.
if language_model._tied_weights_keys is not None:
if getattr(language_model, "_tied_weights_keys", None) is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]

self.language_model = language_model
@@ -237,7 +241,6 @@ def forward(
batch_size = pixel_values.shape[0]
else:
batch_size = 1

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -298,6 +301,9 @@ def forward(
# if inputs_embeds is not None:
# print("input_embeds", inputs_embeds.shape)
# attention_mask.shape, inputs_embeds.shape)
# TODO remove this code by fixing the ddp training issue
inputs_embeds = inputs_embeds.to(
self.language_model.lm_head.weight.dtype)
outputs = self.language_model(
input_ids=input_ids,
attention_mask=attention_mask,
@@ -356,7 +362,7 @@ def processor_image_token_in_minigpt4(self,

# concatenate query embeddings with prompt embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)
inputs_embeds = inputs_embeds.to(language_model_inputs.device)
inputs_embeds = inputs_embeds.to(device=language_model_inputs.device)
# concatenate the text embeddings with image embeddings
inputs_embeds_with_images = []
attention_mask_with_images = []
@@ -426,7 +432,6 @@ def generate(
batch_size = pixel_values.shape[0]
else:
batch_size = 1

if not self.custom_vision_model:
# do the processing as blip2 and mini gpt-4;
image_embeds = self.vision_model(
@@ -473,6 +478,8 @@ def generate(
self.language_model.model)
# convert the dtype.
# FIXME check when need to do this
inputs_embeds = inputs_embeds.to(
device=self.language_model.lm_head.weight.device)
inputs_embeds = inputs_embeds.to(
self.language_model.lm_head.weight.dtype)
outputs = self.language_model.generate(
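The new .to(...) calls in forward() and generate() follow the same pattern: align inputs_embeds with the language model head before handing them to the LLM, since with fp16 weights and device_map placement the projected image embeddings can otherwise arrive in the wrong dtype or on the wrong device. A compact sketch of the combined alignment as in-method code, with lm_head as the reference just as in the diff:

# Sketch of the dtype/device alignment added in forward() and generate().
ref = self.language_model.lm_head.weight
inputs_embeds = inputs_embeds.to(device=ref.device, dtype=ref.dtype)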
2 changes: 1 addition & 1 deletion src/lmflow/models/vision_encoder/clip_encoder.py
@@ -128,7 +128,6 @@ def prepare_inputs_labels_for_multimodal(
image_features = [x.flatten(0, 1) for x in image_features]
else:
image_features = self.encode_images(images, language_projection)

new_input_embeds = []
new_labels = [] if labels is not None else None
cur_image_idx = 0
@@ -164,6 +163,7 @@ def prepare_inputs_labels_for_multimodal(
cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
cur_labels = cur_labels[image_token_start+2:]
else:
cur_input_ids = cur_input_ids.to(device=language_model.device)
cur_new_input_embeds.append(language_model.embed_tokens(cur_input_ids[:image_token_start]))
cur_new_input_embeds.append(cur_image_features)
if labels is not None:
10 changes: 8 additions & 2 deletions src/lmflow/pipeline/finetuner.py
@@ -296,7 +296,6 @@ def compute_metrics(eval_preds):
callbacks=trainer_callbacks
)
# Training
# import pdb; pdb.set_trace()
if training_args.do_train:
checkpoint = None
last_checkpoint = self.last_checkpoint
@@ -312,7 +311,14 @@ def compute_metrics(eval_preds):
if model_args.save_aggregated_lora:
model.merge_lora_weights()
model.save(finetuner_args.output_dir,model_args.save_aggregated_lora)

# save language_projection for multi-modal model;
if self.finetuner_args.save_language_projection:
language_projection_state = trainer.model.language_projection.state_dict()
torch.save(
language_projection_state,
osp.join(
self.finetuner_args.output_dir,
"language_projection.pth"))
metrics = train_result.metrics

max_train_samples = (
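Stage 2 then consumes this artifact through --pretrained_language_projection_path. A minimal sketch of loading it back: the file name comes from the code above, the stage-1 output directory from the script earlier in this PR, and the backend_model.language_projection attribute path is assumed from examples/finetune_multi_modal.py and vision2seq_model.py.

# Sketch: re-loading the projection weights written by the finetuner.
import torch

state = torch.load(
    "output_models/finetune_llava-336px-vicuna-7b-v1.3_stage1/language_projection.pth",
    map_location="cpu")
model.backend_model.language_projection.load_state_dict(state)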
13 changes: 6 additions & 7 deletions src/lmflow/pipeline/inferencer.py
@@ -123,7 +123,7 @@ def inference(
temperature: float=0.0,
prompt_structure: str='{input}',
remove_image_flag: bool=False,
chatbot_format: str="mini_gpt",
chatbot_type: str="mini_gpt",
):
"""
Perform inference for a model
@@ -165,9 +165,9 @@ def inference(
input['images'] = np.array(input['images'])
if remove_image_flag:
# remove the image flag <ImageHere> in tokenization;
if chatbot_format == "mini_gpt":
if chatbot_type == "mini_gpt":
image_split_flag = "<ImageHere>"
elif chatbot_format:
elif chatbot_type:
image_split_flag = "<image>"
else:
raise NotImplementedError
@@ -186,7 +186,7 @@ def inference(
).to(device=self.local_rank)
input_ids.append(temp_inputs['input_ids'])
attention_mask.append(temp_inputs['attention_mask'])
if chatbot_format == "llava":
if chatbot_type == "llava":
# add the flag for inserting the image.
# TODO should merge the way of handling image flag in minigpt and llava.
index_tensor = torch.tensor(
@@ -200,7 +200,7 @@ def inference(
temp_inputs["input_ids"].shape[1])
if len(image_token_indexes) > 1:
image_token_indexes = image_token_indexes[:-1]
if chatbot_format == "llava":
if chatbot_type == "llava":
input_ids = input_ids[:-1]
attention_mask = attention_mask[:-1]
inputs = temp_inputs
@@ -219,7 +219,6 @@ def inference(
raise NotImplementedError(
f"device \"{self.inferencer_args.device}\" is not supported"
)

if remove_image_flag:
inputs["image_token_indexes"] = image_token_indexes
inputs["one_sample_multiple_images"] = True
@@ -553,4 +552,4 @@ def speculative_sampling(input_ids: torch.Tensor,


def stream_inference(self):
raise NotImplementedError("Streaming output for SpeculativeInferencer is not supported yet")
raise NotImplementedError("Streaming output for SpeculativeInferencer is not supported yet")
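The chatbot_format → chatbot_type rename is purely a keyword change; the branch it drives is unchanged. For clarity, the selection logic above condensed into a standalone sketch:

# Sketch of the image-placeholder selection keyed on chatbot_type (was chatbot_format).
def image_split_flag(chatbot_type: str) -> str:
    if chatbot_type == "mini_gpt":
        return "<ImageHere>"
    elif chatbot_type:  # any other non-empty value, e.g. "llava"
        return "<image>"
    raise NotImplementedError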
7 changes: 3 additions & 4 deletions src/lmflow/utils/multimodal.py
@@ -42,10 +42,9 @@ def adapt_llava_model_to_lmflow_type(state_dict):
key = key.replace("model.embed_tokens",
"language_model.model.embed_tokens")
key = key.replace("model.mm_projector", "language_projection")
key = key.replace("lm_head", "model.language_model.lm_head")
key = key.replace("model.norm", "language_model.model.layers")
key = key.replace("lm_head", "language_model.lm_head")
key = key.replace("model.norm", "language_model.model.norm")
if "vision_tower" in key:
continue
new_state_dict[key] = item
return new_state_dict

return new_state_dict
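The corrected lm_head and model.norm mappings are what let merged LLaVA checkpoints load cleanly into LMFlow's vision2seq model. A sketch of the intended flow, using the *.bin glob from run_vis_chatbot_llava.sh; the model.backend_model entry point is an assumption, not shown in this diff.

# Sketch: merge sharded LLaVA *.bin checkpoints and remap their keys for LMFlow.
import glob
import torch

state_dict = {}
for shard in sorted(glob.glob(
        "output_models/llava-v1-0719-336px-lora-merge-vicuna-13b-v1.3/*.bin")):
    state_dict.update(torch.load(shard, map_location="cpu"))

lmflow_state_dict = adapt_llava_model_to_lmflow_type(state_dict)
missing, unexpected = model.backend_model.load_state_dict(lmflow_state_dict, strict=False)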