[Mypy] Part 3 fix typing for nested directories for most of directory #4161

Merged: 10 commits, Apr 23, 2024
Changes from 9 commits
29 changes: 15 additions & 14 deletions .github/workflows/mypy.yaml
@@ -32,19 +32,20 @@ jobs:
pip install types-setuptools
- name: Mypy
run: |
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/attention --config-file pyproject.toml
# TODO(sang): Fix nested dir
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml

mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# TODO(sang): Follow up
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
# TODO(sang): Fix nested dir
mypy vllm/model_executor/*.py --config-file pyproject.toml
# TODO(sang): Fix nested dir
# mypy vllm/lora/*.py --config-file pyproject.toml

26 changes: 12 additions & 14 deletions format.sh
@@ -94,21 +94,19 @@ echo 'vLLM yapf: Done'

# Run mypy
echo 'vLLM mypy:'
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/attention --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml

# TODO(sang): Follow up
mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/model_executor/*.py --config-file pyproject.toml
# mypy vllm/lora/*.py --config-file pyproject.toml


CODESPELL_EXCLUDES=(
6 changes: 4 additions & 2 deletions pyproject.toml
@@ -46,15 +46,17 @@ ignore = [
python_version = "3.8"

ignore_missing_imports = true
check_untyped_defs = true
check_untyped_defs = true
follow_imports = "skip"

files = "vllm"
# TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = [
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
# Ignore triton kernels in ops.
'vllm/attention/ops/.*\.py$'
]


[tool.codespell]
ignore-words-list = "dout, te, indicies"
skip = "./tests/prompts,./benchmarks/sonnet.txt"
2 changes: 1 addition & 1 deletion vllm/attention/backends/abstract.py
@@ -116,7 +116,7 @@ def forward(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
attn_metadata: AttentionMetadata,
kv_scale: float,
) -> torch.Tensor:
raise NotImplementedError
1 change: 1 addition & 0 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -243,6 +243,7 @@ def forward(

if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
assert prefill_meta.prompt_lens is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# triton attention
# When block_tables are not filled, it means q and k are the
3 changes: 2 additions & 1 deletion vllm/attention/backends/torch_sdpa.py
@@ -106,7 +106,7 @@ def forward(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: TorchSDPAMetadata,
attn_metadata: TorchSDPAMetadata, # type: ignore
kv_scale: float,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
@@ -136,6 +136,7 @@ def forward(
kv_scale)

if attn_metadata.is_prompt:
assert attn_metadata.prompt_lens is not None
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
1 change: 1 addition & 0 deletions vllm/attention/backends/xformers.py
@@ -288,6 +288,7 @@ def _run_memory_efficient_xformers_forward(
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
"""
assert attn_metadata.prompt_lens is not None
original_query = query
if self.num_kv_heads != self.num_heads:
# GQA/MQA requires the shape [B, M, G, H, K].
1 change: 1 addition & 0 deletions vllm/core/block/block_table.py
@@ -104,6 +104,7 @@ def append_token_ids(self,
token_ids (List[int]): The sequence of token IDs to be appended.
"""
assert self._is_allocated
assert self._blocks is not None

self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
num_lookahead_slots)
6 changes: 4 additions & 2 deletions vllm/core/block/common.py
@@ -99,7 +99,7 @@ def __init__(
refcounter: RefCounter,
allocator: BlockAllocator,
):
self._copy_on_writes = defaultdict(list)
self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)
self._refcounter = refcounter
self._allocator = allocator

@@ -138,6 +138,8 @@ def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
prev_block=block.prev_block).block_id

# Track src/dst copy.
assert src_block_id is not None
assert block_id is not None
self._copy_on_writes[src_block_id].append(block_id)

return block_id
@@ -180,6 +182,6 @@ def recurse(block: Block, lst: List[Block]) -> None:
recurse(block.prev_block, lst)
lst.append(block)

all_blocks = []
all_blocks: List[Block] = []
recurse(last_block, all_blocks)
return all_blocks
6 changes: 2 additions & 4 deletions vllm/core/block/interfaces.py
@@ -52,8 +52,7 @@ def __call__(
class BlockAllocator(ABC):

@abstractmethod
def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
pass

@abstractmethod
@@ -98,8 +97,7 @@ class NoFreeBlocksError(ValueError):
class DeviceAwareBlockAllocator(BlockAllocator):

@abstractmethod
def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
pass

@abstractmethod
16 changes: 9 additions & 7 deletions vllm/distributed/device_communicators/custom_all_reduce.py
@@ -1,5 +1,5 @@
from contextlib import contextmanager
from typing import Optional
from typing import Any, List, Optional

import torch
import torch.distributed as dist
@@ -17,7 +17,7 @@

logger = init_logger(__name__)

_CA_HANDLE = None
_CA_HANDLE: Optional["CustomAllreduce"] = None
_IS_CAPTURING = False
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]

@@ -50,7 +50,7 @@ def init_custom_ar() -> None:
"Cannot test GPU P2P because not all GPUs are visible to the "
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" is set.")
return False
return
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
full_nvlink = _is_full_nvlink(rank, world_size)
@@ -110,7 +110,7 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
ca_handle = get_handle()
# when custom allreduce is disabled, this will be None
if ca_handle is None:
return
return None
if is_capturing():
if torch.cuda.is_current_stream_capturing():
if ca_handle.should_custom_ar(input):
@@ -128,6 +128,8 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
if ca_handle.should_custom_ar(input):
return ca_handle.all_reduce_unreg(input)

return None


@contextmanager
def _nvml():
@@ -210,14 +212,14 @@ def _get_ipc_meta(self, inp: torch.Tensor):
return self._gather_ipc_meta(shard_data)

def _gather_ipc_meta(self, shard_data):
all_data = [None] * self.world_size
all_data: List[Optional[Any]] = [None] * self.world_size
dist.all_gather_object(all_data, shard_data)

handles = []
offsets = []
for i in range(len(all_data)):
handles.append(all_data[i][0])
offsets.append(all_data[i][1])
handles.append(all_data[i][0]) # type: ignore
offsets.append(all_data[i][1]) # type: ignore
return handles, offsets

def register_buffer(self, inp: torch.Tensor):
11 changes: 5 additions & 6 deletions vllm/distributed/device_communicators/pynccl.py
@@ -108,8 +108,7 @@ def ncclGetUniqueId() -> NcclUniqueId:
]


# enums
class ncclDataType_t(ctypes.c_int):
class ncclDataType_t:
ncclInt8 = 0
ncclChar = 0
ncclUint8 = 1
@@ -128,7 +127,7 @@ class ncclDataType_t(ctypes.c_int):
ncclNumTypes = 10

@classmethod
def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t':
def from_torch(cls, dtype: torch.dtype) -> int:

Collaborator:
This is weird that mypy can't handle this; it should definitely return the c_int.


Collaborator Author:
Hmm, when I added a breakpoint, it does return int.

(Pdb) ncclDataType_t.from_torch(tensor.dtype)
7
(Pdb) type(ncclDataType_t.from_torch(tensor.dtype))
<class 'int'>

Do you think it is a bug?


Member:
Well, technically, this function indeed returns int. ctypes will automatically convert the int to a c_int. Note that ctypes.c_int is a class/container that holds an int; it is not itself an int.

TL;DR: def from_torch(cls, dtype: torch.dtype) -> int: is correct.
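
For reference, a minimal standalone sketch (not part of this PR) of the conversion being described: when a foreign function's argtypes declares c_int, ctypes wraps a plain Python int into a c_int automatically at call time. It assumes a Unix-like platform where libc symbols are reachable via CDLL(None).

import ctypes

# Load the C library (on Linux/macOS, CDLL(None) exposes libc symbols).
libc = ctypes.CDLL(None)

# Declare the C signature: int abs(int).
libc.abs.argtypes = [ctypes.c_int]
libc.abs.restype = ctypes.c_int

# Pass a plain Python int; ctypes converts it to c_int under the hood.
print(libc.abs(-7))  # prints 7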


Collaborator Author:
# enums
class ncclDataType_t(ctypes.c_int):
    pass


class ncclDataTypeEnum:
    ncclInt8 = ncclDataType_t(0)
    ncclChar = ncclDataType_t(0)
    ncclUint8 = ncclDataType_t(1)
    ncclInt32 = ncclDataType_t(2)
    ncclInt = ncclDataType_t(2)
    ncclUint32 = ncclDataType_t(3)
    ncclInt64 = ncclDataType_t(4)
    ncclUint64 = ncclDataType_t(5)
    ncclFloat16 = ncclDataType_t(6)
    ncclHalf = ncclDataType_t(6)
    ncclFloat32 = ncclDataType_t(7)
    ncclFloat = ncclDataType_t(7)
    ncclFloat64 = ncclDataType_t(8)
    ncclDouble = ncclDataType_t(8)
    ncclBfloat16 = ncclDataType_t(9)
    ncclNumTypes = ncclDataType_t(10)

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> "ncclDataType_t":
        if dtype == torch.int8:
            return cls.ncclInt8
        if dtype == torch.uint8:
            return cls.ncclUint8
        if dtype == torch.int32:
            return cls.ncclInt32
        if dtype == torch.int64:
            return cls.ncclInt64
        if dtype == torch.float16:
            return cls.ncclFloat16
        if dtype == torch.float32:
            return cls.ncclFloat32
        if dtype == torch.float64:
            return cls.ncclFloat64
        if dtype == torch.bfloat16:
            return cls.ncclBfloat16
        raise ValueError(f"Unsupported dtype: {dtype}")

I fixed it this way so that it returns the correct type.


Collaborator Author (@rkooo567, Apr 18, 2024):
Oh @youkaichao, I just saw your msg. So returning a plain int makes more sense then, given the automatic conversion above?


Member:
I think automatic conversion makes more sense. The above code is difficult to read.


Collaborator Author:
I think inheriting from c_int here is technically not necessary then? It is just like a simple enum.


Member:
yes, that's ChatGPT's fault 😉


Collaborator Author (@rkooo567, Apr 18, 2024):
Fixed. PTAL! (81313dc)

if dtype == torch.int8:
return cls.ncclInt8
if dtype == torch.uint8:
@@ -157,7 +156,7 @@ class ncclRedOp_t(ctypes.c_int):
ncclNumOps = 5

@classmethod
def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t':
def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM:
return cls.ncclSum
if op == ReduceOp.PRODUCT:
@@ -180,8 +179,8 @@ def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t':
_c_ncclAllReduce = nccl.ncclAllReduce
_c_ncclAllReduce.restype = ctypes.c_int
_c_ncclAllReduce.argtypes = [
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int,
ctypes.c_void_p, ctypes.c_void_p

Member:
This is not intuitive. Let's keep ncclDataType_t and ncclRedOp_t here.

Does mypy understand the call signature of _c_ncclAllReduce? I suppose it will just ignore them. And we can leave a comment here, saying that int will be automatically converted by ctypes.


Collaborator Author:
I find ncclDataType_t being a random enum and a type at the same time a bit weird, though. Why don't we then just do

class ncclDataType_t(ctypes.c_int):
    pass

class ncclDataEnum:
    ... = 0
    ... = 1

?


Member:
Well, that's also doable.

Please use ncclDataType_t = ctypes.c_int, and ncclDataTypeEnum.


Collaborator Author:
sgtm!
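
As a reference, here is a minimal self-contained sketch of the pattern settled on above: a plain ctypes.c_int alias for the argtypes declarations plus an int-valued enum class. The exact names and members in the merged commit may differ.

import ctypes

import torch

# Alias used in argtypes declarations; stays a real ctypes type.
ncclDataType_t = ctypes.c_int


class ncclDataTypeEnum:
    # Plain int constants; ctypes converts them to c_int at call time.
    ncclInt8 = 0
    ncclUint8 = 1
    ncclFloat16 = 6
    ncclFloat32 = 7
    ncclFloat64 = 8
    ncclBfloat16 = 9

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> int:
        if dtype == torch.int8:
            return cls.ncclInt8
        if dtype == torch.uint8:
            return cls.ncclUint8
        if dtype == torch.float16:
            return cls.ncclFloat16
        if dtype == torch.float32:
            return cls.ncclFloat32
        if dtype == torch.float64:
            return cls.ncclFloat64
        if dtype == torch.bfloat16:
            return cls.ncclBfloat16
        raise ValueError(f"Unsupported dtype: {dtype}")


print(ncclDataTypeEnum.from_torch(torch.float32))  # 7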

]

# equivalent to c declaration:
5 changes: 4 additions & 1 deletion vllm/distributed/device_communicators/pynccl_utils.py
@@ -30,6 +30,7 @@ def is_initialized() -> bool:
def set_pynccl_stream(stream: torch.cuda.Stream):
"""Set the cuda stream for communication"""
try:
assert comm is not None
comm.stream = stream
yield
finally:
@@ -52,6 +53,7 @@ def init_process_group(world_size: int,
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
"""All-reduces the input tensor across the process group."""
assert input_.is_cuda, f"{input_} should be a cuda tensor"
assert comm is not None
comm.all_reduce(input_, op)


@@ -62,8 +64,9 @@ def destroy_process_group() -> None:

def get_world_size() -> int:
"""Returns the world size."""
assert comm is not None
return comm.world_size


def get_nccl_backend():
def get_nccl_backend() -> Optional["NCCLCommunicator"]:
return comm
5 changes: 3 additions & 2 deletions vllm/engine/output_processor/interfaces.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, Iterable, List
from typing import Callable, List

from transformers import PreTrainedTokenizer

@@ -8,6 +8,7 @@
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter


class SequenceGroupOutputProcessor(ABC):
@@ -27,7 +28,7 @@ def create_output_processor(
scheduler_config: SchedulerConfig,
detokenizer: Detokenizer,
scheduler: Scheduler,
seq_counter: Iterable[int],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: "StopChecker",
):
5 changes: 3 additions & 2 deletions vllm/engine/output_processor/multi_step.py
@@ -1,4 +1,4 @@
from typing import Callable, Iterable, List
from typing import Callable, List

from transformers import PreTrainedTokenizer

@@ -11,6 +11,7 @@
from vllm.sequence import (Logprob, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter

logger = init_logger(__name__)

@@ -33,7 +34,7 @@ def __init__(
self,
detokenizer: Detokenizer,
scheduler: Scheduler,
seq_counter: Iterable[int],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: StopChecker,
):
9 changes: 5 additions & 4 deletions vllm/engine/output_processor/single_step.py
@@ -1,4 +1,4 @@
from typing import Iterable, List, Tuple, Union
from typing import Dict, List, Tuple, Union

from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
@@ -10,6 +10,7 @@
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter

logger = init_logger(__name__)

@@ -33,7 +34,7 @@ def __init__(
scheduler_config: SchedulerConfig,
detokenizer: Detokenizer,
scheduler: Scheduler,
seq_counter: Iterable[int],
seq_counter: Counter,
stop_checker: StopChecker,
):
self.scheduler_config = scheduler_config
@@ -68,7 +69,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict = {
parent_child_dict: Dict[int, List[SequenceOutput]] = {
parent_seq.seq_id: []
for parent_seq in parent_seqs
}
@@ -91,7 +92,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
continue
# Fork the parent sequence if there are multiple child samples.
for child_sample in child_samples[:-1]:
new_child_seq_id = next(self.seq_counter)
new_child_seq_id: int = next(self.seq_counter)
child = parent.fork(new_child_seq_id)
child.append_token_id(child_sample.output_token,
child_sample.logprobs)