From 6e42a65e1e2fb10889453a73922465d096246a89 Mon Sep 17 00:00:00 2001 From: Roi Tiefenbrunn Date: Mon, 7 Oct 2024 15:42:17 +0300 Subject: [PATCH 01/14] Simplify HQT config files (#1219) --- .../act_maxabs_pow2_weights_pcs_opt_pow2_quant.json | 2 -- .../quantization_config/maxabs_measure.json | 7 ++----- .../maxabs_measure_include_outputs.json | 2 -- .../text-generation/quantization_config/maxabs_quant.json | 2 -- .../quantization_config/maxabs_quant_phi.json | 1 - .../quantization_config/unit_scale_quant.json | 2 -- 6 files changed, 2 insertions(+), 14 deletions(-) diff --git a/examples/text-generation/quantization_config/act_maxabs_pow2_weights_pcs_opt_pow2_quant.json b/examples/text-generation/quantization_config/act_maxabs_pow2_weights_pcs_opt_pow2_quant.json index bfb932f098..c7c2bd9621 100644 --- a/examples/text-generation/quantization_config/act_maxabs_pow2_weights_pcs_opt_pow2_quant.json +++ b/examples/text-generation/quantization_config/act_maxabs_pow2_weights_pcs_opt_pow2_quant.json @@ -3,7 +3,5 @@ "mode": "QUANTIZE", "observer": "maxabs", "scale_method": "ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2", - "allowlist": {"types": [], "names": []}, - "blocklist": {"types": [], "names": []}, "dump_stats_path": "./hqt_output/measure" } diff --git a/examples/text-generation/quantization_config/maxabs_measure.json b/examples/text-generation/quantization_config/maxabs_measure.json index 3645fe743a..773bcc54b0 100644 --- a/examples/text-generation/quantization_config/maxabs_measure.json +++ b/examples/text-generation/quantization_config/maxabs_measure.json @@ -2,8 +2,5 @@ "method": "HOOKS", "mode": "MEASURE", "observer": "maxabs", - "allowlist": {"types": [], "names": []}, - "blocklist": {"types": [], "names": []}, - "dump_stats_path": "./hqt_output/measure", - "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx" -} \ No newline at end of file + "dump_stats_path": "./hqt_output/measure" +} diff --git a/examples/text-generation/quantization_config/maxabs_measure_include_outputs.json b/examples/text-generation/quantization_config/maxabs_measure_include_outputs.json index 72dff310ee..230884c3a1 100644 --- a/examples/text-generation/quantization_config/maxabs_measure_include_outputs.json +++ b/examples/text-generation/quantization_config/maxabs_measure_include_outputs.json @@ -3,7 +3,5 @@ "mode": "MEASURE", "observer": "maxabs", "measure_exclude": "NONE", - "allowlist": {"types": [], "names": []}, - "blocklist": {"types": [], "names": []}, "dump_stats_path": "./hqt_output/measure" } \ No newline at end of file diff --git a/examples/text-generation/quantization_config/maxabs_quant.json b/examples/text-generation/quantization_config/maxabs_quant.json index 34fab4601d..ce8bae27a8 100644 --- a/examples/text-generation/quantization_config/maxabs_quant.json +++ b/examples/text-generation/quantization_config/maxabs_quant.json @@ -3,7 +3,5 @@ "mode": "QUANTIZE", "observer": "maxabs", "scale_method": "maxabs_hw", - "allowlist": {"types": [], "names": []}, - "blocklist": {"types": [], "names": []}, "dump_stats_path": "./hqt_output/measure" } \ No newline at end of file diff --git a/examples/text-generation/quantization_config/maxabs_quant_phi.json b/examples/text-generation/quantization_config/maxabs_quant_phi.json index a77200c99f..e7c6b6ddd2 100644 --- a/examples/text-generation/quantization_config/maxabs_quant_phi.json +++ b/examples/text-generation/quantization_config/maxabs_quant_phi.json @@ -3,7 +3,6 @@ "mode": "QUANTIZE", "observer": "maxabs", "scale_method": "maxabs_hw", - "allowlist": {"types": [], "names": []}, "blocklist": {"types": [], "names": [ "matmul_qk", "matmul_av", diff --git a/examples/text-generation/quantization_config/unit_scale_quant.json b/examples/text-generation/quantization_config/unit_scale_quant.json index 6bbbde8672..216cf27e68 100644 --- a/examples/text-generation/quantization_config/unit_scale_quant.json +++ b/examples/text-generation/quantization_config/unit_scale_quant.json @@ -3,7 +3,5 @@ "mode": "QUANTIZE", "observer": "maxabs", "scale_method": "unit_scale", - "allowlist": {"types": [], "names": []}, - "blocklist": {"types": [], "names": []}, "dump_stats_path": "./hqt_output/measure" } From 8e9744a7a02fe6cb4f7f4628ff5e3f4812bb875e Mon Sep 17 00:00:00 2001 From: Yan Tomsinsky <73292515+Yantom1@users.noreply.github.com> Date: Mon, 7 Oct 2024 15:44:56 +0300 Subject: [PATCH 02/14] unify_measurements.py script support to unify PCQ 70B 8x (#1322) --- .../quantization_tools/unify_measurements.py | 34 ++++++++++++++++--- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/examples/text-generation/quantization_tools/unify_measurements.py b/examples/text-generation/quantization_tools/unify_measurements.py index 0efc06c8db..4282e4ac49 100644 --- a/examples/text-generation/quantization_tools/unify_measurements.py +++ b/examples/text-generation/quantization_tools/unify_measurements.py @@ -79,9 +79,24 @@ def unify_measurements( for i in range(0, len(max_inputs)): max_inputs[i] = max(measurement_json[node_name]["inputs"][i], max_inputs[i]) if max_outputs is not None: - max_outputs = max(measurement_json[node_name]["outputs"], max_outputs) + if isinstance(max_outputs[0], list): + for i in range(0, len(max_outputs)): + for j in range(0, len(max_outputs[i])): + max_outputs[i][j] = max( + measurement_json[node_name]["outputs"][i][j], max_outputs[i][j] + ) + else: + for i in range(0, len(max_outputs)): + max_outputs[i] = max(measurement_json[node_name]["outputs"][i], max_outputs[i]) if max_weight is not None: - max_weight = max(measurement_json[node_name]["params"]["weight"], max_weight) + if isinstance(max_weight, dict): + for key, values in max_weight.items(): + for i in range(0, len(values)): + max_weight[key][i] = max( + measurement_json[node_name]["params"]["weight"][key][i], max_weight[key][i] + ) + else: + max_weight = max(measurement_json[node_name]["params"]["weight"], max_weight) else: for measurement_json in measurements_jsons: for i in range(0, len(max_inputs)): @@ -99,9 +114,20 @@ def unify_measurements( for i in range(0, len(max_inputs)): unified_json["Nodes"][node_name]["inputs"][i] = max_inputs[i] if max_outputs is not None: - unified_json["Nodes"][node_name]["outputs"] = max_outputs + if isinstance(max_outputs[0], list): + for i in range(0, len(max_outputs)): + for j in range(0, len(max_outputs[i])): + unified_json["Nodes"][node_name]["outputs"][i][j] = max_outputs[i][j] + else: + for i in range(0, len(max_outputs)): + unified_json["Nodes"][node_name]["outputs"][i] = max_outputs[i] if max_weight is not None: - unified_json["Nodes"][node_name]["params"]["weight"] = max_weight + if isinstance(max_weight, dict): + for key, values in max_weight.items(): + for i in range(0, len(values)): + unified_json["Nodes"][node_name]["params"]["weight"][key][i] = max_weight[key][i] + else: + unified_json["Nodes"][node_name]["params"]["weight"] = max_weight else: for i in range(0, len(max_inputs)): for j in range(0, len(max_inputs[i])): From 08ae4fed9805f78d9ae4d43d0046346056a9de59 Mon Sep 17 00:00:00 2001 From: Konrad Drozd Date: Mon, 7 Oct 2024 14:45:39 +0200 Subject: [PATCH 03/14] Add misc. training args (#1346) --- examples/summarization/run_summarization.py | 6 ++++++ optimum/habana/accelerate/accelerator.py | 4 +++- optimum/habana/transformers/trainer.py | 1 + optimum/habana/transformers/training_args.py | 14 ++++++++++++++ 4 files changed, 24 insertions(+), 1 deletion(-) diff --git a/examples/summarization/run_summarization.py b/examples/summarization/run_summarization.py index 8715c4e75f..28498fc0a2 100755 --- a/examples/summarization/run_summarization.py +++ b/examples/summarization/run_summarization.py @@ -375,6 +375,12 @@ def main(): token=model_args.token, ) + if training_args.do_train and training_args.use_compiled_autograd: + from habana_frameworks.torch.dynamo.compile_backend.experimental import enable_compiled_autograd + + enable_compiled_autograd() + torch._C._set_autograd_fallback_mode("nothing") + # Log on each process the small summary: mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast logger.warning( diff --git a/optimum/habana/accelerate/accelerator.py b/optimum/habana/accelerate/accelerator.py index 2a307bdcd9..67aa9b8984 100644 --- a/optimum/habana/accelerate/accelerator.py +++ b/optimum/habana/accelerate/accelerator.py @@ -118,6 +118,7 @@ def __init__( step_scheduler_with_optimizer: bool = True, kwargs_handlers: list[KwargsHandler] | None = None, dynamo_backend: GaudiDynamoBackend | str | None = None, + dynamic: bool | None = None, distribution_strategy: str = None, force_autocast: bool = False, ): @@ -310,6 +311,7 @@ def __init__( FutureWarning, ) self.step_scheduler_with_optimizer = step_scheduler_with_optimizer + self.dynamic = dynamic # Mixed precision attributes self.scaler = None @@ -776,7 +778,7 @@ def _prepare_deepspeed(self, *args): if self.state.dynamo_plugin.backend == GaudiDynamoBackend.HPU_BACKEND and not is_compiled_module( kwargs["model"] ): - engine.compile() + engine.compile(compile_kwargs={"dynamic": self.dynamic}) if optimizer is not None: optimizer = DeepSpeedOptimizerWrapper(optimizer) if scheduler is not None: diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index a9c9a1c923..843f646b14 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -2431,6 +2431,7 @@ def create_accelerator_and_postprocess(self): "deepspeed_plugin": self.args.deepspeed_plugin, "gradient_accumulation_plugin": gradient_accumulation_plugin, "distribution_strategy": self.args.distribution_strategy, + "dynamic": self.args.compile_dynamic, } if is_accelerate_available("0.28.0"): args["dataloader_config"] = dataloader_config diff --git a/optimum/habana/transformers/training_args.py b/optimum/habana/transformers/training_args.py index 3a71d46506..44af85fd54 100644 --- a/optimum/habana/transformers/training_args.py +++ b/optimum/habana/transformers/training_args.py @@ -97,6 +97,10 @@ class GaudiTrainingArguments(TrainingArguments): Whether to use HPU graphs for performing inference. It will speed up latency but may not be compatible with some operations. use_hpu_graphs_for_training (`bool`, *optional*, defaults to `False`): Whether to use HPU graphs for performing inference. It will speed up training but may not be compatible with some operations. + use_compiled_autograd (`bool`, *optional*, defaults to `False`): + Whether to use compiled autograd for training. Currently only for summarization models. + compile_dynamic (`bool|None`, *optional*, defaults to `None`): + Set value of 'dynamic' parameter for torch.compile. disable_tensor_cache_hpu_graphs (`bool`, *optional*, defaults to `False`): Whether to disable tensor cache when using hpu graphs. If True, tensors won't be cached in hpu graph and memory can be saved. max_hpu_graphs (`int`, *optional*): @@ -156,6 +160,16 @@ class GaudiTrainingArguments(TrainingArguments): }, ) + use_compiled_autograd: Optional[bool] = field( + default=False, + metadata={"help": ("Whether to use compiled autograd for training. Currently only for summarization models.")}, + ) + + compile_dynamic: Optional[bool | None] = field( + default=None, + metadata={"help": ("Set value of 'dynamic' parameter for torch.compile.")}, + ) + disable_tensor_cache_hpu_graphs: Optional[bool] = field( default=False, metadata={"help": "Whether to use a tensor cache for hpu graphs."}, From ad86795dbe4c77324fe2dd83a1bcd0db43aeef59 Mon Sep 17 00:00:00 2001 From: Uri Livne Date: Mon, 7 Oct 2024 15:46:10 +0300 Subject: [PATCH 04/14] Add quantization config for low bs case (#1377) --- .../quantization_config/maxabs_quant_scalar_scales.json | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 examples/text-generation/quantization_config/maxabs_quant_scalar_scales.json diff --git a/examples/text-generation/quantization_config/maxabs_quant_scalar_scales.json b/examples/text-generation/quantization_config/maxabs_quant_scalar_scales.json new file mode 100644 index 0000000000..3c0fe17f86 --- /dev/null +++ b/examples/text-generation/quantization_config/maxabs_quant_scalar_scales.json @@ -0,0 +1,8 @@ +{ + "method": "HOOKS", + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "maxabs_hw", + "dump_stats_path": "./hqt_output/measure", + "scale_format": "scalar" +} \ No newline at end of file From e35e970b6c6e82a9adf289e796e3546953d12389 Mon Sep 17 00:00:00 2001 From: Yan Tomsinsky <73292515+Yantom1@users.noreply.github.com> Date: Tue, 8 Oct 2024 01:47:53 +0300 Subject: [PATCH 05/14] Remove HQT from OHF (#1257) Co-authored-by: Adam Stachowicz Co-authored-by: Adam Stachowicz <105052242+astachowiczhabana@users.noreply.github.com> Co-authored-by: Yeonsil Yoon --- examples/image-to-text/README.md | 2 +- examples/image-to-text/run_pipeline.py | 41 ++++++---------------- examples/text-generation/utils.py | 47 ++++++++++---------------- 3 files changed, 29 insertions(+), 61 deletions(-) diff --git a/examples/image-to-text/README.md b/examples/image-to-text/README.md index 2ac99dc829..b5e261f32a 100644 --- a/examples/image-to-text/README.md +++ b/examples/image-to-text/README.md @@ -93,7 +93,7 @@ python3 run_pipeline.py \ ``` ### Inference with FP8 -Inference for Llava-1.5-7b, Llava-1.5-13b, Llava-v1.6-mistral-7b and Llava-v1.6-vicuna-13b in FP8 precision are enabled using [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html), which provides model measurement and quantization capabilities in PyTorch. INC is used by default for measuring and quantization. Habana Quantization Toolkit (HQT), which was used earlier, will be removed in future releases. To use HQT, disable INC by setting the following environment variable: `USE_INC=0`. +Inference for Llava-1.5-7b, Llava-1.5-13b, Llava-v1.6-mistral-7b and Llava-v1.6-vicuna-13b in FP8 precision are enabled using [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html), which provides model measurement and quantization capabilities in PyTorch. More information on enabling FP8 in SynapseAI is available here: https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html diff --git a/examples/image-to-text/run_pipeline.py b/examples/image-to-text/run_pipeline.py index d80939b43f..9f523fc3c7 100644 --- a/examples/image-to-text/run_pipeline.py +++ b/examples/image-to-text/run_pipeline.py @@ -37,43 +37,21 @@ def setup_quantization(model, args): - if os.getenv("USE_INC", "1") != "0": - try: - from neural_compressor.torch.quantization import FP8Config, convert, prepare - except ImportError: - raise ImportError( - "Module neural_compressor is missing. Please use a newer Synapse version to use quantization, or set the environment variable to USE_INC=0" - ) - - config = FP8Config.from_json_file(args.quant_config) - if config.measure: - model = prepare(model, config) - elif config.quantize: - model = convert(model, config) - else: - import habana_frameworks.torch.core as htcore - import habana_quantization_toolkit + from neural_compressor.torch.quantization import FP8Config, convert, prepare - habana_quantization_toolkit.prep_model(model) - htcore.hpu_initialize(model) + config = FP8Config.from_json_file(args.quant_config) + if config.measure: + model = prepare(model, config) + elif config.quantize: + model = convert(model, config) return model def finalize_quantization(model): - if os.getenv("USE_INC", "1") != "0": - try: - from neural_compressor.torch.quantization import finalize_calibration - except ImportError: - raise ImportError( - "Module neural_compressor is missing. Please use a newer Synapse version to use quantization, or set the environment variable to USE_INC=0" - ) - - finalize_calibration(model) - else: - import habana_quantization_toolkit + from neural_compressor.torch.quantization import finalize_calibration - habana_quantization_toolkit.finish_measurements(model) + finalize_calibration(model) def main(): @@ -151,7 +129,7 @@ def main(): # set args.quant_config with env variable if it is set args.quant_config = os.getenv("QUANT_CONFIG", "") - + os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE") adapt_transformers_to_gaudi() model_type = AutoConfig.from_pretrained(args.model_name_or_path).model_type @@ -227,6 +205,7 @@ def main(): if args.quant_config: generator.model = setup_quantization(generator.model, args) + htcore.hpu_initialize(generator.model) # warm up for i in range(args.warmup): diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 66690c9b05..4cbebefa71 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -189,41 +189,30 @@ def get_torch_compiled_model(model): def setup_quantization(model, args): - if os.getenv("USE_INC", "1") != "0": - try: - from neural_compressor.torch.quantization import FP8Config, convert, prepare - except ImportError: - raise ImportError( - "Module neural_compressor is missing. Please use a newer Synapse version to use quantization, or set the environment variable to USE_INC=0" - ) - - config = FP8Config.from_json_file(args.quant_config) - if config.measure: - model = prepare(model, config) - elif config.quantize: - model = convert(model, config) - else: - import habana_quantization_toolkit + try: + from neural_compressor.torch.quantization import FP8Config, convert, prepare + except ImportError: + raise ImportError( + "Module neural_compressor is missing. Please use a newer Synapse version to use quantization." + ) - habana_quantization_toolkit.prep_model(model) + config = FP8Config.from_json_file(args.quant_config) + if config.measure: + model = prepare(model, config) + if config.quantize: + model = convert(model, config) return model def finalize_quantization(model): - if os.getenv("USE_INC", "1") != "0": - try: - from neural_compressor.torch.quantization import finalize_calibration - except ImportError: - raise ImportError( - "Module neural_compressor is missing. Please use a newer Synapse version to use quantization, or set the environment variable to USE_INC=0" - ) - - finalize_calibration(model) - else: - import habana_quantization_toolkit - - habana_quantization_toolkit.finish_measurements(model) + try: + from neural_compressor.torch.quantization import finalize_calibration + except ImportError: + raise ImportError( + "Module neural_compressor is missing. Please use a newer Synapse version to use quantization." + ) + finalize_calibration(model) def setup_model(args, model_dtype, model_kwargs, logger): From 6b2b243e0d9b22e8d9786f3fe3f2452a2b633cbe Mon Sep 17 00:00:00 2001 From: Danny Semiat Date: Tue, 8 Oct 2024 01:50:09 +0300 Subject: [PATCH 06/14] Load INC GPTQ checkpoint & rename params (#1364) Co-authored-by: Yaser Afshar Co-authored-by: Harish Subramony <81822986+hsubramony@users.noreply.github.com> Co-authored-by: Yeonsil Yoon --- Makefile | 1 + examples/text-generation/README.md | 63 +++++++++++++++++++++- examples/text-generation/run_generation.py | 42 +++++++++++---- examples/text-generation/utils.py | 27 ++++++++-- tests/test_text_generation_example.py | 37 +++++++++++++ 5 files changed, 155 insertions(+), 15 deletions(-) diff --git a/Makefile b/Makefile index 2ac2c85fe4..636ce76a04 100644 --- a/Makefile +++ b/Makefile @@ -105,6 +105,7 @@ slow_tests_diffusers: test_installs # Run text-generation non-regression tests slow_tests_text_generation_example: test_installs + BUILD_CUDA_EXT=0 python -m pip install -vvv --no-build-isolation git+https://github.com/HabanaAI/AutoGPTQ.git python -m pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.18.0 python -m pytest tests/test_text_generation_example.py tests/test_encoder_decoder.py -v -s --token $(TOKEN) diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index dc6081c5e6..2a3b68f3cd 100755 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -502,7 +502,7 @@ python run_generation.py \ ### Loading 4 Bit Checkpoints from Hugging Face -You can load pre-quantized 4bit models with the argument `--load_quantized_model`. +You can load pre-quantized 4bit models with the argument `--load_quantized_model_with_inc`. Currently, uint4 checkpoints and single device are supported. More information on enabling 4 bit inference in SynapseAI is available here: https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_UINT4.html. @@ -524,7 +524,35 @@ python run_lm_eval.py \ --attn_softmax_bf16 \ --bucket_size=128 \ --bucket_internal \ ---load_quantized_model +--load_quantized_model_with_inc +``` + +### Loading 4 Bit Checkpoints from Neural Compressor (INC) + +You can load a pre-quantized 4-bit checkpoint with the argument `--local_quantized_inc_model_path`, supplied with the original model with the argument `--model_name_or_path`. +Currently, only uint4 checkpoints and single-device configurations are supported. +**Note:** In this process, you can load a checkpoint that has been quantized using INC. +More information on enabling 4-bit inference in SynapseAI is available here: +https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_UINT4.html?highlight=inference%20using%20int4#enabling-and-running-uint4-in-pytorch-models. + +Below is an example of loading a llama7b model with a 4bit checkpoint quantized in INC. +Please note that the model checkpoint name is denoted as ``. +Additionally, the following environment variables are used for performance optimizations and are planned to be removed in future versions: +`SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=1` +```bash +SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=1 \ +python run_lm_eval.py \ +-o acc_load_uint4_model.txt \ +--model_name_or_path meta-llama/Llama-2-7b-hf \ +--use_hpu_graphs \ +--use_kv_cache \ +--trim_logits \ +--batch_size 1 \ +--bf16 \ +--attn_softmax_bf16 \ +--bucket_size=128 \ +--bucket_internal \ +--local_quantized_inc_model_path \ ``` ### Using Habana Flash Attention @@ -555,6 +583,37 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \ For more details see [documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#using-fused-sdpa). +### Running with UINT4 weight quantization using AutoGPTQ + + +Llama2-7b in UINT4 weight only quantization is enabled using [AutoGPTQ Fork](https://github.com/HabanaAI/AutoGPTQ), which provides quantization capabilities in PyTorch. +Currently, the support is for UINT4 inference of pre-quantized models only. + +You can run a *UINT4 weight quantized* model using AutoGPTQ by setting the following environment variables: +`SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=true` before running the command, +and by adding the argument `--load_quantized_model_with_autogptq`. + +***Note:*** +Setting the above environment variables improves performance. These variables will be removed in future releases. + + +Here is an example to run a quantized model : +```bash +SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false \ +ENABLE_EXPERIMENTAL_FLAGS=true python run_generation.py \ +--attn_softmax_bf16 \ +--model_name_or_path \ +--use_hpu_graphs \ +--limit_hpu_graphs \ +--use_kv_cache \ +--bucket_size 128 \ +--bucket_internal \ +--trim_logits \ +--max_new_tokens 128 \ +--batch_size 1 \ +--bf16 \ +--load_quantized_model_with_autogptq +``` ## Language Model Evaluation Harness diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index a8d56ff1cb..8ae8b547fd 100755 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -293,21 +293,11 @@ def setup_parser(parser): type=str, help="Path to serialize const params. Const params will be held on disk memory instead of being allocated on host memory.", ) - parser.add_argument( - "--disk_offload", - action="store_true", - help="Whether to enable device map auto. In case no space left on cpu, weights will be offloaded to disk.", - ) parser.add_argument( "--trust_remote_code", action="store_true", help="Whether to trust the execution of code from datasets/models defined on the Hub. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.", ) - parser.add_argument( - "--load_quantized_model", - action="store_true", - help="Whether to load model from hugging face checkpoint.", - ) parser.add_argument( "--parallel_strategy", type=str, @@ -326,6 +316,35 @@ def setup_parser(parser): help="Run the inference with dataset for specified --n_iterations(default:5)", ) + parser.add_argument( + "--run_partial_dataset", + action="store_true", + help="Run the inference with dataset for specified --n_iterations(default:5)", + ) + + quant_parser_group = parser.add_mutually_exclusive_group() + quant_parser_group.add_argument( + "--load_quantized_model_with_autogptq", + action="store_true", + help="Load an AutoGPTQ quantized checkpoint using AutoGPTQ.", + ) + quant_parser_group.add_argument( + "--disk_offload", + action="store_true", + help="Whether to enable device map auto. In case no space left on cpu, weights will be offloaded to disk.", + ) + quant_parser_group.add_argument( + "--load_quantized_model_with_inc", + action="store_true", + help="Load a Huggingface quantized checkpoint using INC.", + ) + quant_parser_group.add_argument( + "--local_quantized_inc_model_path", + type=str, + default=None, + help="Path to neural-compressor quantized model, if set, the checkpoint will be loaded.", + ) + args = parser.parse_args() if args.torch_compile: @@ -338,6 +357,9 @@ def setup_parser(parser): args.flash_attention_fast_softmax = True args.quant_config = os.getenv("QUANT_CONFIG", "") + if args.quant_config and args.load_quantized_model_with_autogptq: + raise RuntimeError("Setting both quant_config and load_quantized_model_with_autogptq is unsupported. ") + if args.quant_config == "" and args.disk_offload: logger.warning( "`--disk_offload` was tested only with fp8, it may not work with full precision. If error raises try to remove the --disk_offload flag." diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 4cbebefa71..61a8aa3338 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -237,10 +237,32 @@ def setup_model(args, model_dtype, model_kwargs, logger): torch_dtype=model_dtype, **model_kwargs, ) - elif args.load_quantized_model: + elif args.load_quantized_model_with_autogptq: + from transformers import GPTQConfig + + quantization_config = GPTQConfig(bits=4, use_exllama=False) + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, torch_dtype=model_dtype, quantization_config=quantization_config, **model_kwargs + ) + elif args.load_quantized_model_with_inc: from neural_compressor.torch.quantization import load model = load(model_name_or_path=args.model_name_or_path, format="huggingface", device="hpu", **model_kwargs) + elif args.local_quantized_inc_model_path: + org_model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + **model_kwargs, + ) + + from neural_compressor.torch.quantization import load + + model = load( + model_name_or_path=args.local_quantized_inc_model_path, + format="default", + device="hpu", + original_model=org_model, + **model_kwargs, + ) else: if args.assistant_model is not None: assistant_model = AutoModelForCausalLM.from_pretrained( @@ -613,8 +635,7 @@ def initialize_model(args, logger): "token": args.token, "trust_remote_code": args.trust_remote_code, } - - if args.load_quantized_model: + if args.load_quantized_model_with_inc or args.local_quantized_inc_model_path: model_kwargs["torch_dtype"] = torch.bfloat16 if args.trust_remote_code: diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index a17333cf68..96d6043f36 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -67,6 +67,9 @@ ("mistralai/Mixtral-8x7B-v0.1", 2, 48, True, 2048, 2048, 1147.50), ("microsoft/phi-2", 1, 1, True, 128, 128, 254.08932787178165), ], + "load_quantized_model_with_autogptq": [ + ("TheBloke/Llama-2-7b-Chat-GPTQ", 1, 10, False, 128, 2048, 456.7), + ], "deepspeed": [ ("bigscience/bloomz", 8, 1, 36.77314954096159), ("meta-llama/Llama-2-70b-hf", 8, 1, 64.10514998902435), @@ -110,6 +113,7 @@ ("state-spaces/mamba-130m-hf", 224, False, 794.542), ], "fp8": [], + "load_quantized_model_with_autogptq": [], "deepspeed": [ ("bigscience/bloomz-7b1", 8, 1, 31.994268212011505), ], @@ -132,6 +136,7 @@ def _test_text_generation( world_size: int = 8, torch_compile: bool = False, fp8: bool = False, + load_quantized_model_with_autogptq: bool = False, max_input_tokens: int = 0, max_output_tokens: int = 100, parallel_strategy: str = None, @@ -243,6 +248,8 @@ def _test_text_generation( f"--max_input_tokens {max_input_tokens}", "--limit_hpu_graphs", ] + if load_quantized_model_with_autogptq: + command += ["--load_quantized_model_with_autogptq"] if parallel_strategy is not None: command += [ f"--parallel_strategy={parallel_strategy}", @@ -336,6 +343,36 @@ def test_text_generation_fp8( ) +@pytest.mark.parametrize( + "model_name, world_size, batch_size, reuse_cache, input_len, output_len, baseline", + MODELS_TO_TEST["load_quantized_model_with_autogptq"], +) +def test_text_generation_gptq( + model_name: str, + baseline: float, + world_size: int, + batch_size: int, + reuse_cache: bool, + input_len: int, + output_len: int, + token: str, +): + deepspeed = True if world_size > 1 else False + _test_text_generation( + model_name, + baseline, + token, + deepspeed=deepspeed, + world_size=world_size, + fp8=False, + load_quantized_model_with_autogptq=True, + batch_size=batch_size, + reuse_cache=reuse_cache, + max_input_tokens=input_len, + max_output_tokens=output_len, + ) + + @pytest.mark.parametrize("model_name, world_size, batch_size, baseline", MODELS_TO_TEST["deepspeed"]) def test_text_generation_deepspeed(model_name: str, baseline: float, world_size: int, batch_size: int, token: str): _test_text_generation(model_name, baseline, token, deepspeed=True, world_size=world_size, batch_size=batch_size) From 211085b3c3be88804fefc42aa2d0c414e7582a69 Mon Sep 17 00:00:00 2001 From: Piotr Bielak Date: Tue, 8 Oct 2024 00:52:53 +0200 Subject: [PATCH 07/14] Enable FusedSDPA fp8 in Llama FT (#1388) Co-authored-by: Yaser Afshar Co-authored-by: Harish Subramony <81822986+hsubramony@users.noreply.github.com> --- examples/language-modeling/run_lora_clm.py | 10 ++++++++++ .../accelerate/utils/transformer_engine.py | 20 +++++++++++++++++++ .../models/llama/configuration_llama.py | 2 ++ .../models/llama/modeling_llama.py | 18 +++++++++++++++-- 4 files changed, 48 insertions(+), 2 deletions(-) diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py index ebbcc2e4d0..0b16be0725 100644 --- a/examples/language-modeling/run_lora_clm.py +++ b/examples/language-modeling/run_lora_clm.py @@ -172,6 +172,10 @@ class ModelArguments: ) }, ) + flash_attention_fp8: bool = field( + default=False, + metadata={"help": ("Whether to enable flash attention in FP8.")}, + ) use_fused_rope: bool = field( default=True, metadata={ @@ -509,6 +513,7 @@ def main(): "trust_remote_code": True if model_args.trust_remote_code else None, "use_cache": False if training_args.gradient_checkpointing else model_args.use_cache, "token": model_args.token, + "flash_attention_fp8": model_args.flash_attention_fp8, } if model_args.config_name: config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) @@ -705,6 +710,11 @@ def main(): model.generation_config.use_flash_attention = True model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask + + if model_args.flash_attention_fp8: + import habana_frameworks.torch.hpu as hthpu + + assert hthpu.get_device_name() == "GAUDI3", "Flash attention in FP8 is supported only on Gaudi3" if not model_args.use_fused_rope: model.generation_config.use_fused_rope = False diff --git a/optimum/habana/accelerate/utils/transformer_engine.py b/optimum/habana/accelerate/utils/transformer_engine.py index 823da61d5c..b40b3b2110 100755 --- a/optimum/habana/accelerate/utils/transformer_engine.py +++ b/optimum/habana/accelerate/utils/transformer_engine.py @@ -42,6 +42,8 @@ def _convert_model(model, to_transformer_engine=True, _convert_linear=True): """ Recursively converts the linear layer of a model to their `transformers_engine` counterpart. """ + from optimum.habana.transformers.models.llama.modeling_llama import ModuleFusedSDPA + if not is_fp8_available(): raise ImportError("Using `convert_model` requires transformer_engine to be installed.") for name, module in model.named_children(): @@ -75,6 +77,24 @@ def _convert_model(model, to_transformer_engine=True, _convert_linear=True): new_module.bias.copy_(module.bias) setattr(model, name, new_module) + elif isinstance(module, ModuleFusedSDPA) and module.flash_attention_fp8 and to_transformer_engine: + from habana_frameworks.torch.hpex.experimental.transformer_engine import ( + FusedAttention as TE_FusedAttention, + ) + + class TE_ModuleFusedSDPA(torch.nn.Module): + def __init__(self): + super().__init__() + self._hpu_kernel_fsdpa = TE_FusedAttention( + scale=module.scale, + attention_dropout=module.attention_dropout, + enable_recompute=module.enable_recompute, + ) + + def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode): + return self._hpu_kernel_fsdpa(query, key, value, attn_mask, is_causal, softmax_mode) + + setattr(model, name, TE_ModuleFusedSDPA()) else: _convert_model(module, to_transformer_engine=to_transformer_engine, _convert_linear=_convert_linear) diff --git a/optimum/habana/transformers/models/llama/configuration_llama.py b/optimum/habana/transformers/models/llama/configuration_llama.py index fb159cfc48..0d5b2ef6e3 100644 --- a/optimum/habana/transformers/models/llama/configuration_llama.py +++ b/optimum/habana/transformers/models/llama/configuration_llama.py @@ -28,6 +28,7 @@ def __init__( head_dim=None, fused_qkv=False, parallel_strategy=None, + flash_attention_fp8=False, **kwargs, ): super().__init__( @@ -57,3 +58,4 @@ def __init__( self.fused_qkv = fused_qkv self.parallel_strategy = parallel_strategy + self.flash_attention_fp8 = flash_attention_fp8 diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 6764a46b1c..21bafda4b2 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -351,9 +351,13 @@ def gaudi_llama_repeat_kv( # FusedScaledDotProductAttention class ModuleFusedSDPA(torch.nn.Module): - def __init__(self, fusedSDPA): + def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8): super().__init__() self._hpu_kernel_fsdpa = fusedSDPA + self.scale = scale + self.attention_dropout = attention_dropout + self.enable_recompute = enable_recompute + self.flash_attention_fp8 = flash_attention_fp8 def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode): return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode) @@ -416,7 +420,6 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.matmul_av = Matmul() self.k_cache = KVCache() self.v_cache = KVCache() - self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None if hasattr(config, "fused_qkv") and config.fused_qkv: self.num_heads = config.num_attention_heads self.head_dim = config.hidden_size // self.num_heads @@ -432,6 +435,17 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.v_proj = None self.inp_seq_len = -1 self.norm_factor = 1.0 / math.sqrt(self.head_dim) + self.fused_scaled_dot_product_attention = ( + ModuleFusedSDPA( + FusedSDPA, + scale=self.norm_factor, + attention_dropout=self.attention_dropout, + enable_recompute=False, + flash_attention_fp8=getattr(config, "flash_attention_fp8", False), + ) + if FusedSDPA + else None + ) def get_k_proj_weight(self): """4bit quantization in GPTQ replaces the k_proj.weight with qweight.""" From 51a2c8e2b9c11497d5c41047c53968f9f77fe1f6 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Mon, 7 Oct 2024 16:05:26 -0700 Subject: [PATCH 08/14] Valid sequence length for sdpa (#1183) Co-authored-by: Harish Co-authored-by: Libin Tang Co-authored-by: regisss <15324346+regisss@users.noreply.github.com> --- examples/text-generation/run_generation.py | 7 ++ examples/text-generation/utils.py | 1 + .../generation/configuration_utils.py | 1 + .../habana/transformers/generation/utils.py | 2 + .../models/llama/modeling_llama.py | 94 +++++++++++++++---- 5 files changed, 88 insertions(+), 17 deletions(-) diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 8ae8b547fd..29b97574a6 100755 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -465,6 +465,13 @@ def generate(size=None, reduce_recompile=False): max_length=args.max_input_tokens, truncation=True, ) + + def compute_valid_sequence_lengths_tensor(input_tokens): + attn_mask = input_tokens["attention_mask"] + return torch.sum(attn_mask, dim=1) + + valid_sequence_lengths = compute_valid_sequence_lengths_tensor(input_tokens).to(args.device) + generation_config.valid_sequence_lengths = valid_sequence_lengths else: input_tokens = tokenizer.batch_encode_plus(input_sentences, return_tensors="pt", padding=True) encode_duration = time.perf_counter() - encode_t0 diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 61a8aa3338..df37b10a30 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -590,6 +590,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer): generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask generation_config.flash_attention_fast_softmax = args.flash_attention_fast_softmax generation_config.trust_remote_code = args.trust_remote_code + generation_config.valid_sequence_lengths = None return generation_config diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index ce38a07ed9..ec04f139c9 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -55,3 +55,4 @@ def __init__(self, **kwargs): self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None) self.flash_attention_fast_softmax = kwargs.get("flash_attention_fast_softmax", None) self.use_fused_rope = kwargs.get("use_fused_rope", None) + self.valid_sequence_lengths = kwargs.get("valid_sequence_lengths", None) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 8a787a8b68..42eff8db58 100644 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -1218,6 +1218,8 @@ def generate( True if generation_config.flash_attention_fast_softmax else False ) model_kwargs["num_virtual_tokens"] = num_virtual_tokens + if generation_config.valid_sequence_lengths is not None: + model_kwargs["valid_sequence_lengths"] = generation_config.valid_sequence_lengths if not self.config.is_encoder_decoder: calculated_max_length = input_ids.shape[1] + num_virtual_tokens diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 21bafda4b2..55da544464 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -359,8 +359,33 @@ def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute, flash_ self.enable_recompute = enable_recompute self.flash_attention_fp8 = flash_attention_fp8 - def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode): - return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode) + def forward( + self, + query, + key, + value, + attn_mask, + dropout_p, + is_casual, + scale, + softmax_mode, + recompute_mode, + valid_sequence_lengths, + padding_side="left", + ): + return self._hpu_kernel_fsdpa.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_casual, + scale, + softmax_mode, + recompute_mode, + valid_sequence_lengths, + padding_side, + ) class Matmul(torch.nn.Module): @@ -506,6 +531,7 @@ def pre_attn_forward( flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, flash_attention_fast_softmax: Optional[bool] = False, + valid_sequence_lengths: Optional[torch.Tensor] = None, cache_idx: int = None, num_virtual_tokens: int = None, **kwargs, @@ -636,30 +662,54 @@ def pre_attn_forward( past_key_value = None if use_flash_attention and FusedSDPA is not None: - import habana_frameworks.torch.hpu as ht - softmax_mode = "fast" if flash_attention_fast_softmax else "None" if q_len == 1: # next token use_recompute = True if os.getenv("QUANT_CONFIG", "") else False - with ht.sdp_kernel(enable_recompute=use_recompute): - attn_output = self.fused_scaled_dot_product_attention( - query_states, key_states, value_states, attention_mask, 0.0, False, None, "None" - ) + attn_output = self.fused_scaled_dot_product_attention( + query_states, + key_states, + value_states, + attention_mask, + 0.0, + False, + None, + softmax_mode, + use_recompute, + None, + "None", + ) else: # first token if flash_attention_causal_mask: - # causal masking on first token requires inputs to be of the same length - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = self.fused_scaled_dot_product_attention( - query_states, key_states, value_states, None, 0.0, True, None, softmax_mode - ) + attn_output = self.fused_scaled_dot_product_attention( + query_states, + key_states, + value_states, + None, + 0.0, + True, + None, + softmax_mode, + flash_attention_recompute, + valid_sequence_lengths, + "left", + ) else: - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = self.fused_scaled_dot_product_attention( - query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode - ) + attn_output = self.fused_scaled_dot_product_attention( + query_states, + key_states, + value_states, + attention_mask, + 0.0, + False, + None, + softmax_mode, + flash_attention_recompute, + None, + "None", + ) else: query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv( @@ -855,6 +905,7 @@ def forward( flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, flash_attention_fast_softmax: Optional[bool] = False, + valid_sequence_lengths: Optional[torch.Tensor] = None, cache_idx: int = None, num_virtual_tokens: int = None, **kwargs, @@ -888,6 +939,7 @@ def forward( flash_attention_recompute=flash_attention_recompute, flash_attention_causal_mask=flash_attention_causal_mask, flash_attention_fast_softmax=flash_attention_fast_softmax, + valid_sequence_lengths=valid_sequence_lengths, cache_idx=cache_idx, num_virtual_tokens=num_virtual_tokens, **kwargs, @@ -923,6 +975,7 @@ def pre_attn( flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, flash_attention_fast_softmax: Optional[bool] = False, + valid_sequence_lengths: Optional[torch.Tensor] = None, cache_idx: int = None, num_virtual_tokens: int = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: @@ -943,6 +996,7 @@ def pre_attn( flash_attention_recompute=flash_attention_recompute, flash_attention_causal_mask=flash_attention_causal_mask, flash_attention_fast_softmax=flash_attention_fast_softmax, + valid_sequence_lengths=valid_sequence_lengths, cache_idx=cache_idx, num_virtual_tokens=num_virtual_tokens, ) @@ -1036,6 +1090,7 @@ def forward( flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, flash_attention_fast_softmax: Optional[bool] = False, + valid_sequence_lengths: torch.Tensor = None, cache_idx: int = None, lazy_mode: Optional[bool] = True, num_virtual_tokens: int = None, @@ -1175,6 +1230,7 @@ def forward( flash_attention_recompute, flash_attention_causal_mask, flash_attention_fast_softmax, + valid_sequence_lengths, None, ) else: @@ -1194,6 +1250,7 @@ def forward( flash_attention_recompute=flash_attention_recompute, flash_attention_causal_mask=flash_attention_causal_mask, flash_attention_fast_softmax=flash_attention_fast_softmax, + valid_sequence_lengths=valid_sequence_lengths, cache_idx=cache_idx, num_virtual_tokens=num_virtual_tokens, ) @@ -1272,6 +1329,7 @@ def forward( flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, flash_attention_fast_softmax: Optional[bool] = False, + valid_sequence_lengths: torch.Tensor = None, cache_idx: int = None, lazy_mode: Optional[bool] = True, num_virtual_tokens: int = None, @@ -1304,6 +1362,7 @@ def forward( flash_attention_recompute=flash_attention_recompute, flash_attention_causal_mask=flash_attention_causal_mask, flash_attention_fast_softmax=flash_attention_fast_softmax, + valid_sequence_lengths=valid_sequence_lengths, cache_idx=cache_idx, lazy_mode=lazy_mode, num_virtual_tokens=num_virtual_tokens, @@ -1427,6 +1486,7 @@ def prepare_inputs_for_generation( "flash_attention_recompute": kwargs.get("flash_attention_recompute"), "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), "flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax"), + "valid_sequence_lengths": kwargs.get("valid_sequence_lengths"), "cache_idx": kwargs.get("cache_idx"), "lazy_mode": kwargs.get("lazy_mode"), "num_virtual_tokens": kwargs.get("num_virtual_tokens"), From 1b60d04c740095e6fec71368e9d1cee5e004a792 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Thu, 10 Oct 2024 13:24:04 -0400 Subject: [PATCH 09/14] Multiple fixes (dynamo graph break, qwen-moe, multicard) (#1410) --- .../models/llama/modeling_llama.py | 63 +++++++++---------- 1 file changed, 30 insertions(+), 33 deletions(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 55da544464..f0e9935c81 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -1,6 +1,5 @@ import copy import math -import os from typing import List, Optional, Tuple, Union import torch @@ -662,11 +661,8 @@ def pre_attn_forward( past_key_value = None if use_flash_attention and FusedSDPA is not None: - softmax_mode = "fast" if flash_attention_fast_softmax else "None" - if q_len == 1: # next token - use_recompute = True if os.getenv("QUANT_CONFIG", "") else False attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, @@ -675,13 +671,14 @@ def pre_attn_forward( 0.0, False, None, - softmax_mode, - use_recompute, + "None", + False, None, "None", ) else: # first token + softmax_mode = "fast" if flash_attention_fast_softmax else "None" if flash_attention_causal_mask: attn_output = self.fused_scaled_dot_product_attention( query_states, @@ -845,22 +842,22 @@ def pre_attn_forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: hidden_states, attn_weights, present_key_value = GaudiLlamaAttention.pre_attn_forward( self, - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - use_cache, - cache_position, - position_embeddings, - token_idx, - attn_softmax_bf16, - reuse_cache, - use_flash_attention, - flash_attention_recompute, - flash_attention_causal_mask, - flash_attention_fast_softmax, - cache_idx, + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + flash_attention_fast_softmax=flash_attention_fast_softmax, + cache_idx=cache_idx, **kwargs, ) @@ -924,17 +921,17 @@ def forward( residual = hidden_states hidden_states, self_attn_weights, present_key_value = self.pre_attn( - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - use_cache, - cache_position, - position_embeddings, - token_idx, - attn_softmax_bf16, - reuse_cache, + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, flash_attention_causal_mask=flash_attention_causal_mask, From a2ba41a2a191700774c788482ee33acbd3968075 Mon Sep 17 00:00:00 2001 From: Harish Subramony <81822986+hsubramony@users.noreply.github.com> Date: Mon, 14 Oct 2024 12:53:53 -0700 Subject: [PATCH 10/14] datasets downgrade version to 2.21.0 (#1413) --- examples/text-generation/requirements_lm_eval.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/text-generation/requirements_lm_eval.txt b/examples/text-generation/requirements_lm_eval.txt index 494612f122..e632dc1236 100644 --- a/examples/text-generation/requirements_lm_eval.txt +++ b/examples/text-generation/requirements_lm_eval.txt @@ -1 +1,2 @@ https://github.com/EleutherAI/lm-evaluation-harness/archive/0bf683b4e6a9df359b3156ba9ba8d62bdd47e0c0.zip +datasets==2.21.0 From 11f020d1fdfbfb543e2a9682819b64c729df7f1d Mon Sep 17 00:00:00 2001 From: ZhengHongming888 Date: Thu, 17 Oct 2024 06:58:14 -0700 Subject: [PATCH 11/14] Update ci sentence_transformer.sh (#1424) --- tests/ci/sentence_transformers.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ci/sentence_transformers.sh b/tests/ci/sentence_transformers.sh index 03b3d768a9..e731f9b291 100644 --- a/tests/ci/sentence_transformers.sh +++ b/tests/ci/sentence_transformers.sh @@ -7,6 +7,6 @@ python -m pip install --upgrade pip python -m pip install $OPTIMUM_HABANA_PATH[tests] cd $SENTENCE_TRANSFORMER_PATH/tests python -m pip install .. -pytest test_cmnrl.py test_evaluator.py test_multi_process.py test_train_stsb.py test_compute_embeddings.py test_model_card_data.py test_trainer.py test_util.py test_pretrained_stsb.py +pytest test_cmnrl.py test_multi_process.py test_compute_embeddings.py test_model_card_data.py test_util.py cd $OPTIMUM_HABANA_PATH/tests python -m pytest test_sentence_transformers.py From cf1cbd1d3e3a8651126dfd32151b95fd30c59d5b Mon Sep 17 00:00:00 2001 From: Jimin Ha Date: Thu, 17 Oct 2024 07:01:17 -0700 Subject: [PATCH 12/14] Fix load INC load weights compile error due to Transformer 4.45 upgrade. (#1421) --- examples/text-generation/utils.py | 49 ++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index df37b10a30..78cc79a238 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -245,8 +245,12 @@ def setup_model(args, model_dtype, model_kwargs, logger): args.model_name_or_path, torch_dtype=model_dtype, quantization_config=quantization_config, **model_kwargs ) elif args.load_quantized_model_with_inc: - from neural_compressor.torch.quantization import load + #TODO: This will be removed in v1.19 Synapse release + #Override neural_compressor _load_remaining_pretrained_weight for the Transformer 4.45 release. + import neural_compressor.torch.algorithms.weight_only.save_load as nc_sl + nc_sl.WOQModelLoader._load_remaining_pretrained_weight = local_load_remaining_pretrained_weight + from neural_compressor.torch.quantization import load model = load(model_name_or_path=args.model_name_or_path, format="huggingface", device="hpu", **model_kwargs) elif args.local_quantized_inc_model_path: org_model = AutoModelForCausalLM.from_pretrained( @@ -662,3 +666,46 @@ def initialize_model(args, logger): logger.info(f"device: {args.device}, n_hpu: {args.world_size}, bf16: {model_dtype == torch.bfloat16}") logger.info(f"Model initialization took {(init_end - init_start):.3f}s") return model, assistant_model, tokenizer, generation_config + +#TODO:This will be removed from Synapse v1.19 release. +#This is to override _load_remaining_pretrained_weight for Transformer 4.45 release. +def local_load_remaining_pretrained_weight(self,model): + from transformers.modeling_utils import _load_state_dict_into_meta_model, load_state_dict + + resolved_archive_file = self.kwargs.pop("resolved_archive_file", None) + torch_dtype = self.kwargs.pop("torch_dtype", torch.float32) + dtype_orig = self.kwargs.pop("dtype_orig", None) + offload_folder = self.kwargs.pop("offload_folder", None) + offload_state_dict = self.kwargs.pop("offload_state_dict", False) + + # restore default dtype + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) + + if not isinstance(resolved_archive_file, list): + resolved_archive_file = [resolved_archive_file] + for shard_file in resolved_archive_file: + state_dict = load_state_dict(shard_file) + + params_dict={ + "model": model, + "state_dict": state_dict, + "start_prefix": "", + "expected_keys": list(state_dict.keys()), + "device_map": {"": self.device}, + "offload_folder": offload_folder, + "state_dict_folder": tempfile.mkdtemp() if offload_state_dict else None, + "state_dict_index": {} if offload_state_dict else None, + "dtype": torch_dtype, + "keep_in_fp32_modules": [], + } + + _load_state_dict_into_meta_model(**params_dict) + + # make sure token embedding weights are still tied if needed + model.tie_weights() + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + + return model From 8d759369dbb36fc2a511bf2687d8534647f52ee6 Mon Sep 17 00:00:00 2001 From: Harish Subramony <81822986+hsubramony@users.noreply.github.com> Date: Thu, 17 Oct 2024 07:12:06 -0700 Subject: [PATCH 13/14] Update language-modeling README.md, add trust_remote_code for flan-t5-xl (#1422) --- examples/language-modeling/README.md | 3 ++- examples/text-generation/README.md | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/language-modeling/README.md b/examples/language-modeling/README.md index b0c027df65..8ea0cdd554 100644 --- a/examples/language-modeling/README.md +++ b/examples/language-modeling/README.md @@ -868,7 +868,8 @@ python3 ../gaudi_spawn.py --world_size 8 --use_mpi peft_poly_seq2seq_with_genera --per_device_eval_batch_size 4 \ --bf16 \ --use_hpu_graphs_for_inference \ - --use_hpu_graphs_for_training + --use_hpu_graphs_for_training \ + --trust_remote_code ``` diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 2a3b68f3cd..38209bd521 100755 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -633,6 +633,8 @@ First, you should install the requirements: pip install -r requirements_lm_eval.txt ``` +> [!NOTE] +> If custom models on hub is being used, please set env variable HF_DATASETS_TRUST_REMOTE_CODE=true instead of arg --trust_remote_code with the installed lm_eval version and dependency datasets==2.21.0 ### Examples From f4cb594ff60b74c41bbe88d58ee6f8dad63222a1 Mon Sep 17 00:00:00 2001 From: "Seunghyuk Park (shepark)" Date: Thu, 17 Oct 2024 07:12:48 -0700 Subject: [PATCH 14/14] Update unify_measurements.py support info (#1425) --- examples/text-generation/README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 38209bd521..87a58f8da3 100755 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -470,6 +470,10 @@ cards 0-3 and cards 4-7 will be unified in two different measurement files. All More information on usage of the unifier script can be found in fp8 Habana docs: https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html +> [!NOTE] +> unify_measurements.py does not support PCQ mode. (default: PTQ) + + ### CPU memory reduction on single card