🤏 New models for tests #2287

Merged (73 commits, Nov 25, 2024)

Commits (all by qgallouedec):
2ea93cb  first commit (Oct 27, 2024)
6735bf5  uncomment (Oct 27, 2024)
53f3091  other tests adaptations (Oct 27, 2024)
2db5415  Remove unused variable in test_setup_chat_format (Oct 27, 2024)
4202f8d  Merge branch 'main' into tiny-models-for-testing (Oct 28, 2024)
7c4069e  Remove unused import statement (Oct 28, 2024)
88e371e  Merge branch 'main' into tiny-models-for-testing (Oct 31, 2024)
170c950  Merge branch 'main' into tiny-models-for-testing (Oct 31, 2024)
73bafb8  Merge branch 'main' into tiny-models-for-testing (Oct 31, 2024)
dd5c131  Merge branch 'main' into tiny-models-for-testing (Oct 31, 2024)
029d758  style (Oct 31, 2024)
ad43271  Merge branch 'main' into tiny-models-for-testing (Nov 10, 2024)
79bb504  Add Bart model (Nov 10, 2024)
71d04f0  Update BCOTrainerTester class in test_bco_trainer.py (Nov 10, 2024)
2c364c5  Update model IDs and tokenizers in test files (Nov 10, 2024)
48bb040  Add new models and processors (Nov 10, 2024)
68d1fa1  Update model IDs in test files (Nov 10, 2024)
5219d9b  Fix formatting issue in test_dataset_formatting.py (Nov 10, 2024)
a45fbcb  Refactor dataset formatting in test_dataset_formatting.py (Nov 10, 2024)
e39a75b  Fix dataset sequence length in SFTTrainerTester (Nov 10, 2024)
e8c0e43  Remove tokenizer (Nov 11, 2024)
3393333  Remove print statement (Nov 11, 2024)
162fdb2  Add reward_model_path and sft_model_path to PPO trainer (Nov 11, 2024)
8c1effe  Fix tokenizer padding issue (Nov 11, 2024)
ea50da1  Add chat template for testing purposes in PaliGemma model (Nov 11, 2024)
1f52cec  Update PaliGemma model and chat template (Nov 11, 2024)
5855322  Increase learning rate to speed up test (Nov 11, 2024)
c627f71  Update model names in run_dpo.sh and run_sft.sh scripts (Nov 11, 2024)
607e68f  Update model and dataset names (Nov 11, 2024)
d29a272  Fix formatting issue in test_dataset_formatting.py (Nov 11, 2024)
17ae6ed  Fix formatting issue in test_dataset_formatting.py (Nov 11, 2024)
3c6829d  Remove unused chat template (Nov 11, 2024)
e06a597  Merge branch 'main' into tiny-models-for-testing (Nov 11, 2024)
2ffb098  Update model generation script (Nov 14, 2024)
a6728f8  Merge branch 'main' into tiny-models-for-testing (Nov 18, 2024)
06ca8fb  additional models (Nov 19, 2024)
4fc8172  Update model references in test files (Nov 19, 2024)
8c0f901  Merge branch 'tiny-models-for-testing' of https://github.com/huggingf… (Nov 19, 2024)
7eef63a  Remove unused imports in test_online_dpo_trainer.py (Nov 19, 2024)
0e2f55a  Add is_llm_blender_available import and update reward_tokenizer (Nov 19, 2024)
651b845  Refactor test_online_dpo_trainer.py: Move skipped test case decorator (Nov 19, 2024)
ef0b761  Merge branch 'main' into tiny-models-for-testing (Nov 19, 2024)
bcd1282  remove models without chat templates (Nov 19, 2024)
87a4261  Merge branch 'tiny-models-for-testing' of https://github.com/huggingf… (Nov 19, 2024)
c5a8649  Update model names in scripts and tests (Nov 19, 2024)
cea219d  Update model_id in test_modeling_value_head.py (Nov 19, 2024)
cf5070b  Update model versions in test files (Nov 19, 2024)
164d3a4  Fix formatting issue in test_dataset_formatting.py (Nov 20, 2024)
647254e  Update embedding model ID in BCOTrainerTester (Nov 20, 2024)
761b239  Update test_online_dpo_trainer.py with reward model changes (Nov 20, 2024)
5e708bb  Update expected formatted text in test_dataset_formatting.py (Nov 20, 2024)
0c6c4e5  Merge branch 'main' into tiny-models-for-testing (Nov 20, 2024)
90f7426  Add reward_tokenizer to TestOnlineDPOTrainer (Nov 20, 2024)
7d19973  Merge branch 'tiny-models-for-testing' of https://github.com/huggingf… (Nov 20, 2024)
f0bd082  Merge branch 'main' into tiny-models-for-testing (Nov 20, 2024)
bffdc57  fix tests (Nov 20, 2024)
290ef65  Merge branch 'tiny-models-for-testing' of https://github.com/huggingf… (Nov 20, 2024)
e984765  Add SIMPLE_CHAT_TEMPLATE to T5 tokenizer (Nov 20, 2024)
b65ca75  Fix dummy_text format in test_rloo_trainer.py (Nov 20, 2024)
6145820  Skip outdated test for chatML data collator (Nov 20, 2024)
ae6f210  Add new vision language models (Nov 20, 2024)
0a9b7d7  Commented out unused model IDs in test_vdpo_trainer (Nov 20, 2024)
48ff8d8  Update model and vision configurations in generate_tiny_models.py and… (Nov 20, 2024)
36938c1  Merge branch 'main' into tiny-models-for-testing (Nov 21, 2024)
a3ff8ee  Update model and tokenizer references (Nov 21, 2024)
c03aa35  Merge branch 'main' into tiny-models-for-testing (Nov 24, 2024)
2e7695a  Merge branch 'main' into tiny-models-for-testing (Nov 25, 2024)
c851842  Don't push if it already exists (Nov 25, 2024)
8ee173a  Add comment explaining test skip (Nov 25, 2024)
48a134d  Fix model_exists function call and add new models (Nov 25, 2024)
58c033a  Merge branch 'main' into tiny-models-for-testing (Nov 25, 2024)
f8b02be  Update LlavaForConditionalGeneration model and processor (Nov 25, 2024)
33baa27  `qgallouedec` -> `trl-internal-testing` (Nov 25, 2024)
Files changed (changes from all commits):
commands/run_dpo.sh (1 addition, 1 deletion)

@@ -2,7 +2,7 @@
 # This script runs an SFT example end-to-end on a tiny model using different possible configurations
 # but defaults to QLoRA + PEFT
 OUTPUT_DIR="test_dpo/"
-MODEL_NAME="trl-internal-testing/tiny-random-LlamaForCausalLM"
+MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
 DATASET_NAME="trl-internal-testing/hh-rlhf-helpful-base-trl-style"
 MAX_STEPS=5
 BATCH_SIZE=2
commands/run_sft.sh (1 addition, 1 deletion)

@@ -2,7 +2,7 @@
 # This script runs an SFT example end-to-end on a tiny model using different possible configurations
 # but defaults to QLoRA + PEFT
 OUTPUT_DIR="test_sft/"
-MODEL_NAME="trl-internal-testing/tiny-random-LlamaForCausalLM"
+MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
 DATASET_NAME="stanfordnlp/imdb"
 MAX_STEPS=5
 BATCH_SIZE=2
docs/source/clis.mdx (1 addition, 1 deletion)

@@ -23,7 +23,7 @@ We also recommend you passing a YAML config file to configure your training prot…

 ```yaml
 model_name_or_path:
-  trl-internal-testing/tiny-random-LlamaForCausalLM
+  Qwen/Qwen2.5-0.5B
 dataset_name:
   stanfordnlp/imdb
 report_to:
examples/cli_configs/example_config.yaml (1 addition, 1 deletion)

@@ -7,7 +7,7 @@
 # CUDA_VISIBLE_DEVICES: 0

 model_name_or_path:
-  trl-internal-testing/tiny-random-LlamaForCausalLM
+  Qwen/Qwen2.5-0.5B
 dataset_name:
   stanfordnlp/imdb
 report_to:
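Both config files now point at a Qwen model. A hedged sketch of launching training from such a YAML file, mirroring the `subprocess` pattern used in `tests/test_cli.py` below; the `--config` flag and the config path are taken from this PR's files, so treat the exact invocation as an assumption:

```python
import subprocess

# Launch SFT from the YAML config shown above (assumes the file is saved at
# this path and that the `trl` CLI is installed in the environment).
subprocess.run(
    "trl sft --config examples/cli_configs/example_config.yaml",
    shell=True,
    check=True,
)
```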
scripts/generate_tiny_models.py (new file, 193 additions)

[Review comment, Member] For reference, we have a similar script in transformers in case you want to see the generic case: https://github.com/huggingface/transformers/blob/a0f4f3174f4aee87dd88ffda95579f7450934fc8/utils/create_dummy_models.py#L1403

@@ -0,0 +1,193 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This script generates tiny models used in the TRL library for unit tests. It pushes them to the Hub under the
# `trl-internal-testing` organization.
# This script is meant to be run when adding a new tiny model to the TRL library.

from huggingface_hub import HfApi, ModelCard
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    BartConfig,
    BartModel,
    BloomConfig,
    BloomForCausalLM,
    CLIPVisionConfig,
    CohereConfig,
    CohereForCausalLM,
    DbrxConfig,
    DbrxForCausalLM,
    FalconMambaConfig,
    FalconMambaForCausalLM,
    Gemma2Config,
    Gemma2ForCausalLM,
    GemmaConfig,
    GemmaForCausalLM,
    GPT2Config,
    GPT2LMHeadModel,
    GPTNeoXConfig,
    GPTNeoXForCausalLM,
    Idefics2Config,
    Idefics2ForConditionalGeneration,
    LlamaConfig,
    LlamaForCausalLM,
    LlavaConfig,
    LlavaForConditionalGeneration,
    LlavaNextConfig,
    LlavaNextForConditionalGeneration,
    MistralConfig,
    MistralForCausalLM,
    OPTConfig,
    OPTForCausalLM,
    PaliGemmaConfig,
    PaliGemmaForConditionalGeneration,
    Phi3Config,
    Phi3ForCausalLM,
    Qwen2Config,
    Qwen2ForCausalLM,
    SiglipVisionConfig,
    T5Config,
    T5ForConditionalGeneration,
)
from transformers.models.idefics2.configuration_idefics2 import Idefics2VisionConfig


ORGANIZATION = "trl-internal-testing"

MODEL_CARD = """
---
library_name: transformers
tags: [trl]
---

# Tiny {model_class_name}

This is a minimal model built for unit tests in the [TRL](https://github.com/huggingface/trl) library.
"""


api = HfApi()


def push_to_hub(model, tokenizer, suffix=None):
    model_class_name = model.__class__.__name__

[Review comment, Member] Not sure it matters much, but this won't make a distinction between base and instruct models, as they share the same class. If we don't care about this difference in our tests, no need to change it.

[Reply, qgallouedec (author), Nov 25, 2024] Yes, I wasn't sure what to do about that. Most of our tests are based on Qwen2.5 in its instruct version, so I don't know how compatible the trainers are with non-instruct versions. Let's keep it like this for the moment.

    content = MODEL_CARD.format(model_class_name=model_class_name)
    model_card = ModelCard(content)
    repo_id = f"{ORGANIZATION}/tiny-{model_class_name}"
    if suffix is not None:
        repo_id += f"-{suffix}"

    if api.repo_exists(repo_id):
        print(f"Model {repo_id} already exists, skipping")
    else:
        model.push_to_hub(repo_id)
        tokenizer.push_to_hub(repo_id)
        model_card.push_to_hub(repo_id)
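To make the naming scheme concrete, here is a minimal sketch of what `push_to_hub` produces; `make_repo_id` is a hypothetical helper written only for illustration, while both expected names appear verbatim in this PR's diffs:

```python
def make_repo_id(model_class_name: str, suffix: str | None = None) -> str:
    # Mirrors push_to_hub above: tiny-<class name>, plus an optional version suffix.
    repo_id = f"trl-internal-testing/tiny-{model_class_name}"
    if suffix is not None:
        repo_id += f"-{suffix}"
    return repo_id


assert make_repo_id("Qwen2ForCausalLM", "2.5") == "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
assert make_repo_id("BartModel") == "trl-internal-testing/tiny-BartModel"
```

As the thread above notes, the class name alone cannot distinguish base from instruct checkpoints, so the version suffix carries that information where needed.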


# Decoder models
for model_id, config_class, model_class, suffix in [
("bigscience/bloomz-560m", BloomConfig, BloomForCausalLM, None),
[Review comment, Member] After this PR is merged, I would be in favour of relying on just a small, curated set of popular architectures for our tests (e.g. Qwen / Mistral / Llama / Gemma) and removing all the rest where appropriate.

[Review comment, Member] Also, is this script supposed to be re-run whenever we add a model to the list? If so, I recommend adding a note either at the top of this script or in our contributor guide.

[Reply, qgallouedec (author)]

> of just relying on a small, curated set of popular architectures

Would you remove any model from this list?

> supposed to be re-run whenever we add a model

Yes.

> adding a note

I added a note in the script in c851842.

("CohereForAI/aya-expanse-8b", CohereConfig, CohereForCausalLM, None),
("databricks/dbrx-instruct", DbrxConfig, DbrxForCausalLM, None),
("tiiuae/falcon-7b-instruct", FalconMambaConfig, FalconMambaForCausalLM, None),
("google/gemma-2-2b-it", Gemma2Config, Gemma2ForCausalLM, None),
("google/gemma-7b-it", GemmaConfig, GemmaForCausalLM, None),
("openai-community/gpt2", GPT2Config, GPT2LMHeadModel, None),
("EleutherAI/pythia-14m", GPTNeoXConfig, GPTNeoXForCausalLM, None),
("meta-llama/Meta-Llama-3-8B-Instruct", LlamaConfig, LlamaForCausalLM, "3"),
("meta-llama/Llama-3.1-8B-Instruct", LlamaConfig, LlamaForCausalLM, "3.1"),
("meta-llama/Llama-3.2-1B-Instruct", LlamaConfig, LlamaForCausalLM, "3.2"),
("mistralai/Mistral-7B-Instruct-v0.1", MistralConfig, MistralForCausalLM, "0.1"),
("mistralai/Mistral-7B-Instruct-v0.2", MistralConfig, MistralForCausalLM, "0.2"),
("facebook/opt-1.3b", OPTConfig, OPTForCausalLM, None),
("microsoft/Phi-3.5-mini-instruct", Phi3Config, Phi3ForCausalLM, None),
("Qwen/Qwen2.5-32B-Instruct", Qwen2Config, Qwen2ForCausalLM, "2.5"),
("Qwen/Qwen2.5-Coder-0.5B", Qwen2Config, Qwen2ForCausalLM, "2.5-Coder"),
]:
tokenizer = AutoTokenizer.from_pretrained(model_id)
config = config_class(
vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()),
hidden_size=8,
num_attention_heads=4,
num_key_value_heads=2,
num_hidden_layers=2,
intermediate_size=32,
)
model = model_class(config)
push_to_hub(model, tokenizer, suffix)
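Per the thread above, the script is meant to be re-run whenever an entry is added to one of these lists. A hedged sketch of that workflow, assuming you are logged in with write access to the `trl-internal-testing` organization:

```python
import subprocess

# Re-run the generator after adding a new entry. Existing repos are skipped
# by the api.repo_exists guard in push_to_hub, so repeating this is safe.
subprocess.run("python scripts/generate_tiny_models.py", shell=True, check=True)
```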


# Encoder-decoder models
for model_id, config_class, model_class, suffix in [
("google/flan-t5-small", T5Config, T5ForConditionalGeneration, None),
("facebook/bart-base", BartConfig, BartModel, None),
]:
tokenizer = AutoTokenizer.from_pretrained(model_id)
config = config_class(
vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()),
d_model=16,
encoder_layers=2,
decoder_layers=2,
d_kv=2,
d_ff=64,
num_layers=6,
num_heads=8,
decoder_start_token_id=0,
is_encoder_decoder=True,
)
model = model_class(config)
push_to_hub(model, tokenizer, suffix)


# Vision Language Models
# fmt: off
for model_id, config_class, text_config_class, vision_config_class, model_class in [
("HuggingFaceM4/idefics2-8b", Idefics2Config, MistralConfig, Idefics2VisionConfig, Idefics2ForConditionalGeneration),
("llava-hf/llava-1.5-7b-hf", LlavaConfig, LlamaConfig, CLIPVisionConfig, LlavaForConditionalGeneration),
("llava-hf/llava-v1.6-mistral-7b-hf", LlavaNextConfig, MistralConfig, CLIPVisionConfig, LlavaNextForConditionalGeneration),
("google/paligemma-3b-pt-224", PaliGemmaConfig, GemmaConfig, SiglipVisionConfig, PaliGemmaForConditionalGeneration),
]:
# fmt: on
processor = AutoProcessor.from_pretrained(model_id)
kwargs = {}
if config_class == PaliGemmaConfig:
kwargs["projection_dim"] = 8
vision_kwargs = {}
if vision_config_class in [CLIPVisionConfig, SiglipVisionConfig]:
vision_kwargs["projection_dim"] = 8
if vision_config_class == CLIPVisionConfig:
vision_kwargs["image_size"] = 336
vision_kwargs["patch_size"] = 14
config = config_class(
text_config=text_config_class(
vocab_size=processor.tokenizer.vocab_size + len(processor.tokenizer.added_tokens_encoder),
hidden_size=8,
num_attention_heads=4,
num_key_value_heads=2,
num_hidden_layers=2,
intermediate_size=32,
),
vision_config=vision_config_class(
hidden_size=8,
num_attention_heads=4,
num_hidden_layers=2,
intermediate_size=32,
**vision_kwargs,
),
**kwargs,
)
model = model_class(config)
push_to_hub(model, processor)
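Once pushed, tests consume these repos like any other Hub model. A minimal sketch, assuming network access to the Hub; the repo ID comes from this PR's diffs, and the prompt is purely illustrative:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)  # tiny: hidden_size=8, 2 layers

inputs = tokenizer("Hello, world!", return_tensors="pt")
# Output quality is irrelevant; tests only need the plumbing to work, fast.
outputs = model.generate(**inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0]))
```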
tests/slow/testing_constants.py (2 additions, 2 deletions)

@@ -14,8 +14,8 @@

 # TODO: push them under trl-org
 MODELS_TO_TEST = [
-    "trl-internal-testing/tiny-random-LlamaForCausalLM",
-    "HuggingFaceM4/tiny-random-MistralForCausalLM",
+    "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
+    "trl-internal-testing/tiny-MistralForCausalLM-0.2",
 ]

 # We could have also not declared these variables but let's be verbose
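A hedged sketch of how such a constant can drive a parameterized slow test; the test body is hypothetical, only the model IDs come from the diff above, and `parameterized` is already used elsewhere in this PR (e.g. in `tests/test_bco_trainer.py`):

```python
import unittest

from parameterized import parameterized
from transformers import AutoModelForCausalLM

MODELS_TO_TEST = [
    "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
    "trl-internal-testing/tiny-MistralForCausalLM-0.2",
]


class ModelLoadTester(unittest.TestCase):
    @parameterized.expand([(model_id,) for model_id in MODELS_TO_TEST])
    def test_can_load(self, model_id):
        # Loading alone already exercises config/weight compatibility.
        self.assertIsNotNone(AutoModelForCausalLM.from_pretrained(model_id))
```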
tests/test_bco_trainer.py (17 additions, 19 deletions)

@@ -30,30 +30,30 @@

 class BCOTrainerTester(unittest.TestCase):
     def setUp(self):
-        self.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
+        self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
         self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
         self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id)
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
         self.tokenizer.pad_token = self.tokenizer.eos_token

         # get t5 as seq2seq example:
-        model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab"
+        model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration"
         self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
         self.t5_ref_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
         self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)

         # get embedding model
-        model_id = "facebook/bart-base"
+        model_id = "trl-internal-testing/tiny-BartModel"
         self.embedding_model = AutoModel.from_pretrained(model_id)
         self.embedding_tokenizer = AutoTokenizer.from_pretrained(model_id)

     @parameterized.expand(
         [
-            ["gpt2", True, True, "standard_unpaired_preference"],
-            ["gpt2", True, False, "standard_unpaired_preference"],
-            ["gpt2", False, True, "standard_unpaired_preference"],
-            ["gpt2", False, False, "standard_unpaired_preference"],
-            ["gpt2", True, True, "conversational_unpaired_preference"],
+            ("qwen", True, True, "standard_unpaired_preference"),
+            ("qwen", True, False, "standard_unpaired_preference"),
+            ("qwen", False, True, "standard_unpaired_preference"),
+            ("qwen", False, False, "standard_unpaired_preference"),
+            ("qwen", True, True, "conversational_unpaired_preference"),
         ]
     )
     @require_sklearn
@@ -73,7 +73,7 @@ def test_bco_trainer(self, name, pre_compute, eval_dataset, config_name):

         dummy_dataset = load_dataset("trl-internal-testing/zen", config_name)

-        if name == "gpt2":
+        if name == "qwen":
             model = self.model
             ref_model = self.ref_model
             tokenizer = self.tokenizer
@@ -160,9 +160,9 @@ def test_tokenize_and_process_tokens(self):
         self.assertListEqual(tokenized_dataset["prompt"], train_dataset["prompt"])
         self.assertListEqual(tokenized_dataset["completion"], train_dataset["completion"])
         self.assertListEqual(tokenized_dataset["label"], train_dataset["label"])
-        self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [5377, 11141])
-        self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1])
-        self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [318, 1365, 621, 8253, 13])
+        self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [31137])
+        self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1])
+        self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [374, 2664, 1091, 16965, 13])
         self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1, 1, 1, 1])

         fn_kwargs = {
@@ -178,15 +178,13 @@ def test_tokenize_and_process_tokens(self):
         self.assertListEqual(processed_dataset["prompt"], train_dataset["prompt"])
         self.assertListEqual(processed_dataset["completion"], train_dataset["completion"])
         self.assertListEqual(processed_dataset["label"], train_dataset["label"])
-        self.assertListEqual(processed_dataset["prompt_input_ids"][0], [50256, 5377, 11141])
-        self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1])
+        self.assertListEqual(processed_dataset["prompt_input_ids"][0], [31137])
+        self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1])
         self.assertListEqual(
-            processed_dataset["completion_input_ids"][0], [50256, 5377, 11141, 318, 1365, 621, 8253, 13, 50256]
+            processed_dataset["completion_input_ids"][0], [31137, 374, 2664, 1091, 16965, 13, 151645]
         )
-        self.assertListEqual(processed_dataset["completion_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1, 1, 1])
-        self.assertListEqual(
-            processed_dataset["completion_labels"][0], [-100, -100, -100, 318, 1365, 621, 8253, 13, 50256]
-        )
+        self.assertListEqual(processed_dataset["completion_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1])
+        self.assertListEqual(processed_dataset["completion_labels"][0], [-100, 374, 2664, 1091, 16965, 13, 151645])

     @require_sklearn
     def test_bco_trainer_without_providing_ref_model(self):
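The new expected IDs reflect the switch from GPT-2's tokenizer to Qwen2.5's. A hedged sketch of how such assertions can be regenerated by hand; the sample text is illustrative rather than the actual dataset example, and the `<|im_end|>` check matches the trailing 151645 above:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
# Print the IDs a dataset example produces, then paste them into the assertions.
print(tok("some prompt text", add_special_tokens=False).input_ids)
print(tok.convert_tokens_to_ids("<|im_end|>"))  # 151645 for the Qwen2.5 tokenizer
```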
tests/test_best_of_n_sampler.py (1 addition, 1 deletion)

@@ -31,7 +31,7 @@ class BestOfNSamplerTester(unittest.TestCase):
     Tests the BestOfNSampler class
     """

-    ref_model_name = "trl-internal-testing/dummy-GPT2-correct-vocab"
+    ref_model_name = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
     output_length_sampler = LengthSampler(2, 6)
     model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)
     tokenizer = AutoTokenizer.from_pretrained(ref_model_name)
tests/test_callbacks.py (7 additions, 7 deletions)

@@ -60,9 +60,9 @@ def __init__(self, model, ref_model, args, train_dataset, eval_dataset, processi…

 class WinRateCallbackTester(unittest.TestCase):
     def setUp(self):
-        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
-        self.ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
-        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
+        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        self.ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
         self.tokenizer.pad_token = self.tokenizer.eos_token
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
         dataset["train"] = dataset["train"].select(range(8))
@@ -219,8 +219,8 @@ def test_lora(self):
 @require_wandb
 class LogCompletionsCallbackTester(unittest.TestCase):
     def setUp(self):
-        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
-        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
+        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
         self.tokenizer.pad_token = self.tokenizer.eos_token
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
         dataset["train"] = dataset["train"].select(range(8))
@@ -283,8 +283,8 @@ def test_basic(self):
 )
 class MergeModelCallbackTester(unittest.TestCase):
     def setUp(self):
-        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-LlamaForCausalLM")
-        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-random-LlamaForCausalLM")
+        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
         self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

     def test_callback(self):
tests/test_cli.py (2 additions, 2 deletions)

@@ -21,7 +21,7 @@ class CLITester(unittest.TestCase):
     def test_sft_cli(self):
         try:
             subprocess.run(
-                "trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name stanfordnlp/imdb --learning_rate 1e-4 --lr_scheduler_type cosine",
+                "trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name stanfordnlp/imdb --learning_rate 1e-4 --lr_scheduler_type cosine",
                 shell=True,
                 check=True,
             )
@@ -32,7 +32,7 @@ def test_sft_cli(self):
     def test_dpo_cli(self):
         try:
             subprocess.run(
-                "trl dpo --max_steps 1 --output_dir tmp-dpo --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/tiny-ultrafeedback-binarized --learning_rate 1e-4 --lr_scheduler_type cosine",
+                "trl dpo --max_steps 1 --output_dir tmp-dpo --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/tiny-ultrafeedback-binarized --learning_rate 1e-4 --lr_scheduler_type cosine",
                 shell=True,
                 check=True,
             )