[Continuation] Merge EmbeddedLLM/vllm-rocm into vLLM main #1836

Merged
63 commits merged on Dec 8, 2023

Changes from 1 commit (of 63)
43af310
port dtype_float16.cuh and cache_kernels.cu
pcmoritz Oct 10, 2023
cc81866
port dtype_bfloat16.cuh
pcmoritz Oct 10, 2023
475b5e2
port attention_utils.cuh
pcmoritz Oct 10, 2023
ddc496c
port more kernels
pcmoritz Oct 10, 2023
5eaa7a1
fix typo
pcmoritz Oct 10, 2023
f7273c6
add cuda_compat.h
pcmoritz Oct 10, 2023
99c3be7
Merge branch 'main' into port-to-rocm
pcmoritz Oct 16, 2023
f8093dc
sync branches
pcmoritz Oct 16, 2023
41df689
update
pcmoritz Oct 16, 2023
93be9c5
update
pcmoritz Oct 16, 2023
d96fa3c
fixes
pcmoritz Oct 16, 2023
421365b
cleanup
pcmoritz Oct 16, 2023
06b800e
update
pcmoritz Oct 16, 2023
2312beb
update
pcmoritz Oct 16, 2023
2958b39
update
pcmoritz Oct 16, 2023
3f89734
fmt
pcmoritz Oct 16, 2023
5397a57
cleanup
pcmoritz Oct 16, 2023
90e02d2
refactor
pcmoritz Oct 16, 2023
a420202
update
pcmoritz Oct 16, 2023
b072182
Merge branch 'main' into port-to-rocm
pcmoritz Oct 17, 2023
2d1e435
detecting rocm and adding flag for compiling
iAmir97 Oct 17, 2023
e231b79
using asm volatile instead of hip api
iAmir97 Oct 17, 2023
31bb335
using asm volatile for type casting of f16
iAmir97 Oct 17, 2023
b027d06
Hipifying csrc file to accomodate rocm builds
kliuae Nov 27, 2023
9a1781c
Checked CUDA ROCm Compatibility (#15)
tjtanaa Nov 29, 2023
0f67117
merged with latest upstream
kliuae Nov 29, 2023
7dbf2d4
format code
kliuae Nov 29, 2023
52ffcf0
downgrade torch requirement in toml to torch 2.0.1 to accommodate ROC…
kliuae Nov 29, 2023
27f0513
Merged changes from vllm main
kliuae Dec 1, 2023
5cce649
Merged with changes in vllm main
kliuae Dec 1, 2023
16d3ccc
Updated Dockerfile, rocm installation guide and setuppy
kliuae Dec 1, 2023
d764f9d
Updated amd installation guide and dockerfile
kliuae Dec 2, 2023
e798632
Added num_gpus for ray init in ROCm
kliuae Dec 2, 2023
0e8129f
Synced torch version with vllm main in pyproject.toml
kliuae Dec 2, 2023
2b3821b
Format code
kliuae Dec 2, 2023
0c8795a
Merge branch 'main' into vllm-cuda-rocm-dev
kliuae Dec 4, 2023
5793f30
Updated dockerfile.rocm and requirements-rocm.txt
kliuae Dec 4, 2023
b172cdd
Disable mistral for ROCm
kliuae Dec 4, 2023
9cd5b18
Format code
kliuae Dec 4, 2023
b86f88a
Revert to cuda kernels
kliuae Dec 5, 2023
9727ab4
Merge remote-tracking branch 'pcmoritz/port-to-rocm'
kliuae Dec 5, 2023
c4aa2af
Port latest kernels to ROCm
kliuae Dec 5, 2023
f8c304e
Update readme
kliuae Dec 5, 2023
e608c30
Cleaned up kernel code
kliuae Dec 5, 2023
951e225
Added wrapper for setting devFuncAttributeMaxDynamicSharedMemorySize
kliuae Dec 6, 2023
25f9a97
Added wrapper for setting devFuncAttributeMaxDynamicSharedMemorySize
kliuae Dec 6, 2023
e984ada
Updated ROCm warp size
kliuae Dec 6, 2023
cc1195f
Format code
kliuae Dec 6, 2023
f92980e
Check hip from wrapper
kliuae Dec 6, 2023
66b4aa1
Format code
kliuae Dec 6, 2023
4a0ecb8
Enable support for mistral models
kliuae Dec 6, 2023
acf51a8
Fixed hip device attribute
kliuae Dec 6, 2023
4a52977
Format code
kliuae Dec 6, 2023
23a987a
Restored awq file
kliuae Dec 7, 2023
8787a4e
Format code
kliuae Dec 7, 2023
5911131
Merge latest vllm main
kliuae Dec 7, 2023
9fa8075
Updated rocm dockerfile
kliuae Dec 7, 2023
81e052d
Update amd installation guide
kliuae Dec 7, 2023
fb8ac26
Update vLLM Documentations (#18)
tjtanaa Dec 7, 2023
98f5487
Updated setup.py, vllm/utils.py and amd-installation doc
kliuae Dec 8, 2023
d90187a
Updated setup.py
kliuae Dec 8, 2023
c840531
Format code
kliuae Dec 8, 2023
9dba1d8
Merge branch 'main' into vllm-cuda-rocm-mod
kliuae Dec 8, 2023

Check hip from wrapper
kliuae committed Dec 6, 2023
commit f92980e357d7fc0691f6ab54df885a2a86ee7ce9
25 changes: 17 additions & 8 deletions setup.py
@@ -19,15 +19,24 @@
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"}
Collaborator:

Just curious: which part of the code imposes this requirement? That is, why is gfx8 not supported? While I don't think we have to support it, I'd like to know why we don't.

tjtanaa (Contributor, Author):

The list of ROCm-supported archs was compiled based on what AMD supports for ROCm and HIP. Furthermore, each arch has its own set of assembly instructions, so we also have to make sure that the assembly instructions currently in use are supported by those archs.

To the best of our knowledge, the following are the ARCH requirements needed by different libraries:

  1. PyTorch: gfx900, gfx906, gfx908, gfx90a, gfx1030, gfx1101
  2. vLLM custom ops: gfx90a, gfx908, gfx906, gfx1030, gfx1100
  3. Flash-Attention ROCm: gfx90a, gfx940, gfx941, gfx942

Should we use the intersection of all three ARCH requirements instead?
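For illustration only (not part of this PR), a minimal sketch of what that intersection would look like; the three sets below simply restate the arch lists quoted above, and the result would contain only gfx90a:

# Hypothetical sketch: intersect the arch lists quoted in the comment above.
PYTORCH_ARCHS = {"gfx900", "gfx906", "gfx908", "gfx90a", "gfx1030", "gfx1101"}
VLLM_CUSTOM_OP_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"}
FLASH_ATTN_ROCM_ARCHS = {"gfx90a", "gfx940", "gfx941", "gfx942"}

# Intersection of all three leaves only gfx90a.
ROCM_SUPPORTED_ARCHS = PYTORCH_ARCHS & VLLM_CUSTOM_OP_ARCHS & FLASH_ATTN_ROCM_ARCHS
print(sorted(ROCM_SUPPORTED_ARCHS))  # ['gfx90a']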

WoosukKwon (Collaborator), Dec 8, 2023:

@tjtanaa Thanks for the detailed explanation. Sorry, I have little background on this stuff. Maybe I should learn more about ROCm and AMD GPUs 😂

As far as I understand, the vLLM custom ops support every "recent" AMD GPU, and the supported GPU list is currently limited by ROCm Flash Attention. Is this correct?

tjtanaa (Contributor, Author), Dec 8, 2023:

@WoosukKwon We believe that, in the near future, the set of supported GPU archs is going to be restricted by ROCm Flash Attention.

hongxiayang (Collaborator), Dec 8, 2023:

FYI: the supported gfx archs for ROCm are documented here (the "LLVM target" column): https://rocm.docs.amd.com/en/latest/release/gpu_os_support.html#linux-supported-gpus.

# SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)


+def _is_hip():
+return torch.version.hip


+def _is_cuda():
+return torch.version.cuda


# Compiler flags.
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
# TODO(woosuk): Should we use -O3?
NVCC_FLAGS = ["-O2", "-std=c++17"]

-if torch.version.hip and ROCM_HOME is not None:
+if _is_hip() and ROCM_HOME is not None:
NVCC_FLAGS += ["-DUSE_ROCM"]

-if torch.version.cuda and CUDA_HOME is None:
+if _is_cuda() and CUDA_HOME is None:
raise RuntimeError(
"Cannot find CUDA_HOME. CUDA must be available to build the package.")

@@ -129,7 +138,7 @@ def get_torch_arch_list() -> Set[str]:

# First, check the TORCH_CUDA_ARCH_LIST environment variable.
compute_capabilities = get_torch_arch_list()
-if torch.version.cuda and not compute_capabilities:
+if _is_cuda() and not compute_capabilities:
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
# GPUs on the current machine.
device_count = torch.cuda.device_count()
@@ -140,7 +149,7 @@ def get_torch_arch_list() -> Set[str]:
"GPUs with compute capability below 7.0 are not supported.")
compute_capabilities.add(f"{major}.{minor}")

-if torch.version.cuda:
+if _is_cuda():
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
if not compute_capabilities:
# If no GPU is specified nor available, add all supported architectures
@@ -191,7 +200,7 @@ def get_torch_arch_list() -> Set[str]:
num_threads = min(os.cpu_count(), 8)
NVCC_FLAGS += ["--threads", str(num_threads)]

-elif torch.version.hip:
+elif _is_hip():
amd_arch = get_amdgpu_offload_arch()
if amd_arch not in ROCM_SUPPORTED_ARCHS:
raise RuntimeError(
@@ -211,7 +220,7 @@ def get_torch_arch_list() -> Set[str]:
"csrc/pybind.cpp",
]

-if torch.version.cuda:
+if _is_cuda():
vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")

vllm_extension = CUDAExtension(
@@ -245,7 +254,7 @@ def find_version(filepath: str) -> str:
def get_vllm_version() -> str:
version = find_version(get_path("vllm", "__init__.py"))

-if torch.version.hip:
+if _is_hip():
# Get the HIP version
hipcc_version = get_hipcc_rocm_version()
if hipcc_version != MAIN_CUDA_VERSION:
@@ -271,7 +280,7 @@ def read_readme() -> str:

def get_requirements() -> List[str]:
"""Get Python package dependencies from requirements.txt."""
-if torch.version.hip:
+if _is_hip():
with open(get_path("requirements-rocm.txt")) as f:
requirements = f.read().strip().split("\n")
else:
9 changes: 5 additions & 4 deletions vllm/engine/arg_utils.py
@@ -7,6 +7,7 @@

from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
+from vllm.utils import is_hip


@dataclass
@@ -89,7 +90,7 @@ def add_cli_args(
help='directory to download and load the weights, '
'default to the default cache dir of '
'huggingface')
-if torch.cuda.is_available() and torch.version.hip:
+if is_hip():
# do something specific for HIP
parser.add_argument(
'--load-format',
@@ -106,7 +107,7 @@ def add_cli_args(
help='data type for model weights and activations. '
'The default option is FP16 precision '
'Supports FP16 and BF16 ')
-elif torch.cuda.is_available() and torch.version.cuda:
+else:
# do something specific for CUDA
parser.add_argument(
'--load-format',
@@ -197,7 +198,7 @@ def add_cli_args(
parser.add_argument('--disable-log-stats',
action='store_true',
help='disable logging statistics')
-if torch.cuda.is_available() and torch.version.hip:
+if is_hip():
# Quantization settings.
parser.add_argument('--quantization',
'-q',
@@ -206,7 +207,7 @@ def add_cli_args(
default=None,
help='Method used to quantize the weights')

-elif torch.cuda.is_available() and torch.version.cuda:
+else:
# Quantization settings.
parser.add_argument('--quantization',
'-q',
3 changes: 2 additions & 1 deletion vllm/engine/ray_utils.py
@@ -3,6 +3,7 @@

from vllm.config import ParallelConfig
from vllm.logger import init_logger
+from vllm.utils import is_hip

import torch

@@ -75,7 +76,7 @@ def initialize_cluster(
"Ray is not installed. Please install Ray to use distributed "
"serving.")
# Connect to a ray cluster.
-if torch.version.hip:
+if is_hip():
ray.init(address=ray_address,
ignore_reinit_error=True,
num_gpus=parallel_config.world_size)
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/attention.py
@@ -10,6 +10,7 @@
from vllm._C import ops
from vllm._C import cache_ops
from vllm.model_executor.input_metadata import InputMetadata
+from vllm.utils import is_hip

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -161,7 +162,7 @@ def forward(
p=0.0,
scale=self.scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
-(torch.cuda.is_available() and torch.version.hip) else None,
+(is_hip()) else None,
)
output = out.view_as(query)
else:
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/quantization/__init__.py
@@ -2,12 +2,13 @@
import torch
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.utils import is_hip

_QUANTIZATION_CONFIG_REGISTRY = {
"squeezellm": SqueezeLLMConfig,
}

-if torch.cuda.is_available() and torch.version.cuda:
+if not is_hip():
from vllm.model_executor.layers.quantization.awq import AWQConfig
_QUANTIZATION_CONFIG_REGISTRY["awq"] = AWQConfig

5 changes: 3 additions & 2 deletions vllm/model_executor/layers/quantization/awq.py
@@ -2,10 +2,11 @@

import torch
from torch.nn.parameter import Parameter
-if torch.cuda.is_available() and torch.version.hip:
+from vllm.utils import is_hip
+if is_hip():
# do something specific for HIP
print("Warning: vLLM does not support AWQ on ROCm.")
-elif torch.cuda.is_available() and torch.version.cuda:
+else:
from vllm._C import ops

from vllm.model_executor.layers.linear import (LinearMethodBase,
10 changes: 4 additions & 6 deletions vllm/model_executor/layers/quantization/squeezellm.py
@@ -7,6 +7,7 @@
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.utils import is_hip


class SqueezeLLMConfig(QuantizationConfig):
@@ -114,14 +115,11 @@ def apply_weights(self,
lookup_table = weights["lookup_table"]
out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
-if torch.cuda.is_available() and torch.version.hip:
-out_float = torch.zeros(out_shape,
-device="cuda",
-dtype=torch.float)
+if is_hip():
+out_float = torch.zeros(out_shape, device="cuda", dtype=torch.float)
ops.squeezellm_gemm(reshaped_x, qweight, out_float, lookup_table)
out = out_float.to(dtype=torch.float16)
-# do something specific for HIP
-elif torch.cuda.is_available() and torch.version.cuda:
+else:
# NOTE: The output tensor should be zero-initialized.
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
3 changes: 2 additions & 1 deletion vllm/model_executor/model_loader.py
@@ -10,6 +10,7 @@
from vllm.model_executor.models import *
from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights)
+from vllm.utils import is_hip

# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = {
@@ -44,7 +45,7 @@
# in models such as Mistral
"MistralForCausalLM",
]
-if torch.version.hip:
+if is_hip():
for rocm_model in _ROCM_DISABLED_MODELS:
del _MODEL_REGISTRY[rocm_model]

4 changes: 4 additions & 0 deletions vllm/utils.py
@@ -53,3 +53,7 @@ def random_uuid() -> str:
def in_wsl() -> bool:
# Reference: https://github.com/microsoft/WSL/issues/4071
return "microsoft" in " ".join(uname()).lower()


+def is_hip():
+return torch.version.hip
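As a usage note rather than part of the diff: torch.version.hip is None on CUDA builds of PyTorch and a version string on ROCm builds, so callers are expected to treat the helper's return value as a boolean gate. A minimal sketch, assuming a PyTorch install is present:

import torch
from vllm.utils import is_hip

if is_hip():
    # ROCm/HIP build of PyTorch: torch.version.hip is a version string.
    print("ROCm build, HIP version:", torch.version.hip)
else:
    # CUDA (or CPU) build: torch.version.hip is None, so is_hip() is falsy.
    print("Non-ROCm build, CUDA version:", torch.version.cuda)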