Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔀 Add MergeModelCallBack #2282

Merged
merged 48 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
63f7ae0
Create mergekit_utils.py
August-murr Oct 25, 2024
55f3a25
adding mergekit as an optional dependancy
August-murr Oct 25, 2024
9f829fa
adding MergeModel to callbacks
August-murr Oct 25, 2024
f666369
adding mergekit_utils dependencies to callbacks
August-murr Oct 25, 2024
bb52d4b
setting lower bound for mergekit
August-murr Oct 30, 2024
5d9a5a9
setting mergekit lower band to 0.0.5.1
August-murr Oct 30, 2024
9ba148c
Merge branch 'main' into MergeModelCallBack
qgallouedec Nov 5, 2024
0fe2d67
adding support for MergeModelCallBack __init__.py
August-murr Nov 14, 2024
caee4b2
adding support for mergemodelcallback
August-murr Nov 14, 2024
b278e5a
mergemodelcallback tests
August-murr Nov 14, 2024
311a27a
Update callbacks.py
August-murr Nov 14, 2024
ec8b5ec
Update __init__.py
August-murr Nov 14, 2024
147b188
Update __init__.py
August-murr Nov 14, 2024
7f5dbc1
Update test_callbacks.py
August-murr Nov 14, 2024
c17f8e3
Update trl/trainer/callbacks.py
August-murr Nov 15, 2024
0d9c8a0
Update trl/trainer/callbacks.py
August-murr Nov 15, 2024
ca07b42
Update trl/trainer/callbacks.py
August-murr Nov 15, 2024
6229885
using different dataset for tests
August-murr Nov 15, 2024
9ea5a6b
Update trl/mergekit_utils.py
August-murr Nov 15, 2024
1a2b425
Update trl/mergekit_utils.py
August-murr Nov 15, 2024
f0f84eb
Merge branch 'main' into MergeModelCallBack
qgallouedec Nov 15, 2024
c924cb9
Apply suggestions from code review
August-murr Nov 15, 2024
8b5a4a9
replacing get_last_checkpoint
August-murr Nov 15, 2024
34e94ff
Merge branch 'MergeModelCallBack' of https://github.com/August-murr/t…
August-murr Nov 15, 2024
0a5db60
renaming Merge to merge_models
August-murr Nov 15, 2024
906eafa
setting mergers default value to linear
August-murr Nov 15, 2024
8d03608
removing unnecessary docs and comments
August-murr Nov 15, 2024
eb66e99
adding docstring to Mergeconfig
August-murr Nov 15, 2024
eb7b228
adding mergekits link to docstring
August-murr Nov 15, 2024
1057c59
precommit
August-murr Nov 15, 2024
1c85ee5
removing duplicated import
August-murr Nov 16, 2024
18d0388
typos in mergekit_utils docstring
August-murr Nov 16, 2024
ca8f361
fixing tests
August-murr Nov 17, 2024
0a25ee8
making mergemodelcallback tests optional
August-murr Nov 18, 2024
cd76890
Make import optional
qgallouedec Nov 18, 2024
c8afdbc
minor
qgallouedec Nov 18, 2024
26aa418
Merge branch 'main' into MergeModelCallBack
kashif Nov 20, 2024
fb12119
Merge branch 'MergeModelCallBack' of https://github.com/August-murr/t…
qgallouedec Nov 21, 2024
88de59a
use tmp dir in test
qgallouedec Nov 21, 2024
fed6045
sort
qgallouedec Nov 21, 2024
01e38c8
Add import error checks for mergekit extra
qgallouedec Nov 21, 2024
9e7a068
use a common _merge_and_maybe_push method and compat with windows path
qgallouedec Nov 21, 2024
fa5bafe
debug windows
qgallouedec Nov 21, 2024
47ecef5
Update dependencies for mergekit and add test dependencies
qgallouedec Nov 21, 2024
7fe26d5
Add assertion to check if merged folder exists in the last checkpoint
qgallouedec Nov 21, 2024
a57d88a
Fix temporary directory cleanup in test_callbacks.py
qgallouedec Nov 21, 2024
69ea0ed
Add sys import and skip test for Python versions below 3.10 due to cl…
qgallouedec Nov 21, 2024
d89eadc
revert change for debug
qgallouedec Nov 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Merge branch 'main' into MergeModelCallBack
  • Loading branch information
kashif authored Nov 20, 2024
commit 26aa418ed04c367170d00c1c16b8c5d3ae3ffe85
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
"diffusers": ["diffusers>=0.18.0"],
# liger-kernel depends on triton, which is only available on Linux https://github.com/triton-lang/triton#compatibility
"liger": ["liger-kernel>=0.4.0; sys_platform != 'win32'"],
"llm_judge": ["openai>=1.23.2", "llm-blender>=0.0.2"],
"judges": ["openai>=1.23.2", "llm-blender>=0.0.2"],
"mergekit": ["mergekit>=0.0.5.1"],
"peft": ["peft>=0.8.0"],
"quantization": ["bitsandbytes"],
Expand Down
32 changes: 31 additions & 1 deletion tests/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,19 @@

from transformers import is_bitsandbytes_available, is_sklearn_available, is_wandb_available

from trl import is_diffusers_available, is_llm_blender_available
from trl import BaseBinaryJudge, BasePairwiseJudge, is_diffusers_available, is_llm_blender_available
from trl.import_utils import is_mergekit_available


# transformers.testing_utils contains a require_bitsandbytes function, but relies on pytest markers which we don't use
# in our test suite. We therefore need to implement our own version of this function.
def require_bitsandbytes(test_case):
"""
Decorator marking a test that requires bitsandbytes. Skips the test if bitsandbytes is not available.
"""
return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case)


def require_diffusers(test_case):
"""
Decorator marking a test that requires diffusers. Skips the test if diffusers is not available.
Expand Down Expand Up @@ -53,3 +62,24 @@ def require_mergekit(test_case):
Decorator marking a test that requires Mergekit. Skips the test if Mergekit is not available.
"""
return unittest.skipUnless(is_mergekit_available(), "test requires Mergekit")(test_case)


class RandomBinaryJudge(BaseBinaryJudge):
"""
Random binary judge, for testing purposes.
"""

def judge(self, prompts, completions, gold_completions=None, shuffle_order=True):
return [random.choice([0, 1, -1]) for _ in range(len(prompts))]


class RandomPairwiseJudge(BasePairwiseJudge):
"""
Random pairwise judge, for testing purposes.
"""

def judge(self, prompts, completions, shuffle_order=True, return_scores=False):
if not return_scores:
return [random.randint(0, len(completion) - 1) for completion in completions]
else:
return [random.random() for _ in range(len(prompts))]
Loading
You are viewing a condensed version of this merge commit. You can view the full changes here.