From 517396664f044952edd8534177d9631dc7135dd5 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 24 Sep 2024 17:40:56 +0800 Subject: [PATCH] Enable BNB multi-backend support (#31098) * enable cpu bnb path * fix style * fix code style * fix 4 bit path * Update src/transformers/utils/import_utils.py Co-authored-by: Aarni Koskela * add multi backend refactor tests * fix style * tweak 4bit quantizer + fix corresponding tests * tweak 8bit quantizer + *try* fixing corresponding tests * fix dequant bnb 8bit * account for Intel CPU in variability of expected outputs * enable cpu and xpu device map * further tweaks to account for Intel CPU * fix autocast to work with both cpu + cuda * fix comments * fix comments * switch to testing_utils.torch_device * allow for xpu in multi-gpu tests * fix tests 4bit for CPU NF4 * fix bug with is_torch_xpu_available needing to be called as func * avoid issue where test reports attr err due to other failure * fix formatting * fix typo from resolving of merge conflict * polish based on last PR review Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * fix CI * Update src/transformers/integrations/integration_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/integrations/integration_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix error log * fix error msg * add \n in error log * make quality * rm bnb cuda restriction in doc * cpu model don't need dispatch * fix doc * fix style * check cuda avaliable in testing * fix tests * Update docs/source/en/model_doc/chameleon.md Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update docs/source/en/model_doc/llava_next.md Co-authored-by: Aarni Koskela * Update tests/quantization/bnb/test_4bit.py Co-authored-by: Aarni Koskela * Update tests/quantization/bnb/test_4bit.py Co-authored-by: Aarni Koskela * fix doc * fix check multibackends * fix import sort * remove check torch in bnb * docs: update bitsandbytes references with multi-backend info * docs: fix small mistakes in bnb paragraph * run formatting * reveret bnb check * move bnb multi-backend check to import_utils * Update src/transformers/utils/import_utils.py Co-authored-by: Aarni Koskela * fix bnb check * minor fix for bnb * check lib first * fix code style * Revert "run formatting" This reverts commit ac108c6d6b34f45a5745a736ba57282405cfaa61. 
* fix format * give warning when bnb version is low and no cuda found] * fix device assignment check to be multi-device capable * address akx feedback on get_avlbl_dev fn * revert partially, as we don't want the function that public, as docs would be too much (enforced) --------- Co-authored-by: Aarni Koskela Co-authored-by: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- docs/source/en/llm_tutorial_optimization.md | 2 +- docs/source/en/model_doc/chameleon.md | 12 ++- docs/source/en/model_doc/llava_next.md | 12 ++- docs/source/en/model_doc/llava_next_video.md | 12 ++- docs/source/en/model_doc/llava_onevision.md | 14 ++- docs/source/en/model_doc/mixtral.md | 2 +- docs/source/en/model_doc/video_llava.md | 12 ++- docs/source/en/model_memory_anatomy.md | 2 +- docs/source/en/perf_train_gpu_one.md | 2 +- docs/source/en/quantization/bitsandbytes.md | 8 ++ docs/source/en/quantization/overview.md | 16 ++- src/transformers/integrations/__init__.py | 2 + src/transformers/integrations/bitsandbytes.py | 98 ++++++++++++++++- .../quantizers/quantizer_bnb_4bit.py | 22 +++- .../quantizers/quantizer_bnb_8bit.py | 23 ++-- src/transformers/testing_utils.py | 48 ++++++++- src/transformers/utils/__init__.py | 32 ++++++ src/transformers/utils/import_utils.py | 22 +++- tests/quantization/bnb/test_4bit.py | 100 ++++++++++++------ tests/quantization/bnb/test_mixed_int8.py | 95 +++++++++++------ 20 files changed, 436 insertions(+), 100 deletions(-) diff --git a/docs/source/en/llm_tutorial_optimization.md b/docs/source/en/llm_tutorial_optimization.md index a675a6de39a2fc..9d3d8ad6ba8b86 100644 --- a/docs/source/en/llm_tutorial_optimization.md +++ b/docs/source/en/llm_tutorial_optimization.md @@ -181,7 +181,7 @@ for every matrix multiplication. Dequantization and re-quantization is performed Therefore, inference time is often **not** reduced when using quantized weights, but rather increases. Enough theory, let's give it a try! To quantize the weights with Transformers, you need to make sure that -the [`bitsandbytes`](https://github.com/TimDettmers/bitsandbytes) library is installed. +the [`bitsandbytes`](https://github.com/bitsandbytes-foundation/bitsandbytes) library is installed. ```bash !pip install bitsandbytes diff --git a/docs/source/en/model_doc/chameleon.md b/docs/source/en/model_doc/chameleon.md index 28ec01ad615871..2fa9c1db866c7e 100644 --- a/docs/source/en/model_doc/chameleon.md +++ b/docs/source/en/model_doc/chameleon.md @@ -128,7 +128,17 @@ processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokeniza ### Quantization using Bitsandbytes -The model can be loaded in 8 or 4 bits, greatly reducing the memory requirements while maintaining the performance of the original model. First make sure to install bitsandbytes, `pip install bitsandbytes` and make sure to have access to a CUDA compatible GPU device. Simply change the snippet above with: +The model can be loaded in 8 or 4 bits, greatly reducing the memory requirements while maintaining the performance of the original model. First make sure to install bitsandbytes, `pip install bitsandbytes` and to have access to a GPU/accelerator that is supported by the library. + + + +bitsandbytes is being refactored to support multiple backends beyond CUDA. 
Currently, ROCm (AMD GPU) and Intel CPU implementations are mature, with Intel XPU in progress and Apple Silicon support expected by Q4/Q1. For installation instructions and the latest backend updates, visit [this link](https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend). + +We value your feedback to help identify bugs before the full release! Check out [these docs](https://huggingface.co/docs/bitsandbytes/main/en/non_cuda_backends) for more details and feedback links. + + + +Simply replace the snippet above with: ```python from transformers import ChameleonForConditionalGeneration, BitsAndBytesConfig diff --git a/docs/source/en/model_doc/llava_next.md b/docs/source/en/model_doc/llava_next.md index d0558be76467a2..f04827cc7d5f74 100644 --- a/docs/source/en/model_doc/llava_next.md +++ b/docs/source/en/model_doc/llava_next.md @@ -233,7 +233,17 @@ processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokeniza ### Quantization using Bitsandbytes -The model can be loaded in 8 or 4 bits, greatly reducing the memory requirements while maintaining the performance of the original model. First make sure to install bitsandbytes, `pip install bitsandbytes` and make sure to have access to a CUDA compatible GPU device. Simply change the snippet above with: +The model can be loaded in 8 or 4 bits, greatly reducing the memory requirements while maintaining the performance of the original model. First make sure to install bitsandbytes, `pip install bitsandbytes`, and to have access to a GPU/accelerator that is supported by the library. + + + +bitsandbytes is being refactored to support multiple backends beyond CUDA. Currently, ROCm (AMD GPU) and Intel CPU implementations are mature, with Intel XPU in progress and Apple Silicon support expected by Q4/Q1. For installation instructions and the latest backend updates, visit [this link](https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend). + +We value your feedback to help identify bugs before the full release! Check out [these docs](https://huggingface.co/docs/bitsandbytes/main/en/non_cuda_backends) for more details and feedback links. + + + +Simply replace the snippet above with: ```python from transformers import LlavaNextForConditionalGeneration, BitsAndBytesConfig diff --git a/docs/source/en/model_doc/llava_next_video.md b/docs/source/en/model_doc/llava_next_video.md index 48e50f950621e8..fe905dfb7932ab 100644 --- a/docs/source/en/model_doc/llava_next_video.md +++ b/docs/source/en/model_doc/llava_next_video.md @@ -205,7 +205,17 @@ processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokeniza The model can be loaded in lower bits, significantly reducing memory burden while maintaining the performance of the original model. This allows for efficient deployment on resource-constrained cases. -First make sure to install bitsandbytes by running `pip install bitsandbytes` and to have access to a CUDA compatible GPU device. Load the quantized model by simply adding [`BitsAndBytesConfig`](../main_classes/quantization#transformers.BitsAndBytesConfig) as shown below: +First, make sure to install bitsandbytes by running `pip install bitsandbytes` and to have access to a GPU/accelerator that is supported by the library. + + + +bitsandbytes is being refactored to support multiple backends beyond CUDA. Currently, ROCm (AMD GPU) and Intel CPU implementations are mature, with Intel XPU in progress and Apple Silicon support expected by Q4/Q1.
For installation instructions and the latest backend updates, visit [this link](https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend). + +We value your feedback to help identify bugs before the full release! Check out [these docs](https://huggingface.co/docs/bitsandbytes/main/en/non_cuda_backends) for more details and feedback links. + + + +Then simply load the quantized model by adding [`BitsAndBytesConfig`](../main_classes/quantization#transformers.BitsAndBytesConfig) as shown below: ```python diff --git a/docs/source/en/model_doc/llava_onevision.md b/docs/source/en/model_doc/llava_onevision.md index 64a127abca4c28..717784da738d8c 100644 --- a/docs/source/en/model_doc/llava_onevision.md +++ b/docs/source/en/model_doc/llava_onevision.md @@ -264,9 +264,19 @@ processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spac ## Model optimization -### Quantization using Bitsandbytes +### Quantization using bitsandbytes -The model can be loaded in 8 or 4 bits, greatly reducing the memory requirements while maintaining the performance of the original model. First make sure to install bitsandbytes, `pip install bitsandbytes` and make sure to have access to a CUDA compatible GPU device. Simply change the snippet above with: +The model can be loaded in 8 or 4 bits, greatly reducing the memory requirements while maintaining the performance of the original model. First make sure to install bitsandbytes, `pip install bitsandbytes`, and to have access to a GPU/accelerator that is supported by the library. + + + +bitsandbytes is being refactored to support multiple backends beyond CUDA. Currently, ROCm (AMD GPU) and Intel CPU implementations are mature, with Intel XPU in progress and Apple Silicon support expected by Q4/Q1. For installation instructions and the latest backend updates, visit [this link](https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend). + +We value your feedback to help identify bugs before the full release! Check out [these docs](https://huggingface.co/docs/bitsandbytes/main/en/non_cuda_backends) for more details and feedback links. + + + +Simply replace the snippet above with: + ```python from transformers import LlavaOnevisionForConditionalGeneration, BitsAndBytesConfig diff --git a/docs/source/en/model_doc/mixtral.md b/docs/source/en/model_doc/mixtral.md index 26eff8ec21ad7a..71c7d7921ef005 100644 --- a/docs/source/en/model_doc/mixtral.md +++ b/docs/source/en/model_doc/mixtral.md @@ -141,7 +141,7 @@ The Flash Attention-2 model uses also a more memory efficient cache slicing mech As the Mixtral model has 45 billion parameters, that would require about 90GB of GPU RAM in half precision (float16), since each parameter is stored in 2 bytes. However, one can shrink down the size of the model using [quantization](../quantization.md). If the model is quantized to 4 bits (or half a byte per parameter), a single A100 with 40GB of RAM is enough to fit the entire model, as in that case only about 27 GB of RAM is required. -Quantizing a model is as simple as passing a `quantization_config` to the model. Below, we'll leverage the BitsAndyBytes quantization (but refer to [this page](../quantization.md) for other quantization methods): +Quantizing a model is as simple as passing a `quantization_config` to the model.
Below, we'll leverage the bitsandbytes quantization library (but refer to [this page](../quantization.md) for alternative quantization methods): ```python >>> import torch diff --git a/docs/source/en/model_doc/video_llava.md b/docs/source/en/model_doc/video_llava.md index f098e82a177670..1c4b5b4b874dd7 100644 --- a/docs/source/en/model_doc/video_llava.md +++ b/docs/source/en/model_doc/video_llava.md @@ -139,7 +139,17 @@ processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokeniza The model can be loaded in lower bits, significantly reducing memory burden while maintaining the performance of the original model. This allows for efficient deployment on resource-constrained cases. -First make sure to install bitsandbytes by running `pip install bitsandbytes` and to have access to a CUDA compatible GPU device. Load the quantized model by simply adding [`BitsAndBytesConfig`](../main_classes/quantization#transformers.BitsAndBytesConfig) as shown below: +First make sure to install bitsandbytes by running `pip install bitsandbytes` and to have access to a GPU/accelerator that is supported by the library. + + + +bitsandbytes is being refactored to support multiple backends beyond CUDA. Currently, ROCm (AMD GPU) and Intel CPU implementations are mature, with Intel XPU in progress and Apple Silicon support expected by Q4/Q1. For installation instructions and the latest backend updates, visit [this link](https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend). + +We value your feedback to help identify bugs before the full release! Check out [these docs](https://huggingface.co/docs/bitsandbytes/main/en/non_cuda_backends) for more details and feedback links. + + + +Load the quantized model by simply adding [`BitsAndBytesConfig`](../main_classes/quantization#transformers.BitsAndBytesConfig) as shown below: ```python diff --git a/docs/source/en/model_memory_anatomy.md b/docs/source/en/model_memory_anatomy.md index c1d9d4c54bc728..44c197aae5cfe4 100644 --- a/docs/source/en/model_memory_anatomy.md +++ b/docs/source/en/model_memory_anatomy.md @@ -233,7 +233,7 @@ Let's look at the details. **Optimizer States:** - 8 bytes * number of parameters for normal AdamW (maintains 2 states) -- 2 bytes * number of parameters for 8-bit AdamW optimizers like [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) +- 2 bytes * number of parameters for 8-bit AdamW optimizers like [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) - 4 bytes * number of parameters for optimizers like SGD with momentum (maintains only 1 state) **Gradients** diff --git a/docs/source/en/perf_train_gpu_one.md b/docs/source/en/perf_train_gpu_one.md index c90f2ca58483bf..364fc46544c6fd 100644 --- a/docs/source/en/perf_train_gpu_one.md +++ b/docs/source/en/perf_train_gpu_one.md @@ -284,7 +284,7 @@ training_args = TrainingArguments(per_device_train_batch_size=4, optim="adamw_bn However, we can also use a third-party implementation of the 8-bit optimizer for demonstration purposes to see how that can be integrated. -First, follow the installation guide in the GitHub [repo](https://github.com/TimDettmers/bitsandbytes) to install the `bitsandbytes` library +First, follow the installation guide in the GitHub [repo](https://github.com/bitsandbytes-foundation/bitsandbytes) to install the `bitsandbytes` library that implements the 8-bit Adam optimizer. Next you need to initialize the optimizer.
This involves two steps: diff --git a/docs/source/en/quantization/bitsandbytes.md b/docs/source/en/quantization/bitsandbytes.md index 334b6145e537fe..e9447555e82449 100644 --- a/docs/source/en/quantization/bitsandbytes.md +++ b/docs/source/en/quantization/bitsandbytes.md @@ -38,6 +38,14 @@ pip install --upgrade accelerate transformers + + +bitsandbytes is being refactored to support multiple backends beyond CUDA. Currently, ROCm (AMD GPU) and Intel CPU implementations are mature, with Intel XPU in progress and Apple Silicon support expected by Q4/Q1. For installation instructions and the latest backend updates, visit [this link](https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend). + +We value your feedback to help identify bugs before the full release! Check out [these docs](https://huggingface.co/docs/bitsandbytes/main/en/non_cuda_backends) for more details and feedback links. + + + Now you can quantize a model by passing a `BitsAndBytesConfig` to [`~PreTrainedModel.from_pretrained`] method. This works for any model in any modality, as long as it supports loading with Accelerate and contains `torch.nn.Linear` layers. diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 9eb74793a12797..97bb0cf5326308 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -49,7 +49,7 @@ Use the table below to help you decide which quantization method to use. |-------------------------------------|-------------------------|-----|----------|----------------|-----------------------|-------------------------|----------------|-------------------------------------|--------------|------------------------|---------------------------------------------| | [AQLM](./aqlm) | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🟢 | 1 / 2 | 🟢 | 🟢 | 🟢 | https://github.com/Vahe1994/AQLM | | [AWQ](./awq) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | ? | 4 | 🟢 | 🟢 | 🟢 | https://github.com/casper-hansen/AutoAWQ | -| [bitsandbytes](./bitsandbytes) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 4 / 8 | 🟢 | 🟢 | 🟢 | https://github.com/TimDettmers/bitsandbytes | +| [bitsandbytes](./bitsandbytes) | 🟢 | 🟡 * | 🟢 | 🟡 * | 🔴 ** | 🔴 (soon!) | 4 / 8 | 🟢 | 🟢 | 🟢 | https://github.com/bitsandbytes-foundation/bitsandbytes | | [EETQ](./eetq) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | ? | 8 | 🟢 | 🟢 | 🟢 | https://github.com/NetEase-FuXi/EETQ | | GGUF / GGML (llama.cpp) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 1 - 8 | 🔴 | [See GGUF section](../gguf) | [See GGUF section](../gguf) | https://github.com/ggerganov/llama.cpp | | [GPTQ](./gptq) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 2 - 3 - 4 - 8 | 🟢 | 🟢 | 🟢 | https://github.com/AutoGPTQ/AutoGPTQ | @@ -57,3 +57,17 @@ Use the table below to help you decide which quantization method to use. | [Quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/quanto | | [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM | | [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | partial support (int4 weight only) | | 4 / 8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao | + + + +\* bitsandbytes is being refactored to support multiple backends beyond CUDA. Currently, ROCm (AMD GPU) and Intel CPU implementations are mature, with Intel XPU in progress and Apple Silicon support expected by Q4/Q1. For installation instructions and the latest backend updates, visit [this link](https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend). + +We value your feedback to help identify bugs before the full release! 
Check out [these docs](https://huggingface.co/docs/bitsandbytes/main/en/non_cuda_backends) for more details and feedback links. + + + + + +\** bitsandbytes is seeking contributors to help develop and lead the Apple Silicon backend. Interested? Contact them directly via their repo. Stipends may be available through sponsorships. + + diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 0a28ff022a536b..00bbcf2d060fe9 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -31,6 +31,7 @@ "replace_with_bnb_linear", "set_module_8bit_tensor_to_device", "set_module_quantized_tensor_to_device", + "validate_bnb_backend_availability", ], "deepspeed": [ "HfDeepSpeedConfig", @@ -124,6 +125,7 @@ replace_with_bnb_linear, set_module_8bit_tensor_to_device, set_module_quantized_tensor_to_device, + validate_bnb_backend_availability, ) from .deepspeed import ( HfDeepSpeedConfig, diff --git a/src/transformers/integrations/bitsandbytes.py b/src/transformers/integrations/bitsandbytes.py index f37ca9a2650bf3..2501261b55e091 100644 --- a/src/transformers/integrations/bitsandbytes.py +++ b/src/transformers/integrations/bitsandbytes.py @@ -6,7 +6,15 @@ from packaging import version -from ..utils import is_accelerate_available, is_bitsandbytes_available, logging +from ..utils import ( + get_available_devices, + is_accelerate_available, + is_bitsandbytes_available, + is_bitsandbytes_multi_backend_available, + is_ipex_available, + is_torch_available, + logging, +) if is_bitsandbytes_available(): @@ -332,7 +340,7 @@ def get_keys_to_not_convert(model): # Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41 -def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None): +def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", state=None): """ Helper function to dequantize 4bit or 8bit bnb weights. 
@@ -350,7 +358,7 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None): logger.warning_once( f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`" ) - return output_tensor + return output_tensor.to(dtype) if state.SCB is None: state.SCB = weight.SCB @@ -361,7 +369,7 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None): if state.CxB is None: state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB) out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB) - return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t() + return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t().to(dtype) def _create_accelerate_new_hook(old_hook): @@ -383,6 +391,7 @@ def _create_accelerate_new_hook(old_hook): def _dequantize_and_replace( model, + dtype, modules_to_not_convert=None, current_key_name=None, quantization_config=None, @@ -422,7 +431,7 @@ def _dequantize_and_replace( else: state = None - new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state)) + new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, dtype, state)) if bias is not None: new_module.bias = bias @@ -441,6 +450,7 @@ def _dequantize_and_replace( if len(list(module.children())) > 0: _, has_been_replaced = _dequantize_and_replace( module, + dtype, modules_to_not_convert, current_key_name, quantization_config, @@ -458,6 +468,7 @@ def dequantize_and_replace( ): model, has_been_replaced = _dequantize_and_replace( model, + model.dtype, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config, ) @@ -468,3 +479,80 @@ def dequantize_and_replace( ) return model + + +def _validate_bnb_multi_backend_availability(raise_exception): + import bitsandbytes as bnb + + bnb_supported_devices = getattr(bnb, "supported_torch_devices", set()) + available_devices = get_available_devices() + + if available_devices == {"cpu"} and not is_ipex_available(): + from importlib.util import find_spec + + if find_spec("intel_extension_for_pytorch"): + logger.warning( + "You have Intel IPEX installed but if you're intending to use it for CPU, it might not have the right version. Be sure to double check that your PyTorch and IPEX installs are compatible." + ) + + available_devices.discard("cpu") # Only Intel CPU is supported by BNB at the moment + + if not available_devices.intersection(bnb_supported_devices): + if raise_exception: + bnb_supported_devices_with_info = set( # noqa: C401 + '"cpu" (needs an Intel CPU and intel_extension_for_pytorch installed and compatible with the PyTorch version)' + if device == "cpu" + else device + for device in bnb_supported_devices + ) + err_msg = ( + f"None of the available devices `available_devices = {available_devices or None}` are supported by the bitsandbytes version you have installed: `bnb_supported_devices = {bnb_supported_devices_with_info}`. 
" + "Please check the docs to see if the backend you intend to use is available and how to install it: https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend" + ) + + logger.error(err_msg) + raise RuntimeError(err_msg) + + logger.warning("No supported devices found for bitsandbytes multi-backend.") + return False + + logger.debug("Multi-backend validation successful.") + return True + + +def _validate_bnb_cuda_backend_availability(raise_exception): + if not is_torch_available(): + return False + + import torch + + if not torch.cuda.is_available(): + log_msg = ( + "CUDA is required but not available for bitsandbytes. Please consider installing the multi-platform enabled version of bitsandbytes, which is currently a work in progress. " + "Please check currently supported platforms and installation instructions at https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend" + ) + if raise_exception: + logger.error(log_msg) + raise RuntimeError(log_msg) + + logger.warning(log_msg) + return False + + logger.debug("CUDA backend validation successful.") + return True + + +def validate_bnb_backend_availability(raise_exception=False): + """ + Validates if the available devices are supported by bitsandbytes, optionally raising an exception if not. + """ + if not is_bitsandbytes_available(): + if importlib.util.find_spec("bitsandbytes") and version.parse( + importlib.metadata.version("bitsandbytes") + ) < version.parse("0.43.1"): + return _validate_bnb_cuda_backend_availability(raise_exception) + return False + + if is_bitsandbytes_multi_backend_available(): + return _validate_bnb_multi_backend_availability(raise_exception) + return _validate_bnb_cuda_backend_availability(raise_exception) diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py index 827ca310f35a1a..73e7664aeb884d 100644 --- a/src/transformers/quantizers/quantizer_bnb_4bit.py +++ b/src/transformers/quantizers/quantizer_bnb_4bit.py @@ -29,6 +29,7 @@ is_accelerate_available, is_bitsandbytes_available, is_torch_available, + is_torch_xpu_available, logging, ) @@ -65,8 +66,6 @@ def __init__(self, quantization_config, **kwargs): self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules def validate_environment(self, *args, **kwargs): - if not torch.cuda.is_available(): - raise RuntimeError("No GPU found. 
A GPU is needed for quantization.") if not is_accelerate_available(): raise ImportError( f"Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" @@ -76,6 +75,12 @@ def validate_environment(self, *args, **kwargs): "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" ) + from ..integrations import validate_bnb_backend_availability + from ..utils import is_bitsandbytes_multi_backend_available + + bnb_multibackend_is_enabled = is_bitsandbytes_multi_backend_available() + validate_bnb_backend_availability(raise_exception=True) + if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): raise ValueError( "Converting into 4-bit or 8-bit weights from tf/flax weights is currently not supported, please make" @@ -91,7 +96,9 @@ def validate_environment(self, *args, **kwargs): device_map_without_lm_head = { key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert } - if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values(): + if set(device_map.values()) == {"cpu"} and bnb_multibackend_is_enabled: + pass + elif "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values(): raise ValueError( "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the " "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules " @@ -255,10 +262,15 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.update_device_map def update_device_map(self, device_map): if device_map is None: - device_map = {"": torch.cuda.current_device()} + if torch.cuda.is_available(): + device_map = {"": torch.cuda.current_device()} + elif is_torch_xpu_available(): + device_map = {"": f"xpu:{torch.xpu.current_device()}"} + else: + device_map = {"": "cpu"} logger.info( "The device_map was not initialized. " - "Setting device_map to {'':torch.cuda.current_device()}. " + f"Setting device_map to {device_map}. " "If you want to use the model for inference, please set device_map ='auto' " ) return device_map diff --git a/src/transformers/quantizers/quantizer_bnb_8bit.py b/src/transformers/quantizers/quantizer_bnb_8bit.py index dbfceac2de8667..65d97716d02cf8 100644 --- a/src/transformers/quantizers/quantizer_bnb_8bit.py +++ b/src/transformers/quantizers/quantizer_bnb_8bit.py @@ -27,6 +27,7 @@ is_accelerate_available, is_bitsandbytes_available, is_torch_available, + is_torch_xpu_available, logging, ) from .quantizers_utils import get_module_from_name @@ -64,9 +65,6 @@ def __init__(self, quantization_config, **kwargs): self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules def validate_environment(self, *args, **kwargs): - if not torch.cuda.is_available(): - raise RuntimeError("No GPU found. 
A GPU is needed for quantization.") - if not is_accelerate_available(): raise ImportError( f"Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" @@ -76,6 +74,12 @@ def validate_environment(self, *args, **kwargs): "Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" ) + from ..integrations import validate_bnb_backend_availability + from ..utils import is_bitsandbytes_multi_backend_available + + bnb_multibackend_is_enabled = is_bitsandbytes_multi_backend_available() + validate_bnb_backend_availability(raise_exception=True) + if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): raise ValueError( "Converting into 4-bit or 8-bit weights from tf/flax weights is currently not supported, please make" @@ -91,7 +95,9 @@ def validate_environment(self, *args, **kwargs): device_map_without_lm_head = { key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert } - if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values(): + if set(device_map.values()) == {"cpu"} and bnb_multibackend_is_enabled: + pass + elif "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values(): raise ValueError( "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the " "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules " @@ -127,10 +133,15 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": def update_device_map(self, device_map): if device_map is None: - device_map = {"": torch.cuda.current_device()} + if torch.cuda.is_available(): + device_map = {"": torch.cuda.current_device()} + elif is_torch_xpu_available(): + device_map = {"": f"xpu:{torch.xpu.current_device()}"} + else: + device_map = {"": "cpu"} logger.info( "The device_map was not initialized. " - "Setting device_map to {'':torch.cuda.current_device()}. " + f"Setting device_map to {device_map}. " "If you want to use the model for inference, please set device_map ='auto' " ) return device_map diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index e0608acfeb8a54..2cc0fa5710895a 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -61,6 +61,7 @@ is_auto_gptq_available, is_av_available, is_bitsandbytes_available, + is_bitsandbytes_multi_backend_available, is_bs4_available, is_cv2_available, is_cython_available, @@ -224,6 +225,17 @@ def parse_int_from_env(key, default=None): _run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False) +def get_device_count(): + import torch + + if is_torch_xpu_available(): + num_devices = torch.xpu.device_count() + else: + num_devices = torch.cuda.device_count() + + return num_devices + + def is_pt_tf_cross_test(test_case): """ Decorator marking a test as a test that control interactions between PyTorch and TensorFlow. 
@@ -331,6 +343,29 @@ def tooslow(test_case): return unittest.skip(reason="test is too slow")(test_case) +def skip_if_not_implemented(test_func): + @functools.wraps(test_func) + def wrapper(*args, **kwargs): + try: + return test_func(*args, **kwargs) + except NotImplementedError as e: + raise unittest.SkipTest(f"Test skipped due to NotImplementedError: {e}") + + return wrapper + + +def apply_skip_if_not_implemented(cls): + """ + Class decorator to apply @skip_if_not_implemented to all test methods. + """ + for attr_name in dir(cls): + if attr_name.startswith("test_"): + attr = getattr(cls, attr_name) + if callable(attr): + setattr(cls, attr_name, skip_if_not_implemented(attr)) + return cls + + def custom_tokenizers(test_case): """ Decorator marking a test for a custom tokenizer. @@ -738,9 +773,9 @@ def require_torch_multi_gpu(test_case): if not is_torch_available(): return unittest.skip(reason="test requires PyTorch")(test_case) - import torch + device_count = get_device_count() - return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case) + return unittest.skipUnless(device_count > 1, "test requires multiple GPUs")(test_case) def require_torch_multi_accelerator(test_case): @@ -947,6 +982,15 @@ def require_torch_gpu(test_case): return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case) +def require_torch_gpu_if_bnb_not_multi_backend_enabled(test_case): + """ + Decorator marking a test that requires a GPU if bitsandbytes multi-backend feature is not enabled. + """ + if is_bitsandbytes_available() and is_bitsandbytes_multi_backend_available(): + return test_case + return require_torch_gpu(test_case) + + def require_torch_accelerator(test_case): """Decorator marking a test that requires an accessible accelerator and PyTorch.""" return unittest.skipUnless(torch_device is not None and torch_device != "cpu", "test requires accelerator")( diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index eee350349f5565..93976c2375565b 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -15,6 +15,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import lru_cache +from typing import FrozenSet + from huggingface_hub import get_full_repo_name # for backward compatibility from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY # for backward compatibility from packaging import version @@ -118,6 +121,7 @@ is_auto_gptq_available, is_av_available, is_bitsandbytes_available, + is_bitsandbytes_multi_backend_available, is_bs4_available, is_coloredlogs_available, is_cv2_available, @@ -277,3 +281,31 @@ def check_min_version(min_version): + "Check out https://github.com/huggingface/transformers/tree/main/examples#important-note for the examples corresponding to other " "versions of HuggingFace Transformers." ) + + +@lru_cache() +def get_available_devices() -> FrozenSet[str]: + """ + Returns a frozenset of devices available for the current PyTorch installation. 
+ """ + devices = {"cpu"} # `cpu` is always supported as a device in PyTorch + + if is_torch_cuda_available(): + devices.add("cuda") + + if is_torch_mps_available(): + devices.add("mps") + + if is_torch_xpu_available(): + devices.add("xpu") + + if is_torch_npu_available(): + devices.add("npu") + + if is_torch_mlu_available(): + devices.add("mlu") + + if is_torch_musa_available(): + devices.add("musa") + + return frozenset(devices) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index ad8b649aaa4e84..289dd02fdd52c5 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -849,15 +849,29 @@ def is_torch_xpu_available(check_device=False): return hasattr(torch, "xpu") and torch.xpu.is_available() +@lru_cache() def is_bitsandbytes_available(): - if not is_torch_available(): + if not is_torch_available() or not _bitsandbytes_available: return False - # bitsandbytes throws an error if cuda is not available - # let's avoid that by adding a simple check import torch - return _bitsandbytes_available and torch.cuda.is_available() + # `bitsandbytes` versions older than 0.43.1 eagerly require CUDA at import time, + # so those versions of the library are practically only available when CUDA is too. + if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.1"): + return torch.cuda.is_available() + + # Newer versions of `bitsandbytes` can be imported on systems without CUDA. + return True + + +def is_bitsandbytes_multi_backend_available() -> bool: + if not is_bitsandbytes_available(): + return False + + import bitsandbytes as bnb + + return "multi_backend" in getattr(bnb, "features", set()) def is_flash_attn_2_available(): diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 785402b3f798ee..0ac9b3d82fc7b0 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -30,12 +30,13 @@ pipeline, ) from transformers.testing_utils import ( + apply_skip_if_not_implemented, is_bitsandbytes_available, is_torch_available, require_accelerate, require_bitsandbytes, require_torch, - require_torch_gpu, + require_torch_gpu_if_bnb_not_multi_backend_enabled, require_torch_multi_gpu, slow, torch_device, @@ -85,7 +86,7 @@ def forward(self, input, *args, **kwargs): @require_bitsandbytes @require_accelerate @require_torch -@require_torch_gpu +@require_torch_gpu_if_bnb_not_multi_backend_enabled @slow class Base4bitTest(unittest.TestCase): # We keep the constants inside the init function and model loading inside setUp function @@ -111,6 +112,7 @@ def setUp(self): self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) +@apply_skip_if_not_implemented class Bnb4BitTest(Base4bitTest): def setUp(self): super().setUp() @@ -206,7 +208,7 @@ def test_rwkv_4bit(self): tok = AutoTokenizer.from_pretrained(model_id) text = "Hello my name is" - input_ids = tok.encode(text, return_tensors="pt").to(0) + input_ids = tok.encode(text, return_tensors="pt").to(torch_device) _ = model.generate(input_ids, max_new_tokens=30) @@ -217,7 +219,9 @@ def test_generate_quality(self): the same output across GPUs. So we'll generate few tokens (5-10) and check their output. 
""" encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - output_sequences = self.model_4bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_sequences = self.model_4bit.generate( + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 + ) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @@ -234,7 +238,7 @@ def test_generate_quality_config(self): encoded_input = self.tokenizer(self.input_text, return_tensors="pt") output_sequences = model_4bit_from_config.generate( - input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10 + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 ) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @@ -252,7 +256,9 @@ def test_generate_quality_dequantize(self): model_4bit.dequantize() encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - output_sequences = model_4bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_sequences = model_4bit.generate( + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 + ) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @@ -267,15 +273,18 @@ def test_device_assignment(self): self.assertEqual(self.model_4bit.device.type, "cpu") self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before) - # Move back to CUDA device - self.model_4bit.to(0) - self.assertEqual(self.model_4bit.device, torch.device(0)) - self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before) + if torch.cuda.is_available(): + # Move back to CUDA device + self.model_4bit.to("cuda") + self.assertEqual(self.model_4bit.device.type, "cuda") + self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before) def test_device_and_dtype_assignment(self): r""" - Test whether trying to cast (or assigning a device to) a model after converting it in 4-bit will throw an error. - Checks also if other models are casted correctly. + Test whether attempting to change the device or cast the dtype of a model + after converting it to 4-bit precision will raise an appropriate error. + The test ensures that such operations are prohibited on 4-bit models + to prevent invalid conversions. """ # Moving with `to` or `cuda` is not supported with versions < 0.43.2. 
@@ -297,25 +306,24 @@ def test_device_and_dtype_assignment(self): self.model_4bit.to(torch.float16) with self.assertRaises(ValueError): - # Tries with a `dtype` and `device` - self.model_4bit.to(device="cuda:0", dtype=torch.float16) - - with self.assertRaises(ValueError): - # Tries with a cast + # Tries to cast the 4-bit model to float32 using `float()` self.model_4bit.float() with self.assertRaises(ValueError): - # Tries with a cast + # Tries to cast the 4-bit model to float16 using `half()` self.model_4bit.half() # Test if we did not break anything + self.model_4bit.to(torch.device(torch_device)) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt") self.model_fp16 = self.model_fp16.to(torch.float32) - _ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + _ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10) - # Check that this does not throw an error - _ = self.model_fp16.cuda() + if torch.cuda.is_available(): + # Check that this does not throw an error + _ = self.model_fp16.cuda() # Check this does not throw an error _ = self.model_fp16.to("cpu") @@ -344,8 +352,9 @@ def test_bnb_4bit_wrong_config(self): @require_bitsandbytes @require_accelerate @require_torch -@require_torch_gpu +@require_torch_gpu_if_bnb_not_multi_backend_enabled @slow +@apply_skip_if_not_implemented class Bnb4BitT5Test(unittest.TestCase): @classmethod def setUpClass(cls): @@ -375,14 +384,14 @@ def test_inference_without_keep_in_fp32(self): # test with `google-t5/t5-small` model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto") - encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) _ = model.generate(**encoded_input) # test with `flan-t5-small` model = T5ForConditionalGeneration.from_pretrained( self.dense_act_model_name, load_in_4bit=True, device_map="auto" ) - encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) _ = model.generate(**encoded_input) T5ForConditionalGeneration._keep_in_fp32_modules = modules @@ -400,17 +409,18 @@ def test_inference_with_keep_in_fp32(self): # there was a bug with decoders - this test checks that it is fixed self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear4bit)) - encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) _ = model.generate(**encoded_input) # test with `flan-t5-small` model = T5ForConditionalGeneration.from_pretrained( self.dense_act_model_name, load_in_4bit=True, device_map="auto" ) - encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) _ = model.generate(**encoded_input) +@apply_skip_if_not_implemented class Classes4BitModelTest(Base4bitTest): def setUp(self): super().setUp() @@ -460,6 +470,7 @@ def test_correct_head_class(self): self.assertTrue(self.seq_to_seq_model.lm_head.weight.__class__ == torch.nn.Parameter) +@apply_skip_if_not_implemented class Pipeline4BitTest(Base4bitTest): def setUp(self): super().setUp() @@ -469,7 +480,8 @@ def tearDown(self): TearDown function needs to be called at the end of each test to free the GPU memory and cache, 
also to avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27 """ - del self.pipe + if hasattr(self, "pipe"): + del self.pipe gc.collect() torch.cuda.empty_cache() @@ -484,7 +496,12 @@ def test_pipeline(self): self.pipe = pipeline( "text-generation", model=self.model_name, - model_kwargs={"device_map": "auto", "load_in_4bit": True, "torch_dtype": torch.float16}, + model_kwargs={ + "device_map": "auto", + "load_in_4bit": True, + # float16 isn't supported on CPU, use bfloat16 instead + "torch_dtype": torch.bfloat16 if torch_device == "cpu" else torch.float16, + }, max_new_tokens=self.MAX_NEW_TOKENS, ) @@ -494,6 +511,7 @@ def test_pipeline(self): @require_torch_multi_gpu +@apply_skip_if_not_implemented class Bnb4bitTestMultiGpu(Base4bitTest): def setUp(self): super().setUp() @@ -515,10 +533,13 @@ def test_multi_gpu_loading(self): encoded_input = self.tokenizer(self.input_text, return_tensors="pt") # Second real batch - output_parallel = model_parallel.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_parallel = model_parallel.generate( + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 + ) self.assertIn(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) +@apply_skip_if_not_implemented class Bnb4BitTestTraining(Base4bitTest): def setUp(self): self.model_name = "facebook/opt-350m" @@ -531,7 +552,10 @@ def test_training(self): # Step 1: freeze all parameters model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True) - self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()}) + if torch.cuda.is_available(): + self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()}) + else: + self.assertTrue(all(param.device.type == "cpu" for param in model.parameters())) for param in model.parameters(): param.requires_grad = False # freeze the model - train adapters later @@ -547,10 +571,10 @@ def test_training(self): module.v_proj = LoRALayer(module.v_proj, rank=16) # Step 3: dummy batch - batch = self.tokenizer("Test batch ", return_tensors="pt").to(0) + batch = self.tokenizer("Test batch ", return_tensors="pt").to(torch_device) # Step 4: Check if the gradient is not None - with torch.cuda.amp.autocast(): + with torch.autocast(torch_device): out = model.forward(**batch) out.logits.norm().backward() @@ -562,6 +586,7 @@ def test_training(self): self.assertTrue(module.weight.grad is None) +@apply_skip_if_not_implemented class Bnb4BitGPT2Test(Bnb4BitTest): model_name = "openai-community/gpt2-xl" EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187 @@ -570,8 +595,9 @@ class Bnb4BitGPT2Test(Bnb4BitTest): @require_bitsandbytes @require_accelerate @require_torch -@require_torch_gpu +@require_torch_gpu_if_bnb_not_multi_backend_enabled @slow +@apply_skip_if_not_implemented class BaseSerializationTest(unittest.TestCase): model_name = "facebook/opt-125m" input_text = "Mars colonists' favorite meals are" @@ -635,7 +661,9 @@ def test_serialization(self, quant_type="nf4", double_quant=True, safe_serializa d1[k].quant_state.as_dict().values(), ): if isinstance(v0, torch.Tensor): - self.assertTrue(torch.equal(v0, v1.to(v0.device))) + # The absmax will not be saved in the quant_state when using NF4 in CPU + if v0.numel() != 0: + self.assertTrue(torch.equal(v0, v1.to(v0.device))) else: self.assertTrue(v0 == v1) @@ -659,6 +687,7 @@ def _decode(token): ) +@apply_skip_if_not_implemented class 
ExtendedSerializationTest(BaseSerializationTest): """ tests more combinations of parameters @@ -706,8 +735,9 @@ class GPTSerializationTest(BaseSerializationTest): @require_bitsandbytes @require_accelerate -@require_torch_gpu +@require_torch_gpu_if_bnb_not_multi_backend_enabled @slow +@apply_skip_if_not_implemented class Bnb4BitTestBasicConfigTest(unittest.TestCase): def test_load_in_4_and_8_bit_fails(self): with self.assertRaisesRegex(ValueError, "load_in_4bit and load_in_8bit are both True"): diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index ca3f043c749a31..5a99ab32e42b8c 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -30,14 +30,17 @@ pipeline, ) from transformers.testing_utils import ( + apply_skip_if_not_implemented, is_accelerate_available, + is_bitsandbytes_available, is_torch_available, require_accelerate, require_bitsandbytes, require_torch, - require_torch_gpu, + require_torch_gpu_if_bnb_not_multi_backend_enabled, require_torch_multi_gpu, slow, + torch_device, ) @@ -77,10 +80,14 @@ def forward(self, input, *args, **kwargs): return self.module(input, *args, **kwargs) + self.adapter(input) +if is_bitsandbytes_available(): + import bitsandbytes as bnb + + @require_bitsandbytes @require_accelerate @require_torch -@require_torch_gpu +@require_torch_gpu_if_bnb_not_multi_backend_enabled @slow class BaseMixedInt8Test(unittest.TestCase): # We keep the constants inside the init function and model loading inside setUp function @@ -108,6 +115,7 @@ def setUp(self): self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) +@apply_skip_if_not_implemented class MixedInt8Test(BaseMixedInt8Test): def setUp(self): super().setUp() @@ -240,7 +248,6 @@ def test_llm_skip(self): r""" A simple test to check if `llm_int8_skip_modules` works as expected """ - import bitsandbytes as bnb quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["classifier"]) seq_classification_model = AutoModelForSequenceClassification.from_pretrained( @@ -263,7 +270,9 @@ def test_generate_quality(self): the same output across GPUs. So we'll generate few tokens (5-10) and check their output. 
""" encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - output_sequences = self.model_8bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_sequences = self.model_8bit.generate( + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 + ) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @@ -280,7 +289,7 @@ def test_generate_quality_config(self): encoded_input = self.tokenizer(self.input_text, return_tensors="pt") output_sequences = model_8bit_from_config.generate( - input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10 + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 ) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @@ -298,7 +307,9 @@ def test_generate_quality_dequantize(self): model_8bit.dequantize() encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - output_sequences = model_8bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_sequences = model_8bit.generate( + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 + ) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @@ -319,8 +330,10 @@ def test_raise_if_config_and_load_in_8bit(self): def test_device_and_dtype_assignment(self): r""" - Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error. - Checks also if other models are casted correctly. + Test whether attempting to change the device or cast the dtype of a model + after converting it to 8-bit precision will raise an appropriate error. + The test ensures that such operations are prohibited on 8-bit models + to prevent invalid conversions. 
""" with self.assertRaises(ValueError): # Tries with `str` @@ -332,21 +345,21 @@ def test_device_and_dtype_assignment(self): with self.assertRaises(ValueError): # Tries with a `device` - self.model_8bit.to(torch.device("cuda:0")) + self.model_8bit.to(torch.device(torch_device)) with self.assertRaises(ValueError): - # Tries with a `device` + # Tries to cast the 8-bit model to float32 using `float()` self.model_8bit.float() with self.assertRaises(ValueError): - # Tries with a `device` + # Tries to cast the 4-bit model to float16 using `half()` self.model_8bit.half() # Test if we did not break anything encoded_input = self.tokenizer(self.input_text, return_tensors="pt") self.model_fp16 = self.model_fp16.to(torch.float32) - _ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + _ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10) # Check this does not throw an error _ = self.model_fp16.to("cpu") @@ -385,7 +398,9 @@ def test_int8_serialization(self): # generate encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - output_sequences = model_from_saved.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_sequences = model_from_saved.generate( + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 + ) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @@ -410,7 +425,9 @@ def test_int8_serialization_regression(self): # generate encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - output_sequences = model_from_saved.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_sequences = model_from_saved.generate( + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 + ) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @@ -435,7 +452,9 @@ def test_int8_serialization_sharded(self): # generate encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - output_sequences = model_from_saved.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_sequences = model_from_saved.generate( + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 + ) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @@ -455,7 +474,7 @@ def test_int8_from_pretrained(self): # generate encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @@ -463,7 +482,7 @@ def test_int8_from_pretrained(self): @require_bitsandbytes @require_accelerate @require_torch -@require_torch_gpu +@require_torch_gpu_if_bnb_not_multi_backend_enabled @slow class MixedInt8T5Test(unittest.TestCase): @classmethod @@ -494,14 +513,14 @@ def test_inference_without_keep_in_fp32(self): # test with `google-t5/t5-small` model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") - encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) _ = 
model.generate(**encoded_input) # test with `flan-t5-small` model = T5ForConditionalGeneration.from_pretrained( self.dense_act_model_name, load_in_8bit=True, device_map="auto" ) - encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) _ = model.generate(**encoded_input) T5ForConditionalGeneration._keep_in_fp32_modules = modules @@ -511,7 +530,6 @@ def test_inference_with_keep_in_fp32(self): `flan-t5-small` uses `T5DenseGatedActDense` whereas `google-t5/t5-small` uses `T5DenseReluDense`. We need to test both cases. """ - import bitsandbytes as bnb from transformers import T5ForConditionalGeneration @@ -521,14 +539,14 @@ def test_inference_with_keep_in_fp32(self): # there was a bug with decoders - this test checks that it is fixed self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear8bitLt)) - encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) _ = model.generate(**encoded_input) # test with `flan-t5-small` model = T5ForConditionalGeneration.from_pretrained( self.dense_act_model_name, load_in_8bit=True, device_map="auto" ) - encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) _ = model.generate(**encoded_input) def test_inference_with_keep_in_fp32_serialized(self): @@ -538,7 +556,6 @@ def test_inference_with_keep_in_fp32_serialized(self): `flan-t5-small` uses `T5DenseGatedActDense` whereas `google-t5/t5-small` uses `T5DenseReluDense`. We need to test both cases. """ - import bitsandbytes as bnb from transformers import T5ForConditionalGeneration @@ -553,14 +570,14 @@ def test_inference_with_keep_in_fp32_serialized(self): # there was a bug with decoders - this test checks that it is fixed self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear8bitLt)) - encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) _ = model.generate(**encoded_input) # test with `flan-t5-small` model = T5ForConditionalGeneration.from_pretrained( self.dense_act_model_name, load_in_8bit=True, device_map="auto" ) - encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) _ = model.generate(**encoded_input) @@ -614,6 +631,7 @@ def test_correct_head_class(self): self.assertTrue(self.seq_to_seq_model.lm_head.weight.__class__ == torch.nn.Parameter) +@apply_skip_if_not_implemented class MixedInt8TestPipeline(BaseMixedInt8Test): def setUp(self): super().setUp() @@ -623,7 +641,8 @@ def tearDown(self): TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to avoid unexpected behaviors. 
Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27 """ - del self.pipe + if hasattr(self, "pipe"): + del self.pipe gc.collect() torch.cuda.empty_cache() @@ -648,6 +667,7 @@ def test_pipeline(self): @require_torch_multi_gpu +@apply_skip_if_not_implemented class MixedInt8TestMultiGpu(BaseMixedInt8Test): def setUp(self): super().setUp() @@ -669,11 +689,14 @@ def test_multi_gpu_loading(self): encoded_input = self.tokenizer(self.input_text, return_tensors="pt") # Second real batch - output_parallel = model_parallel.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_parallel = model_parallel.generate( + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 + ) self.assertIn(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @require_torch_multi_gpu +@apply_skip_if_not_implemented class MixedInt8TestCpuGpu(BaseMixedInt8Test): def setUp(self): super().setUp() @@ -683,7 +706,7 @@ def check_inference_correctness(self, model): encoded_input = self.tokenizer(self.input_text, return_tensors="pt") # Check the exactness of the results - output_parallel = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_parallel = model.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10) # Get the generation output_text = self.tokenizer.decode(output_parallel[0], skip_special_tokens=True) @@ -819,6 +842,7 @@ def test_cpu_gpu_disk_loading_custom_device_map_kwargs(self): self.check_inference_correctness(model_8bit) +@apply_skip_if_not_implemented class MixedInt8TestTraining(BaseMixedInt8Test): def setUp(self): self.model_name = "facebook/opt-350m" @@ -831,7 +855,10 @@ def test_training(self): # Step 1: freeze all parameters model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True) - self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()}) + if torch.cuda.is_available(): + self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()}) + else: + self.assertTrue(all(param.device.type == "cpu" for param in model.parameters())) for param in model.parameters(): param.requires_grad = False # freeze the model - train adapters later @@ -847,10 +874,10 @@ def test_training(self): module.v_proj = LoRALayer(module.v_proj, rank=16) # Step 3: dummy batch - batch = self.tokenizer("Test batch ", return_tensors="pt").to(0) + batch = self.tokenizer("Test batch ", return_tensors="pt").to(torch_device) # Step 4: Check if the gradient is not None - with torch.cuda.amp.autocast(): + with torch.autocast(torch_device): out = model.forward(**batch) out.logits.norm().backward() @@ -862,6 +889,7 @@ def test_training(self): self.assertTrue(module.weight.grad is None) +@apply_skip_if_not_implemented class MixedInt8GPT2Test(MixedInt8Test): model_name = "openai-community/gpt2-xl" EXPECTED_RELATIVE_DIFFERENCE = 1.8720077507258357 @@ -870,6 +898,9 @@ class MixedInt8GPT2Test(MixedInt8Test): EXPECTED_OUTPUTS.add("Hello my name is John Doe, and I'm a fan of the") # Expected values on a A10 EXPECTED_OUTPUTS.add("Hello my name is John Doe, and I am a member of the") + # Expected values on Intel CPU + EXPECTED_OUTPUTS.add("Hello my name is John Doe. I am a man. I am") + EXPECTED_OUTPUTS.add("Hello my name is John, and I'm a writer. 
I'm") def test_int8_from_pretrained(self): r""" @@ -887,6 +918,6 @@ def test_int8_from_pretrained(self): # generate encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
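As a quick sanity check before quantizing on a non-CUDA machine, the helpers introduced in this patch can be probed directly. This is a minimal sketch, assuming `get_available_devices`, `is_bitsandbytes_multi_backend_available`, and `validate_bnb_backend_availability` are importable from the paths shown in the diff above:

```python
# Minimal environment probe using the helpers added in this patch.
from transformers.integrations import validate_bnb_backend_availability
from transformers.utils import get_available_devices, is_bitsandbytes_multi_backend_available

# Frozenset of device types visible to this PyTorch install, e.g. {"cpu", "cuda"} or {"cpu", "xpu"}.
print("available devices:", sorted(get_available_devices()))

# True only for bitsandbytes builds that advertise the "multi_backend" feature flag.
print("multi-backend bitsandbytes:", is_bitsandbytes_multi_backend_available())

# Returns True/False; with raise_exception=True it raises RuntimeError with installation hints,
# which is what the 4-bit and 8-bit quantizers now do inside validate_environment().
print("bnb backend usable:", validate_bnb_backend_availability(raise_exception=False))
```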
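The documentation snippets added above stop at the import line, so here is one possible end-to-end version of that flow. It is a sketch only: `facebook/opt-350m` is just an example checkpoint, and the dtype choice mirrors the CPU note in the updated pipeline test (float16 is not supported on CPU, so bfloat16 is used there).

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Prefer bfloat16 for compute on CPU; float16 on CUDA/XPU-class accelerators.
compute_dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

# With a multi-backend bitsandbytes install, device_map="auto" may resolve to CUDA, XPU, or CPU.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=quantization_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

inputs = tokenizer("Hello my name is", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```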
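The `dequantize_bnb_weight` change threads the model dtype through dequantization, so `model.dequantize()` (the public entry point the updated tests call) hands weights back in the model's original dtype on every backend. A hedged sketch of that round trip, again with an example checkpoint:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Example checkpoint; any model with torch.nn.Linear layers behaves the same way.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)

# Replaces the bnb quantized Linear modules with plain nn.Linear,
# dequantizing the weights back to model.dtype (float16 here).
model.dequantize()
print({param.dtype for param in model.parameters()})  # expected: {torch.float16}
```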
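For test authors, the new testing utilities compose just like the updated bnb test suites above: the GPU requirement is waived when a multi-backend bitsandbytes build is installed, and a `NotImplementedError` from a still-maturing backend becomes a skip rather than a failure. A sketch of a downstream test using them (the class name and checkpoint are illustrative):

```python
import unittest

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from transformers.testing_utils import (
    apply_skip_if_not_implemented,
    require_accelerate,
    require_bitsandbytes,
    require_torch_gpu_if_bnb_not_multi_backend_enabled,
    slow,
)


@require_bitsandbytes
@require_accelerate
@require_torch_gpu_if_bnb_not_multi_backend_enabled
@slow
@apply_skip_if_not_implemented
class Bnb4BitSmokeTest(unittest.TestCase):
    def test_loads_in_4bit(self):
        # Runs on CUDA, or on CPU/XPU when a multi-backend bitsandbytes build is present.
        model = AutoModelForCausalLM.from_pretrained(
            "facebook/opt-125m", quantization_config=BitsAndBytesConfig(load_in_4bit=True)
        )
        self.assertTrue(model.is_loaded_in_4bit)
```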