Skip to content

Commit

Permalink
[Minor] More fix of test_cache.py CI test failure (#2750)
Browse files Browse the repository at this point in the history
  • Loading branch information
LiuXiaoxuanPKU authored Feb 6, 2024
1 parent ed70c70 commit fe6d09a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
9 changes: 4 additions & 5 deletions tests/kernels/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,15 @@ def test_swap_blocks(
num_blocks: int,
dtype: torch.dtype,
seed: int,
device: int,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
src_device = f"{direction[0]}:{device}" if direction[
0] == "cuda" else direction[0]
dst_device = f"{direction[1]}:{device}" if direction[
1] == "cuda" else direction[1]

src_device = device if direction[0] == "cuda" else 'cpu'
dst_device = device if direction[1] == "cuda" else 'cpu'

src_blocks = random.sample(range(num_blocks), num_mappings)
# For the same device, mapping must not overlap
Expand Down
18 changes: 12 additions & 6 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,13 @@ def create_kv_caches_with_random(
key_cache = torch.empty(size=key_cache_shape,
dtype=torch_dtype,
device=device)
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
key_cache.uniform_(-scale, scale)
elif cache_dtype == 'fp8_e5m2':
if cache_dtype == 'fp8_e5m2':
_generate_random_fp8_e5m2(key_cache, -scale, scale)
elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
key_cache.uniform_(-scale, scale)
else:
raise ValueError(
f"Does not support key cache of type {cache_dtype}")
key_caches.append(key_cache)

value_cache_shape = (num_blocks, num_heads, head_size, block_size)
Expand All @@ -270,9 +273,12 @@ def create_kv_caches_with_random(
value_cache = torch.empty(size=value_cache_shape,
dtype=torch_dtype,
device=device)
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
value_cache.uniform_(-scale, scale)
elif cache_dtype == 'fp8_e5m2':
if cache_dtype == 'fp8_e5m2':
_generate_random_fp8_e5m2(value_cache, -scale, scale)
elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
value_cache.uniform_(-scale, scale)
else:
raise ValueError(
f"Does not support value cache of type {cache_dtype}")
value_caches.append(value_cache)
return key_caches, value_caches

0 comments on commit fe6d09a

Please sign in to comment.