Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔀 Add MergeModelCallBack #2282

Merged
merged 48 commits into from
Nov 21, 2024

Conversation

August-murr
Copy link
Collaborator

@August-murr August-murr commented Oct 25, 2024

What does this PR do?

Fixes #2241

Since the focus was on replicating the checkpoint merging methods from the paper, I have covered only Linear, TIES, SLERP, and DARE-TIES merging methods. These were the ones used in the paper and were primarily tested with the DPOTrainer.

Please provide feedback on any issues and suggest improvements regarding the structure of the files in this PR, as well as the documentation, to enhance clarity.

Here's an example of usage:

pip install trl[mergekit]

since it's an optional dependency

from trl.mergekit_utils import MergeConfig
Trainer = DPOTrainer(...)
config = MergeConfig("ties")
config.target_model_path = "path_to_target_model" #if none is provided, the reference model will be used
config.policy_model_weight = 0.7 #optional
config.target_model_weight = 0.3 #optional 
merge_callback = MergeModelCallBack(config,push_to_hub=True,merge_at_every_checkpoint=True)
Trainer.add_callback(callback)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

@lewtun
@qgallouedec

setup.py Outdated Show resolved Hide resolved
@August-murr
Copy link
Collaborator Author

@lewtun
@qgallouedec
Feedback would be appreciated!

@August-murr August-murr marked this pull request as ready for review October 28, 2024 12:54
@qgallouedec
Copy link
Member

Thanks a lot @August-murr for the work. Can you add documentation, and test?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

trl/trainer/callbacks.py Outdated Show resolved Hide resolved
@qgallouedec qgallouedec changed the title MergeModelCallBack 🔀 Add MergeModelCallBack Nov 5, 2024
@August-murr
Copy link
Collaborator Author

Thanks a lot @August-murr for the work. Can you add documentation, and test?

I've already added most of the docs, as for the tests, unfortunately I won't be able to do it for a few days and if nobody else added them, I'll do it later.

@August-murr
Copy link
Collaborator Author

The tests I added validate the success of the merge and I could expand it if necessary.
I also added docs to the callbacks file but was unable to produce the HTML file similar to the callback docs so I'd appreciate it if you could confirm whether the docs are properly generated or not.

Copy link
Member

@lewtun lewtun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this callback @August-murr - the implementation is very clean! I've left some minor comments on the PR and then I think it's good to merge (pun intended :))

Can you also fix the tests - it seems to be an issue with the path creation on Windows. Also can you add the callback to the docs here: https://github.com/huggingface/trl/blob/main/docs/source/callbacks.mdx

trl/mergekit_utils.py Show resolved Hide resolved
trl/trainer/callbacks.py Outdated Show resolved Hide resolved
trl/trainer/callbacks.py Outdated Show resolved Hide resolved
trl/trainer/callbacks.py Outdated Show resolved Hide resolved
tests/test_callbacks.py Outdated Show resolved Hide resolved
trl/mergekit_utils.py Outdated Show resolved Hide resolved
trl/mergekit_utils.py Outdated Show resolved Hide resolved
trl/mergekit_utils.py Outdated Show resolved Hide resolved
trl/trainer/callbacks.py Outdated Show resolved Hide resolved
trl/trainer/callbacks.py Outdated Show resolved Hide resolved
August-murr and others added 5 commits November 15, 2024 13:49
removing ## from docs

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
removing ## from docs

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
adding types

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
@qgallouedec
Copy link
Member

qgallouedec commented Nov 18, 2024

Another question that came up during the review: why have a new configuration class when we can use the mergekit one directly? I'm afraid of confusing the user, tempted to use :

from mergekit import MergeConfiguration
from trl import MergeModelCallback
    
merge_callback = MergeModelCallback(MergeConfiguration())

@August-murr
Copy link
Collaborator Author

Another question that came up during the review: why have a new configuration class when we can use the mergekit one directly? I'm afraid of confusing the user, tempted to use :

from mergekit import MergeConfiguration
from trl import MergeModelCallback
    
merge_callback = MergeModelCallback(MergeConfiguration())

Actually, ease of use for the user was the reason why I had to write the class in mergekit_utils since mergekit uses a yaml file to get it's Merge config, which is easier to implement but more complicated for the user.

and if you wanted to use MergeConfiguration directly from mergekit:

from mergekit.config import MergeConfiguration

merge_config_dict = {
    "dtype": "float16",
    "merge_method": "linear",
    "models": [
        {"model": "path_to_model_1", "parameters": {"weight": 0.4}},
        {"model": "path_to_model_2", "parameters": {"weight": 0.6}},
    ],
}

config = MergeConfiguration.model_validate(merge_config_dict)

As you add more parameters to the configuration, the dictionary becomes increasingly nested.

The current implementation, although harder to maintain, simplifies everything for the user:

from trl.mergekit_utils import MergeConfig
config = MergeConfig("linear")
config.policy_model_weight = 0.4
config.target_model_weight = 0.6

@qgallouedec
Copy link
Member

That makes sense.
Do you think we can get the best of both worlds by making trl.MergeConfig inherits from mergekit.config.MergeConfigurationMergeConfig?

@August-murr
Copy link
Collaborator Author

That makes sense.
Do you think we can get the best of both worlds by making trl.MergeConfig inherits from mergekit.config.MergeConfigurationMergeConfig?

I'll figure it out.

@August-murr
Copy link
Collaborator Author

That makes sense. Do you think we can get the best of both worlds by making trl.MergeConfig inherits from mergekit.config.MergeConfigurationMergeConfig?

The main issue with using Mergekit's MergeConfiguration directly is that it’s not really designed to work on its own. It relies heavily on dictionaries, usually loaded from a YAML file, or using a bunch of classes from mergekit to set things up:

class MergeConfiguration(BaseModel):
    merge_method: str
    slices: Optional[List[OutputSliceDefinition]] = None
    models: Optional[List[InputModelDefinition]] = None
    parameters: Optional[Dict[str, ParameterSetting]] = None
    base_model: Optional[ModelReference] = None
    dtype: Optional[str] = None
    tokenizer_source: Union[
        Literal["union"], Literal["base"], ModelReference, None
    ] = None
    tokenizer: Optional[TokenizerConfig] = None
    chat_template: Optional[str] = None
    out_dtype: Optional[str] = None

If someone wanted to set up the configuration manually, they’d either need to:

  1. Write or add to a YAML file, or
  2. Write a big, nested dictionary themselves (which only gets more complicated as you add more details), or
  3. Use multiple classes from mergekit (e.g., OutputSliceDefinition, InputModelDefinition, etc.), as seen here.

Neither option is user-friendly.

I admit the current implementation looks messy, but the alternative would create more complications for the user. Maybe in future versions, the Mergekit team will make MergeConfiguration simpler and easier to work with.

@August-murr
Copy link
Collaborator Author

@qgallouedec
Anything else you'd want me to do?

@qgallouedec
Copy link
Member

qgallouedec commented Nov 21, 2024

LGTM thanks!
I've just applied some minor refinements:

  • compat with windows file path
  • use tmp dir in tests
  • sort imports and function
  • common method for saving and pushing in the callback
  • add "trl" to model tags

@August-murr
Copy link
Collaborator Author

@qgallouedec
About the failed tests:
The tests do not fail on Ubuntu; they only fail on Windows. I realized that the issue arose from a permission error from the temporary directory when trying to delete the merged files, specifically the model.safetensors.

@qgallouedec
Copy link
Member

@qgallouedec About the failed tests: The tests do not fail on Ubuntu; they only fail on Windows. I realized that the issue arose from a permission error from the temporary directory when trying to delete the merged files, specifically the model.safetensors.

Ah thanks, I was debugging, but I don't have access to windows vm right now (explains fa5bafe). Any idea how to solve it?

@qgallouedec
Copy link
Member

Found a solution with a57d88a

@qgallouedec qgallouedec merged commit 6578fdc into huggingface:main Nov 21, 2024
13 checks passed
@August-murr
Copy link
Collaborator Author

@qgallouedec
Sorry I wasn't able to sort it out myself.

@qgallouedec
Copy link
Member

No worry, thanks a lot for this nice addition!

@August-murr August-murr deleted the MergeModelCallBack branch December 13, 2024 09:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add model merging callback
6 participants