Skip to content

Commit

Permalink
[CI/Build] Fix mypy errors (vllm-project#6968)
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 authored Jul 31, 2024
1 parent f230cc2 commit 9f0e69b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
4 changes: 2 additions & 2 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import contextlib
import functools
from typing import List, Optional, Tuple, Type
from typing import List, Optional, Tuple, Union

import torch

Expand Down Expand Up @@ -336,7 +336,7 @@ def scaled_fp8_quant(
"""
# This code assumes batch_dim and num_tokens are flattened
assert (input.ndim == 2)
shape = input.shape
shape: Union[Tuple[int, int], torch.Size] = input.shape
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)
Expand Down
6 changes: 2 additions & 4 deletions vllm/multimodal/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ class MultiModalInputs(_MultiModalInputsBase):
"""

@staticmethod
def _try_concat(
tensors: List[NestedTensors],
) -> Union[GenericSequence[NestedTensors], NestedTensors]:
def _try_concat(tensors: List[NestedTensors]) -> BatchedTensors:
"""
If each input tensor in the batch has the same shape, return a single
batched tensor; otherwise, return a list of :class:`NestedTensors` with
Expand Down Expand Up @@ -105,7 +103,7 @@ def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs:
return {
k: MultiModalInputs._try_concat(item_list)
for k, item_list in item_lists.items()
} # type: ignore
}

@staticmethod
def as_kwargs(
Expand Down

0 comments on commit 9f0e69b

Please sign in to comment.