Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mypy] Pass type checking in vllm/inputs #11680

Merged
merged 15 commits into from
Jan 2, 2025
Prev Previous commit
Next Next commit
just ignore errors
Signed-off-by: Tobias Pitters <tobias.pitters@gmail.com>
  • Loading branch information
CloseChoice committed Jan 2, 2025
commit 17254a9c8d51c0ccc5c841be6de13540db0b1059
29 changes: 4 additions & 25 deletions vllm/inputs/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
Optional, Tuple, Union, cast)

import torch
from typing_extensions import (NotRequired, TypedDict, TypeIs, TypeVar,
assert_never)
from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never

if TYPE_CHECKING:
from vllm.multimodal import (MultiModalDataDict, MultiModalKwargs,
Expand Down Expand Up @@ -177,26 +176,6 @@ class TokenInputs(TypedDict):
"""


def is_token_inputs(
inputs: Union[TokenInputs,
"MultiModalInputsV2"]) -> TypeIs[TokenInputs]:
"""
Helper function to make sure mypy narrows down the type to
TokenInputs.
"""
return inputs["type"] == "token"


def is_multimodal_inputs(
inputs: Union[TokenInputs, "MultiModalInputsV2"]
) -> TypeIs["MultiModalInputsV2"]:
"""
Helper function to make sure mypy narrows down the type to
MultiModalInputsV2.
"""
return inputs["type"] == "multimodal"


def token_inputs(
prompt_token_ids: List[int],
token_type_ids: Optional[List[int]] = None,
Expand Down Expand Up @@ -328,11 +307,11 @@ def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]:
def multi_modal_hashes(self) -> List[str]:
inputs = self.inputs

if is_token_inputs(inputs):
if inputs["type"] == "token":
return inputs.get("multi_modal_hashes", [])
elif is_multimodal_inputs(inputs):
elif inputs["type"] == "multimodal":
# only the case when we use MultiModalInputsV2
return inputs.get("mm_hashes", [])
return inputs.get("mm_hashes", []) # type: ignore[return-value]

assert_never(inputs) # type: ignore[arg-type]

Expand Down
Loading