Commit c2b4a1b
[Doc] Add typing hints / mypy types cleanup (#3816)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
michaelfeil and ywang96 authored Apr 12, 2024
1 parent e46a60a commit c2b4a1b
Showing 11 changed files with 90 additions and 64 deletions.
62 changes: 33 additions & 29 deletions benchmarks/backend_request_func.py
@@ -27,8 +27,8 @@ class RequestFuncInput:
 class RequestFuncOutput:
     generated_text: str = ""
     success: bool = False
-    latency: float = 0
-    ttft: float = 0  # Time to first token
+    latency: float = 0.0
+    ttft: float = 0.0  # Time to first token
     itl: List[float] = field(
         default_factory=list)  # List of inter-token latencies
     prompt_len: int = 0
@@ -58,23 +58,24 @@ async def async_request_tgi(
         output = RequestFuncOutput()
         output.prompt_len = request_func_input.prompt_len

-        ttft = 0
+        ttft = 0.0
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
             async with session.post(url=api_url, json=payload) as response:
                 if response.status == 200:
-                    async for chunk in response.content:
-                        chunk = chunk.strip()
-                        if not chunk:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
                             continue

-                        chunk = remove_prefix(chunk.decode("utf-8"), "data:")
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+                                              "data:")

                         data = json.loads(chunk)
                         timestamp = time.perf_counter()
                         # First token
-                        if ttft == 0:
+                        if ttft == 0.0:
                             ttft = time.perf_counter() - st
                             output.ttft = ttft

@@ -119,23 +119,24 @@ async def async_request_trt_llm(
         output = RequestFuncOutput()
         output.prompt_len = request_func_input.prompt_len

-        ttft = 0
+        ttft = 0.0
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
             async with session.post(url=api_url, json=payload) as response:
                 if response.status == 200:
-                    async for chunk in response.content:
-                        chunk = chunk.strip()
-                        if not chunk:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
                             continue

-                        chunk = remove_prefix(chunk.decode("utf-8"), "data:")
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+                                              "data:")

                         data = json.loads(chunk)
                         timestamp = time.perf_counter()
                         # First token
-                        if ttft == 0:
+                        if ttft == 0.0:
                             ttft = time.perf_counter() - st
                             output.ttft = ttft

@@ -151,7 +153,7 @@ async def async_request_trt_llm(
                     output.success = True

                 else:
-                    output.error = response.reason
+                    output.error = response.reason or ""
                     output.success = False
         except Exception:
             output.success = False
@@ -195,7 +197,7 @@ async def async_request_deepspeed_mii(
                     output.generated_text = parsed_resp["text"][0]
                     output.success = True
                 else:
-                    output.error = response.reason
+                    output.error = response.reason or ""
                     output.success = False
         except Exception:
             output.success = False
@@ -234,19 +236,20 @@ async def async_request_openai_completions(
         output.prompt_len = request_func_input.prompt_len

         generated_text = ""
-        ttft = 0
+        ttft = 0.0
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
             async with session.post(url=api_url, json=payload,
                                     headers=headers) as response:
                 if response.status == 200:
-                    async for chunk in response.content:
-                        chunk = chunk.strip()
-                        if not chunk:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
                             continue

-                        chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+                                              "data: ")
                         if chunk == "[DONE]":
                             latency = time.perf_counter() - st
                         else:
@@ -255,7 +258,7 @@ async def async_request_openai_completions(
                             if data["choices"][0]["text"]:
                                 timestamp = time.perf_counter()
                                 # First token
-                                if ttft == 0:
+                                if ttft == 0.0:
                                     ttft = time.perf_counter() - st
                                     output.ttft = ttft

@@ -315,19 +318,20 @@ async def async_request_openai_chat_completions(
         output.prompt_len = request_func_input.prompt_len

         generated_text = ""
-        ttft = 0
+        ttft = 0.0
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
             async with session.post(url=api_url, json=payload,
                                     headers=headers) as response:
                 if response.status == 200:
-                    async for chunk in response.content:
-                        chunk = chunk.strip()
-                        if not chunk:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
                             continue

-                        chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+                                              "data: ")
                         if chunk == "[DONE]":
                             latency = time.perf_counter() - st
                         else:
@@ -337,7 +341,7 @@ async def async_request_openai_chat_completions(
                             delta = data["choices"][0]["delta"]
                             if delta.get("content", None):
                                 # First token
-                                if ttft == 0:
+                                if ttft == 0.0:
                                     ttft = time.perf_counter() - st
                                     output.ttft = ttft

@@ -354,7 +358,7 @@ async def async_request_openai_chat_completions(
                     output.success = True
                     output.latency = latency
                 else:
-                    output.error = response.reason
+                    output.error = response.reason or ""
                     output.success = False
         except Exception:
             output.success = False
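
Note: the hunks above call a remove_prefix helper that is defined elsewhere in this module and is not part of the diff. A minimal sketch of such a helper, assuming it simply mirrors str.removeprefix (which only exists on Python 3.9+), would be:

def remove_prefix(text: str, prefix: str) -> str:
    # Backport of str.removeprefix for Python < 3.9:
    # drop the prefix if present, otherwise return the text unchanged.
    if text.startswith(prefix):
        return text[len(prefix):]
    return text
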
3 changes: 2 additions & 1 deletion docs/source/conf.py
@@ -12,6 +12,7 @@

 import logging
 import sys
+from typing import List

 from sphinx.ext import autodoc

@@ -45,7 +46,7 @@
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = []
+exclude_patterns: List[str] = []

 # Exclude the prompt "$" when copying code
 copybutton_prompt_text = r"\$ "
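
The annotation here is what lets mypy check the variable at all: an empty list literal has no inferable element type, so an unannotated module-level exclude_patterns = [] typically triggers mypy's "Need type annotation" error. A small illustration (the appended values are just examples):

from typing import List

# Without the annotation, mypy reports something like:
#   Need type annotation for "exclude_patterns"
exclude_patterns: List[str] = []

exclude_patterns.append("_build")  # OK: checked as a list of str
exclude_patterns.append(123)       # mypy would flag this line as an error
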
5 changes: 3 additions & 2 deletions setup.py
@@ -5,7 +5,7 @@
 import subprocess
 import sys
 from shutil import which
-from typing import List
+from typing import Dict, List

 import torch
 from packaging.version import Version, parse
@@ -52,7 +52,7 @@ def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None:

 class cmake_build_ext(build_ext):
     # A dict of extension directories that have been configured.
-    did_config = {}
+    did_config: Dict[str, bool] = {}

     #
     # Determine number of compilation jobs and optionally nvcc compile threads.
@@ -261,6 +261,7 @@ def get_nvcc_cuda_version() -> Version:
     Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
     """
+    assert CUDA_HOME is not None, "CUDA_HOME is not set"
     nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"],
                                           universal_newlines=True)
     output = nvcc_output.split()
31 changes: 19 additions & 12 deletions vllm/core/block/interfaces.py
@@ -1,5 +1,5 @@
-from abc import ABC, abstractmethod, abstractproperty
-from typing import Dict, List, Optional, Protocol
+from abc import ABC, abstractmethod
+from typing import Dict, FrozenSet, List, Optional, Protocol

 from vllm.utils import Device

@@ -10,23 +10,28 @@ class Block(ABC):
     def append_token_ids(self, token_ids: List[int]) -> None:
         pass

-    @abstractproperty
+    @property
+    @abstractmethod
     def block_id(self) -> Optional[int]:
         pass

-    @abstractproperty
+    @property
+    @abstractmethod
     def token_ids(self) -> List[int]:
         pass

-    @abstractproperty
+    @property
+    @abstractmethod
     def num_empty_slots(self) -> int:
         pass

-    @abstractproperty
+    @property
+    @abstractmethod
     def is_full(self) -> bool:
         pass

-    @abstractproperty
+    @property
+    @abstractmethod
     def prev_block(self) -> Optional["Block"]:
         pass

@@ -47,12 +52,13 @@ def __call__(
 class BlockAllocator(ABC):

     @abstractmethod
-    def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
+    def allocate_mutable(self, prev_block: Optional[Block],
+                         device: Device) -> Block:
         pass

     @abstractmethod
     def allocate_immutable(self, prev_block: Optional[Block],
-                           token_ids: List[int]) -> Block:
+                           token_ids: List[int], device: Device) -> Block:
         pass

     @abstractmethod
@@ -64,11 +70,12 @@ def fork(self, last_block: Block) -> List[Block]:
         pass

     @abstractmethod
-    def get_num_free_blocks(self) -> int:
+    def get_num_free_blocks(self, device: Device) -> int:
         pass

-    @abstractproperty
-    def all_block_ids(self) -> frozenset[int]:
+    @property
+    @abstractmethod
+    def all_block_ids(self) -> FrozenSet[int]:
         pass

     @abstractmethod
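
For background, abc.abstractproperty has been deprecated since Python 3.3; stacking @property on top of @abstractmethod is the supported replacement and the spelling mypy understands (@property must be the outermost decorator, as in the hunks above). A minimal, illustrative sketch of the pattern; the concrete subclass here is hypothetical and not part of the diff:

from abc import ABC, abstractmethod
from typing import Optional


class Block(ABC):

    @property
    @abstractmethod
    def block_id(self) -> Optional[int]:
        """Subclasses must expose the (possibly unassigned) block id."""
        ...


class InMemoryBlock(Block):
    """Hypothetical concrete implementation, used only for illustration."""

    def __init__(self, block_id: Optional[int] = None) -> None:
        self._block_id = block_id

    @property
    def block_id(self) -> Optional[int]:
        return self._block_id


print(InMemoryBlock(block_id=7).block_id)  # -> 7
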
10 changes: 8 additions & 2 deletions vllm/engine/metrics.py
@@ -1,6 +1,6 @@
 import time
 from dataclasses import dataclass
-from typing import Dict, List
+from typing import Dict, List, Protocol

 import numpy as np
 from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
@@ -119,6 +119,12 @@ class Stats:
     time_e2e_requests: List[float]


+class SupportsMetricsInfo(Protocol):
+
+    def metrics_info(self) -> Dict[str, str]:
+        ...
+
+
 class StatLogger:
     """StatLogger is used LLMEngine to log to Promethus and Stdout."""

@@ -135,7 +141,7 @@ def __init__(self, local_interval: float, labels: Dict[str, str]) -> None:
         self.labels = labels
         self.metrics = Metrics(labelnames=list(labels.keys()))

-    def info(self, type: str, obj: object) -> None:
+    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
         if type == "cache_config":
             self.metrics.info_cache_config.info(obj.metrics_info())

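
SupportsMetricsInfo above is a structural typing.Protocol: mypy accepts any object that provides a matching metrics_info() method, with no inheritance or registration required, which is why StatLogger.info can drop the overly loose object annotation. A small sketch of the idea; the config class below is an illustrative stand-in, not the real vLLM CacheConfig:

from typing import Dict, Protocol


class SupportsMetricsInfo(Protocol):

    def metrics_info(self) -> Dict[str, str]:
        ...


class DummyCacheConfig:
    """Structurally satisfies SupportsMetricsInfo without inheriting from it."""

    def metrics_info(self) -> Dict[str, str]:
        return {"cache_dtype": "auto", "block_size": "16"}


def log_metrics(obj: SupportsMetricsInfo) -> None:
    # mypy checks that obj has a compatible metrics_info() method.
    print(obj.metrics_info())


log_metrics(DummyCacheConfig())
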
8 changes: 7 additions & 1 deletion vllm/logger.py
@@ -4,6 +4,7 @@
 import logging
 import os
 import sys
+from typing import Optional

 VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))

@@ -26,7 +27,7 @@ def format(self, record):


 _root_logger = logging.getLogger("vllm")
-_default_handler = None
+_default_handler: Optional[logging.Handler] = None


 def _setup_logger():
@@ -55,7 +56,12 @@ def init_logger(name: str):
     # Use the same settings as above for root logger
     logger = logging.getLogger(name)
     logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG"))
+
     if VLLM_CONFIGURE_LOGGING:
+        if _default_handler is None:
+            raise ValueError(
+                "_default_handler is not set up. This should never happen!"
+                " Please open an issue on Github.")
         logger.addHandler(_default_handler)
         logger.propagate = False
     return logger
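
Two things happen in this file: annotating _default_handler as Optional[logging.Handler] tells mypy the global may be None, and the new guard in init_logger narrows it back to logging.Handler before addHandler is called, so that call type-checks. A trimmed sketch of the narrowing pattern, reusing the same names as the diff:

import logging
from typing import Optional

_default_handler: Optional[logging.Handler] = None


def init_logger(name: str) -> logging.Logger:
    logger = logging.getLogger(name)
    if _default_handler is None:
        # After this raise, mypy treats _default_handler as a Handler
        # (not Optional), so the addHandler() call below is accepted.
        raise ValueError("_default_handler is not set up.")
    logger.addHandler(_default_handler)
    return logger


_default_handler = logging.StreamHandler()
print(init_logger("demo").handlers)  # [<StreamHandler ...>]
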
15 changes: 8 additions & 7 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -247,11 +247,12 @@ def _yarn_find_correction_dim(num_rotations: int,


 # Find dim range bounds based on rotations
-def _yarn_find_correction_range(low_rot: int,
-                                high_rot: int,
-                                dim: int,
-                                base: float = 10000,
-                                max_position_embeddings: int = 2048) -> int:
+def _yarn_find_correction_range(
+        low_rot: int,
+        high_rot: int,
+        dim: int,
+        base: float = 10000,
+        max_position_embeddings: int = 2048) -> Tuple[int, int]:
     low = math.floor(
         _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
     high = math.ceil(
@@ -293,8 +294,8 @@ def __init__(
         *,
         extrapolation_factor: float = 1,
         attn_factor: float = 1,
-        beta_fast: float = 32,
-        beta_slow: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
     ) -> None:
         self.scaling_factor = scaling_factor
         self.extrapolation_factor = extrapolation_factor
4 changes: 2 additions & 2 deletions vllm/transformers_utils/config.py
@@ -1,10 +1,10 @@
-from typing import Optional
+from typing import Dict, Optional

 from transformers import AutoConfig, PretrainedConfig

 from vllm.transformers_utils.configs import *

-_CONFIG_REGISTRY = {
+_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
     "chatglm": ChatGLMConfig,
     "dbrx": DbrxConfig,
     "mpt": MPTConfig,
2 changes: 1 addition & 1 deletion vllm/transformers_utils/configs/dbrx.py
@@ -12,7 +12,7 @@

 logger = logging.get_logger(__name__)

-DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
+DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}  # type: ignore


 class DbrxAttentionConfig(PretrainedConfig):