Enable scaled FP8 (e4m3fn) KV cache on ROCm (AMD GPU) (vllm-project#3290)
Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: HaiShaw <hixiao@gmail.com>
Co-authored-by: AdrianAbeyta <Adrian.Abeyta@amd.com>
Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
Co-authored-by: root <root@gt-pla-u18-08.pla.dcgpu>
Co-authored-by: mawong-amd <156021403+mawong-amd@users.noreply.github.com>
Co-authored-by: ttbachyinsda <ttbachyinsda@outlook.com>
Co-authored-by: guofangze <guofangze@kuaishou.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: jacobthebanana <50071502+jacobthebanana@users.noreply.github.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
12 people authored Apr 3, 2024
1 parent 3dcb3e8 commit 2ff767b
Showing 41 changed files with 2,592 additions and 142 deletions.
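In brief, the commit threads a `quantization_param_path` option (per-layer KV cache scaling factors) through the engine alongside the existing `kv_cache_dtype` option, with `fp8` mapping to E4M3 on ROCm. Below is a minimal usage sketch, assuming the `LLM` constructor accepts the same keyword arguments the benchmarks in this diff pass through; the model name and scales path are placeholders.

```python
from vllm import LLM, SamplingParams

# Minimal sketch: model name and scales path are assumed placeholders.
llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    kv_cache_dtype="fp8",  # FP8 E4M3 on ROCm per this commit; E5M2 (unscaled) on CUDA
    quantization_param_path="kv_cache_scales.json",  # per-layer KV cache scaling factors
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```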
1 change: 1 addition & 0 deletions .gitignore
@@ -181,6 +181,7 @@ _build/
# hip files generated by PyTorch
*.hip
*_hip*
hip_compat.h

# Benchmark dataset
*.json
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -19,7 +19,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11")
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")

# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100")
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100")

#
# Supported/expected torch versions for CUDA/ROCm.
18 changes: 16 additions & 2 deletions benchmarks/benchmark_latency.py
@@ -24,6 +24,7 @@ def main(args: argparse.Namespace):
dtype=args.dtype,
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
quantization_param_path=args.quantization_param_path,
device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight,
enable_chunked_prefill=args.enable_chunked_prefill,
@@ -127,10 +128,23 @@ def run_to_completion(profile_dir: Optional[str] = None):
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=['auto', 'fp8_e5m2'],
choices=['auto', 'fp8'],
default='auto',
help=
'Data type for kv cache storage. If "auto", will use model data type.')
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
parser.add_argument(
'--quantization-param-path',
type=str,
default=None,
help='Path to the JSON file containing the KV cache scaling factors. '
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.')
parser.add_argument(
'--profile',
action='store_true',
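The `--quantization-param-path` flag added above points at a JSON file of KV cache scaling factors. As a sketch of the idea only: per-layer scales grouped by tensor-parallel rank. The field names and values below are assumed for illustration and may not match the loader added in this commit.

```python
import json

# Illustrative shape only (assumed schema): TP rank -> layer index -> scale.
scales = {
    "model_type": "llama",
    "kv_cache": {
        "dtype": "float8_e4m3fn",
        "scaling_factor": {
            "0": {"0": 0.0408, "1": 0.0503},
        },
    },
}

with open("kv_cache_scales.json", "w") as f:
    json.dump(scales, f, indent=2)
```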
22 changes: 19 additions & 3 deletions benchmarks/benchmark_throughput.py
@@ -72,6 +72,7 @@ def run_vllm(
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
quantization_param_path: Optional[str],
device: str,
enable_prefix_caching: bool,
gpu_memory_utilization: float = 0.9,
@@ -89,6 +90,7 @@
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
device=device,
enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir)
@@ -217,7 +219,8 @@ def main(args: argparse.Namespace):
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len, args.enforce_eager,
args.kv_cache_dtype, args.device,
args.kv_cache_dtype,
args.quantization_param_path, args.device,
args.enable_prefix_caching,
args.gpu_memory_utilization, args.download_dir)
elif args.backend == "hf":
@@ -306,10 +309,23 @@ def main(args: argparse.Namespace):
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8_e5m2"],
choices=["auto", "fp8"],
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type.')
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
parser.add_argument(
'--quantization-param-path',
type=str,
default=None,
help='Path to the JSON file containing the KV cache scaling factors. '
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.')
parser.add_argument(
"--device",
type=str,
13 changes: 10 additions & 3 deletions benchmarks/kernels/benchmark_paged_attention.py
@@ -97,6 +97,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
torch.cuda.cudart().cudaProfilerStart()
start_time = time.perf_counter()

# Using default kv_scale
kv_scale = 1.0

for _ in range(num_iters):
if version == "v1":
ops.paged_attention_v1(
@@ -112,6 +115,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
elif version == "v2":
ops.paged_attention_v2(
@@ -130,6 +134,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
else:
raise ValueError(f"Invalid version: {version}")
@@ -179,11 +184,13 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8_e5m2"],
choices=["auto", "fp8"],
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type.')
parser.add_argument("--device", type=str, choices=["cuda"], default="cuda")
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
args = parser.parse_args()
print(args)

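The `kv_scale` argument threaded into the attention ops above is a per-tensor scaling factor (1.0 by default, as in the benchmark). Below is a minimal numerical sketch of the scale-then-narrow round trip it enables, assuming PyTorch's `float8_e4m3fn` dtype is available (PyTorch 2.1+); this is an illustration of the idea, not the kernel's actual code path.

```python
import torch

def quantize_kv(x: torch.Tensor, kv_scale: float) -> torch.Tensor:
    # Divide by the per-tensor scale, then narrow to FP8 E4M3.
    return (x / kv_scale).to(torch.float8_e4m3fn)

def dequantize_kv(x_fp8: torch.Tensor, kv_scale: float) -> torch.Tensor:
    # Widen back to the compute dtype and undo the scale.
    return x_fp8.to(torch.float16) * kv_scale

key = torch.randn(16, 128, dtype=torch.float16)
kv_scale = key.abs().max().item() / 448.0  # 448 is the largest finite E4M3 value
restored = dequantize_kv(quantize_kv(key, kv_scale), kv_scale)
print("max abs error:", (key - restored).abs().max().item())
```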
1 change: 1 addition & 0 deletions cmake/utils.cmake
@@ -117,6 +117,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)

list(APPEND GPU_FLAGS
"-DUSE_ROCM"
"-DENABLE_FP8_E4M3"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc")
2 changes: 1 addition & 1 deletion csrc/attention/attention_dtypes.h
@@ -4,4 +4,4 @@
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"
#include "dtype_bfloat16.cuh"
#include "dtype_fp8_e5m2.cuh"
#include "dtype_fp8.cuh"