
Commit c2b4a1b

michaelfeil and ywang96 authored
[Doc] Add typing hints / mypy types cleanup (vllm-project#3816)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
1 parent e46a60a commit c2b4a1b

File tree

11 files changed: +90 −64 lines


benchmarks/backend_request_func.py

Lines changed: 33 additions & 29 deletions
@@ -27,8 +27,8 @@ class RequestFuncInput:
 class RequestFuncOutput:
     generated_text: str = ""
     success: bool = False
-    latency: float = 0
-    ttft: float = 0  # Time to first token
+    latency: float = 0.0
+    ttft: float = 0.0  # Time to first token
     itl: List[float] = field(
         default_factory=list)  # List of inter-token latencies
     prompt_len: int = 0
@@ -58,23 +58,24 @@ async def async_request_tgi(
         output = RequestFuncOutput()
         output.prompt_len = request_func_input.prompt_len

-        ttft = 0
+        ttft = 0.0
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
             async with session.post(url=api_url, json=payload) as response:
                 if response.status == 200:
-                    async for chunk in response.content:
-                        chunk = chunk.strip()
-                        if not chunk:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
                             continue

-                        chunk = remove_prefix(chunk.decode("utf-8"), "data:")
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+                                              "data:")

                         data = json.loads(chunk)
                         timestamp = time.perf_counter()
                         # First token
-                        if ttft == 0:
+                        if ttft == 0.0:
                             ttft = time.perf_counter() - st
                             output.ttft = ttft

@@ -119,23 +120,24 @@ async def async_request_trt_llm(
         output = RequestFuncOutput()
         output.prompt_len = request_func_input.prompt_len

-        ttft = 0
+        ttft = 0.0
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
             async with session.post(url=api_url, json=payload) as response:
                 if response.status == 200:
-                    async for chunk in response.content:
-                        chunk = chunk.strip()
-                        if not chunk:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
                             continue

-                        chunk = remove_prefix(chunk.decode("utf-8"), "data:")
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+                                              "data:")

                         data = json.loads(chunk)
                         timestamp = time.perf_counter()
                         # First token
-                        if ttft == 0:
+                        if ttft == 0.0:
                             ttft = time.perf_counter() - st
                             output.ttft = ttft

@@ -151,7 +153,7 @@ async def async_request_trt_llm(
                     output.success = True

                 else:
-                    output.error = response.reason
+                    output.error = response.reason or ""
                     output.success = False
         except Exception:
             output.success = False
@@ -195,7 +197,7 @@ async def async_request_deepspeed_mii(
                     output.generated_text = parsed_resp["text"][0]
                     output.success = True
                 else:
-                    output.error = response.reason
+                    output.error = response.reason or ""
                     output.success = False
         except Exception:
             output.success = False
@@ -234,19 +236,20 @@ async def async_request_openai_completions(
         output.prompt_len = request_func_input.prompt_len

         generated_text = ""
-        ttft = 0
+        ttft = 0.0
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
             async with session.post(url=api_url, json=payload,
                                     headers=headers) as response:
                 if response.status == 200:
-                    async for chunk in response.content:
-                        chunk = chunk.strip()
-                        if not chunk:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
                             continue

-                        chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+                                              "data: ")
                         if chunk == "[DONE]":
                             latency = time.perf_counter() - st
                         else:
@@ -255,7 +258,7 @@ async def async_request_openai_completions(
                             if data["choices"][0]["text"]:
                                 timestamp = time.perf_counter()
                                 # First token
-                                if ttft == 0:
+                                if ttft == 0.0:
                                     ttft = time.perf_counter() - st
                                     output.ttft = ttft

@@ -315,19 +318,20 @@ async def async_request_openai_chat_completions(
         output.prompt_len = request_func_input.prompt_len

         generated_text = ""
-        ttft = 0
+        ttft = 0.0
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
             async with session.post(url=api_url, json=payload,
                                     headers=headers) as response:
                 if response.status == 200:
-                    async for chunk in response.content:
-                        chunk = chunk.strip()
-                        if not chunk:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
                             continue

-                        chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+                                              "data: ")
                         if chunk == "[DONE]":
                             latency = time.perf_counter() - st
                         else:
@@ -337,7 +341,7 @@ async def async_request_openai_chat_completions(
                             delta = data["choices"][0]["delta"]
                             if delta.get("content", None):
                                 # First token
-                                if ttft == 0:
+                                if ttft == 0.0:
                                     ttft = time.perf_counter() - st
                                     output.ttft = ttft

@@ -354,7 +358,7 @@ async def async_request_openai_chat_completions(
                     output.success = True
                     output.latency = latency
                 else:
-                    output.error = response.reason
+                    output.error = response.reason or ""
                     output.success = False
         except Exception:
             output.success = False
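Note on the pattern above: mypy fixes a variable's type at its first binding, so a local `ttft = 0` is inferred as int and the later assignment of `time.perf_counter() - st` (a float) is rejected; likewise rebinding `chunk` from bytes to str is an incompatible assignment. Hence the `0.0` literals and the `chunk_bytes`/`chunk` split. A minimal sketch reproducing both errors (function names here are illustrative, not from the patch):

    import time

    def before() -> None:
        ttft = 0                    # inferred as int
        ttft = time.perf_counter()  # mypy error: "float" assigned to "int"
        chunk = b"data: x".strip()     # inferred as bytes
        chunk = chunk.decode("utf-8")  # mypy error: "str" assigned to "bytes"

    def after() -> None:
        ttft = 0.0                  # float from the first binding
        ttft = time.perf_counter()  # OK
        chunk_bytes = b"data: x".strip()     # bytes keeps its own name
        chunk = chunk_bytes.decode("utf-8")  # str gets a fresh name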

docs/source/conf.py

Lines changed: 2 additions & 1 deletion
@@ -12,6 +12,7 @@

 import logging
 import sys
+from typing import List

 from sphinx.ext import autodoc

@@ -45,7 +46,7 @@
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = []
+exclude_patterns: List[str] = []

 # Exclude the prompt "$" when copying code
 copybutton_prompt_text = r"\$ "

setup.py

Lines changed: 3 additions & 2 deletions
@@ -5,7 +5,7 @@
 import subprocess
 import sys
 from shutil import which
-from typing import List
+from typing import Dict, List

 import torch
 from packaging.version import Version, parse
@@ -52,7 +52,7 @@ def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None:

 class cmake_build_ext(build_ext):
     # A dict of extension directories that have been configured.
-    did_config = {}
+    did_config: Dict[str, bool] = {}

     #
     # Determine number of compilation jobs and optionally nvcc compile threads.
@@ -261,6 +261,7 @@ def get_nvcc_cuda_version() -> Version:

     Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
     """
+    assert CUDA_HOME is not None, "CUDA_HOME is not set"
     nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"],
                                           universal_newlines=True)
     output = nvcc_output.split()
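Two common mypy idioms appear here. An empty literal such as `did_config = {}` gives the checker nothing to infer and produces "Need type annotation", so the fix spells the types out. And since `CUDA_HOME` is typed `Optional[str]` in torch's stubs, the added assert both narrows the type for the call below and fails loudly when CUDA is absent. A small standalone sketch of both (the placeholder path is illustrative):

    from typing import Dict, Optional

    did_config: Dict[str, bool] = {}  # annotation required: {} alone is uninferable

    CUDA_HOME: Optional[str] = "/usr/local/cuda"  # Optional[str], as in torch's stubs

    assert CUDA_HOME is not None, "CUDA_HOME is not set"
    nvcc_path = CUDA_HOME + "/bin/nvcc"  # OK: mypy has narrowed Optional[str] to str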

vllm/core/block/interfaces.py

Lines changed: 19 additions & 12 deletions
@@ -1,5 +1,5 @@
-from abc import ABC, abstractmethod, abstractproperty
-from typing import Dict, List, Optional, Protocol
+from abc import ABC, abstractmethod
+from typing import Dict, FrozenSet, List, Optional, Protocol

 from vllm.utils import Device

@@ -10,23 +10,28 @@ class Block(ABC):
     def append_token_ids(self, token_ids: List[int]) -> None:
         pass

-    @abstractproperty
+    @property
+    @abstractmethod
     def block_id(self) -> Optional[int]:
         pass

-    @abstractproperty
+    @property
+    @abstractmethod
     def token_ids(self) -> List[int]:
         pass

-    @abstractproperty
+    @property
+    @abstractmethod
     def num_empty_slots(self) -> int:
         pass

-    @abstractproperty
+    @property
+    @abstractmethod
     def is_full(self) -> bool:
         pass

-    @abstractproperty
+    @property
+    @abstractmethod
     def prev_block(self) -> Optional["Block"]:
         pass

@@ -47,12 +52,13 @@ def __call__(
 class BlockAllocator(ABC):

     @abstractmethod
-    def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
+    def allocate_mutable(self, prev_block: Optional[Block],
+                         device: Device) -> Block:
         pass

     @abstractmethod
     def allocate_immutable(self, prev_block: Optional[Block],
-                           token_ids: List[int]) -> Block:
+                           token_ids: List[int], device: Device) -> Block:
         pass

     @abstractmethod
@@ -64,11 +70,12 @@ def fork(self, last_block: Block) -> List[Block]:
         pass

     @abstractmethod
-    def get_num_free_blocks(self) -> int:
+    def get_num_free_blocks(self, device: Device) -> int:
         pass

-    @abstractproperty
-    def all_block_ids(self) -> frozenset[int]:
+    @property
+    @abstractmethod
+    def all_block_ids(self) -> FrozenSet[int]:
         pass

     @abstractmethod
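`abc.abstractproperty` has been deprecated since Python 3.3; the modern spelling stacks `@property` on top of `@abstractmethod`, and the order matters: `@property` must be outermost. The `FrozenSet` change is related housekeeping, since subscripting the builtin `frozenset[int]` requires Python 3.9+, while `typing.FrozenSet` works on older interpreters. A minimal sketch of the replacement pattern (the class name is an illustrative stand-in):

    from abc import ABC, abstractmethod
    from typing import FrozenSet

    class Allocator(ABC):

        @property            # outermost, so the attribute reads like a field
        @abstractmethod      # innermost, so ABC still enforces an override
        def all_block_ids(self) -> FrozenSet[int]:
            ...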

vllm/engine/metrics.py

Lines changed: 8 additions & 2 deletions
@@ -1,6 +1,6 @@
 import time
 from dataclasses import dataclass
-from typing import Dict, List
+from typing import Dict, List, Protocol

 import numpy as np
 from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
@@ -119,6 +119,12 @@ class Stats:
     time_e2e_requests: List[float]


+class SupportsMetricsInfo(Protocol):
+
+    def metrics_info(self) -> Dict[str, str]:
+        ...
+
+
 class StatLogger:
     """StatLogger is used LLMEngine to log to Promethus and Stdout."""

@@ -135,7 +141,7 @@ def __init__(self, local_interval: float, labels: Dict[str, str]) -> None:
         self.labels = labels
         self.metrics = Metrics(labelnames=list(labels.keys()))

-    def info(self, type: str, obj: object) -> None:
+    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
         if type == "cache_config":
             self.metrics.info_cache_config.info(obj.metrics_info())
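`SupportsMetricsInfo` is a `typing.Protocol`, so `info()` now documents exactly what it needs from its argument; structural typing means any object with a matching `metrics_info()` method qualifies, with no inheritance required. A self-contained sketch (`CacheConfigStub` is a hypothetical caller-side class, not from the patch):

    from typing import Dict, Protocol

    class SupportsMetricsInfo(Protocol):
        def metrics_info(self) -> Dict[str, str]:
            ...

    class CacheConfigStub:  # note: does not inherit from the Protocol
        def metrics_info(self) -> Dict[str, str]:
            return {"block_size": "16"}

    def log_info(obj: SupportsMetricsInfo) -> None:
        print(obj.metrics_info())

    log_info(CacheConfigStub())  # type-checks: the method signature matches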

vllm/logger.py

Lines changed: 7 additions & 1 deletion
@@ -4,6 +4,7 @@
 import logging
 import os
 import sys
+from typing import Optional

 VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))

@@ -26,7 +27,7 @@ def format(self, record):


 _root_logger = logging.getLogger("vllm")
-_default_handler = None
+_default_handler: Optional[logging.Handler] = None


 def _setup_logger():
@@ -55,7 +56,12 @@ def init_logger(name: str):
     # Use the same settings as above for root logger
     logger = logging.getLogger(name)
     logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG"))
+
     if VLLM_CONFIGURE_LOGGING:
+        if _default_handler is None:
+            raise ValueError(
+                "_default_handler is not set up. This should never happen!"
+                " Please open an issue on Github.")
         logger.addHandler(_default_handler)
         logger.propagate = False
     return logger
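Annotating `_default_handler` as `Optional[logging.Handler]` makes the None case visible to mypy, so `logger.addHandler(_default_handler)` no longer type-checks without a guard; the added raise both narrows the type and surfaces a broken setup at runtime instead of silently passing None. A condensed sketch of the narrowing:

    import logging
    from typing import Optional

    _default_handler: Optional[logging.Handler] = None

    def init_logger(name: str) -> logging.Logger:
        logger = logging.getLogger(name)
        if _default_handler is None:
            # Raising here narrows Optional[Handler] to Handler below.
            raise ValueError("_default_handler is not set up.")
        logger.addHandler(_default_handler)  # mypy: guaranteed non-None
        return logger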

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 8 additions & 7 deletions
@@ -247,11 +247,12 @@ def _yarn_find_correction_dim(num_rotations: int,


 # Find dim range bounds based on rotations
-def _yarn_find_correction_range(low_rot: int,
-                                high_rot: int,
-                                dim: int,
-                                base: float = 10000,
-                                max_position_embeddings: int = 2048) -> int:
+def _yarn_find_correction_range(
+        low_rot: int,
+        high_rot: int,
+        dim: int,
+        base: float = 10000,
+        max_position_embeddings: int = 2048) -> Tuple[int, int]:
     low = math.floor(
         _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
     high = math.ceil(
@@ -293,8 +294,8 @@ def __init__(
         *,
         extrapolation_factor: float = 1,
         attn_factor: float = 1,
-        beta_fast: float = 32,
-        beta_slow: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
     ) -> None:
         self.scaling_factor = scaling_factor
         self.extrapolation_factor = extrapolation_factor
vllm/transformers_utils/config.py

Lines changed: 2 additions & 2 deletions
@@ -1,10 +1,10 @@
-from typing import Optional
+from typing import Dict, Optional

 from transformers import AutoConfig, PretrainedConfig

 from vllm.transformers_utils.configs import *

-_CONFIG_REGISTRY = {
+_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
     "chatglm": ChatGLMConfig,
     "dbrx": DbrxConfig,
     "mpt": MPTConfig,

vllm/transformers_utils/configs/dbrx.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@

 logger = logging.get_logger(__name__)

-DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
+DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}  # type: ignore


 class DbrxAttentionConfig(PretrainedConfig):
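For this empty module-level dict the commit opts for `# type: ignore` rather than an annotation; both silence mypy's "Need type annotation" error, the trade-off being that the ignore also drops checking of later writes. A sketch contrasting the two options (the ARCHIVE_MAP names here are illustrative):

    from typing import Dict

    # Option 1 (as in this commit): suppress the inference error entirely.
    ARCHIVE_MAP = {}  # type: ignore

    # Option 2: annotate it, so later writes are still checked.
    ARCHIVE_MAP_TYPED: Dict[str, str] = {}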
