Fp8 e4m3_fnuz support for rocm #2588
Conversation
self.input_scale,
self.activation_scale_ub,
bias,
self.dtype,
)


class Fp8Linear(torch.nn.Module):
Would it be cleaner to have a separate Fp8LinearRocm?
Maybe, it depends a bit on how much conditional code we end up with. We did separate FP8 Marlin for this reason.
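For context, a minimal sketch of the two layouts under discussion (class names from the thread, constructor arguments purely illustrative):

```python
import torch


class Fp8Linear(torch.nn.Module):
    # Option kept in this PR: a single class, with platform-specific code
    # behind conditionals. Constructor arguments here are illustrative only.
    def __init__(self, qweight, scale, bias=None):
        super().__init__()
        self.qweight = qweight
        self.scale = scale
        self.bias = bias


class Fp8LinearRocm(Fp8Linear):
    # Alternative raised above: a dedicated ROCm layer, analogous to how FP8
    # Marlin was split out, so the e4m3fn -> e4m3fnuz handling would live
    # here instead of behind `if SYSTEM == "rocm"` branches.
    def __init__(self, qweight, scale, bias=None):
        # Hypothetical: weights would be normalized to e4m3fnuz here.
        super().__init__(qweight, scale, bias=bias)
```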
@@ -92,9 +123,17 @@ def get_weights(self, weights: "Weights", prefix: str):
    .reshape(-1)
    .expand(w.shape[0])
)
try:
    input_scale = weights.get_tensor(
Weights also has _has_tensor, maybe we should make it public and use it here?
Same for the try: [...] get_tensor below.
Updated to use has_tensor.
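A minimal sketch of the resulting pattern, assuming `weights` is the loader's Weights instance and `prefix` the current weight prefix (both taken from the surrounding diff):

```python
# Probe for the optional per-tensor input scale with the now-public
# has_tensor helper instead of catching the failure from get_tensor.
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
    input_scale = weights.get_tensor(f"{prefix}.input_scale")
```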
@@ -72,6 +99,10 @@ def fp8_quantize(
    # as both required as inputs to torch._scaled_mm
    qweight = qweight.to(qdtype)
    scale = scale.float().reciprocal()

    if SYSTEM == "rocm":
        qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
We should wire up scale at some point for CUDA as well.
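For reference, a sketch of what such a conversion helper typically does, modeled on the common ROCm FP8 pattern (e.g. as used in vLLM); the exact implementation in this PR may differ:

```python
from typing import Optional, Tuple

import torch


def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # The bit pattern 0b10000000 is -0 in e4m3fn but NaN in e4m3fnuz,
    # so map it to zero before reinterpreting the storage.
    as_int8 = weight.view(torch.int8)
    as_int8[as_int8 == -128] = 0
    weight = as_int8.view(torch.float8_e4m3fnuz)
    # For identical bits, an e4m3fnuz value is half the e4m3fn value, so the
    # dequantization scales are doubled to keep dequantized values unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale
```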
    bias=self.bias,
)

if type(output) is tuple and len(output) == 2:
Did this change between torch versions or is output for AMD different?
Suggested change:
- if type(output) is tuple and len(output) == 2:
+ if isinstance(output, tuple) and len(output) == 2:
Done
This is a common change for torch 2.5. https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/Blas.cpp#L1175
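A small compatibility shim along those lines; the helper name is hypothetical:

```python
def unwrap_scaled_mm_output(output):
    # torch._scaled_mm returned an (output, amax) tuple before torch 2.5 and
    # returns the output tensor directly from 2.5 onwards; accept both forms.
    if isinstance(output, tuple) and len(output) == 2:
        return output[0]
    return output
```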
@@ -62,7 +62,7 @@ def from_unquant(cls, weight, bias, dtype):
        return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)

    @classmethod
-   def from_fp8(cls, weight, scale, _input_scale, bias, dtype):
+   def from_fp8(cls, weight, scale, _input_scale, _scale_upper_bound, bias, dtype):
Type.
These arguments get a bit messy. It's easy to mix up a tensor or a float (which was already happening here?). Maybe we should switch these to kwargs-only so that the call sites need to be explicit (+ type annotations).
Converted them to kwargs and added type hints.
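A sketch of the keyword-only signature with type hints; only the argument names come from the diff, the minimal constructor is hypothetical:

```python
from typing import Optional

import torch


class Fp8Linear(torch.nn.Module):
    def __init__(
        self,
        qweight: torch.Tensor,
        scale: torch.Tensor,
        input_scale: Optional[torch.Tensor] = None,
        scale_upper_bound: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        dtype: torch.dtype = torch.float16,
    ) -> None:
        super().__init__()
        self.qweight = qweight
        self.scale = scale
        self.input_scale = input_scale
        self.scale_upper_bound = scale_upper_bound
        self.bias = bias
        self.dtype = dtype

    @classmethod
    def from_fp8(
        cls,
        *,
        weight: torch.Tensor,
        scale: torch.Tensor,
        input_scale: Optional[torch.Tensor] = None,
        scale_upper_bound: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        dtype: torch.dtype,
    ) -> "Fp8Linear":
        # Keyword-only arguments force call sites to be explicit, so a tensor
        # scale cannot silently end up where a float upper bound is expected.
        return cls(
            qweight=weight,
            scale=scale,
            input_scale=input_scale,
            scale_upper_bound=scale_upper_bound,
            bias=bias,
            dtype=dtype,
        )
```

Call sites then read, for example, `Fp8Linear.from_fp8(weight=w, scale=s, dtype=torch.bfloat16)`.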
@@ -342,22 +342,19 @@ def get_model(
    model_type = config_dict.get("model_type", None)

    quantization_config = config_dict.get("quantization_config", None)
    compression_config = config_dict.get("compression_config", None)
@danieldk the config key was renamed to quantization_config.
I think we should check for both keys, at least for the time being. Some customers/users may have checkpoints that still have compression_config. Maybe with a comment that compression_config is for backwards compatibility?
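A sketch of the backwards-compatible lookup being suggested; the helper name is hypothetical:

```python
def get_quantization_config(config_dict: dict):
    # Prefer the new key name...
    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is None:
        # ...but keep reading `compression_config` for backwards
        # compatibility with checkpoints written before the rename.
        quantization_config = config_dict.get("compression_config", None)
    return quantization_config
```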
@@ -125,9 +164,24 @@ def get_weights_col_packed(
    )
    scale = scale.reshape(-1).expand(w.shape[0])

    input_scale = None
    if weights.get_tensor(f"{prefix}.input_scale"):
Suggested change:
- if weights.get_tensor(f"{prefix}.input_scale"):
+ if weights.has_tensor(f"{prefix}.input_scale"):
?
input_scale = [
    _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
    for p, shape in zip(prefixes, shapes)
    if weights.has_tensor(f"{p}.input_scale")
Given this conditional, we probably need an assertion like:
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
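A sketch with that assertion in place; `weights`, `prefixes`, `shapes` and `_load_scalar_or_matrix_scale` are the names from the diff, while the wrapper name and the final concatenation are illustrative:

```python
from typing import Optional

import torch


def _load_packed_input_scales(weights, prefixes, shapes) -> Optional[torch.Tensor]:
    input_scale = [
        _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
        for p, shape in zip(prefixes, shapes)
        if weights.has_tensor(f"{p}.input_scale")
    ]
    # Either every packed prefix carries an input scale or none of them do;
    # a partial set would silently misalign scales and columns.
    assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
    if not input_scale:
        return None
    return torch.cat(input_scale, dim=0)
```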
What does this PR do?
Fixes # (issue)

Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.

Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.