[Mypy] Part 3 fix typing for nested directories for most of directory #4161

Merged
merged 10 commits on Apr 23, 2024
Changes from 1 commit
fixed
rkooo567 committed Apr 18, 2024
commit 58f852e47c6aa15d5df3f33ccfccdf0f85e9f17b
61 changes: 35 additions & 26 deletions vllm/distributed/device_communicators/pynccl.py
@@ -124,25 +124,29 @@ def ncclGetUniqueId() -> NcclUniqueId:

 # enums
 class ncclDataType_t(ctypes.c_int):
-    ncclInt8 = 0
-    ncclChar = 0
-    ncclUint8 = 1
-    ncclInt32 = 2
-    ncclInt = 2
-    ncclUint32 = 3
-    ncclInt64 = 4
-    ncclUint64 = 5
-    ncclFloat16 = 6
-    ncclHalf = 6
-    ncclFloat32 = 7
-    ncclFloat = 7
-    ncclFloat64 = 8
-    ncclDouble = 8
-    ncclBfloat16 = 9
-    ncclNumTypes = 10
+    pass
+
+
+class ncclDataTypeEnum:
+    ncclInt8 = ncclDataType_t(0)
+    ncclChar = ncclDataType_t(0)
+    ncclUint8 = ncclDataType_t(1)
+    ncclInt32 = ncclDataType_t(2)
+    ncclInt = ncclDataType_t(2)
+    ncclUint32 = ncclDataType_t(3)
+    ncclInt64 = ncclDataType_t(4)
+    ncclUint64 = ncclDataType_t(5)
+    ncclFloat16 = ncclDataType_t(6)
+    ncclHalf = ncclDataType_t(6)
+    ncclFloat32 = ncclDataType_t(7)
+    ncclFloat = ncclDataType_t(7)
+    ncclFloat64 = ncclDataType_t(8)
+    ncclDouble = ncclDataType_t(8)
+    ncclBfloat16 = ncclDataType_t(9)
+    ncclNumTypes = ncclDataType_t(10)

     @classmethod
-    def from_torch(cls, dtype: torch.dtype) -> int:
+    def from_torch(cls, dtype: torch.dtype) -> "ncclDataType_t":
         if dtype == torch.int8:
             return cls.ncclInt8
         if dtype == torch.uint8:
@@ -163,15 +167,19 @@ def from_torch(cls, dtype: torch.dtype) -> int:


 class ncclRedOp_t(ctypes.c_int):
-    ncclSum = 0
-    ncclProd = 1
-    ncclMax = 2
-    ncclMin = 3
-    ncclAvg = 4
-    ncclNumOps = 5
+    pass
+
+
+class ncclRedOpEnum:
+    ncclSum = ncclRedOp_t(0)
+    ncclProd = ncclRedOp_t(1)
+    ncclMax = ncclRedOp_t(2)
+    ncclMin = ncclRedOp_t(3)
+    ncclAvg = ncclRedOp_t(4)
+    ncclNumOps = ncclRedOp_t(5)

     @classmethod
-    def from_torch(cls, op: ReduceOp) -> int:
+    def from_torch(cls, op: ReduceOp) -> "ncclRedOp_t":
         if op == ReduceOp.SUM:
             return cls.ncclSum
         if op == ReduceOp.PRODUCT:
@@ -262,11 +270,12 @@ def all_reduce(self,
                    stream=None):
         if stream is None:
             stream = self.stream
+        breakpoint()
         result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
                                   ctypes.c_void_p(tensor.data_ptr()),
                                   tensor.numel(),
-                                  ncclDataType_t.from_torch(tensor.dtype),
-                                  ncclRedOp_t.from_torch(op), self.comm,
+                                  ncclDataTypeEnum.from_torch(tensor.dtype),
+                                  ncclRedOpEnum.from_torch(op), self.comm,
                                   ctypes.c_void_p(stream.cuda_stream))
         assert result == 0

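Note on the pattern above: the commit keeps each ctypes type (ncclDataType_t, ncclRedOp_t) as a bare c_int subclass used purely as a type, and moves the named constants into plain holder classes (ncclDataTypeEnum, ncclRedOpEnum) whose attributes mypy can check, so from_torch gets a precise return type instead of int. A minimal self-contained sketch of the pattern, trimmed to two dtypes (names mirror the diff; the real module covers all ten NCCL dtypes and calls into libnccl):

import ctypes

import torch


class ncclDataType_t(ctypes.c_int):
    pass


class ncclDataTypeEnum:
    # Plain holder class: attributes are ordinary values mypy can check.
    ncclFloat16 = ncclDataType_t(6)
    ncclFloat32 = ncclDataType_t(7)

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> "ncclDataType_t":
        if dtype == torch.float16:
            return cls.ncclFloat16
        if dtype == torch.float32:
            return cls.ncclFloat32
        raise ValueError(f"Unsupported dtype: {dtype}")


# The resulting ctypes value is passed straight through to the C API,
# as in the all_reduce hunk above:
#   _c_ncclAllReduce(..., ncclDataTypeEnum.from_torch(tensor.dtype), ...)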
5 changes: 3 additions & 2 deletions vllm/engine/output_processor/interfaces.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Iterable, List
+from typing import Callable, List

 from transformers import PreTrainedTokenizer

@@ -8,6 +8,7 @@
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
 from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.utils import Counter


 class SequenceGroupOutputProcessor(ABC):
@@ -27,7 +28,7 @@ def create_output_processor(
         scheduler_config: SchedulerConfig,
         detokenizer: Detokenizer,
         scheduler: Scheduler,
-        seq_counter: Iterable[int],
+        seq_counter: Counter,
         get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
         stop_checker: "StopChecker",
     ):
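The signature change above replaces Iterable[int] with vllm.utils.Counter, the concrete type the engine actually passes in. For context, Counter is roughly the following (a sketch, not the verbatim vLLM source): it implements __next__ but not __iter__, so it supports next(counter) while not being an Iterable, which is what the old annotation got wrong.

class Counter:
    """Monotonically increasing counter used to hand out sequence ids."""

    def __init__(self, start: int = 0) -> None:
        self.counter = start

    def __next__(self) -> int:
        i = self.counter
        self.counter += 1
        return i

    def reset(self) -> None:
        self.counter = 0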
5 changes: 3 additions & 2 deletions vllm/engine/output_processor/multi_step.py
@@ -1,4 +1,4 @@
-from typing import Callable, Iterable, List
+from typing import Callable, List

 from transformers import PreTrainedTokenizer

@@ -11,6 +11,7 @@
 from vllm.sequence import (Logprob, Sequence, SequenceGroup,
                            SequenceGroupOutput, SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.utils import Counter

 logger = init_logger(__name__)

@@ -33,7 +34,7 @@ def __init__(
         self,
         detokenizer: Detokenizer,
         scheduler: Scheduler,
-        seq_counter: Iterable[int],
+        seq_counter: Counter,
         get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
         stop_checker: StopChecker,
     ):
7 changes: 4 additions & 3 deletions vllm/engine/output_processor/single_step.py
@@ -1,4 +1,4 @@
-from typing import Dict, Iterable, Iterator, List, Tuple, Union
+from typing import Dict, List, Tuple, Union

 from vllm.config import SchedulerConfig
 from vllm.core.scheduler import Scheduler
@@ -10,6 +10,7 @@
 from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
                            SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.utils import Counter

 logger = init_logger(__name__)

@@ -33,13 +34,13 @@ def __init__(
         scheduler_config: SchedulerConfig,
         detokenizer: Detokenizer,
         scheduler: Scheduler,
-        seq_counter: Iterable[int],
+        seq_counter: Counter,
         stop_checker: StopChecker,
     ):
         self.scheduler_config = scheduler_config
         self.detokenizer = detokenizer
         self.scheduler = scheduler
-        self.seq_counter: Iterator[int] = iter(seq_counter)
+        self.seq_counter = seq_counter
         self.stop_checker = stop_checker

     def process_outputs(self, sequence_group: SequenceGroup,
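Dropping the iter() wrapper in __init__ above is more than a typing fix: assuming Counter behaves as sketched earlier (no __iter__), iter(seq_counter) would raise TypeError at runtime. Storing the Counter directly keeps existing call sites working, since they advance it with next():

from vllm.utils import Counter

seq_counter = Counter()
first_id = next(seq_counter)   # 0
second_id = next(seq_counter)  # 1
# iter(seq_counter) would fail: Counter defines __next__ but not
# __iter__, so it is not an Iterable, matching the annotation change.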