Fp8 e4m3_fnuz support for rocm #2588

Merged: 5 commits merged from rocm-fp8 into main on Oct 16, 2024

Conversation

mht-sharma (Collaborator):

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

self.input_scale,
self.activation_scale_ub,
bias,
self.dtype,
)


class Fp8Linear(torch.nn.Module):
mht-sharma (Collaborator, Author):

Would it be cleaner to have a separate Fp8LinearRocm?

Member:

Maybe, it depends a bit on how much conditional code we end up with. We did separate FP8 Marlin for this reason.

@@ -92,9 +123,17 @@ def get_weights(self, weights: "Weights", prefix: str):
.reshape(-1)
.expand(w.shape[0])
)
try:
input_scale = weights.get_tensor(
Member:

Weights also has _has_tensor maybe we should make it public and use it here?

Member:

Same for try: [...]get_tensor below.

mht-sharma (Collaborator, Author):

Updated to use the has_tensor
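
For context, a minimal sketch of what the public helper and the guarded load could look like. Only the has_tensor/get_tensor names come from the snippets above; the internal name-to-file index shown here is an assumption for illustration:

```python
from typing import Dict


class Weights:
    """Simplified stand-in for the real weight loader (illustration only)."""

    def __init__(self, routing: Dict[str, str]):
        # Maps tensor name -> checkpoint file it lives in (assumed layout).
        self.routing = routing

    def has_tensor(self, name: str) -> bool:
        # Public replacement for the former _has_tensor.
        return name in self.routing

# Guarded load instead of try/except around get_tensor:
# input_scale = (
#     weights.get_tensor(f"{prefix}.input_scale")
#     if weights.has_tensor(f"{prefix}.input_scale")
#     else None
# )
```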

@@ -72,6 +99,10 @@ def fp8_quantize(
# as both required as inputs to torch._scaled_mm
qweight = qweight.to(qdtype)
scale = scale.float().reciprocal()

if SYSTEM == "rocm":
qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
Member:

We should wire up scale at some point for CUDA as well.
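
For readers unfamiliar with the FNUZ format: for the same bit pattern, torch.float8_e4m3fnuz (the variant supported by ROCm hardware) represents half the value of torch.float8_e4m3fn, and the e4m3fn bit pattern for -0.0 decodes as NaN in fnuz. A conversion helper therefore remaps that bit pattern and doubles the scales. The following is a sketch of the idea, not necessarily the exact implementation in this PR:

```python
import torch


def normalize_e4m3fn_to_e4m3fnuz(weight, weight_scale, input_scale=None):
    """Sketch: convert float8_e4m3fn tensors to float8_e4m3fnuz for ROCm."""
    assert weight.dtype == torch.float8_e4m3fn
    as_int8 = weight.view(torch.int8)
    # Bit pattern 0x80 is -0.0 in e4m3fn but NaN in e4m3fnuz; remap it to 0.
    as_int8[as_int8 == -128] = 0
    weight = as_int8.view(torch.float8_e4m3fnuz)
    # Same bits mean half the value in fnuz, so double the scales to keep
    # dequantized values unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale
```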

bias=self.bias,
)

if type(output) is tuple and len(output) == 2:
Member:

Did this change between torch versions or is output for AMD different?

Suggested change
if type(output) is tuple and len(output) == 2:
if isinstance(output, tuple) and len(output) == 2:

mht-sharma (Collaborator, Author):

Done
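
As background for the question above: depending on the PyTorch build, torch._scaled_mm has returned either an (output, amax) tuple or just the output tensor, so the guard keeps the layer working across versions. A self-contained sketch, with the argument wiring assumed for illustration:

```python
import torch


def fp8_matmul(qinput, qweight, input_scale, weight_scale, dtype, bias=None):
    # Sketch of the call site discussed above; kwarg names follow
    # torch._scaled_mm, but the exact wiring here is illustrative.
    out = torch._scaled_mm(
        qinput,
        qweight.t(),
        scale_a=input_scale,
        scale_b=weight_scale,
        out_dtype=dtype,
        bias=bias,
    )
    # Some builds return (output, amax); newer ones return the tensor directly.
    if isinstance(out, tuple) and len(out) == 2:
        out = out[0]
    return out
```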

@@ -62,7 +62,7 @@ def from_unquant(cls, weight, bias, dtype):
return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)

@classmethod
def from_fp8(cls, weight, scale, _input_scale, bias, dtype):
def from_fp8(cls, weight, scale, _input_scale, _scale_upper_bound, bias, dtype):
Member:

Type.

These arguments get a bit messy. It's easy to mix up a tensor or a float (which was already happening here?). Maybe we should switch these to kwargs-only so that the call sites need to be explicit (+ type annotations).

mht-sharma (Collaborator, Author):

Converted them to kwargs and added type hints.
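
A sketch of the keyword-only signature being discussed. Parameter names follow the diff above; the class name, types, and constructor wiring are assumptions for illustration:

```python
from typing import Optional

import torch


class Fp8LinearSketch(torch.nn.Module):  # hypothetical stand-in for the real layer
    def __init__(self, **kwargs) -> None:
        super().__init__()
        self.config = kwargs  # placeholder; the real class builds its weights here

    @classmethod
    def from_fp8(
        cls,
        *,  # keyword-only: every call site must name its arguments explicitly
        weight: torch.Tensor,
        scale: torch.Tensor,
        input_scale: Optional[torch.Tensor] = None,
        scale_upper_bound: Optional[float] = None,
        bias: Optional[torch.Tensor] = None,
        dtype: torch.dtype,
    ) -> "Fp8LinearSketch":
        # A float scale and a tensor scale can no longer be silently swapped:
        # callers have to write scale=... / input_scale=... by name.
        return cls(
            weight=weight,
            scale=scale,
            input_scale=input_scale,
            scale_upper_bound=scale_upper_bound,
            bias=bias,
            dtype=dtype,
        )
```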

@@ -342,22 +342,19 @@ def get_model(
model_type = config_dict.get("model_type", None)

quantization_config = config_dict.get("quantization_config", None)
compression_config = config_dict.get("compression_config", None)
mht-sharma (Collaborator, Author):

@danieldk the config key was renamed to quantization_config.

Member:

I think we should check for both keys, at least for the time being. Some customers/users may have checkpoints that still have compression_config. Maybe with a comment that compression_config is for backwards compatibility?
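
A minimal sketch of the backwards-compatible lookup suggested here, assuming config_dict is the parsed model config from the diff above:

```python
quantization_config = config_dict.get("quantization_config", None)
if quantization_config is None:
    # Backwards compatibility: older checkpoints may still carry the
    # pre-rename "compression_config" key with the same content.
    quantization_config = config_dict.get("compression_config", None)
```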

mht-sharma requested a review from danieldk on October 14, 2024, 12:00
@@ -125,9 +164,24 @@ def get_weights_col_packed(
)
scale = scale.reshape(-1).expand(w.shape[0])

input_scale = None
if weights.get_tensor(f"{prefix}.input_scale"):
Member:

Suggested change
if weights.get_tensor(f"{prefix}.input_scale"):
if weights.has_tensor(f"{prefix}.input_scale"):

?

input_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
for p, shape in zip(prefixes, shapes)
if weights.has_tensor(f"{p}.input_scale")
Member:

Given this conditional, we probably need an assertion like

assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
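
In context, the guarded load plus that assertion might read as follows. The helper names come from the snippet above; how the per-shard scales are collapsed afterwards is an assumption:

```python
input_scale = [
    _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
    for p, shape in zip(prefixes, shapes)
    if weights.has_tensor(f"{p}.input_scale")
]
# Either every shard ships an input scale or none does; a partial set would
# silently pair scales with the wrong shards.
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
# One plausible way to collapse the per-shard scales into a single value:
input_scale = torch.cat(input_scale).max() if input_scale else None
```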

danieldk self-requested a review on October 15, 2024, 13:32
danieldk merged commit 704a58c into main on Oct 16, 2024 (11 of 12 checks passed)
danieldk deleted the rocm-fp8 branch on October 16, 2024, 07:54