[mypy] Enable type checking for test directory (#5017)
DarkLight1337 authored Jun 15, 2024
1 parent 1b8a0d7 commit 0e9164b
Showing 92 changed files with 510 additions and 379 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/mypy.yaml
@@ -47,5 +47,5 @@ jobs:
mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
mypy vllm/logging --config-file pyproject.toml
-mypy vllm/model_executor --config-file pyproject.toml
+mypy tests --config-file pyproject.toml
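The hunk above adds `mypy tests --config-file pyproject.toml` to the CI job; most of the remaining changes in this commit are the annotations that run needs. The most common pattern is giving an empty container an explicit element type. A minimal illustration of why (hypothetical code, not from the repository):

```python
from typing import List


def word_lengths(words: List[str]) -> List[int]:
    # Annotating the empty list tells mypy the element type up front, mirroring
    # the `foo: List[int] = []`-style annotations added throughout this commit.
    lengths: List[int] = []
    for word in words:
        lengths.append(len(word))
    # lengths.append("oops")  # mypy would reject this as an incompatible argument type
    return lengths
```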
18 changes: 9 additions & 9 deletions benchmarks/benchmark_serving.py
@@ -31,7 +31,7 @@
import warnings
from dataclasses import dataclass
from datetime import datetime
-from typing import AsyncGenerator, List, Optional, Tuple
+from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple

import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
@@ -200,12 +200,12 @@ def calculate_metrics(
dur_s: float,
tokenizer: PreTrainedTokenizerBase,
) -> Tuple[BenchmarkMetrics, List[int]]:
-actual_output_lens = []
+actual_output_lens: List[int] = []
total_input = 0
completed = 0
-itls = []
-tpots = []
-ttfts = []
+itls: List[float] = []
+tpots: List[float] = []
+ttfts: List[float] = []
for i in range(len(outputs)):
if outputs[i].success:
# We use the tokenizer to count the number of output tokens for all
@@ -265,7 +265,7 @@ async def benchmark(
disable_tqdm: bool,
):
if backend in ASYNC_REQUEST_FUNCS:
-request_func = ASYNC_REQUEST_FUNCS.get(backend)
+request_func = ASYNC_REQUEST_FUNCS[backend]
else:
raise ValueError(f"Unknown backend: {backend}")

@@ -292,7 +292,7 @@ async def benchmark(
pbar = None if disable_tqdm else tqdm(total=len(input_requests))

benchmark_start_time = time.perf_counter()
-tasks = []
+tasks: List[asyncio.Task] = []
async for request in get_request(input_requests, request_rate):
prompt, prompt_len, output_len = request
request_func_input = RequestFuncInput(
@@ -310,7 +310,7 @@ async def benchmark(
pbar=pbar)))
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)

-if not disable_tqdm:
+if pbar is not None:
pbar.close()

benchmark_duration = time.perf_counter() - benchmark_start_time
@@ -466,7 +466,7 @@ def main(args: argparse.Namespace):

# Save config and results to json
if args.save_result:
-result_json = {}
+result_json: Dict[str, Any] = {}

# Setup
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
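Two of the hunks above are about narrowing Optional values rather than adding annotations: `ASYNC_REQUEST_FUNCS.get(backend)` is typed as Optional even after the `in` check, whereas indexing is not, and `pbar` is Optional by construction, so the close is guarded by an explicit None check. A rough sketch of the same pattern with stand-in names (not the repository's code):

```python
from typing import Callable, Dict, Optional

HANDLERS: Dict[str, Callable[[], str]] = {"openai": lambda: "ok"}


def run(backend: str, progress_note: Optional[str] = None) -> str:
    if backend in HANDLERS:
        # HANDLERS.get(backend) would be Optional[Callable[[], str]], which mypy
        # flags when called; indexing keeps the non-Optional type (a miss raises
        # KeyError at runtime, but the `in` check above rules that out).
        handler = HANDLERS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")
    if progress_note is not None:
        # An explicit `is not None` check narrows Optional[str] to str for mypy.
        print(progress_note.upper())
    return handler()
```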
4 changes: 2 additions & 2 deletions benchmarks/benchmark_throughput.py
@@ -108,8 +108,8 @@ def run_vllm(
)

# Add the requests to the engine.
-prompts = []
-sampling_params = []
+prompts: List[str] = []
+sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
sampling_params.append(
10 changes: 5 additions & 5 deletions benchmarks/kernels/benchmark_aqlm.py
@@ -86,9 +86,9 @@ def dequant_no_scale(
# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against
# the generic pytorch version.
# Just visual comparison.
-def dequant_test(k: int, parts: torch.tensor, nbooks: int, bits: int) -> None:
+def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:

-n = parts.sum().item()
+n = int(parts.sum().item())

device = torch.device('cuda:0')

@@ -204,7 +204,7 @@ def main():
sys.stdout = sys.__stdout__


-def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int,
+def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
methods):

# I didn't see visible improvements from increasing these, but feel free :)
@@ -252,10 +252,10 @@ def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int,
print('')


-def run_timing(num_calls: int, m: int, k: int, parts: torch.tensor,
+def run_timing(num_calls: int, m: int, k: int, parts: torch.Tensor,
nbooks: int, bits: int, method) -> float:

-n = parts.sum().item()
+n = int(parts.sum().item())

device = torch.device('cuda:0')

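The aqlm benchmark fixes are about spelling rather than new annotations: `torch.Tensor` is the class and therefore a valid type annotation, while `torch.tensor` is the factory function, and wrapping `.item()` in `int()` gives mypy a definite integer, since the stubs type `.item()` as a generic Python number. A small sketch of both points (assumes torch is installed; not the repository's code):

```python
import torch


def total_size(parts: torch.Tensor) -> int:
    # torch.Tensor (the class) is the valid annotation; torch.tensor is a factory
    # function. int() pins the result of .item() to an int that can later be used
    # as a shape or loop bound without mypy complaints.
    return int(parts.sum().item())


print(total_size(torch.tensor([2, 3, 5])))  # 10
```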
8 changes: 5 additions & 3 deletions benchmarks/kernels/benchmark_marlin.py
@@ -1,4 +1,5 @@
import argparse
+from typing import List

import torch
import torch.utils.benchmark as benchmark
@@ -23,8 +24,9 @@
K_FULL_OPTS = [False, True]


-def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
-              size_m, size_k, size_n):
+def bench_run(results: List[benchmark.Measurement], model: str,
+              act_order: bool, is_k_full: bool, num_bits: int, group_size: int,
+              size_m: int, size_k: int, size_n: int):
label = "Quant Matmul"

sub_label = ("{}, act={} k_full={}, b={}, g={}, "
@@ -156,7 +158,7 @@ def main(args):
for i, model in enumerate(args.models):
print(f"[{i}] {model}")

-results = []
+results: List[benchmark.Measurement] = []

for model in args.models:
for layer in WEIGHT_SHAPES[model]:
26 changes: 18 additions & 8 deletions benchmarks/kernels/benchmark_moe.py
@@ -1,7 +1,7 @@
import argparse
import time
from datetime import datetime
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, TypedDict

import ray
import torch
@@ -12,8 +12,17 @@
from vllm.model_executor.layers.fused_moe.fused_moe import *


+class BenchmarkConfig(TypedDict):
+    BLOCK_SIZE_M: int
+    BLOCK_SIZE_N: int
+    BLOCK_SIZE_K: int
+    GROUP_SIZE_M: int
+    num_warps: int
+    num_stages: int
+
+
def benchmark_config(
-config: Dict[str, int],
+config: BenchmarkConfig,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
@@ -92,7 +101,7 @@ def run():
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

-latencies = []
+latencies: List[float] = []
for i in range(num_iters):
prepare(i)
torch.cuda.synchronize()
@@ -111,7 +120,7 @@ def get_configs_compute_bound() -> List[Dict[str, int]]:
# Reduced search space for faster tuning.
# TODO(woosuk): Increase the search space and use a performance model to
# prune the search space.
-configs = []
+configs: List[BenchmarkConfig] = []
for num_stages in [2, 3, 4, 5]:
for block_m in [16, 32, 64, 128, 256]:
for block_k in [64, 128, 256]:
@@ -175,8 +184,8 @@ def tune(
topk: int,
dtype: torch.dtype,
use_fp8: bool,
-search_space: List[Dict[str, int]],
-) -> Dict[str, int]:
+search_space: List[BenchmarkConfig],
+) -> BenchmarkConfig:
best_config = None
best_time = float("inf")
for config in tqdm(search_space):
@@ -199,10 +208,11 @@ def tune(
best_config = config
now = datetime.now()
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
+assert best_config is not None
return best_config


-def sort_config(config: Dict[str, int]) -> Dict[str, int]:
+def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
return {
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
@@ -214,7 +224,7 @@ def sort_config(config: Dict[str, int]) -> Dict[str, int]:


def save_configs(
-configs: Dict[int, Dict[str, int]],
+configs: Dict[int, BenchmarkConfig],
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
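The `BenchmarkConfig` TypedDict above replaces the looser `Dict[str, int]`, so mypy can check each key and its value type, and the added `assert best_config is not None` narrows the Optional accumulator before it is returned. A toy version of the same two ideas (the names here are made up, not from the repository):

```python
from typing import List, Optional, TypedDict


class TileConfig(TypedDict):
    BLOCK_SIZE_M: int
    num_warps: int


def pick_best(search_space: List[TileConfig]) -> TileConfig:
    best: Optional[TileConfig] = None
    for config in search_space:
        if best is None or config["num_warps"] < best["num_warps"]:
            best = config
    # Without this assert, mypy reports returning Optional[TileConfig]
    # where TileConfig is declared.
    assert best is not None
    return best


print(pick_best([{"BLOCK_SIZE_M": 64, "num_warps": 4},
                 {"BLOCK_SIZE_M": 128, "num_warps": 8}]))
```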
11 changes: 7 additions & 4 deletions benchmarks/kernels/benchmark_paged_attention.py
@@ -1,7 +1,7 @@
import argparse
import random
import time
-from typing import Optional
+from typing import List, Optional

import torch

@@ -54,14 +54,17 @@ def main(

# Create the block tables.
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
-block_tables = []
+block_tables_lst: List[List[int]] = []
for _ in range(num_seqs):
block_table = [
random.randint(0, NUM_BLOCKS - 1)
for _ in range(max_num_blocks_per_seq)
]
-block_tables.append(block_table)
-block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)
+block_tables_lst.append(block_table)
+
+block_tables = torch.tensor(block_tables_lst,
+                            dtype=torch.int,
+                            device=device)

# Create the KV cache.
key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
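Here the nested Python list gets its own name (`block_tables_lst`) because mypy assigns a variable a single type; reusing `block_tables` first as a `List[List[int]]` and then as a `torch.Tensor` would be reported as an incompatible assignment. A condensed sketch of the shape of the fix (assumes torch; sizes are arbitrary, not the repository's code):

```python
from typing import List

import torch


def build_block_tables(num_seqs: int, blocks_per_seq: int) -> torch.Tensor:
    # Build the plain Python structure under one name...
    block_tables_lst: List[List[int]] = [
        list(range(blocks_per_seq)) for _ in range(num_seqs)
    ]
    # ...and keep the tensor under another, so each name has one consistent type.
    return torch.tensor(block_tables_lst, dtype=torch.int)


print(build_block_tables(2, 3))
```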
7 changes: 4 additions & 3 deletions benchmarks/kernels/benchmark_rope.py
@@ -1,11 +1,12 @@
import argparse
from itertools import accumulate
-from typing import Optional
+from typing import List, Optional

import nvtx
import torch

-from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
+                                                         get_rope)


def benchmark_rope_kernels_multi_lora(
@@ -37,7 +38,7 @@ def benchmark_rope_kernels_multi_lora(
})
# non-batched RoPE takes only one scaling factor, we create multiple
# instances to simulate the same behavior
-non_batched_ropes = []
+non_batched_ropes: List[RotaryEmbedding] = []
for scaling_factor in scaling_factors:
non_batched_ropes.append(
get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
12 changes: 6 additions & 6 deletions examples/fp8/extract_scales.py
@@ -2,7 +2,7 @@
import glob
import json
import os
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
@@ -19,7 +19,7 @@ def _prepare_hf_weights(
quantized_model_dir: str,
load_format: str = "auto",
fall_back_to_pt: bool = True,
-) -> Tuple[str, List[str], bool]:
+) -> Tuple[List[str], bool]:
if not os.path.isdir(quantized_model_dir):
raise FileNotFoundError(
f"The quantized model directory `{quantized_model_dir}` "
@@ -94,7 +94,7 @@ def _hf_tensorfile_iterator(filename: str, load_format: str,


def _kv_scales_extractor(
-hf_tensor_files: Iterable[str],
+hf_tensor_files: List[str],
use_safetensors: bool,
rank_keyword: str = "rank",
expected_tp_size: Optional[int] = None) -> Dict[int, Dict[int, float]]:
@@ -115,7 +115,7 @@ def _kv_scales_extractor(
for char in rank_keyword:
assert not char.isdecimal(
), f"Rank keyword {rank_keyword} contains a numeric character!"
-rank_scales_map = {}
+rank_scales_map: Dict[int, Dict[int, float]] = {}
for tensor_file in hf_tensor_files:
try:
rank_idx = tensor_file.find(rank_keyword)
@@ -141,7 +141,7 @@ def _kv_scales_extractor(
raise

if rank not in rank_scales_map:
-layer_scales_map = {}
+layer_scales_map: Dict[int, float] = {}
rank_scales_map[rank] = layer_scales_map
else:
raise RuntimeError(
@@ -222,7 +222,7 @@ def _metadata_extractor(quantized_model_dir: str,
"does not exist.")
metadata_files = glob.glob(os.path.join(quantized_model_dir, "*.json"))

-result = {}
+result: Dict[str, Any] = {}
for file in metadata_files:
with open(file) as f:
try:
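Besides the usual container annotations, `_prepare_hf_weights` has its declared return type trimmed from a three-element tuple to the two values it actually returns, the kind of drift mypy surfaces once a directory is checked. A tiny illustration (hypothetical function, not from the file):

```python
from typing import List, Tuple


def split_words(text: str) -> Tuple[List[str], bool]:
    # Declaring Tuple[str, List[str], bool] here while returning two values would be
    # flagged by mypy as an incompatible return value, as in the fix above.
    words = text.split()
    return words, len(words) > 0


print(split_words("enable mypy on tests"))
```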
8 changes: 4 additions & 4 deletions examples/offline_inference_distributed.py
@@ -5,7 +5,7 @@
Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html
"""

-from typing import Dict
+from typing import Any, Dict, List

import numpy as np
import ray
@@ -40,8 +40,8 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
# The output is a list of RequestOutput objects that contain the prompt,
# generated text, and other information.
outputs = self.llm.generate(batch["text"], sampling_params)
-prompt = []
-generated_text = []
+prompt: List[str] = []
+generated_text: List[str] = []
for output in outputs:
prompt.append(output.prompt)
generated_text.append(' '.join([o.text for o in output.outputs]))
@@ -71,7 +71,7 @@ def scheduling_strategy_fn():
pg, placement_group_capture_child_tasks=True))


-resources_kwarg = {}
+resources_kwarg: Dict[str, Any] = {}
if tensor_parallel_size == 1:
# For tensor_parallel_size == 1, we simply set num_gpus=1.
resources_kwarg["num_gpus"] = 1
2 changes: 1 addition & 1 deletion format.sh
@@ -111,7 +111,7 @@ mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
mypy vllm/logging --config-file pyproject.toml
-mypy vllm/model_executor --config-file pyproject.toml
+mypy tests --config-file pyproject.toml


# If git diff returns a file that is in the skip list, the file may be checked anyway:
8 changes: 5 additions & 3 deletions tests/core/block/test_block_table.py
@@ -1,3 +1,5 @@
+from typing import List
+
import pytest

from vllm.core.block.block_table import BlockTable
@@ -28,7 +30,7 @@ def test_allocate_naive(block_size: int, sequence_len: int):
token_ids = list(range(sequence_len))
num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size)))

-block_tables = []
+block_tables: List[BlockTable] = []
for i in range(5):
assert allocator.get_num_free_blocks(
device=Device.GPU) == num_gpu_blocks - i * num_blocks_per_alloc
@@ -73,7 +75,7 @@ def test_allocate_prefix_caching(block_size: int, sequence_len: int):
num_immutable_blocks_per_alloc = len(
chunked_tokens) - num_mutable_blocks_per_alloc

-block_tables = []
+block_tables: List[BlockTable] = []
for alloc_i in range(1, 6):

block_tables.append(
@@ -268,7 +270,7 @@ def test_append_token_ids_correct_content(block_size: int, sequence_len: int,
)
block_table.allocate(token_ids=token_ids, device=Device.GPU)

-appended_so_far = []
+appended_so_far: List[int] = []
for append in chunk_list(token_ids_to_append, append_size):
block_table.append_token_ids(append)
appended_so_far.extend(append)