[Quantization][FP8] Add support for FP8 models with input_scale for output projection and QK quantization #15734
@@ -140,6 +140,11 @@ def get_cache_scale(self, name: str) -> Optional[str]:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

Review thread on the `.q_proj` mapping:

- Review comment: Is it necessary to extend this part if it is already handled in …? The same question applies to the change to ….
- Reply: It was necessary; I had to add this for it to work.
- Reply: OK, just checked again: the compressed-tensors modification is not necessary, unless maybe there is an int8 w8a8 output_scale, but that work isn't happening yet. However, some FP8 models are quantized as "fp8" in the quantization config and some just as "quark", so the renaming needs to happen in both places.
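For illustration, here is a minimal standalone sketch of the remapping added in this hunk. The function below is a hypothetical stand-in, not the actual vLLM method; only the suffix patterns come from the diff:

```python
from typing import Optional


def remap_scale_name(name: str) -> Optional[str]:
    """Sketch of the renaming above: checkpoint tensors such as
    "...q_proj.output_scale" become the attention module's
    "...attn.q_scale" parameter."""
    if name.endswith(".output_scale") and ".q_proj" in name:
        return name.replace(".q_proj.output_scale", ".attn.q_scale")
    if name.endswith("self_attn.prob_output_scale"):
        return name.replace(".prob_output_scale", ".attn.prob_scale")
    return None


# "model.layers.0.self_attn.q_proj.output_scale"
#   -> "model.layers.0.self_attn.attn.q_scale"
print(remap_scale_name("model.layers.0.self_attn.q_proj.output_scale"))
# "model.layers.0.self_attn.prob_output_scale"
#   -> "model.layers.0.self_attn.attn.prob_scale"
print(remap_scale_name("model.layers.0.self_attn.prob_output_scale"))
```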
@@ -38,6 +38,9 @@ def create_weights(self, layer: torch.nn.Module):
                                           requires_grad=False)
        layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
                                           requires_grad=False)
        # Initialize P = softmax(QK^T) scales
        layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0),
                                              requires_grad=False)

    def apply(self, layer: torch.nn.Module) -> torch.Tensor:
        raise RuntimeError(
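As a rough sketch of the placeholder convention above (an assumption inferred from the diff, not code copied from vLLM): the scales are registered as -1.0 so that a later `> 0.0` check can tell whether the checkpoint actually supplied a value.

```python
import torch


class AttnScales(torch.nn.Module):
    """Minimal sketch of the placeholder convention: scales start at -1.0,
    and a value still <= 0.0 after weight loading means the checkpoint did
    not provide that scale."""

    def __init__(self) -> None:
        super().__init__()
        for name in ("q_scale", "k_scale", "v_scale", "prob_scale"):
            self.register_parameter(
                name,
                torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False))


scales = AttnScales()
print(bool(scales.prob_scale > 0.0))  # False until a checkpoint overwrites it
```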
@@ -97,5 +100,38 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                "may cause accuracy issues. Please make sure k/v_scale "
                "scaling factors are available in the fp8 checkpoint.")

        if layer.q_scale > 0.0:
            q_scale = layer.q_scale
            if current_platform.is_fp8_fnuz():
                q_scale *= 2
            layer.calculate_kv_scales = False
        else:
            q_scale = 1.0
        if layer.prob_scale > 0.0:
            prob_scale = layer.prob_scale
            if current_platform.is_fp8_fnuz():
                prob_scale *= 2
        else:
            prob_scale = 1.0

        is_singleton_float = lambda x: isinstance(x, float) or isinstance(
            x, torch.Tensor) and x.numel() == 1 and x.is_floating_point()
        if not is_singleton_float(q_scale) or not is_singleton_float(
                prob_scale):
            raise ValueError("Only support per-tensor scaling factor "
                             "for fp8-quantized Q/prob")

        # These are used in the final Attention.forward()
        layer._q_scale.copy_(q_scale)
        layer._prob_scale.copy_(prob_scale)
        if q_scale == 1.0 or prob_scale == 1.0:
            logger.warning_once(
                f"Using Q scale {q_scale} and prob scale {prob_scale} "
                "with fp8 attention. This may cause accuracy issues. "
                "Please make sure Q/prob scaling factors are "
                "available in the fp8 checkpoint.")

        del layer.k_scale
        del layer.v_scale
        del layer.q_scale
        del layer.prob_scale

Review thread on the `q_scale *= 2` adjustment:

- Review comment: Why multiply by 2 here? Is that because of the extra bit in …?
- Review comment: Also, why can't we just keep the scales on device?
- Reply: Changed to use `is_fp8_fnuz()` and to keep the scales on device (not sure why the CPU handling was there).
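A hedged aside on the factor of two discussed above: `float8_e4m3fnuz` (used on some ROCm platforms) has an exponent bias of 8 instead of 7, so a given bit pattern decodes to half the value it would have as `float8_e4m3fn`; doubling the per-tensor scale is the usual compensation. The snippet below only illustrates that relationship and is not the vLLM code path:

```python
import torch

# Requires a PyTorch build with float8 dtypes (roughly 2.1+).
# The two FP8 variants differ in exponent bias (7 vs. 8), which also shows
# up as different representable maxima.
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0

# Hypothetical per-tensor scale calibrated against e4m3fn in a checkpoint;
# doubling it keeps the dequantized product (code * scale) consistent when
# the same codes are interpreted as e4m3fnuz.
checkpoint_q_scale = 0.02
q_scale_for_fnuz = checkpoint_q_scale * 2
print(q_scale_for_fnuz)
```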
Review thread on `_prob_scale`:

- Review comment: QQ: why is this called `_prob_scale`?
- Reply: I think this is the scale for P, which is the tensor resulting from the `softmax(Q@K)` calculation. O is `softmax(Q@K)@V`.
- Reply: It comes from a parameter in the model, something like `self_attn.prob_output_scale`; it gets remapped to `.attn.prob_scale`. See also @ProExpertProg's comments.
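To make the thread concrete, here is a conceptual sketch of where `q_scale` and `prob_scale` sit in FP8 attention. It is not the fused kernel, and the helper below is made up for illustration; it just fake-quantizes Q and P = softmax(Q@K^T) per tensor before the matmuls:

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0


def fake_quant_fp8(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Per-tensor quantize to FP8 and dequantize again, emulating the
    precision loss of feeding an FP8 operand into the matmul."""
    q = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q.to(x.dtype) * scale


q_scale, prob_scale = 0.05, 1.0 / FP8_MAX  # P lies in [0, 1] after softmax

Q, K, V = (torch.randn(4, 8) for _ in range(3))
# P = softmax(Q @ K^T), the tensor prob_scale refers to.
P = torch.softmax(fake_quant_fp8(Q, q_scale) @ K.T / 8 ** 0.5, dim=-1)
# O = P @ V, the attention output.
O = fake_quant_fp8(P, prob_scale) @ V
print(O.shape)
```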