Skip to content

Commit

Permalink
[Feat/Train] Add code for instruction tuning of OF2.0, add more datas…
Browse files Browse the repository at this point in the history
…et (#191)

* update

* updates

* update

* update

* update

* black format

* updates

* update

* update

* update

* black format

* update

* black format

* merge conflicts with main branch and clean code

* modify loss name in wandb log

* update configuration_flamingo.py

---------

Co-authored-by: Bo Li <drluodian@gmail.com>
  • Loading branch information
ZhangYuanhan-AI and Luodian authored Jul 6, 2023
1 parent ca5f4d6 commit f7095f0
Show file tree
Hide file tree
Showing 11 changed files with 365 additions and 178 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -194,4 +194,5 @@ azure
checkpoints
pipeline/serve/examples/*.png

tools
tools
otter9B-mpt7b-0705/config.json
4 changes: 0 additions & 4 deletions flamingo/configuration_flamingo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@
from flamingo.mpt.configuration_mpt import MPTConfig


from flamingo.falcon.configuration_RW import RWConfig
from flamingo.mpt.configuration_mpt import MPTConfig


logger = logging.get_logger(__name__)


Expand Down
12 changes: 7 additions & 5 deletions flamingo/injecting_mpt_into_flamingo.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,15 @@
elif model_choice == "7B":
config_file = "./flamingo/flamingo-mpt-7B.json"
state_dict_files = [
f"{root_dir}/mpt-7b-instruct/pytorch_model-00001-of-00002.bin",
f"{root_dir}/mpt-7b-instruct/pytorch_model-00002-of-00002.bin",
f"{root_dir}/mpt-7b/pytorch_model-00001-of-00002.bin",
f"{root_dir}/mpt-7b/pytorch_model-00002-of-00002.bin",
]
save_path = f"{save_root_dir}/flamingo-mpt-7B-instruct-init"
save_path = f"{save_root_dir}/flamingo-mpt-7B"
else:
raise ValueError("Invalid model_choice. Choose either '30B' or '7B'.")

config = FlamingoConfig.from_json_file(config_file)

model = FlamingoForConditionalGeneration(config=config)

state_dict = {}
Expand Down Expand Up @@ -94,10 +95,11 @@
)
# print incompatible keys
print(load_msg[1])

if args.flamingo_dir is not None:
state_dict_2 = torch.load(f"{args.flamingo_dir}/checkpoint.pt", map_location="cpu")
save_state_dict_2 = rename_flamingo_checkpoint(state_dict_2)

real_vocab_size = config.text_config.vocab_size
# Reshape the token embedding to 50280 for compatible
model.lang_encoder.resize_token_embeddings(save_state_dict_2["lang_encoder.transformer.wte.weight"].shape[0])

Expand All @@ -108,7 +110,7 @@
# print incompatible keys
print(load_msg[1])
# Reshape the token embedding to 50432
model.lang_encoder.resize_token_embeddings(config.text_config.vocab_size)
model.lang_encoder.resize_token_embeddings(real_vocab_size)

print(f"Saving model to {save_path}...")
model.save_pretrained(save_path, max_shard_size="10GB")
197 changes: 197 additions & 0 deletions otter/Otter-MPT7B-config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
{
"_commit_hash": null,
"_name_or_path": "/mnt/petrelfs/zhangyuanhan/weights/flamingo-mpt-7B",
"architectures": [
"FlamingoForConditionalGeneration"
],
"cross_attn_every_n_layers": 4,
"model_type": "otter",
"only_attend_previous": true,
"text_config": {
"_name_or_path": "",
"add_cross_attention": false,
"architectures": [
"MPTForCausalLM"
],
"attn_config": {
"alibi": true,
"alibi_bias_max": 8,
"attn_impl": "torch",
"attn_pdrop": 0,
"attn_type": "multihead_attention",
"attn_uses_sequence_id": false,
"clip_qkv": null,
"prefix_lm": false,
"qk_ln": false,
"softmax_scale": null
},
"bad_words_ids": null,
"begin_suppress_tokens": null,
"bos_token_id": null,
"chunk_size_feed_forward": 0,
"cross_attention_hidden_size": null,
"d_model": 4096,
"decoder_start_token_id": null,
"diversity_penalty": 0.0,
"do_sample": false,
"early_stopping": false,
"emb_pdrop": 0,
"embedding_fraction": 1.0,
"encoder_no_repeat_ngram_size": 0,
"eos_token_id": null,
"expansion_ratio": 4,
"exponential_decay_length_penalty": null,
"finetuning_task": null,
"forced_bos_token_id": null,
"forced_eos_token_id": null,
"hidden_size": 4096,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"init_config": {
"emb_init_std": null,
"emb_init_uniform_lim": null,
"fan_mode": "fan_in",
"init_div_is_residual": true,
"init_gain": 0,
"init_nonlinearity": "relu",
"init_std": 0.02,
"name": "kaiming_normal_",
"verbose": 0
},
"init_device": "cpu",
"is_decoder": false,
"is_encoder_decoder": false,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"learned_pos_emb": true,
"length_penalty": 1.0,
"logit_scale": null,
"max_length": 20,
"max_seq_len": 2048,
"min_length": 0,
"model_type": "mpt",
"n_heads": 32,
"n_layers": 32,
"no_bias": true,
"no_repeat_ngram_size": 0,
"norm_type": "low_precision_layernorm",
"num_beam_groups": 1,
"num_beams": 1,
"num_return_sequences": 1,
"output_attentions": false,
"output_hidden_states": false,
"output_scores": false,
"pad_token_id": null,
"prefix": null,
"problem_type": null,
"pruned_heads": {},
"remove_invalid_values": false,
"repetition_penalty": 1.0,
"resid_pdrop": 0,
"return_dict": true,
"return_dict_in_generate": false,
"sep_token_id": null,
"suppress_tokens": null,
"task_specific_params": null,
"temperature": 1.0,
"tf_legacy_loss": false,
"tie_encoder_decoder": false,
"tie_word_embeddings": true,
"tokenizer_class": null,
"tokenizer_name": "EleutherAI/gpt-neox-20b",
"top_k": 50,
"top_p": 1.0,
"torch_dtype": "bfloat16",
"torchscript": false,
"transformers_version": "4.30.1",
"typical_p": 1.0,
"use_bfloat16": false,
"use_cache": false,
"verbose": 0,
"vocab_size": 50432
},
"torch_dtype": "float32",
"transformers_version": null,
"use_media_placement_augmentation": true,
"vision_config": {
"_name_or_path": "openai/clip-vit-large-patch14",
"add_cross_attention": false,
"architectures": null,
"attention_dropout": 0.0,
"bad_words_ids": null,
"begin_suppress_tokens": null,
"bos_token_id": null,
"chunk_size_feed_forward": 0,
"cross_attention_hidden_size": null,
"decoder_start_token_id": null,
"diversity_penalty": 0.0,
"do_sample": false,
"early_stopping": false,
"encoder_no_repeat_ngram_size": 0,
"eos_token_id": null,
"exponential_decay_length_penalty": null,
"finetuning_task": null,
"forced_bos_token_id": null,
"forced_eos_token_id": null,
"hidden_act": "quick_gelu",
"hidden_size": 1024,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"image_size": 224,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 4096,
"is_decoder": false,
"is_encoder_decoder": false,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"layer_norm_eps": 1e-05,
"length_penalty": 1.0,
"max_length": 20,
"min_length": 0,
"model_type": "clip_vision_model",
"no_repeat_ngram_size": 0,
"num_attention_heads": 16,
"num_beam_groups": 1,
"num_beams": 1,
"num_channels": 3,
"num_hidden_layers": 24,
"num_return_sequences": 1,
"output_attentions": false,
"output_hidden_states": false,
"output_scores": false,
"pad_token_id": null,
"patch_size": 14,
"prefix": null,
"problem_type": null,
"projection_dim": 512,
"pruned_heads": {},
"remove_invalid_values": false,
"repetition_penalty": 1.0,
"return_dict": true,
"return_dict_in_generate": false,
"sep_token_id": null,
"suppress_tokens": null,
"task_specific_params": null,
"temperature": 1.0,
"tf_legacy_loss": false,
"tie_encoder_decoder": false,
"tie_word_embeddings": true,
"tokenizer_class": null,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": null,
"torchscript": false,
"transformers_version": "4.30.1",
"typical_p": 1.0,
"use_bfloat16": false
}
}
18 changes: 17 additions & 1 deletion otter/configuration_otter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.clip import CLIPVisionConfig

from flamingo.falcon.configuration_RW import RWConfig
from flamingo.mpt.configuration_mpt import MPTConfig

logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -67,7 +70,20 @@ def __init__(
logger.info("text_config is None. Initializing the text config with default values.")

self.vision_config = CLIPVisionConfig(**vision_config)
self.text_config = CONFIG_MAPPING[text_config.pop("model_type")](**text_config)
if "architectures" in text_config.keys() and text_config["architectures"] != None:
if text_config["architectures"][0] == "MPTForCausalLM":
self.text_config = MPTConfig(**text_config)
elif text_config["architectures"][0] == "RWForCausalLM":
self.text_config = RWConfig(**text_config)
elif text_config["architectures"][0] == "LlamaForCausalLM":
self.text_config = CONFIG_MAPPING[text_config.pop("model_type")](**text_config)
else:
import pdb

pdb.set_trace()
else:
self.text_config = CONFIG_MAPPING[text_config.pop("model_type")](**text_config)

self.cross_attn_every_n_layers = cross_attn_every_n_layers
self.use_media_placement_augmentation = use_media_placement_augmentation
self.only_attend_previous = only_attend_previous
Expand Down
Loading

0 comments on commit f7095f0

Please sign in to comment.