[Doc] Add typing hints / mypy types cleanup #3816

Merged: 15 commits, Apr 12, 2024
62 changes: 33 additions & 29 deletions benchmarks/backend_request_func.py
@@ -27,8 +27,8 @@ class RequestFuncInput:
class RequestFuncOutput:
generated_text: str = ""
success: bool = False
latency: float = 0
ttft: float = 0 # Time to first token
latency: float = 0.0
ttft: float = 0.0 # Time to first token
Collaborator: so unfortunate mypy cannot handle this :(

Contributor Author: On the other hand, mypy hints at a potential source of an actual bug. I think in this context any plausible use of that int should also work when the value is a float. E.g. int(str(0)) works, but int(str(0.0)) raises a ValueError.
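
To make the two comments above concrete, here is a minimal sketch (not part of the PR; the measure_* function names are invented) of why mypy prefers the 0.0 literal for these timing variables, and why the str round trip breaks for a float-looking value:

```python
import time


def measure_ttft_bad() -> float:
    ttft = 0  # mypy infers ttft as int from this literal ...
    ttft = time.perf_counter()  # ... and flags this later float assignment
    return ttft


def measure_ttft_good() -> float:
    ttft = 0.0  # inferred as float, so later float assignments type-check
    ttft = time.perf_counter()
    return ttft


# The round-trip pitfall mentioned in the comment:
print(int(str(0)))  # 0
# int(str(0.0))     # ValueError: invalid literal for int() with base 10: '0.0'
```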

itl: List[float] = field(
default_factory=list) # List of inter-token latencies
prompt_len: int = 0
@@ -58,23 +58,24 @@ async def async_request_tgi(
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len

ttft = 0
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload) as response:
if response.status == 200:
async for chunk in response.content:
chunk = chunk.strip()
if not chunk:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue

chunk = remove_prefix(chunk.decode("utf-8"), "data:")
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data:")

data = json.loads(chunk)
timestamp = time.perf_counter()
# First token
if ttft == 0:
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft

@@ -119,23 +120,24 @@ async def async_request_trt_llm(
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len

ttft = 0
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload) as response:
if response.status == 200:
async for chunk in response.content:
chunk = chunk.strip()
if not chunk:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue

chunk = remove_prefix(chunk.decode("utf-8"), "data:")
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data:")

data = json.loads(chunk)
timestamp = time.perf_counter()
# First token
if ttft == 0:
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft

@@ -151,7 +153,7 @@ async def async_request_trt_llm(
output.success = True

else:
output.error = response.reason
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
@@ -195,7 +197,7 @@ async def async_request_deepspeed_mii(
output.generated_text = parsed_resp["text"][0]
output.success = True
else:
output.error = response.reason
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
@@ -234,19 +236,20 @@ async def async_request_openai_completions(
output.prompt_len = request_func_input.prompt_len

generated_text = ""
ttft = 0
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
headers=headers) as response:
if response.status == 200:
async for chunk in response.content:
chunk = chunk.strip()
if not chunk:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue

chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data: ")
if chunk == "[DONE]":
latency = time.perf_counter() - st
else:
@@ -255,7 +258,7 @@ async def async_request_openai_completions(
if data["choices"][0]["text"]:
timestamp = time.perf_counter()
# First token
if ttft == 0:
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft

@@ -315,19 +318,20 @@ async def async_request_openai_chat_completions(
output.prompt_len = request_func_input.prompt_len

generated_text = ""
ttft = 0
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
headers=headers) as response:
if response.status == 200:
async for chunk in response.content:
chunk = chunk.strip()
if not chunk:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue

chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data: ")
if chunk == "[DONE]":
latency = time.perf_counter() - st
else:
@@ -337,7 +341,7 @@ async def async_request_openai_chat_completions(
delta = data["choices"][0]["delta"]
if delta.get("content", None):
# First token
if ttft == 0:
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft

@@ -354,7 +358,7 @@ async def async_request_openai_chat_completions(
output.success = True
output.latency = latency
else:
output.error = response.reason
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
3 changes: 2 additions & 1 deletion docs/source/conf.py
@@ -13,6 +13,7 @@
import logging
import os
import sys
import typing as t

from sphinx.ext import autodoc

@@ -48,7 +49,7 @@
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
exclude_patterns: t.List[str] = []

# Exclude the prompt "$" when copying code
copybutton_prompt_text = r"\$ "
5 changes: 3 additions & 2 deletions setup.py
@@ -5,7 +5,7 @@
import subprocess
import sys
from shutil import which
from typing import List
from typing import Dict, List

import torch
from packaging.version import Version, parse
@@ -52,7 +52,7 @@ def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None:

class cmake_build_ext(build_ext):
# A dict of extension directories that have been configured.
did_config = {}
did_config: Dict[str, bool] = {}

#
# Determine number of compilation jobs and optionally nvcc compile threads.
@@ -261,6 +261,7 @@ def get_nvcc_cuda_version() -> Version:

Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
"""
assert CUDA_HOME is not None, "CUDA_HOME is not set"
nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"],
universal_newlines=True)
output = nvcc_output.split()
31 changes: 19 additions & 12 deletions vllm/core/block/interfaces.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod, abstractproperty
from typing import Dict, List, Optional, Protocol
from abc import ABC, abstractmethod
from typing import Dict, FrozenSet, List, Optional, Protocol

from vllm.utils import Device

@@ -10,23 +10,28 @@ class Block(ABC):
def append_token_ids(self, token_ids: List[int]) -> None:
pass

@abstractproperty
@property
@abstractmethod
def block_id(self) -> Optional[int]:
Member (on lines -13 to +14): Just a note to the audience why this change makes sense: https://docs.python.org/3.10/library/abc.html#abc.abstractproperty
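
As a minimal sketch of the pattern the note points to (the concrete subclass here is invented purely for illustration), abc.abstractproperty has been deprecated since Python 3.3, and the equivalent modern spelling stacks @property on top of @abstractmethod:

```python
from abc import ABC, abstractmethod
from typing import Optional


class Block(ABC):
    # Old, deprecated spelling:
    #     @abstractproperty
    #     def block_id(self) -> Optional[int]: ...
    #
    # Modern equivalent, with @property as the outermost decorator:
    @property
    @abstractmethod
    def block_id(self) -> Optional[int]:
        pass


class ExampleBlock(Block):
    """Hypothetical concrete block, shown only to demonstrate the override."""

    def __init__(self, block_id: Optional[int]) -> None:
        self._block_id = block_id

    @property
    def block_id(self) -> Optional[int]:
        return self._block_id


print(ExampleBlock(7).block_id)  # 7
```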

pass

@abstractproperty
@property
@abstractmethod
def token_ids(self) -> List[int]:
pass

@abstractproperty
@property
@abstractmethod
def num_empty_slots(self) -> int:
pass

@abstractproperty
@property
@abstractmethod
def is_full(self) -> bool:
pass

@abstractproperty
@property
@abstractmethod
def prev_block(self) -> Optional["Block"]:
pass

@@ -47,12 +52,13 @@ def __call__(
class BlockAllocator(ABC):

@abstractmethod
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
pass

@abstractmethod
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int]) -> Block:
token_ids: List[int], device: Device) -> Block:
pass

@abstractmethod
@@ -64,11 +70,12 @@ def fork(self, last_block: Block) -> List[Block]:
pass

@abstractmethod
def get_num_free_blocks(self) -> int:
def get_num_free_blocks(self, device: Device) -> int:
pass

@abstractproperty
def all_block_ids(self) -> frozenset[int]:
@property
@abstractmethod
def all_block_ids(self) -> FrozenSet[int]:
pass

@abstractmethod
10 changes: 8 additions & 2 deletions vllm/engine/metrics.py
@@ -1,6 +1,6 @@
import time
from dataclasses import dataclass
from typing import Dict, List
from typing import Dict, List, Protocol

import numpy as np
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
@@ -119,6 +119,12 @@ class Stats:
time_e2e_requests: List[float]


class SupportsMetricsInfo(Protocol):

def metrics_info(self) -> Dict[str, str]:
...


class StatLogger:
"""StatLogger is used LLMEngine to log to Promethus and Stdout."""

@@ -135,7 +141,7 @@ def __init__(self, local_interval: float, labels: Dict[str, str]) -> None:
self.labels = labels
self.metrics = Metrics(labelnames=list(labels.keys()))

def info(self, type: str, obj: object) -> None:
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
Collaborator: QQ: should we change the var name obj?

Contributor Author: It would slightly improve readability, but I am not sure whether there is any dependency on the kwarg name. obj is not a great variable name, I agree.
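
For readers unfamiliar with typing.Protocol, a minimal sketch of what the new SupportsMetricsInfo annotation buys (the FakeCacheConfig class and its returned values are invented): any object with a matching metrics_info() method type-checks structurally, no inheritance required. The thread above is about whether the obj parameter could be renamed without breaking keyword call sites.

```python
from typing import Dict, Protocol


class SupportsMetricsInfo(Protocol):
    def metrics_info(self) -> Dict[str, str]:
        ...


class FakeCacheConfig:
    """Hypothetical config object; it never subclasses SupportsMetricsInfo."""

    def metrics_info(self) -> Dict[str, str]:
        return {"block_size": "16"}


def info(type: str, obj: SupportsMetricsInfo) -> None:
    # Structural typing: FakeCacheConfig matches the Protocol by shape alone.
    if type == "cache_config":
        print(obj.metrics_info())


info("cache_config", FakeCacheConfig())
# Keyword call sites like this are why renaming `obj` is not free:
info(type="cache_config", obj=FakeCacheConfig())
```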

if type == "cache_config":
self.metrics.info_cache_config.info(obj.metrics_info())

7 changes: 6 additions & 1 deletion vllm/logger.py
@@ -4,6 +4,7 @@
import logging
import os
import sys
from typing import Optional

VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))

@@ -26,7 +27,7 @@ def format(self, record):


_root_logger = logging.getLogger("vllm")
_default_handler = None
_default_handler: Optional[logging.Handler] = None


def _setup_logger():
@@ -55,7 +56,11 @@ def init_logger(name: str):
# Use the same settings as above for root logger
logger = logging.getLogger(name)
logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG"))

if VLLM_CONFIGURE_LOGGING:
if _default_handler is None:
raise ValueError(
"_default_handler is not set up, this should not happen.")
logger.addHandler(_default_handler)
logger.propagate = False
return logger
15 changes: 8 additions & 7 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -247,11 +247,12 @@ def _yarn_find_correction_dim(num_rotations: int,


# Find dim range bounds based on rotations
def _yarn_find_correction_range(low_rot: int,
high_rot: int,
dim: int,
base: float = 10000,
max_position_embeddings: int = 2048) -> int:
def _yarn_find_correction_range(
low_rot: int,
high_rot: int,
dim: int,
base: float = 10000,
max_position_embeddings: int = 2048) -> Tuple[int, int]:
low = math.floor(
_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(
@@ -293,8 +294,8 @@ def __init__(
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: float = 32,
beta_slow: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
4 changes: 2 additions & 2 deletions vllm/transformers_utils/config.py
@@ -1,10 +1,10 @@
from typing import Optional
from typing import Dict, Optional

from transformers import AutoConfig, PretrainedConfig

from vllm.transformers_utils.configs import *

_CONFIG_REGISTRY = {
_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
"chatglm": ChatGLMConfig,
"dbrx": DbrxConfig,
"mpt": MPTConfig,
2 changes: 1 addition & 1 deletion vllm/transformers_utils/configs/dbrx.py
@@ -12,7 +12,7 @@

logger = logging.get_logger(__name__)

DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} # type: ignore


class DbrxAttentionConfig(PretrainedConfig):