Improve QAT nvfp4 numerics #3050
Conversation
**Summary:** This commit adds an option for the existing `NVFP4InferenceConfig` to dynamically compute an appropriate fp32 per-tensor scale to support two-level scaling according to the NVFP4 specification: https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/. While two-level scaling is supported in `NVFP4Tensor`, today there is no config API for users to invoke it. The existing `NVFP4InferenceConfig` only supports single-level scaling because including an explicit `per_tensor_scale` field would make serialization tricky. In the future, we should add an end-to-end calibration flow so users can first compute an appropriate per-tensor scale for the activations, and then pass this to `NVFP4Tensor` as a static scale, similar to the proposal in #2572.

**Test Plan:**
```
pytest test/prototype/mx_formats/test_inference_workflow.py -k test_inference_workflow_nvfp4
pytest test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

Also did a quick benchmark before and after:
```python
import copy
import time
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import NVFP4InferenceConfig

m_mx1 = torch.nn.Linear(64, 256, bias=True, dtype=torch.bfloat16, device="cuda")
m_mx2 = copy.deepcopy(m_mx1)
config1 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=False)
config2 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True)
quantize_(m_mx1, config=config1)
quantize_(m_mx2, config=config2)
m_mx1 = torch.compile(m_mx1, fullgraph=True, backend="aot_eager")
m_mx2 = torch.compile(m_mx2, fullgraph=True, backend="aot_eager")

start = time.time()
for _ in range(1000):
    m_mx1(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("No per_tensor_scale = ", time.time() - start, "seconds")

start = time.time()
for _ in range(1000):
    m_mx2(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("With per_tensor_scale = ", time.time() - start, "seconds")
```

On a single B200:
```
No per_tensor_scale   = 1.2855589389801025 seconds
With per_tensor_scale = 1.3009123802185059 seconds
```
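For intuition, here is a minimal sketch of how such a dynamic fp32 per-tensor scale can be computed for two-level scaling. The constants follow the NVFP4 spec (fp8 e4m3 max = 448, fp4 e2m1 max = 6), but the helper name is hypothetical, not torchao's actual API:

```python
import torch

F8E4M3_MAX = 448.0  # largest normal float8 e4m3 value
F4_E2M1_MAX = 6.0   # largest fp4 e2m1 value

def dynamic_per_tensor_scale(x: torch.Tensor) -> torch.Tensor:
    # Pick the fp32 per-tensor scale so that, after dividing by it, the
    # largest blockwise scale still fits in float8 e4m3 and the scaled
    # block values fit in fp4 e2m1 (two-level scaling).
    return x.abs().max().float() / (F8E4M3_MAX * F4_E2M1_MAX)
```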
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and from 12 to inf without. Details TBD.

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```
End-to-end tests TBD.
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3050. Note: links to docs will display an error until the docs builds have been completed.

✅ No Failures as of commit fa4d9ee with merge base 5cbbd73. This comment was automatically generated by Dr. CI and updates every 15 minutes.
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and from 12 to inf without. This is achieved by mimicking the PTQ flow more closely, in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`, but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to the original dtype (e.g. bf16), which loses some fidelity relative to fp32
3. Fake round blockwise scales to float8

(A rough sketch of these fixes follows below.)

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```
End-to-end tests TBD.
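For concreteness, here is a self-contained sketch of what these three fixes amount to. This is an approximation under stated assumptions (nearest-value rounding rather than exact fp4 round-to-nearest-even, and a tensor whose numel is divisible by the block size), not torchao's implementation:

```python
import torch

# Non-negative representable fp4 e2m1 values; negatives are handled via sign.
F4_E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fake_quantize_nvfp4(w: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    # Assumes w.numel() is divisible by block_size.
    orig_shape, orig_dtype = w.shape, w.dtype
    wf = w.float().reshape(-1, block_size)          # (2) keep fp32 intermediates
    scale = wf.abs().amax(dim=-1, keepdim=True) / 6.0
    scale = scale.to(torch.float8_e4m3fn).float()   # (3) fake-round scales to fp8
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    scaled = wf / scale
    grid = F4_E2M1_VALUES.to(wf.device)
    # (1) snap to the nearest fp4 e2m1 value, emulating the
    # f32_to_f4_unpacked -> f4_unpacked_to_f32 round trip
    idx = (scaled.abs().unsqueeze(-1) - grid).abs().argmin(dim=-1)
    deq = grid[idx] * scaled.sign() * scale
    return deq.reshape(orig_shape).to(orig_dtype)   # cast back only at the end
```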
test/quantization/test_qat.py (Outdated)
```diff
  self._test_quantize_api_against_ptq(
      NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
-     target_prepare_sqnr=12,
+     target_prepare_sqnr=target_prepare_sqnr,
```
Can you explain this? I would expect that per-tensor scaling would have less error and thus a higher SQNR. But also, I don't know what "prepare SQNR" means.
Update: they're both inf now. The per-tensor case wasn't inf before because we weren't simulating `_addmm_nvfp4_dispatch` in QAT, but instead just called `F.linear` on `NVFP4Tensor.to_nvfp4().to_dtype()`, which doesn't give the same numerics.
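For context, "prepare SQNR" here compares the output of the prepared (fake-quantized) model against the converted (actually quantized) model; inf means they match exactly. A minimal sketch, assuming the usual signal-to-quantization-noise formula:

```python
import torch

def sqnr(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB: higher means the two
    # tensors agree more closely; identical tensors give inf.
    noise = ref - approx
    return 10 * torch.log10(ref.pow(2).mean() / noise.pow(2).mean())
```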
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic (see the sketch below).

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```
End-to-end tests TBD.
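As a rough sketch of the refactor's shape (class and method structure here are hypothetical; the real implementation simulates the nvfp4 mm dispatch rather than a plain `F.linear` on dequantized values, per the discussion above):

```python
import torch
import torch.nn.functional as F

class FakeQuantizedNVFP4Linear(torch.nn.Linear):
    # One module owns both the fake quantization of activations/weights and
    # the mm, so the QAT (prepare) forward can mirror PTQ (convert) numerics.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fq = fake_quantize_nvfp4(x)            # sketch defined earlier
        w_fq = fake_quantize_nvfp4(self.weight)
        return F.linear(x_fq, w_fq, self.bias)
```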
```python
            activation_dtype=e4m3_dtype,
        )
    elif isinstance(base_config, NVFP4InferenceConfig):
        if NVFP4MMConfig.DYNAMIC:
```
nit: if this is a boolean, it might be better to name it `NVFP4MMConfig.is_dynamic`, although that's probably not relevant to this PR.
ok, can fix separately since this is a PTQ config
lg for B200. I think there should be a TODO somewhere to add a path in the future for non-B200s, if they want to at least emulate nvfp4 numerics.
Stack from ghstack (oldest at bottom):
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:** Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:
- With this PR, the QAT nvfp4 quantized model closed about 15% of the perplexity gap between the quantized baseline and the float baseline (10.3681 → 10.2281, vs. 9.418 float)
- Without this PR, the QAT nvfp4 quantized model was about the same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
| | |none | 0|word_perplexity|↓ | 9.418 |± | N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
| | |none | 0|word_perplexity|↓ |10.3681|± | N/A|

# QAT with this PR (quantized)
==> Llama3.2-3B_qat_bs512/eval_quantized.log <==
| | |none | 0|word_perplexity|↓ |10.2281|± | N/A|
```