🤏 New models for tests #2287
Changes from all commits
@@ -0,0 +1,193 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This script generates the tiny models used in the TRL library for unit tests. It pushes them to the Hub under the
# `trl-internal-testing` organization.
# This script is meant to be run when adding a new tiny model to the TRL library.

from huggingface_hub import HfApi, ModelCard
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    BartConfig,
    BartModel,
    BloomConfig,
    BloomForCausalLM,
    CLIPVisionConfig,
    CohereConfig,
    CohereForCausalLM,
    DbrxConfig,
    DbrxForCausalLM,
    FalconMambaConfig,
    FalconMambaForCausalLM,
    Gemma2Config,
    Gemma2ForCausalLM,
    GemmaConfig,
    GemmaForCausalLM,
    GPT2Config,
    GPT2LMHeadModel,
    GPTNeoXConfig,
    GPTNeoXForCausalLM,
    Idefics2Config,
    Idefics2ForConditionalGeneration,
    LlamaConfig,
    LlamaForCausalLM,
    LlavaConfig,
    LlavaForConditionalGeneration,
    LlavaNextConfig,
    LlavaNextForConditionalGeneration,
    MistralConfig,
    MistralForCausalLM,
    OPTConfig,
    OPTForCausalLM,
    PaliGemmaConfig,
    PaliGemmaForConditionalGeneration,
    Phi3Config,
    Phi3ForCausalLM,
    Qwen2Config,
    Qwen2ForCausalLM,
    SiglipVisionConfig,
    T5Config,
    T5ForConditionalGeneration,
)
from transformers.models.idefics2.configuration_idefics2 import Idefics2VisionConfig


ORGANIZATION = "trl-internal-testing"

MODEL_CARD = """
---
library_name: transformers
tags: [trl]
---

# Tiny {model_class_name}

This is a minimal model built for unit tests in the [TRL](https://github.com/huggingface/trl) library.
"""


api = HfApi()


def push_to_hub(model, tokenizer, suffix=None):
    model_class_name = model.__class__.__name__
Review comment: Not sure it matters much, but this won't make a distinction between base and instruct models, as they share the same class. If we don't care about this difference in our tests, there's no need to change it.

Reply: Yes, I wasn't sure what to do about that. Most of our tests are based on Qwen2.5 in its instruct version, so I don't know how compatible the trainers are with non-instruct versions. Let's keep it like this for the moment. (One possible naming scheme is sketched after the function below.)
    content = MODEL_CARD.format(model_class_name=model_class_name)
    model_card = ModelCard(content)
    repo_id = f"{ORGANIZATION}/tiny-{model_class_name}"
    if suffix is not None:
        repo_id += f"-{suffix}"

    if api.repo_exists(repo_id):
        print(f"Model {repo_id} already exists, skipping")
    else:
        model.push_to_hub(repo_id)
        tokenizer.push_to_hub(repo_id)
        model_card.push_to_hub(repo_id)
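Picking up the thread above about base vs. instruct variants: a minimal sketch of how the existing `suffix` mechanism could be extended to mark instruct checkpoints. The `is_instruct` flag and `-Instruct` suffix are hypothetical, not part of this PR:

```python
# Hypothetical sketch: distinguish base and instruct checkpoints that share
# the same model class by extending the repo naming from push_to_hub above.
def tiny_repo_id(model, suffix=None, is_instruct=False):
    model_class_name = model.__class__.__name__
    repo_id = f"{ORGANIZATION}/tiny-{model_class_name}"
    if suffix is not None:
        repo_id += f"-{suffix}"
    if is_instruct:  # hypothetical flag, not in the PR
        repo_id += "-Instruct"
    return repo_id
```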


# Decoder models
for model_id, config_class, model_class, suffix in [
    ("bigscience/bloomz-560m", BloomConfig, BloomForCausalLM, None),
Review comment: After this PR is merged, I would be in favour of relying on just a small, curated set of popular architectures for our tests (e.g. Qwen / Mistral / Llama / Gemma) and removing all the rest where appropriate.

Review comment: Also, is this script supposed to be re-run whenever we add a model to the list? If so, I recommend adding a note either at the top of this script or in our contributor guide.

Reply: Would you remove any model from this list?

Review comment: Yes.

Reply: I added a note in the script in c851842. (An illustration of adding an entry follows the decoder loop below.)
    ("CohereForAI/aya-expanse-8b", CohereConfig, CohereForCausalLM, None),
    ("databricks/dbrx-instruct", DbrxConfig, DbrxForCausalLM, None),
    ("tiiuae/falcon-7b-instruct", FalconMambaConfig, FalconMambaForCausalLM, None),
    ("google/gemma-2-2b-it", Gemma2Config, Gemma2ForCausalLM, None),
    ("google/gemma-7b-it", GemmaConfig, GemmaForCausalLM, None),
    ("openai-community/gpt2", GPT2Config, GPT2LMHeadModel, None),
    ("EleutherAI/pythia-14m", GPTNeoXConfig, GPTNeoXForCausalLM, None),
    ("meta-llama/Meta-Llama-3-8B-Instruct", LlamaConfig, LlamaForCausalLM, "3"),
    ("meta-llama/Llama-3.1-8B-Instruct", LlamaConfig, LlamaForCausalLM, "3.1"),
    ("meta-llama/Llama-3.2-1B-Instruct", LlamaConfig, LlamaForCausalLM, "3.2"),
    ("mistralai/Mistral-7B-Instruct-v0.1", MistralConfig, MistralForCausalLM, "0.1"),
    ("mistralai/Mistral-7B-Instruct-v0.2", MistralConfig, MistralForCausalLM, "0.2"),
    ("facebook/opt-1.3b", OPTConfig, OPTForCausalLM, None),
    ("microsoft/Phi-3.5-mini-instruct", Phi3Config, Phi3ForCausalLM, None),
    ("Qwen/Qwen2.5-32B-Instruct", Qwen2Config, Qwen2ForCausalLM, "2.5"),
    ("Qwen/Qwen2.5-Coder-0.5B", Qwen2Config, Qwen2ForCausalLM, "2.5-Coder"),
]:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    config = config_class(
        vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()),
        hidden_size=8,
        num_attention_heads=4,
        num_key_value_heads=2,
        num_hidden_layers=2,
        intermediate_size=32,
    )
    model = model_class(config)
    push_to_hub(model, tokenizer, suffix)
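As referenced in the review thread above: adding an architecture means appending one tuple to the relevant loop and re-running the script. A hedged illustration, assuming the script's imports and `push_to_hub` are in scope; the SmolLM2 entry is hypothetical and not part of this PR (SmolLM2 reuses the Llama architecture, so no new imports would be needed):

```python
# Hypothetical new decoder entry: (checkpoint providing the tokenizer,
# tiny config class, model class, optional repo-name suffix).
model_id, config_class, model_class, suffix = (
    "HuggingFaceTB/SmolLM2-135M-Instruct", LlamaConfig, LlamaForCausalLM, "SmolLM2",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
config = config_class(
    vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder),
    hidden_size=8,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_hidden_layers=2,
    intermediate_size=32,
)
model = model_class(config)
push_to_hub(model, tokenizer, suffix)
```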


# Encoder-decoder models
for model_id, config_class, model_class, suffix in [
    ("google/flan-t5-small", T5Config, T5ForConditionalGeneration, None),
    ("facebook/bart-base", BartConfig, BartModel, None),
]:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    config = config_class(
        vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()),
        d_model=16,
        encoder_layers=2,
        decoder_layers=2,
        d_kv=2,
        d_ff=64,
        num_layers=6,
        num_heads=8,
        decoder_start_token_id=0,
        is_encoder_decoder=True,
    )
    model = model_class(config)
    push_to_hub(model, tokenizer, suffix)


# Vision Language Models
# fmt: off
for model_id, config_class, text_config_class, vision_config_class, model_class in [
    ("HuggingFaceM4/idefics2-8b", Idefics2Config, MistralConfig, Idefics2VisionConfig, Idefics2ForConditionalGeneration),
    ("llava-hf/llava-1.5-7b-hf", LlavaConfig, LlamaConfig, CLIPVisionConfig, LlavaForConditionalGeneration),
    ("llava-hf/llava-v1.6-mistral-7b-hf", LlavaNextConfig, MistralConfig, CLIPVisionConfig, LlavaNextForConditionalGeneration),
    ("google/paligemma-3b-pt-224", PaliGemmaConfig, GemmaConfig, SiglipVisionConfig, PaliGemmaForConditionalGeneration),
]:
    # fmt: on
    processor = AutoProcessor.from_pretrained(model_id)
    kwargs = {}
    if config_class == PaliGemmaConfig:
        kwargs["projection_dim"] = 8
    vision_kwargs = {}
    if vision_config_class in [CLIPVisionConfig, SiglipVisionConfig]:
        vision_kwargs["projection_dim"] = 8
    if vision_config_class == CLIPVisionConfig:
        vision_kwargs["image_size"] = 336
        vision_kwargs["patch_size"] = 14
    config = config_class(
        text_config=text_config_class(
            vocab_size=processor.tokenizer.vocab_size + len(processor.tokenizer.added_tokens_encoder),
            hidden_size=8,
            num_attention_heads=4,
            num_key_value_heads=2,
            num_hidden_layers=2,
            intermediate_size=32,
        ),
        vision_config=vision_config_class(
            hidden_size=8,
            num_attention_heads=4,
            num_hidden_layers=2,
            intermediate_size=32,
            **vision_kwargs,
        ),
        **kwargs,
    )
    model = model_class(config)
    push_to_hub(model, processor)
Review comment: For reference, we have a similar script in transformers in case you want to see the generic case: https://github.com/huggingface/transformers/blob/a0f4f3174f4aee87dd88ffda95579f7450934fc8/utils/create_dummy_models.py#L1403
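For completeness, a sketch of how a test might consume one of these artifacts. The repo id follows the `{ORGANIZATION}/tiny-{model_class_name}-{suffix}` pattern produced by `push_to_hub` above; the specific checkpoint chosen is just an example:

```python
# Load a generated tiny model in a unit test.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
assert model.config.hidden_size == 8  # tiny by construction, per the script
```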