Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
[mypy][5/N] Support all typing on model executor (vllm-project#4427)
Browse files Browse the repository at this point in the history
  • Loading branch information
rkooo567 authored and robertgshaw2-redhat committed May 6, 2024
1 parent 8ab0de8 commit 7f5a450
Show file tree
Hide file tree
Showing 10 changed files with 61 additions and 34 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/mypy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ jobs:
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml
# TODO(sang): Fix nested dir
mypy vllm/model_executor/*.py --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
2 changes: 1 addition & 1 deletion format.sh
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/model_executor/*.py --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
return schema
if isinstance(schema, BaseModel):
return schema.model_json_schema()
raise AssertionError(f"Unsupported schema type {schema}")


@lru_cache
Expand Down
12 changes: 11 additions & 1 deletion vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ def __init__(
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
if quant_config is None:
self.quant_method = UnquantizedLinearMethod()
self.quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self)

Expand Down Expand Up @@ -162,6 +163,8 @@ def __init__(
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)

# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self, self.input_size,
[self.output_size], self.input_size,
self.output_size, self.params_dtype)
Expand All @@ -175,6 +178,7 @@ def __init__(

def forward(self, x: torch.Tensor) -> torch.Tensor:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, x, bias)
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
Expand Down Expand Up @@ -223,6 +227,8 @@ def __init__(
self.output_size_per_partition = divide(output_size, tp_size)
if output_sizes is None:
output_sizes = [output_size]
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size,
[x // tp_size for x in output_sizes],
Expand Down Expand Up @@ -261,6 +267,7 @@ def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None

# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output:
# All-gather across the partitions.
Expand Down Expand Up @@ -610,6 +617,8 @@ def __init__(
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size_per_partition,
[self.output_size],
Expand Down Expand Up @@ -659,6 +668,7 @@ def forward(self, input_):
input_parallel = splitted_input[tp_rank].contiguous()

# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_parallel)
if self.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/quantization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Type
from typing import Dict, Type

from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
Expand All @@ -9,7 +9,7 @@
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig

QUANTIZATION_METHODS = {
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
"fp8": Fp8Config,
Expand Down
14 changes: 11 additions & 3 deletions vllm/model_executor/layers/quantization/base_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

import torch
from torch import nn
Expand Down Expand Up @@ -76,8 +76,16 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
"quantization config.")

@abstractmethod
def get_quant_method(self, layer: torch.nn.Module) -> QuantizeMethodBase:
"""Get the quantize method to use for the quantized layer."""
def get_quant_method(
self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
"""Get the quantize method to use for the quantized layer.
Args:
layer: The layer for the quant method.
Returns:
The quantize method. None if the given layer doesn't support quant
method.
"""
raise NotImplementedError

@abstractmethod
Expand Down
5 changes: 2 additions & 3 deletions vllm/model_executor/layers/quantization/squeezellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,10 @@ def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig":
return cls(weight_bits)

def get_quant_method(
self,
layer: torch.nn.Module) -> Optional["SqueezeLLMLinearMethod"]:
self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
if isinstance(layer, LinearBase):
return SqueezeLLMLinearMethod(self)
return
return None

def get_scaled_act_names(self) -> List[str]:
return []
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,8 @@ def forward(
torch.full_like(positions, k)).long()
idx = (torch.add(positions, long_prompt_offset)
if long_prompt_offset is not None else positions)
self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to(
idx.device)
self.long_short_cos_sin_cache: torch.Tensor = (
self.long_short_cos_sin_cache.to(idx.device))
idx = torch.add(idx, offsets) if offsets is not None else idx
cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)

Expand Down
47 changes: 27 additions & 20 deletions vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
SamplerOutput, SequenceGroupOutput, SequenceOutput)

# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]]


class Sampler(nn.Module):
"""Samples the next tokens from the model's outputs.
Expand Down Expand Up @@ -155,7 +158,7 @@ def _apply_min_tokens_penalty(
have not been generated yet
"""
# list of indices in logits that will be set to -inf
logits_to_penalize = []
logits_to_penalize: List[Tuple[int, int]] = []
logits_applied = 0
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
Expand Down Expand Up @@ -269,7 +272,7 @@ def _apply_min_p(
def _greedy_sample(
selected_seq_groups: List[SequenceGroupToSample],
samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
) -> SampleResultType:
"""Run greedy sampling on a given samples.
Args:
Expand All @@ -284,7 +287,7 @@ def _greedy_sample(
"""
samples = samples.tolist()
sample_idx = 0
results = []
results: SampleResultType = []
for seq_group in selected_seq_groups:
if not seq_group.do_sample:
results.append(([], []))
Expand All @@ -304,7 +307,7 @@ def _greedy_sample(
def _random_sample(
selected_seq_groups: List[SequenceGroupToSample],
random_samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
) -> SampleResultType:
"""Run random sampling on a given samples.
Args:
Expand All @@ -320,7 +323,7 @@ def _random_sample(
# Find the maximum best_of value of the prompt phase requests.
random_samples = random_samples.cpu()
sample_idx = 0
results = []
results: SampleResultType = []
for seq_group in selected_seq_groups:
if not seq_group.do_sample:
results.append(([], []))
Expand Down Expand Up @@ -348,7 +351,7 @@ def _random_sample(
def _beam_search_sample(
selected_seq_groups: List[SequenceGroupToSample],
logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
) -> SampleResultType:
"""Run beam sampling on a given samples.
Args:
Expand All @@ -370,7 +373,7 @@ def _beam_search_sample(
# NOTE: Beam search is not vectorized, so its speed can be slower than
# other sampling methods.
sample_idx = 0
results = []
results: SampleResultType = []
for seq_group in selected_seq_groups:
if not seq_group.do_sample:
results.append(([], []))
Expand All @@ -391,16 +394,16 @@ def _beam_search_sample(
next_token_ids = next_token_ids.tolist()
else:
# Generation phase.
cumulative_logprobs = [
cumulative_logprobs: List[int] = [
seq_group.seq_data[seq_id].cumulative_logprob
for seq_id in seq_ids
]
cumulative_logprobs = torch.tensor(
cumulative_logprobs_tensor = torch.tensor(
cumulative_logprobs,
dtype=torch.float,
device=seq_group_logprobs.device)
seq_group_logprobs = (seq_group_logprobs +
cumulative_logprobs.unsqueeze(dim=1))
cumulative_logprobs_tensor.unsqueeze(dim=1))
_, topk_ids = torch.topk(seq_group_logprobs.flatten(),
2 * beam_width)
topk_ids = topk_ids.tolist()
Expand Down Expand Up @@ -452,8 +455,10 @@ def _sample_with_torch(
sampling_metadata: SamplingMetadata,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
categorized_seq_group_ids = {t: [] for t in SamplingType}
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
categorized_seq_group_ids: Dict[SamplingType,
List[int]] = {t: []
for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
sampling_params = seq_group.sampling_params
Expand Down Expand Up @@ -555,8 +560,10 @@ def _sample_with_triton_kernel(
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
) -> List[Tuple[List[int], List[int]]]:
categorized_seq_group_ids = {t: [] for t in SamplingType}
) -> SampleResultType:
categorized_seq_group_ids: Dict[SamplingType,
List[int]] = {t: []
for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
sampling_params = seq_group.sampling_params
Expand Down Expand Up @@ -632,7 +639,7 @@ def _sample(
probs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool, modify_greedy_probs: bool
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
"""
Args:
probs: (num_query_tokens_in_batch, num_vocab)
Expand Down Expand Up @@ -680,7 +687,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
def _get_logprobs(
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sample_results: List[Tuple[List[int], List[int]]],
sample_results: SampleResultType,
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
"""Return sample lobprobs and prompt logprobs.
Expand Down Expand Up @@ -751,8 +758,8 @@ def _get_logprobs(
assert len(next_token_ids) == len(query_indices)

if len(query_indices) == 0:
empty_sampled_logprob = []
empty_prompt_logprob = None
empty_sampled_logprob: SampleLogprobs = []
empty_prompt_logprob: Optional[PromptLogprobs] = None
return [empty_prompt_logprob], [empty_sampled_logprob]

query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
Expand Down Expand Up @@ -965,7 +972,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,


def _build_sampler_output(
sample_results: List[Tuple[List[int], List[int]]],
sample_results: SampleResultType,
sampling_metadata: SamplingMetadata,
prompt_logprobs: List[Optional[PromptLogprobs]],
sample_logprobs: List[SampleLogprobs],
Expand Down Expand Up @@ -1009,7 +1016,7 @@ def _build_sampler_output(
)


def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]:
def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
"""Get a list of next prompt tokens to compute logprob from a
given sequence group.
Expand Down
4 changes: 3 additions & 1 deletion vllm/model_executor/model_loader/tensorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _construct_tensorizer_args(self) -> "TensorizerArgs":
"s3_secret_access_key": self.s3_secret_access_key,
"s3_endpoint": self.s3_endpoint,
}
return TensorizerArgs(**tensorizer_args)
return TensorizerArgs(**tensorizer_args) # type: ignore

def verify_with_parallel_config(
self,
Expand Down Expand Up @@ -270,8 +270,10 @@ def __init__(self, tensorizer_config: TensorizerConfig,
self.model = self._init_model()

def _init_model(self):
assert self.tensorizer_config.hf_config is not None
model_args = self.tensorizer_config.hf_config
model_args.torch_dtype = self.tensorizer_config.dtype
assert self.tensorizer_config.model_class is not None
with no_init_or_tensor():
return self.tensorizer_config.model_class(
config=model_args,
Expand Down

0 comments on commit 7f5a450

Please sign in to comment.