Skip to content

[AMD][Kernel][BugFix] fix test_rocm_compressed_tensors_w8a8 for rocm #19509

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
import importlib
from typing import TYPE_CHECKING, Optional, Union

import torch
Expand Down Expand Up @@ -706,10 +705,8 @@ def cutlass_scaled_mm(a: torch.Tensor,

cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
if current_platform.is_rocm() or not cutlass_compatible_b:
triton_scaled_mm_module = importlib.import_module(
"vllm.model_executor.layers.quantization.compressed_tensors."
"triton_scaled_mm")
triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa
triton_scaled_mm)
Comment on lines +708 to +709
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Replacing the dynamic import using importlib.import_module with a direct from ... import ... statement is a good change. This directly addresses the stated torch.compile incompatibility with importlib and generally improves code clarity by making the dependency explicit.

The use of # noqa is appropriate here to suppress linter warnings for an import that is not at the top level of the module, given its conditional nature within the if block.

return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

out = torch.empty((m, n), dtype=out_dtype, device=a.device)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,10 @@ def triton_scaled_mm(input: torch.Tensor,
scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b

assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
[M, 1])
assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size(
[N, 1])
assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1
or scale_a.shape[0] == M)
assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1
or scale_b.shape[0] == N)
Comment on lines +147 to +150
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Modifying the assertions to avoid direct comparison with torch.Size objects (e.g., scale_a.shape == torch.Size([1, 1])) and instead using direct shape attribute access (e.g., scale_a.shape[1] == 1 and scale_a.shape[0] == 1) is a good solution to the torch.compile issue mentioned in the PR description.

The new assertions are logically equivalent to the old ones and ensure compatibility. The way they are split across lines maintains readability.

assert out_dtype.is_floating_point
assert bias is None or bias.is_floating_point()
assert is_weak_contiguous(input)
Expand Down