Memory Efficient GRPO (unslothai#1773)
* Update __init__.py

* Update loader.py

* Update rl.py

* Update rl.py

* Update _utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Better TRL handling

* Update rl.py

* Update tokenizer_utils.py

* Auto patching

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update rl.py

* Update tokenizer_utils.py

* Update rl.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update tokenizer_utils.py

* Update rl.py

* Update rl.py

* Update rl.py

* max seq length

* Update rl.py

* Update rl.py

* Patching

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* NEFTune

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Extra replacements

* Update rl_replacements.py

* Update rl.py

* extra RL replacements

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update llama.py

* Update rl_replacements.py

* Update _utils.py

* Update loader_utils.py

* Update rl.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* autocast

* Update rl_replacements.py

* Update llama.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update llama.py

* Update rl_replacements.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update rl_replacements.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update pyproject.toml

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update llama.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update llama.py

* Update _utils.py

* Update llama.py

* Update _utils.py

* Update rl_replacements.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update rl_replacements.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* GRPO optimized

* Update rl.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Selective Log softmax

* Fix GRPO bsz

* Update rl.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Fix TRL

* Metrics GRPO

* Update rl_replacements.py

* Update rl_replacements.py

* No compile

* Update rl.py

* Remove docs

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl_replacements.py

* Update rl.py

* Update rl.py

* Update rl_replacements.py

* Update rl_replacements.py

* llama-quantize on WINDOWS WSL error fix - edit save.py (gguf saving breaks) (unslothai#1649)

* edit save.py to fix gguf saving breaks.

* add check for .exe or not exe file extension for linux and windows

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update llama.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update llama.py

* Update llama.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl.py

* Update rl.py

* Update rl_replacements.py

* Update rl.py

* Update rl.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* unsloth_num_chunks

* Update rl.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py (unslothai#1754)

Fix typo in comment: know -> now.

This was printed when running the Llama3.1_(8B)-GRPO.ipynb example notebook, so I'd expect others to run into it as well.

* Optional logits

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl_replacements.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* fix an import error (unslothai#1767)

* fix an import error

* Delete .gitignore

* Update loader.py

* Update save.py

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* SamplingParams

* Convert mask to float (unslothai#1762)

* [Windows Support] Add latest `xformers` wheels to pyproject.toml (unslothai#1753)

* Add latest xformers

* Add a couple of lines to docs

* vLLMSamplingParams

* Update __init__.py

* default num_chunks == -1

* Versioning

---------

Co-authored-by: Gennadii Manzhos <105049664+everythingisc00l@users.noreply.github.com>
Co-authored-by: Seth Weidman <seth@sethweidman.com>
Co-authored-by: Nino Risteski <95188570+NinoRisteski@users.noreply.github.com>
Co-authored-by: Edd <68678137+Erland366@users.noreply.github.com>
Co-authored-by: Ben <6579034+versipellis@users.noreply.github.com>
6 people authored Feb 20, 2025
1 parent 67d3440 commit f29da34
Showing 11 changed files with 206 additions and 101 deletions.
7 changes: 5 additions & 2 deletions README.md
@@ -193,7 +193,7 @@ print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://git
### Windows Installation

To run Unsloth directly on Windows:
- Install Triton from this Windows fork and follow the instructions: https://github.com/woct0rdho/triton-windows
- Install Triton from this Windows fork and follow the instructions: https://github.com/woct0rdho/triton-windows (be aware that the Windows fork requires PyTorch >= 2.4 and CUDA 12)
- In the SFTTrainer, set `dataset_num_proc=1` to avoid a crashing issue:
```python
trainer = SFTTrainer(
@@ -202,12 +202,15 @@ trainer = SFTTrainer(
)
```
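For reference, the middle of that call is elided by the diff; a minimal sketch of the full pattern is below. The `model`, `tokenizer`, and `dataset` names are placeholders, and on newer `trl` releases `dataset_num_proc` may belong on `SFTConfig` instead.

```python
from trl import SFTTrainer

# Minimal sketch (placeholder model/tokenizer/dataset); the Windows workaround is
# dataset_num_proc = 1, which disables dataset multiprocessing.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_num_proc = 1,
)
```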

### Advanced/Troubleshooting

For **advanced installation instructions** or if you see weird errors during installations:

1. Install `torch` and `triton`. Go to https://pytorch.org to install it. For example `pip install torch torchvision torchaudio triton`
2. Confirm that CUDA is installed correctly. Try `nvcc`. If that fails, you need to install `cudatoolkit` or CUDA drivers.
3. Install `xformers` manually. You can try installing `vllm` and seeing if `vllm` succeeds. Check if `xformers` succeeded with `python -m xformers.info` Go to https://github.com/facebookresearch/xformers. Another option is to install `flash-attn` for Ampere GPUs.
4. Finally, install `bitsandbytes` and check it with `python -m bitsandbytes`
4. Double check that your versions of Python, CUDA, CUDNN, `torch`, `triton`, and `xformers` are compatible with one another. The [PyTorch Compatibility Matrix](https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix) may be useful.
5. Finally, install `bitsandbytes` and check it with `python -m bitsandbytes`
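
A quick way to confirm the steps above worked is a short Python check (suggested here, not part of this change); whichever import fails points at the step to revisit:

```python
# Suggested sanity check -- not part of this diff.
import torch
print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())

import triton
print("triton", triton.__version__)

import xformers
print("xformers", xformers.__version__)        # `python -m xformers.info` prints a fuller report

import bitsandbytes
print("bitsandbytes", bitsandbytes.__version__)  # `python -m bitsandbytes` runs a self-test
```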

## 📜 [Documentation](https://docs.unsloth.ai)
- Go to our official [Documentation](https://docs.unsloth.ai) for saving to GGUF, checkpointing, evaluation and more!
8 changes: 6 additions & 2 deletions pyproject.toml
@@ -39,7 +39,7 @@ triton = [
"triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
]
huggingface = [
"unsloth_zoo>=2025.2.5",
"unsloth_zoo>=2025.2.6",
"packaging",
"tyro",
"transformers>=4.46.1,!=4.47.0",
@@ -196,6 +196,10 @@ cu126onlytorch260 = [
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
]
cu118 = [
"unsloth[huggingface]",
@@ -344,7 +348,7 @@ colab-ampere-torch220 = [
"flash-attn>=2.6.3",
]
colab-new = [
"unsloth_zoo>=2025.2.5",
"unsloth_zoo>=2025.2.6",
"packaging",
"tyro",
"transformers>=4.46.1,!=4.47.0",
2 changes: 1 addition & 1 deletion unsloth/__init__.py
@@ -196,7 +196,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
# Check for unsloth_zoo
try:
unsloth_zoo_version = importlib_version("unsloth_zoo")
if Version(unsloth_zoo_version) < Version("2025.2.4"):
if Version(unsloth_zoo_version) < Version("2025.2.6"):
try:
os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo")
except:
2 changes: 1 addition & 1 deletion unsloth/models/__init__.py
@@ -20,4 +20,4 @@
from .qwen2 import FastQwen2Model
from .dpo import PatchDPOTrainer, PatchKTOTrainer
from ._utils import is_bfloat16_supported
from .rl import PatchFastRL
from .rl import PatchFastRL, vLLMSamplingParams
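
The newly re-exported `vLLMSamplingParams` rides along with `PatchFastRL`. A hedged usage sketch follows; the keyword arguments mirror vLLM's `SamplingParams`, and exactly how the object is passed to the trainer config may differ by version:

```python
from unsloth import FastLanguageModel, PatchFastRL, vLLMSamplingParams

# Patch TRL's GRPO trainer before loading the model (pattern from the Unsloth GRPO notebooks).
PatchFastRL("GRPO", FastLanguageModel)

# Extra vLLM sampling options to forward during generation; values are illustrative.
sampling_params = vLLMSamplingParams(min_p = 0.1, seed = 3407)
```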
2 changes: 1 addition & 1 deletion unsloth/models/_utils.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "2025.2.12"
__version__ = "2025.2.13"

__all__ = [
"SUPPORTS_BFLOAT16",
97 changes: 59 additions & 38 deletions unsloth/models/llama.py
@@ -700,6 +700,7 @@ def LlamaModel_fast_forward(
elif inputs_requires_grad:
inputs_embeds.requires_grad_(False)
pass
attention_mask = attention_mask[:,:self.max_seq_length] # Must resize!
inputs_embeds *= attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)
if inputs_requires_grad: inputs_embeds.requires_grad_(True)
pass
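
The added resize matters because the mask is broadcast against `inputs_embeds` on the next line; a quick shape sketch (illustrative values, not part of the patch):

```python
import torch

bsz, seq_len, hidden = 2, 5, 8
attention_mask = torch.ones(bsz, seq_len)
inputs_embeds  = torch.randn(bsz, seq_len, hidden)

# (bsz, seq) -> (1, bsz, seq) -> (bsz, 1, seq) -> (bsz, seq, 1):
# the mask broadcasts over the hidden dimension, zeroing padded positions.
mask = attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)
print(mask.shape)                     # torch.Size([2, 5, 1])
print((inputs_embeds * mask).shape)   # torch.Size([2, 5, 8])
```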
@@ -774,9 +775,12 @@ def LlamaModel_fast_forward(
self.SWA_mask = True
self.GA_mask = False
elif attention_mask is not None:

# Fixes https://github.com/unslothai/unsloth/issues/853
# Unsloth needs a 2D mask, not a [2, 1, n, n] mask!

# https://github.com/pytorch/pytorch/issues/103749
# Need to convert to float and not using bool
attention_mask = (1.0 - attention_mask.float()) * torch.finfo(inputs_embeds.dtype).min
dynamic_SWA_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
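
A standalone illustration of the float conversion above (not part of the patch): SDPA mishandles boolean masks (pytorch/pytorch#103749), so padded positions receive a large negative additive bias instead.

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])   # 1 = keep, 0 = pad
dtype = torch.float16                           # dtype of inputs_embeds in the model
additive = (1.0 - attention_mask.float()) * torch.finfo(dtype).min
print(additive)   # kept tokens stay ~0, the padded token gets roughly -65504 for float16
```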
@@ -1030,6 +1034,7 @@ def _CausalLM_fast_forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
num_logits_to_keep: Optional[int] = 0,
logits_to_keep: Optional[int] = 0,
*args, **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:

@@ -1053,16 +1058,16 @@
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
self.model._has_no_labels = labels is None
outputs = self.model(
input_ids=input_ids,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
input_ids = input_ids,
causal_mask = causal_mask,
attention_mask = attention_mask,
position_ids = position_ids,
past_key_values = past_key_values,
inputs_embeds = inputs_embeds,
use_cache = use_cache,
output_attentions = output_attentions,
output_hidden_states = output_hidden_states,
return_dict = return_dict,
)
pass
hidden_states = outputs[0]
@@ -1072,6 +1077,20 @@ def _CausalLM_fast_forward(
logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
logit_scaling = getattr(self.config, "logit_scale", 0)
dtype = lm_head.dtype
num_logits_to_keep = max(num_logits_to_keep, logits_to_keep)

# Output last hidden states without logits if asked
if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1":
if num_logits_to_keep != 0:
hidden_states = hidden_states[:, -num_logits_to_keep:, :]
return CausalLMOutputWithPast(
loss = None,
logits = hidden_states,
past_key_values = outputs.past_key_values,
hidden_states = outputs.hidden_states,
attentions= outputs.attentions,
)
pass

if bsz == 1 and q_len == 1:
logits = torch.mv(lm_head, hidden_states.ravel().to(dtype))
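
The `UNSLOTH_RETURN_HIDDEN_STATES` switch introduced above is what lets memory-efficient GRPO avoid materialising the full `(seq_len, vocab)` logits tensor. A hedged caller-side sketch, where `model` and `input_ids` are placeholders:

```python
import os, torch

# Ask the patched forward to return hidden states in the `logits` slot.
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"

with torch.no_grad():
    out = model(input_ids, logits_to_keep = 1)   # keep only the last position
hidden = out.logits                              # actually hidden states here
# Project only what you need instead of all logits, e.g. the final position:
last_logits = hidden[:, -1, :] @ model.get_output_embeddings().weight.T
```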
@@ -1166,11 +1185,11 @@ def _CausalLM_fast_forward(
return (loss,) + output if loss is not None else output

return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
loss = loss,
logits = logits,
past_key_values = outputs.past_key_values,
hidden_states = outputs.hidden_states,
attentions= outputs.attentions,
)
pass
return _CausalLM_fast_forward
@@ -1180,28 +1199,30 @@ def _CausalLM_fast_forward(
@torch._disable_dynamo
def PeftModelForCausalLM_fast_forward(
self,
input_ids=None,
causal_mask=None,
attention_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
task_ids=None,
num_logits_to_keep=0,
input_ids = None,
causal_mask = None,
attention_mask = None,
inputs_embeds = None,
labels = None,
output_attentions = None,
output_hidden_states = None,
return_dict = None,
task_ids = None,
num_logits_to_keep = 0,
logits_to_keep = 0,
**kwargs,
):
return self.base_model(
input_ids=input_ids,
causal_mask=causal_mask,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
num_logits_to_keep=num_logits_to_keep,
input_ids = input_ids,
causal_mask = causal_mask,
attention_mask = attention_mask,
inputs_embeds = inputs_embeds,
labels = labels,
output_attentions = output_attentions,
output_hidden_states = output_hidden_states,
return_dict = return_dict,
num_logits_to_keep = num_logits_to_keep,
logits_to_keep = logits_to_keep,
**kwargs,
)
pass
@@ -1694,9 +1715,9 @@ def from_pretrained(
elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:
logger.warning_once("Device does not support bfloat16. Will change to float16.")
dtype = torch.float16
elif dtype == torch.float16 and SUPPORTS_BFLOAT16:
logger.warning_once("Device supports bfloat16 but you selected float16. Will change to bfloat16.")
dtype = torch.bfloat16
# elif dtype == torch.float16 and SUPPORTS_BFLOAT16:
# logger.warning_once("Device supports bfloat16 but you selected float16. Will change to bfloat16.")
# dtype = torch.bfloat16

assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32)
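
Since the automatic float16 → bfloat16 upgrade is now disabled (commented out above), the dtype you pass is respected. A hedged loading sketch using the existing helper; the model name is a placeholder:

```python
import torch
from unsloth import FastLanguageModel, is_bfloat16_supported

dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-1B-Instruct",   # placeholder
    max_seq_length = 2048,
    dtype = dtype,
    load_in_4bit = True,
)
```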

10 changes: 7 additions & 3 deletions unsloth/models/loader.py
@@ -24,10 +24,14 @@
from .loader_utils import get_model_name
import os, contextlib, sys
try:
from huggingface_hub.utils import get_token
from huggingface_hub import get_token
except:
# Old HF Hub versions <= 0.0.25
from huggingface_hub.utils._token import get_token
try:
from huggingface_hub.utils import get_token
except:
# For older versions of huggingface_hub
from huggingface_hub.utils._token import get_token
pass
pass
from huggingface_hub import HfFileSystem
import importlib.util
5 changes: 0 additions & 5 deletions unsloth/models/mapper.py
@@ -601,11 +601,6 @@
"Qwen/Qwen2.5-VL-72B-Instruct",
"unsloth/Qwen2.5-VL-72B-Instruct-bnb-4bit",
),
"unsloth/DeepHermes-3-Llama-3-8B-Preview-unsloth-bnb-4bit" : (
"unsloth/DeepHermes-3-Llama-3-8B-Preview",
"NousResearch/DeepHermes-3-Llama-3-8B-Preview",
"unsloth/DeepHermes-3-Llama-3-8B-Preview-bnb-4bit",
),
"unsloth/DeepScaleR-1.5B-Preview-unsloth-bnb-4bit" : (
"unsloth/DeepHermes-3-Llama-3-8B-Preview",
"agentica-org/DeepScaleR-1.5B-Preview",
