Commit

[Feat/Model] add specific configs when using saving final_weights method. (#151)

* update

* update

* update:mean_loss

* black format

* Delete accelerate_config_fsdp.yaml

this file should be at pipeline/accelerate_configs/xxxx.yaml

* update llava training and saving config

* delete accelerate_config

* update

* update

* Delete config.json

* black format

* Update .gitignore

* update .gitignore

---------

Co-authored-by: Li Bo <drluodian@gmail.com>
ZhangYuanhan-AI and Luodian authored Jun 16, 2023
1 parent fa97b6c commit 84b1b2f
Showing 3 changed files with 114 additions and 19 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -192,4 +192,6 @@ azure

 *.mp4
 checkpoints
-pipeline/serve/examples/*.png
+pipeline/serve/examples/*.png
+
+tools
121 changes: 106 additions & 15 deletions pipeline/mimicit_utils/mimicit_dataset.py
@@ -18,6 +18,10 @@

from .transforms import *

# sys.path.append("/mnt/lustre/yhzhang/Otter/pipeline/multi_instruct_data_utils")
# from transforms import *


from torch.utils.data import Dataset


@@ -195,22 +199,37 @@ def process_llava(self, instruction_id, instruction, answer, image_ids, in_conte
         all_texts = ""
         all_instruction_ids = in_context_example_ids + [instruction_id]
         # random.shuffle(all_instruction_ids)
-        for cur_instruction_id in all_instruction_ids[:]:
-            cur_instruction_image_id = self.dataset[cur_instruction_id]["image_ids"][0]
-            cur_instruction = self.dataset[cur_instruction_id]["instruction"]
-            cur_answer = self.dataset[cur_instruction_id]["answer"]
-            cur_image = self.images[cur_instruction_image_id]
+        if "CONV" in instruction_id:
+            for cur_instruction_id in all_instruction_ids[:]:
+                cur_instruction_image_id = self.dataset[cur_instruction_id]["image_ids"][0]
+                cur_instruction = self.dataset[cur_instruction_id]["instruction"]
+                cur_answer = self.dataset[cur_instruction_id]["answer"]
+                cur_instruction = self.pre_question(cur_instruction, self.max_src_length)
+                cur_answer = self.pre_answer(cur_answer, self.max_tgt_length)
+                cur_text = f"User: {cur_instruction} GPT:<answer> {cur_answer}<|endofchunk|>"
+                all_texts += cur_text
+
+            all_texts = f"<image>{all_texts}"
+            cur_image_id = self.dataset[cur_instruction_id]["image_ids"][0]
+            cur_image = self.images[cur_image_id]
             cur_image = Image.open(BytesIO(base64.urlsafe_b64decode(cur_image))).convert("RGB")
-            cur_patch_image = self.patch_resize_transform(cur_image).unsqueeze(0).unsqueeze(0)
-            if len(patch_images) == 0:
-                patch_images = cur_patch_image
-            else:
-                patch_images = torch.cat((patch_images, cur_patch_image))
-
-            cur_instruction = self.pre_question(cur_instruction, self.max_src_length)
-            cur_answer = self.pre_answer(cur_answer, self.max_tgt_length)
-            cur_text = f"<image>User: {cur_instruction} GPT:<answer> {cur_answer}<|endofchunk|>"
-            all_texts += cur_text
+            patch_images = self.patch_resize_transform(cur_image).unsqueeze(0).unsqueeze(0)
+        else:
+            for cur_instruction_id in all_instruction_ids[:]:
+                cur_instruction_image_id = self.dataset[cur_instruction_id]["image_ids"][0]
+                cur_instruction = self.dataset[cur_instruction_id]["instruction"]
+                cur_answer = self.dataset[cur_instruction_id]["answer"]
+                cur_image = self.images[cur_instruction_image_id]
+                cur_image = Image.open(BytesIO(base64.urlsafe_b64decode(cur_image))).convert("RGB")
+                cur_patch_image = self.patch_resize_transform(cur_image).unsqueeze(0).unsqueeze(0)
+                if len(patch_images) == 0:
+                    patch_images = cur_patch_image
+                else:
+                    patch_images = torch.cat((patch_images, cur_patch_image))
+                cur_instruction = self.pre_question(cur_instruction, self.max_src_length)
+                cur_answer = self.pre_answer(cur_answer, self.max_tgt_length)
+                cur_text = f"<image>User: {cur_instruction} GPT:<answer> {cur_answer}<|endofchunk|>"
+                all_texts += cur_text
         # <image>User: {cur_incontext_instruction} GPT:<answer> {cur_incontext_answer}<|endofchunk|><image>User: {instruction} GPT:<answer> {answer}<|endofchunk|>
         # incontext_text = "<image>User: What does this image descibe? GPT:<answer>The children in the image, along with the rest of the family. They are Skiing. <|endofchunk|>"
         # query_text = f"<image>User: What does this image descibe? GPT:<answer>"
@@ -516,3 +535,75 @@ def copy_tensor(src, dst):
    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
    return res


if __name__ == "__main__":
    from PIL import Image, ImageFile
    from io import BytesIO
    import base64
    from tqdm import tqdm
    import json
    import argparse
    import sys

    sys.path.append("/mnt/petrelfs/zhangyuanhan/Otter/")
    from flamingo.modeling_flamingo import FlamingoForConditionalGeneration

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--multi_instruct_path",
        type=str,
        help="path to multi_instruct dataset, this should be a glob pattern such as vision_language_examples.tsv",
    )
    parser.add_argument("--offline", action="store_true")

    args = parser.parse_args()

    args.multi_instruct_path = "/mnt/petrelfs/zhangyuanhan/data/mimicit/LA/LACR_I2I_instructions.json"  # ,/mnt/petrelfs/zhangyuanhan/data/LLaVA-Instruct-150K/LA/LACR_I2I_instructions.json,/mnt/petrelfs/zhangyuanhan/data/LLaVA-Instruct-150K/LA/LACR_T2T_instructions.json,/mnt/petrelfs/zhangyuanhan/data/LLaVA-Instruct-150K/LA/LADD_instructions.json"
    args.images_path = "/mnt/petrelfs/zhangyuanhan/data/mimicit/LA/LA_00.json"
    args.train_config_path = "/mnt/petrelfs/zhangyuanhan/data/mimicit/LA/LACR_I2I_train.json"  # ,/mnt/petrelfs/zhangyuanhan/data/LLaVA-Instruct-150K/LA/LACR_I2I_train.json,/mnt/petrelfs/zhangyuanhan/data/LLaVA-Instruct-150K/LA/LACR_T2T_train.json,/mnt/petrelfs/zhangyuanhan/data/LLaVA-Instruct-150K/LA/LADD_train.json"
    args.max_src_length = 256
    args.max_tgt_length = 256
    args.task = "pretrain"
    args.pretrain_seed = 0
    args.patch_image_size = 224

    from transformers import LlamaTokenizer

    with open("/mnt/petrelfs/zhangyuanhan/weights/flamingo_9b_hf/config.json") as f:
        config = json.load(f)

    tokenizer = LlamaTokenizer.from_pretrained("luodian/llama-7b-hf")

    # add <answer> token to tokenizer
    tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>", "<answer>"]})

    tokenizer.add_special_tokens({"pad_token": "<PAD>"})

    args.tokenizer = tokenizer

    cur_multi_instruct_path, cur_images_path, cur_train_config_path = args.multi_instruct_path, args.images_path, args.train_config_path

    test_dataset = MimicitDataset(args, cur_multi_instruct_path, cur_images_path, cur_train_config_path)

    uniq_id_dict = {}
    samples = []
    counter = 0
    for _ in tqdm(test_dataset):
        if counter > 0:
            break
        counter += 1
        samples.append(_)
    cur_data = test_dataset.collate(samples)
    import pdb

    pdb.set_trace()
    # import pdb;pdb.set_trace()
    # uniq_id, image, caption, question, refs, gt_objects, dataset_name, type = _
    # # index = random.choice(positive_caption_dict[uniq_id])
    # # prompt_uniq_id, prompt_image, prompt_caption, prompt_question, prompt_refs, prompt_gt_objects, prompt_dataset_name, prompt_type = test_dataset.get_prompt_item(int(index))
    # uniq_id, image, caption, question, refs, gt_objects, dataset_name, type = _
    # if uniq_id not in uniq_id_dict:
    #     uniq_id_dict[uniq_id] = 0

    # print(uniq_id, image, caption, question, refs, gt_objects, dataset_name, type)
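The reorganized process_llava path above distinguishes multi-turn conversation samples (instruction ids containing "CONV", which share one image across all turns) from single-turn samples (one image and one "<image>" token per turn). Below is a minimal, self-contained sketch of that packing logic; the helper name pack_llava_sample, the toy turns list, the dummy 224x224 tensors, and the identity transform are illustrative assumptions, not code from this commit.

import torch


def pack_llava_sample(instruction_id, turns, images, transform):
    """Toy restatement of the CONV vs. non-CONV packing shown in the diff above.

    turns:  list of (instruction, answer) strings
    images: list of images, one per turn (CONV samples reuse only the last one)
    """
    patch_images = torch.tensor([])
    all_texts = ""
    if "CONV" in instruction_id:
        # Multi-turn conversation: concatenate every turn as text,
        # then attach a single <image> marker and a single image tensor.
        for instruction, answer in turns:
            all_texts += f"User: {instruction} GPT:<answer> {answer}<|endofchunk|>"
        all_texts = f"<image>{all_texts}"
        patch_images = transform(images[-1]).unsqueeze(0).unsqueeze(0)
    else:
        # Independent turns: every turn carries its own <image> marker and tensor.
        for (instruction, answer), image in zip(turns, images):
            cur_patch = transform(image).unsqueeze(0).unsqueeze(0)
            patch_images = cur_patch if len(patch_images) == 0 else torch.cat((patch_images, cur_patch))
            all_texts += f"<image>User: {instruction} GPT:<answer> {answer}<|endofchunk|>"
    return all_texts, patch_images


if __name__ == "__main__":
    # Dummy "images" as raw tensors and an identity transform keep the sketch dependency-free.
    dummy_images = [torch.rand(3, 224, 224) for _ in range(2)]
    identity = lambda x: x
    text, imgs = pack_llava_sample("LACONV_00001", [("What is shown?", "A cat."), ("Its color?", "Black.")], dummy_images, identity)
    print(text)
    print(imgs.shape)  # (1, 1, 3, 224, 224) for a CONV sample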
8 changes: 5 additions & 3 deletions pipeline/train/instruction_following.py
@@ -141,7 +141,6 @@ def train_one_epoch(
         labels[labels == answer_token_id] = -100
         labels[labels == media_token_id] = -100

-        # import pdb;pdb.set_trace()
         # with accelerator.accumulate(model):
         # with autocast():
         with accelerator.autocast():
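For context on the label masking in this hunk: "<answer>" and "<image>" are registered as additional special tokens (as in the __main__ harness added above), their token ids are looked up once, and those positions are set to -100 so they contribute nothing to the language-modeling loss. A minimal sketch follows, assuming the same luodian/llama-7b-hf tokenizer; the example sentence and printed output are illustrative only.

import torch
from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("luodian/llama-7b-hf")
tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>", "<answer>"]})
tokenizer.add_special_tokens({"pad_token": "<PAD>"})

# look up the ids of the marker tokens
media_token_id = tokenizer("<image>", add_special_tokens=False)["input_ids"][-1]
answer_token_id = tokenizer("<answer>", add_special_tokens=False)["input_ids"][-1]

input_ids = tokenizer("<image>User: hi GPT:<answer> hello<|endofchunk|>", return_tensors="pt")["input_ids"]
labels = input_ids.clone()
# mask the marker tokens out of the loss, as in the training loop above
labels[labels == answer_token_id] = -100
labels[labels == media_token_id] = -100
print(labels)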
@@ -483,8 +482,6 @@ def apply_decay(x):
     model, optimizer, lr_scheduler, multi_instruct_loaders = accelerator.prepare(model, optimizer, lr_scheduler, multi_instruct_loaders)
     model.train()

-    # device_id = accelerator.device
-
     for epoch in range(resume_from_epoch, args.num_epochs):
         for cur_data_loader in multi_instruct_loaders:
             cur_data_loader.dataset.set_epoch(epoch)
@@ -517,6 +514,11 @@ def apply_decay(x):
            get_checkpoint(model=unwrapped_model),
            f"{args.external_save_dir}/final_weights.pt",
        )
        # save the config
        unwrapped_model.config.save_pretrained(args.external_save_dir)
        if model.can_generate():
            model_to_save.generation_config.save_pretrained(args.external_save_dir)

    if args.report_to_wandb and args.save_checkpoints_to_wandb:
        wandb.save(f"{args.external_save_dir}/final_weights.pt")
    if args.save_hf_model:
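The hunk above is the core of this commit: besides dumping final_weights.pt, the trainer now writes the model config (and, when the model can generate, its generation config) into the same directory, so the weights can be reloaded without the original training code. Below is a hedged, standalone sketch of that save pattern using the generic Hugging Face PreTrainedModel API; the gpt2 checkpoint and the save_dir path are placeholders standing in for the Otter/Flamingo model and args.external_save_dir, not values from this repository.

import os

import torch
from transformers import AutoModelForCausalLM

save_dir = "./final_weights_export"  # placeholder directory, not from the commit
os.makedirs(save_dir, exist_ok=True)

# Any PreTrainedModel works here; gpt2 is only a small stand-in for the Otter/Flamingo model.
model = AutoModelForCausalLM.from_pretrained("gpt2")

# 1) raw weights, analogous to saving get_checkpoint(model=unwrapped_model) above
torch.save(model.state_dict(), f"{save_dir}/final_weights.pt")

# 2) the architecture config, so the checkpoint can be rebuilt later
model.config.save_pretrained(save_dir)

# 3) generation defaults, only meaningful for models that can generate
if model.can_generate():
    model.generation_config.save_pretrained(save_dir)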
