Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔀 Add MergeModelCallBack #2282

Merged
merged 48 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
63f7ae0
Create mergekit_utils.py
August-murr Oct 25, 2024
55f3a25
adding mergekit as an optional dependency
August-murr Oct 25, 2024
9f829fa
adding MergeModel to callbacks
August-murr Oct 25, 2024
f666369
adding mergekit_utils dependencies to callbacks
August-murr Oct 25, 2024
bb52d4b
setting lower bound for mergekit
August-murr Oct 30, 2024
5d9a5a9
setting mergekit lower bound to 0.0.5.1
August-murr Oct 30, 2024
9ba148c
Merge branch 'main' into MergeModelCallBack
qgallouedec Nov 5, 2024
0fe2d67
adding support for MergeModelCallBack __init__.py
August-murr Nov 14, 2024
caee4b2
adding support for mergemodelcallback
August-murr Nov 14, 2024
b278e5a
mergemodelcallback tests
August-murr Nov 14, 2024
311a27a
Update callbacks.py
August-murr Nov 14, 2024
ec8b5ec
Update __init__.py
August-murr Nov 14, 2024
147b188
Update __init__.py
August-murr Nov 14, 2024
7f5dbc1
Update test_callbacks.py
August-murr Nov 14, 2024
c17f8e3
Update trl/trainer/callbacks.py
August-murr Nov 15, 2024
0d9c8a0
Update trl/trainer/callbacks.py
August-murr Nov 15, 2024
ca07b42
Update trl/trainer/callbacks.py
August-murr Nov 15, 2024
6229885
using different dataset for tests
August-murr Nov 15, 2024
9ea5a6b
Update trl/mergekit_utils.py
August-murr Nov 15, 2024
1a2b425
Update trl/mergekit_utils.py
August-murr Nov 15, 2024
f0f84eb
Merge branch 'main' into MergeModelCallBack
qgallouedec Nov 15, 2024
c924cb9
Apply suggestions from code review
August-murr Nov 15, 2024
8b5a4a9
replacing get_last_checkpoint
August-murr Nov 15, 2024
34e94ff
Merge branch 'MergeModelCallBack' of https://github.com/August-murr/t…
August-murr Nov 15, 2024
0a5db60
renaming Merge to merge_models
August-murr Nov 15, 2024
906eafa
setting mergers default value to linear
August-murr Nov 15, 2024
8d03608
removing unnecessary docs and comments
August-murr Nov 15, 2024
eb66e99
adding docstring to Mergeconfig
August-murr Nov 15, 2024
eb7b228
adding mergekits link to docstring
August-murr Nov 15, 2024
1057c59
precommit
August-murr Nov 15, 2024
1c85ee5
removing duplicated import
August-murr Nov 16, 2024
18d0388
typos in mergekit_utils docstring
August-murr Nov 16, 2024
ca8f361
fixing tests
August-murr Nov 17, 2024
0a25ee8
making mergemodelcallback tests optional
August-murr Nov 18, 2024
cd76890
Make import optional
qgallouedec Nov 18, 2024
c8afdbc
minor
qgallouedec Nov 18, 2024
26aa418
Merge branch 'main' into MergeModelCallBack
kashif Nov 20, 2024
fb12119
Merge branch 'MergeModelCallBack' of https://github.com/August-murr/t…
qgallouedec Nov 21, 2024
88de59a
use tmp dir in test
qgallouedec Nov 21, 2024
fed6045
sort
qgallouedec Nov 21, 2024
01e38c8
Add import error checks for mergekit extra
qgallouedec Nov 21, 2024
9e7a068
use a common _merge_and_maybe_push method and compat with windows path
qgallouedec Nov 21, 2024
fa5bafe
debug windows
qgallouedec Nov 21, 2024
47ecef5
Update dependencies for mergekit and add test dependencies
qgallouedec Nov 21, 2024
7fe26d5
Add assertion to check if merged folder exists in the last checkpoint
qgallouedec Nov 21, 2024
a57d88a
Fix temporary directory cleanup in test_callbacks.py
qgallouedec Nov 21, 2024
69ea0ed
Add sys import and skip test for Python versions below 3.10 due to cl…
qgallouedec Nov 21, 2024
d89eadc
revert change for debug
qgallouedec Nov 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
# liger-kernel depends on triton, which is only available on Linux https://github.com/triton-lang/triton#compatibility
"liger": ["liger-kernel>=0.2.1; sys_platform != 'win32'"],
"llm_judge": ["openai>=1.23.2", "llm-blender>=0.0.2"],
"mergekit": ["mergekit>=0.0.5.1"],
"peft": ["peft>=0.8.0"],
"quantization": ["bitsandbytes"],
"scikit": ["scikit-learn"],
Expand Down
347 changes: 347 additions & 0 deletions trl/mergekit_utils.py
August-murr marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,347 @@
import os
import torch

from mergekit.config import MergeConfiguration
from mergekit.merge import MergeOptions, run_merge


from huggingface_hub import HfApi


def get_last_checkpoint_path(output_dir):
    """
    Return the most recently modified checkpoint directory under `output_dir`.

    Args:
        output_dir (str): The directory where the checkpoints are saved.

    Returns:
        str: Path of the newest checkpoint directory, or None when the
        output directory is missing or contains no checkpoint folders.
    """
    # Bail out early when the trainer has not produced anything yet.
    if not os.path.exists(output_dir):
        print(f"Output directory '{output_dir}' does not exist.")
        return None

    # Collect every subdirectory whose name mentions "checkpoint".
    candidates = []
    for entry in os.listdir(output_dir):
        full_path = os.path.join(output_dir, entry)
        if "checkpoint" in entry and os.path.isdir(full_path):
            candidates.append(full_path)

    if not candidates:
        print("No checkpoints found.")
        return None

    # The checkpoint with the latest modification time is the most recent one.
    return max(candidates, key=os.path.getmtime)


def upload_model_to_hf(folder_path, repo_id):
    """
    Upload a local model folder to a model repository on the Hugging Face Hub.

    Args:
        folder_path (str): Local directory containing the model files to upload.
        repo_id (str): Target repository id on the Hub (e.g. ``"user/model"``).
    """
    api = HfApi()
    # Create the repository if it doesn't exist. exist_ok=True is required:
    # without it, create_repo raises when the repo already exists, so a second
    # upload to the same repo_id (e.g. after each training run) would fail.
    repo = api.create_repo(repo_id, repo_type="model", exist_ok=True)

    # Upload the folder to the specified repository
    api.upload_folder(
        folder_path=folder_path,
        repo_id=repo.repo_id,
        repo_type=repo.repo_type,
    )

class MergeConfig:
    """
    Configuration for merging a policy model with a target model via
    `mergekit <https://github.com/arcee-ai/mergekit>`_.

    The supported merge methods, and the attributes each one initializes:

    - ``"linear"``: weighted average of the two models
      (``policy_model_weight``, ``target_model_weight``, ``dtype``).
    - ``"ties"``: TIES merging (``policy_model_weight``,
      ``policy_model_density``, ``target_model_weight``,
      ``target_model_density``, ``normalize``, ``dtype``).
    - ``"dare_ties"``: DARE-TIES merging (same attributes as ``"ties"``).
    - ``"slerp"``: spherical linear interpolation (``t_values``, ``dtype``).

    ``policy_model_path`` and ``target_model_path`` start out as ``None`` and
    are set by the callback, not by the user.

    Args:
        method (str): One of ``"linear"``, ``"ties"``, ``"dare_ties"`` or
            ``"slerp"``.

    Raises:
        ValueError: If ``method`` is not one of the supported merge methods.
    """

    def __init__(self, method: str):
        self.method = method
        self.policy_model_path = None  # To be set by the callback, not the user
        self.target_model_path = None

        # Initialize relevant parameters based on the method
        if method == 'linear':
            self.policy_model_weight = 0.5
            self.target_model_weight = 0.5
            self.dtype = 'float16'
        elif method == 'ties':
            self.policy_model_weight = 1.0
            self.policy_model_density = [1.0, 0.7, 0.1]
            self.target_model_weight = 1.0
            self.target_model_density = [1.0]
            self.normalize = 1.0
            self.dtype = 'float16'
        elif method == 'dare_ties':
            self.policy_model_weight = 1.0
            self.policy_model_density = [1.0, 0.7, 0.1]
            self.target_model_weight = 1.0
            self.target_model_density = [1.0]
            self.normalize = 1.0
            self.dtype = 'float16'
        elif method == 'slerp':
            self.t_values = 0.5
            self.dtype = 'float16'
        else:
            raise ValueError(f"Unknown merge method: {method}")

    def create_merge_config_linear(self):
        """
        Create a mergekit configuration for a linear merge of the policy and
        target models, using `self.policy_model_weight`,
        `self.target_model_weight` and `self.dtype`.

        Returns:
            MergeConfiguration: A MergeConfiguration object with the provided settings.
        """
        # Create the merge configuration dictionary
        merge_config_dict = {
            'dtype': self.dtype,
            'merge_method': 'linear',
            'models': [
                {'model': self.policy_model_path, 'parameters': {'weight': self.policy_model_weight}},
                {'model': self.target_model_path, 'parameters': {'weight': self.target_model_weight}}
            ]
        }

        # Create the MergeConfiguration from the dictionary
        return MergeConfiguration.model_validate(merge_config_dict)

    def _create_merge_config_ties_like(self, merge_method):
        """
        Shared builder for the TIES-family methods ('ties' and 'dare_ties'):
        both use an identical configuration layout and differ only in the
        `merge_method` field, so one helper avoids duplicating ~80 lines.

        Args:
            merge_method (str): Either 'ties' or 'dare_ties'.

        Returns:
            MergeConfiguration: A MergeConfiguration object with the provided settings.
        """
        merge_config_dict = {
            'merge_method': merge_method,
            'slices': None,  # Optional slices if needed
            'models': [
                {
                    'model': {
                        'model': {'path': self.target_model_path, 'revision': None},
                        'lora': None,
                        'override_architecture': None
                    },
                    'parameters': {
                        'density': self.target_model_density,
                        'weight': self.target_model_weight
                    }
                },
                {
                    'model': {
                        'model': {'path': self.policy_model_path, 'revision': None},
                        'lora': None,
                        'override_architecture': None
                    },
                    'parameters': {
                        'density': self.policy_model_density,
                        'weight': self.policy_model_weight
                    }
                }
            ],
            'parameters': {'normalize': self.normalize},
            # NOTE(review): the policy model is used as the base model here —
            # confirm this is intended (one might expect the target model).
            'base_model': {
                'model': {'path': self.policy_model_path, 'revision': None},
                'lora': None,
                'override_architecture': None
            },
            'dtype': self.dtype,
            'tokenizer_source': None,
            'tokenizer': None,
            'chat_template': None,
            'out_dtype': None
        }

        # Create the MergeConfiguration from the dictionary
        return MergeConfiguration.model_validate(merge_config_dict)

    def create_merge_config_ties(self):
        """
        Create a mergekit configuration for a TIES merge of the policy and
        target models, using the weights and densities set in `__init__`.

        Returns:
            MergeConfiguration: A MergeConfiguration object with the provided settings.
        """
        return self._create_merge_config_ties_like('ties')

    def create_merge_config_dare_ties(self):
        """
        Create a mergekit configuration for a DARE-TIES merge of the policy
        and target models, using the weights and densities set in `__init__`.

        Returns:
            MergeConfiguration: A MergeConfiguration object with the provided settings.
        """
        return self._create_merge_config_ties_like('dare_ties')

    def create_merge_config_slerp(self):
        """
        Create a mergekit configuration for a SLERP merge of the target model
        with the policy model as the base, using `self.t_values` and
        `self.dtype`.

        Returns:
            MergeConfiguration: A MergeConfiguration object with the provided settings.
        """
        # Create the SLERP merge configuration dictionary
        merge_config_dict = {
            'merge_method': 'slerp',
            'slices': None,  # Optional slices if needed
            'models': [
                {
                    'model': {
                        'model': {'path': self.target_model_path, 'revision': None},
                        'lora': None,
                        'override_architecture': None
                    },
                    'parameters': None  # No specific parameters for SLERP model
                }
            ],
            'parameters': {
                't': self.t_values  # Set the t values for SLERP
            },
            'base_model': {
                'model': {'path': self.policy_model_path, 'revision': None},
                'lora': None,
                'override_architecture': None
            },
            'dtype': self.dtype,
            'tokenizer_source': None,
            'tokenizer': None,
            'chat_template': None,
            'out_dtype': None
        }

        # Create the MergeConfiguration from the dictionary
        return MergeConfiguration.model_validate(merge_config_dict)

    def create(self):
        """
        Build and return the mergekit `MergeConfiguration` matching
        `self.method` (dispatches to the method-specific builder).
        """
        if self.method == 'linear':
            return self.create_merge_config_linear()
        elif self.method == 'ties':
            return self.create_merge_config_ties()
        elif self.method == 'dare_ties':
            return self.create_merge_config_dare_ties()
        elif self.method == 'slerp':
            return self.create_merge_config_slerp()

def Merge(config, out_path):
    """
    Merge two models using mergekit

    Args:
        config (MergeConfig): The merge configuration.
        out_path (str): The output path for the merged model.
    """
    # Run on GPU when one is available; otherwise mergekit falls back to CPU.
    options = MergeOptions(
        cuda=torch.cuda.is_available(),
        copy_tokenizer=True,
        lazy_unpickle=False,
        low_cpu_memory=False,
    )
    run_merge(config, out_path=out_path, options=options)
Loading
Loading