
Commit ce27def

Revert "Nvfp4 static gs (#61)"

This reverts commit c4ef813.

1 parent 310c908
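Context for the revert: "Nvfp4 static gs (#61)" threaded a precomputed ("static") per-tensor global scale, calibrated offline and stored with the checkpoint, through the activation quantize-dequantize path (qdq_nvfp4_with_gs). Reverting restores the dynamic path, where qdq_nvfp4 recomputes the global scale from each activation tensor's absolute maximum. A minimal sketch of the difference, assuming the usual NVFP4 convention for per_tensor_amax_to_scale (the FP8-E4M3 and FP4-E2M1 range constants below are assumptions, not taken from this diff):

import torch

# Assumed NVFP4 convention: the global scale maps a tensor's amax into the
# range representable by FP4-E2M1 blocks with FP8-E4M3 block scales.
FLOAT4_E2M1_MAX = 6.0    # largest finite E2M1 value
FLOAT8_E4M3_MAX = 448.0  # largest finite E4M3 value

def per_tensor_amax_to_scale(amax: torch.Tensor) -> torch.Tensor:
    return FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax

x = torch.randn(4, 64)

# Dynamic global scale (restored by this revert): derived from x itself,
# so it tracks the live activation range on every call.
dynamic_gs = per_tensor_amax_to_scale(x.abs().max())

# Static global scale (the reverted #61 behavior): calibrated offline and
# loaded as a checkpoint tensor such as layer.input_global_scale, so the
# same value is reused for every batch. The 100.0 here is hypothetical.
static_gs = torch.tensor(100.0)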

4 files changed: +16 additions, −71 deletions

examples/offline_inference/basic/basic_hpu.py
Lines changed: 0 additions & 4 deletions

@@ -41,9 +41,6 @@
 # model_path = "/software/users/yiliu4/deepseek-ai/DeepSeek-R1-NVFP4-OFFLINE"
 model_path = "/software/users/yiliu4/HF_HOME/weiweiz1/DeepSeek-R1-NVFP4-RTN"
 
-# model_path = "/software/users/yiliu4/HF_HOME/Yi30/DeepSeek-V2-Lite-NVFP4-W4A4-RTN-GLOBAL-SCALE-WW"
-
-
 import os
 
 os.environ["PT_HPU_ENABLE_LAZY_COLLECTIVES"] = "true"

@@ -123,7 +120,6 @@ def main(args):
     # Create a sampling params object.
     max_model_len = 2048
     model_path = args.model_path
-    print(f"model_path: {model_path}")
     llm = LLM(
         model=model_path,
         # quantization="inc",

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
Lines changed: 8 additions & 18 deletions

@@ -88,19 +88,13 @@ def get_moe_method(
 
 
 def nvfp4_unpacked_weight_gemm(
-    x,
-    weight,
-    weight_scale,
-    weight_global_scale,
-    input_global_scale=None,
+    x, weight_unpacked, weight_scale, weight_global_scale
 ):
-    weight_unpacked = weight
     # return self.run_nvfp4_emulations(x, layer)
     from vllm.model_executor.layers.quantization.utils.nvfp4_qdq import (
         unpacked_nvfp4_to_fp8,
         dequant_nvfp4,
         qdq_nvfp4,
-        qdq_nvfp4_with_gs,
     )
 
     # bs, seq_len, hidden_size = x.shape

@@ -113,7 +107,8 @@ def nvfp4_unpacked_weight_gemm(
         packed=False,
     )
 
-    x = qdq_nvfp4_with_gs(x, input_global_scale)
+    # breakpoint()
+    x = qdq_nvfp4(x)
     out = x @ hp_weight.t()
     # out = out.reshape(bs, seq_len, -1)
     return out

@@ -290,7 +285,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 layer.w13_weight_global_scale[:, 1]):
             logger.warning_once(
                 "w1_weight_global_scale must match w3_weight_global_scale. "
-                f"Accuracy may be affected. {getattr(layer, 'layer_name', '')}")
+                "Accuracy may be affected.")
 
         # Take inverse of global scale saved to disk
         layer.w13_weight_scale_2 = torch.nn.Parameter(

@@ -480,19 +475,16 @@ def apply(
             ]
             local_w3_global_scale = local_w13_global_scale[1]
             local_w3_input_global_scale = local_w13_input_global_scale[1]
-
-            # breakpoint()
+
             local_w1_out = nvfp4_unpacked_weight_gemm(
                 x=current_state_static,
-                input_global_scale=local_w1_input_global_scale,
-                weight=local_w1_unpacked,
+                weight_unpacked=local_w1_unpacked,
                 weight_scale=local_w1_scale,
                 weight_global_scale=local_w1_global_scale,
             )
             local_w3_out = nvfp4_unpacked_weight_gemm(
                 x=current_state_static,
-                input_global_scale=local_w3_input_global_scale,
-                weight=local_w3_unpacked,
+                weight_unpacked=local_w3_unpacked,
                 weight_scale=local_w3_scale,
                 weight_global_scale=local_w3_global_scale,
             )

@@ -501,12 +493,10 @@ def apply(
 
             local_w2_out = nvfp4_unpacked_weight_gemm(
                 x=w13_out,
-                input_global_scale=local_w2_input_global_scale,
-                weight=local_w2_unpacked,
+                weight_unpacked=local_w2_unpacked,
                 weight_scale=local_w2_scale,
                 weight_global_scale=local_w2_global_scale,
             )
-
             padded_weight = experts_mask[expert_index + ep_shift].unsqueeze(
                 1
             )
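After the revert, nvfp4_unpacked_weight_gemm takes an already-unpacked weight and no activation scale; the activation is quantize-dequantized dynamically inside the call. A runnable sketch of that control flow, with trivial stand-ins for the dequant/QDQ helpers from nvfp4_qdq.py (the stand-ins are assumptions; the real helpers reconstruct high-precision tensors from FP4 data):

import torch

# Stand-in for dequant_nvfp4: the real helper rebuilds a high-precision
# weight from unpacked FP4 codes, block scales, and the global scale.
def dequant_nvfp4_stub(w, w_scale, w_gs, original_dtype, packed=False):
    return (w * w_scale / w_gs).to(original_dtype)

# Stand-in for qdq_nvfp4: the real helper quantizes x to NVFP4 with a
# dynamically computed global scale, then dequantizes it back.
def qdq_nvfp4_stub(x):
    return x

def nvfp4_unpacked_weight_gemm(x, weight_unpacked, weight_scale,
                               weight_global_scale):
    # Post-revert signature: no input_global_scale parameter.
    hp_weight = dequant_nvfp4_stub(weight_unpacked, weight_scale,
                                   weight_global_scale,
                                   original_dtype=x.dtype)
    x = qdq_nvfp4_stub(x)  # dynamic activation QDQ
    return x @ hp_weight.t()

out = nvfp4_unpacked_weight_gemm(
    x=torch.randn(8, 64),
    weight_unpacked=torch.randn(32, 64),  # toy stand-in for FP4 codes
    weight_scale=torch.tensor(0.5),
    weight_global_scale=torch.tensor(2.0),
)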

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
Lines changed: 2 additions & 2 deletions

@@ -212,7 +212,6 @@ def apply_weights(self,
         from vllm.model_executor.layers.quantization.utils.nvfp4_qdq import (
             dequant_nvfp4,
             qdq_nvfp4,
-            qdq_nvfp4_with_gs
         )
 
         need_reshape = False

@@ -230,7 +229,8 @@ def apply_weights(self,
             packed=False,
         )
 
-        x = qdq_nvfp4_with_gs(x, layer.input_global_scale)
+        # breakpoint()
+        x = qdq_nvfp4(x)
         out = x @ hp_weight.t()
         if need_reshape:
             out = out.reshape(bs, seq_len, -1)

vllm/model_executor/layers/quantization/utils/nvfp4_qdq.py
Lines changed: 6 additions & 47 deletions

@@ -265,30 +265,18 @@ def nvfp4_quantize(
     return out_scales, data_lp
 
 
-def to_nvfp4(x, x_global_scale=None, do_pack=True):
-    if x_global_scale is None:
-        tensor_amax = torch.max(torch.abs(x))
-        per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
-        x_global_scale = per_tensor_scale
+def to_nvfp4(x, do_pack=True):
+    tensor_amax = torch.max(torch.abs(x))
+    per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
     out_scales, data_lp = nvfp4_quantize(
         data_hp=x,
         block_size=16,
-        per_tensor_scale=x_global_scale,
+        per_tensor_scale=per_tensor_scale,
         do_pack=do_pack,
     )
     return data_lp, out_scales, per_tensor_scale
 
 
-def to_nvfp4_with_gs(x, x_global_scale, do_pack=True):
-    out_scales, data_lp = nvfp4_quantize(
-        data_hp=x,
-        block_size=16,
-        per_tensor_scale=x_global_scale,
-        do_pack=do_pack,
-    )
-    return data_lp, out_scales
-
-
 def dequant_nvfp4(
     data_lp,
     out_scales,

@@ -329,26 +317,11 @@ def check_nan(x):
     return torch.isnan(x).any() or torch.isinf(x).any()
 
 
-def qdq_nvfp4(x, x_global_scale=None):
-    if envs.VLLM_DISABLE_INPUT_QDQ:
-        return x
-
-    data_lp, x_scale = to_nvfp4(x, x_global_scale, do_pack=False)
-    x_dq = dequant_nvfp4(
-        data_lp,
-        x_scale,
-        x_global_scale,
-        original_dtype=x.dtype,
-        packed=False,
-    )
-    return x_dq
-
-
-def qdq_nvfp4(x, x_global_scale=None):
+def qdq_nvfp4(x):
     if envs.VLLM_DISABLE_INPUT_QDQ:
         return x
 
-    data_lp, x_scale = to_nvfp4_with_gs(x, x_global_scale, do_pack=False)
+    data_lp, x_scale, x_global_scale = to_nvfp4(x, do_pack=False)
     x_dq = dequant_nvfp4(
         data_lp,
         x_scale,

@@ -359,20 +332,6 @@ def qdq_nvfp4(x, x_global_scale=None):
     return x_dq
 
 
-def qdq_nvfp4_with_gs(x, x_global_scale):
-    if envs.VLLM_DISABLE_INPUT_QDQ:
-        return x
-
-    data_lp, x_scale = to_nvfp4_with_gs(x, x_global_scale, do_pack=False)
-    x_dq = dequant_nvfp4(
-        data_lp,
-        x_scale,
-        x_global_scale,
-        original_dtype=x.dtype,
-        packed=False,
-    )
-    return x_dq
-
 class NVFP4Linear(torch.nn.Module):
     def __init__(self, in_features, out_features):
         super().__init__()
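For reference, the restored dynamic round trip can be emulated end to end. The sketch below is a self-contained toy, not this file's implementation: it mirrors the to_nvfp4 → dequant_nvfp4 flow of qdq_nvfp4, but simulates the FP4-E2M1 cast by rounding to the eight representable magnitudes instead of packing 4-bit payloads, and the range constants are assumptions:

import torch

FLOAT4_E2M1_MAX = 6.0    # assumed largest finite E2M1 value
FLOAT8_E4M3_MAX = 448.0  # assumed largest finite E4M3 value
E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def per_tensor_amax_to_scale(amax):
    return FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax

def round_to_e2m1(t):
    # Snap each magnitude to the nearest representable E2M1 value.
    diffs = (t.abs().unsqueeze(-1) - E2M1_VALUES).abs()
    return E2M1_VALUES[diffs.argmin(dim=-1)] * torch.sign(t)

def qdq_nvfp4_toy(x, block_size=16):
    orig_dtype = x.dtype
    x = x.float()
    gs = per_tensor_amax_to_scale(x.abs().max())  # dynamic global scale
    blocks = x.reshape(-1, block_size)
    block_amax = blocks.abs().amax(dim=-1, keepdim=True)
    # Block scales live in FP8-E4M3 in the real format; emulate the cast.
    block_scale = (block_amax / FLOAT4_E2M1_MAX) * gs
    block_scale = block_scale.to(torch.float8_e4m3fn).float().clamp(min=1e-12)
    q = round_to_e2m1(blocks * gs / block_scale)  # quantize to the E2M1 grid
    dq = q * block_scale / gs                     # dequantize back
    return dq.reshape(x.shape).to(orig_dtype)

x = torch.randn(2, 64, dtype=torch.bfloat16)
max_err = (qdq_nvfp4_toy(x).float() - x.float()).abs().max()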
