Skip to content

Conversation

andrewor14
Copy link
Contributor

@andrewor14 andrewor14 commented Sep 23, 2025

Stack from ghstack (oldest at bottom):

Summary: This commit adds an option for the existing
NVFP4InferenceConfig to dynamically compute an appropriate
fp32 per tensor scale to support the two level scaling
according to the NVFP4 specification:
https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/.

While two level scaling is supported in NVFP4Tensor, today
there is no config API for users to call this. The existing
NVFP4InferenceConfig only supports single level scaling
because including an explicit per_tensor_scale field would
make serialization tricky.

In the future, we should add an end-to-end calibration flow
so users can compute an appropriate per tensor scale for the
activations first, and then pass this to NVFP4Tensor as a
static scale, similar to the proposal in #2572.

Test Plan:

pytest test/prototype/mx_formats/test_inference_workflow.py -k test_inference_workflow_nvfp4
pytest test/quantization/test_qat.py -k test_quantize_api_nvfp4

Also did a quick benchmark before and after:

import copy
import time
import torch
from torch._inductor.utils import do_bench_using_profiling
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import NVFP4InferenceConfig

m_mx1 = torch.nn.Linear(1024, 1024, bias=True, dtype=torch.bfloat16, device="cuda")
m_mx2 = copy.deepcopy(m_mx1)
config1 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=False)
config2 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True)
quantize_(m_mx1, config=config1)
quantize_(m_mx2, config=config2)
m_mx1 = torch.compile(m_mx1, fullgraph=True, backend="aot_eager")
m_mx2 = torch.compile(m_mx2, fullgraph=True, backend="aot_eager")

def run_no_per_tensor_scale():
   for _ in range(1000):
      m_mx1(torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16))

def run_with_per_tensor_scale():
   for _ in range(1000):
      m_mx2(torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16))

no_per_tensor_scale_elapsed_ms = do_bench_using_profiling(run_no_per_tensor_scale)
with_per_tensor_scale_elapsed_ms = do_bench_using_profiling(run_with_per_tensor_scale)

print("No per_tensor_scale = ", no_per_tensor_scale_elapsed_ms, "ms")
print("With per_tensor_scale = ", with_per_tensor_scale_elapsed_ms, "ms")

On a single B200:

No per_tensor_scale =  105.45409099999787 ms
With per_tensor_scale =  137.24245699997056 ms

**Summary:** This commit adds an option for the existing
`NVFP4InferenceConfig` to dynamically compute an appropriate
fp32 per tensor scale to support the two level scaling
according to the NVFP4 specification:
https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/.

While two level scaling is supported in `NVFP4Tensor`, today
there is no config API for users to call this. The existing
`NVFP4InferenceConfig` only supports single level scaling
because including an explicit `per_tensor_scale` field would
make serialization tricky.

In the future, we should add an end-to-end calibration flow
so users can compute an appropriate per tensor scale for the
activations first, and then pass this to `NVFP4Tensor` as a
static scale, similar to the proposal in #2572.

**Test Plan:**
```
pytest test/prototype/mx_formats/test_inference_workflow.py -k test_inference_workflow_nvfp4
pytest test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

Also did a quick benchmark before and after:
```
import copy
import time
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import NVFP4InferenceConfig

m_mx1 = torch.nn.Linear(64, 256, bias=True, dtype=torch.bfloat16, device="cuda")
m_mx2 = copy.deepcopy(m_mx1)
config1 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=False)
config2 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True)
quantize_(m_mx1, config=config1)
quantize_(m_mx2, config=config2)
m_mx1 = torch.compile(m_mx1, fullgraph=True, backend="aot_eager")
m_mx2 = torch.compile(m_mx2, fullgraph=True, backend="aot_eager")

start = time.time()
for _ in range(1000):
    m_mx1(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("No per_tensor_scale = ", time.time() - start, "seconds")

start = time.time()
for _ in range(1000):
    m_mx2(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("With per_tensor_scale = ", time.time() - start, "seconds")
```

On a single B200:
```
No per_tensor_scale =  1.2855589389801025 seconds
With per_tensor_scale =  1.3009123802185059 seconds
```

[ghstack-poisoned]
Copy link

pytorch-bot bot commented Sep 23, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3049

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 757f6fd with merge base 8e2ca35 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 23, 2025
@andrewor14 andrewor14 requested review from drisspg and vkuzo September 23, 2025 20:34
@andrewor14 andrewor14 added the topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) label Sep 23, 2025
@drisspg
Copy link
Contributor

drisspg commented Sep 23, 2025

**Summary:** This commit adds an option for the existing
`NVFP4InferenceConfig` to dynamically compute an appropriate
fp32 per tensor scale to support the two level scaling
according to the NVFP4 specification:
https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/.

While two level scaling is supported in `NVFP4Tensor`, today
there is no config API for users to call this. The existing
`NVFP4InferenceConfig` only supports single level scaling
because including an explicit `per_tensor_scale` field would
make serialization tricky.

In the future, we should add an end-to-end calibration flow
so users can compute an appropriate per tensor scale for the
activations first, and then pass this to `NVFP4Tensor` as a
static scale, similar to the proposal in #2572.

**Test Plan:**
```
pytest test/prototype/mx_formats/test_inference_workflow.py -k test_inference_workflow_nvfp4
pytest test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

Also did a quick benchmark before and after:
```
import copy
import time
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import NVFP4InferenceConfig

m_mx1 = torch.nn.Linear(64, 256, bias=True, dtype=torch.bfloat16, device="cuda")
m_mx2 = copy.deepcopy(m_mx1)
config1 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=False)
config2 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True)
quantize_(m_mx1, config=config1)
quantize_(m_mx2, config=config2)
m_mx1 = torch.compile(m_mx1, fullgraph=True, backend="aot_eager")
m_mx2 = torch.compile(m_mx2, fullgraph=True, backend="aot_eager")

start = time.time()
for _ in range(1000):
    m_mx1(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("No per_tensor_scale = ", time.time() - start, "seconds")

start = time.time()
for _ in range(1000):
    m_mx2(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("With per_tensor_scale = ", time.time() - start, "seconds")
```

On a single B200:
```
No per_tensor_scale =  1.2855589389801025 seconds
With per_tensor_scale =  1.3009123802185059 seconds
```

[ghstack-poisoned]
@andrewor14
Copy link
Contributor Author

**Summary:** This commit adds an option for the existing
`NVFP4InferenceConfig` to dynamically compute an appropriate
fp32 per tensor scale to support the two level scaling
according to the NVFP4 specification:
https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/.

While two level scaling is supported in `NVFP4Tensor`, today
there is no config API for users to call this. The existing
`NVFP4InferenceConfig` only supports single level scaling
because including an explicit `per_tensor_scale` field would
make serialization tricky.

In the future, we should add an end-to-end calibration flow
so users can compute an appropriate per tensor scale for the
activations first, and then pass this to `NVFP4Tensor` as a
static scale, similar to the proposal in #2572.

**Test Plan:**
```
pytest test/prototype/mx_formats/test_inference_workflow.py -k test_inference_workflow_nvfp4
pytest test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

Also did a quick benchmark before and after:
```
import copy
import time
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import NVFP4InferenceConfig

m_mx1 = torch.nn.Linear(64, 256, bias=True, dtype=torch.bfloat16, device="cuda")
m_mx2 = copy.deepcopy(m_mx1)
config1 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=False)
config2 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True)
quantize_(m_mx1, config=config1)
quantize_(m_mx2, config=config2)
m_mx1 = torch.compile(m_mx1, fullgraph=True, backend="aot_eager")
m_mx2 = torch.compile(m_mx2, fullgraph=True, backend="aot_eager")

start = time.time()
for _ in range(1000):
    m_mx1(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("No per_tensor_scale = ", time.time() - start, "seconds")

start = time.time()
for _ in range(1000):
    m_mx2(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("With per_tensor_scale = ", time.time() - start, "seconds")
```

On a single B200:
```
No per_tensor_scale =  1.2855589389801025 seconds
With per_tensor_scale =  1.3009123802185059 seconds
```

[ghstack-poisoned]
@vkuzo
Copy link
Contributor

vkuzo commented Sep 24, 2025

128, 64, 256 is also a really small shape so not sure perf measurements are representative

@andrewor14
Copy link
Contributor Author

128, 64, 256 is also a really small shape so not sure perf measurements are representative

Ok, updated to 512, 1024, 1024, it seems pretty similar still

@andrewor14 andrewor14 changed the base branch from gh/andrewor14/25/base to main September 25, 2025 14:02
@andrewor14 andrewor14 merged commit a79e66c into main Sep 25, 2025
34 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants