From 6958f2f692fbf048821e141c22e81c3c36dd14a1 Mon Sep 17 00:00:00 2001 From: Lily Liu Date: Tue, 6 Feb 2024 11:38:38 -0800 Subject: [PATCH] [Minor] More fix of test_cache.py CI test failure (#2750) --- tests/kernels/test_cache.py | 9 ++++----- vllm/utils.py | 18 ++++++++++++------ 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index e0368d926d51a..d8dc74bc7b003 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -181,16 +181,15 @@ def test_swap_blocks( num_blocks: int, dtype: torch.dtype, seed: int, - device: int, + device: str, ) -> None: random.seed(seed) torch.random.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) - src_device = f"{direction[0]}:{device}" if direction[ - 0] == "cuda" else direction[0] - dst_device = f"{direction[1]}:{device}" if direction[ - 1] == "cuda" else direction[1] + + src_device = device if direction[0] == "cuda" else 'cpu' + dst_device = device if direction[1] == "cuda" else 'cpu' src_blocks = random.sample(range(num_blocks), num_mappings) # For the same device, mapping must not overlap diff --git a/vllm/utils.py b/vllm/utils.py index 9e9126a2d6377..d7a3a3a2a9ef9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -258,10 +258,13 @@ def create_kv_caches_with_random( key_cache = torch.empty(size=key_cache_shape, dtype=torch_dtype, device=device) - if cache_dtype in ["auto", "half", "bfloat16", "float"]: - key_cache.uniform_(-scale, scale) - elif cache_dtype == 'fp8_e5m2': + if cache_dtype == 'fp8_e5m2': _generate_random_fp8_e5m2(key_cache, -scale, scale) + elif torch_dtype in [torch.half, torch.bfloat16, torch.float]: + key_cache.uniform_(-scale, scale) + else: + raise ValueError( + f"Does not support key cache of type {cache_dtype}") key_caches.append(key_cache) value_cache_shape = (num_blocks, num_heads, head_size, block_size) @@ -270,9 +273,12 @@ def create_kv_caches_with_random( value_cache = torch.empty(size=value_cache_shape, dtype=torch_dtype, device=device) - if cache_dtype in ["auto", "half", "bfloat16", "float"]: - value_cache.uniform_(-scale, scale) - elif cache_dtype == 'fp8_e5m2': + if cache_dtype == 'fp8_e5m2': _generate_random_fp8_e5m2(value_cache, -scale, scale) + elif torch_dtype in [torch.half, torch.bfloat16, torch.float]: + value_cache.uniform_(-scale, scale) + else: + raise ValueError( + f"Does not support value cache of type {cache_dtype}") value_caches.append(value_cache) return key_caches, value_caches