Conversation

andrewor14
Contributor

@andrewor14 andrewor14 commented Sep 23, 2025

Stack from ghstack (oldest at bottom):

Summary: Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.
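
As a rough illustration of the idea only (not the code added in this PR), a fake-quantized NVFP4 linear might look like the sketch below. The helper names (`nvfp4_fake_quantize`, `NVFP4FakeQuantizedLinear`), the nearest-value rounding, and the assumption that sizes are multiples of the 16-element block are all simplifications:

```
import torch
import torch.nn.functional as F

_FP4_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def _round_to_fp4_values(x: torch.Tensor) -> torch.Tensor:
    # Map each element to the nearest representable fp4 (e2m1) magnitude,
    # keeping the sign; magnitudes above 6.0 are clamped to 6.0.
    vals = _FP4_VALUES.to(device=x.device, dtype=x.dtype)
    midpoints = (vals[:-1] + vals[1:]) / 2
    idx = torch.bucketize(x.abs(), midpoints)
    return torch.sign(x) * vals[idx]

def nvfp4_fake_quantize(x: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    # Blockwise fake quantization in fp32: per-block scale rounded through
    # fp8 e4m3, values rounded to fp4, then dequantized back.
    # Assumes x.numel() is divisible by block_size.
    blocks = x.to(torch.float32).reshape(-1, block_size)
    scale = blocks.abs().amax(dim=-1, keepdim=True) / 6.0
    scale = scale.to(torch.float8_e4m3fn).to(torch.float32).clamp(min=1e-12)
    fq = _round_to_fp4_values(blocks / scale) * scale
    return fq.reshape(x.shape)

class NVFP4FakeQuantizedLinear(torch.nn.Linear):
    # Fake-quantize both the activation and the weight, then run the matmul,
    # so the prepared (QAT) model tracks the converted (PTQ) numerics.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fq = nvfp4_fake_quantize(x)
        w_fq = nvfp4_fake_quantize(self.weight)
        bias = self.bias.to(torch.float32) if self.bias is not None else None
        return F.linear(x_fq, w_fq, bias).to(x.dtype)
```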

Unit tests:

```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests:
Fine-tuning Llama3.2-3B with and without this PR in axolotl:

  • fine-tune for 1 epoch on yahma/alpaca-cleaned
  • batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:

  • With this PR, the QAT nvfp4 quantized model closed about 15% of
    the perplexity gap between the quantized and float baselines
  • Without this PR, QAT nvfp4 quantized model was about the
    same as the quantized baseline
```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.418|±  |   N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.3681|±  |   N/A|

# QAT with this PR (quantized)
==> Llama3.2-3B_qat_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.2281|±  |   N/A|
```

**Summary:** This commit adds an option to the existing
`NVFP4InferenceConfig` that dynamically computes an appropriate
fp32 per-tensor scale, supporting the two-level scaling described
in the NVFP4 specification:
https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/.

While two-level scaling is supported in `NVFP4Tensor`, there is
currently no config API that exposes it. The existing
`NVFP4InferenceConfig` only supports single-level scaling,
because including an explicit `per_tensor_scale` field would
make serialization tricky.

In the future, we should add an end-to-end calibration flow
so users can compute an appropriate per tensor scale for the
activations first, and then pass this to `NVFP4Tensor` as a
static scale, similar to the proposal in #2572.
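
As a rough sketch of what dynamic two-level scaling computes (the helper names and constants below are illustrative assumptions, not torchao's API): the per-tensor fp32 scale is chosen so that every blockwise scale fits into fp8 e4m3, and the blockwise scales are then stored relative to it:

```
import torch

F4_E2M1_MAX = 6.0    # largest fp4 (e2m1) magnitude
F8_E4M3_MAX = 448.0  # largest fp8 (e4m3) magnitude

def dynamic_per_tensor_scale(x: torch.Tensor) -> torch.Tensor:
    # Choose the per-tensor scale so that the largest blockwise scale
    # still fits into fp8 e4m3 after dividing by it.
    return x.abs().amax().to(torch.float32) / (F8_E4M3_MAX * F4_E2M1_MAX)

def blockwise_e4m3_scales(
    x: torch.Tensor, per_tensor_scale: torch.Tensor, block_size: int = 16
) -> torch.Tensor:
    # Per-block scales relative to the per-tensor scale, stored in e4m3.
    # Assumes x.numel() is divisible by block_size.
    blocks = x.to(torch.float32).reshape(-1, block_size)
    scales = blocks.abs().amax(dim=-1) / F4_E2M1_MAX / per_tensor_scale
    return scales.to(torch.float8_e4m3fn)
```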

**Test Plan:**
```
pytest test/prototype/mx_formats/test_inference_workflow.py -k test_inference_workflow_nvfp4
pytest test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

Also did a quick benchmark before and after:
```
import copy
import time
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import NVFP4InferenceConfig

m_mx1 = torch.nn.Linear(64, 256, bias=True, dtype=torch.bfloat16, device="cuda")
m_mx2 = copy.deepcopy(m_mx1)
config1 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=False)
config2 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True)
quantize_(m_mx1, config=config1)
quantize_(m_mx2, config=config2)
m_mx1 = torch.compile(m_mx1, fullgraph=True, backend="aot_eager")
m_mx2 = torch.compile(m_mx2, fullgraph=True, backend="aot_eager")

start = time.time()
for _ in range(1000):
    m_mx1(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("No per_tensor_scale = ", time.time() - start, "seconds")

start = time.time()
for _ in range(1000):
    m_mx2(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("With per_tensor_scale = ", time.time() - start, "seconds")
```

On a single B200:
```
No per_tensor_scale =  1.2855589389801025 seconds
With per_tensor_scale =  1.3009123802185059 seconds
```

[ghstack-poisoned]
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
Details TBD.

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]

pytorch-bot bot commented Sep 23, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3050

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit fa4d9ee with merge base 5cbbd73:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

andrewor14 added a commit that referenced this pull request Sep 23, 2025
ghstack-source-id: 04f6bce
Pull Request resolved: #3050
@meta-cla meta-cla bot added the CLA Signed label Sep 23, 2025
@andrewor14 andrewor14 added the topic: improvement label Sep 23, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely;
in particular, in descending order of significance (a short
sketch of items 2 and 3 follows the list):

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to the original
   dtype (e.g. bf16), which loses some fidelity relative to fp32
3. Fake round blockwise scales to float8
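
A tiny illustration of items 2 and 3 (the helper name is assumed, not this PR's actual code): keep the math in fp32 and round-trip the blockwise scale through float8 e4m3, so the fake-quantized path sees the same scale values the PTQ path will use:

```
import torch

def fake_round_scale_to_e4m3(scale_fp32: torch.Tensor) -> torch.Tensor:
    # Quantize the blockwise scale to float8 e4m3 and come straight back to
    # fp32 (not bf16), so no extra fidelity is lost in the intermediate dtype.
    return scale_fp32.to(torch.float8_e4m3fn).to(torch.float32)
```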

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 23, 2025
ghstack-source-id: 47019f4
Pull Request resolved: #3050
self._test_quantize_api_against_ptq(
    NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
    target_prepare_sqnr=12,
    target_prepare_sqnr=target_prepare_sqnr,
Contributor

Can you explain this? I would expect per-tensor scaling to have less error and thus a higher SQNR. Also, I don't know what prepare SQNR means.

Contributor Author

Update: they're both inf now. The per-tensor case wasn't inf before because we weren't simulating `_addmm_nvfp4_dispatch` in QAT, but instead just called `F.linear` on `NVFP4Tensor.to_nvfp4().to_dtype()`, which doesn't give the same numerics.
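
For reference, the prepare vs convert SQNR discussed here compares the prepared (fake-quantized) model's output against the converted (quantized) model's output, so inf means the two match exactly. A minimal sketch of the metric, assuming the standard definition:

```
import torch

def sqnr(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB; returns inf when ref == test.
    ref, test = ref.to(torch.float32), test.to(torch.float32)
    signal = ref.pow(2).mean()
    noise = (ref - test).pow(2).mean()
    return 10 * torch.log10(signal / noise)
```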

andrewor14 added a commit that referenced this pull request Sep 24, 2025
ghstack-source-id: d8f7eff
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 24, 2025
ghstack-source-id: d0120f0
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 25, 2025
ghstack-source-id: d0120f0
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 25, 2025
ghstack-source-id: ecbff90
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 26, 2025
ghstack-source-id: ecbff90
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 26, 2025
ghstack-source-id: a707a59
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 26, 2025
ghstack-source-id: 633bc65
Pull Request resolved: #3050
@andrewor14 andrewor14 requested review from drisspg and vkuzo September 26, 2025 16:04
        activation_dtype=e4m3_dtype,
    )
elif isinstance(base_config, NVFP4InferenceConfig):
    if NVFP4MMConfig.DYNAMIC:
Contributor

nit: if this is a boolean, it might be better to name it `NVFP4MMConfig.is_dynamic`, although that's probably not relevant to this PR

Contributor Author

OK, I can fix that separately since this is a PTQ config.

andrewor14 added a commit that referenced this pull request Sep 26, 2025
ghstack-source-id: 77f47b7
Pull Request resolved: #3050
Contributor

@vkuzo vkuzo left a comment

Looks good for B200. I think there should be a TODO somewhere to add a path for non-B200s in the future, so they can at least emulate nvfp4 numerics.

andrewor14 added a commit that referenced this pull request Sep 30, 2025
ghstack-source-id: 77f47b7
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 30, 2025
ghstack-source-id: bb1356c
Pull Request resolved: #3050
@andrewor14 andrewor14 changed the base branch from gh/andrewor14/26/base to main September 30, 2025 15:29
@andrewor14 andrewor14 merged commit d407246 into main Sep 30, 2025
34 checks passed