Skip to content

Commit

Permalink
[FIX] Fix class naming (vllm-project#1803)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuohan123 authored Nov 28, 2023
1 parent b943890 commit 708e6c1
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 17 deletions.
8 changes: 4 additions & 4 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
SequenceGroupMetadata, SequenceGroupOutputs,
SequenceOutputs, SequenceStatus)
SequenceGroupMetadata, SequenceGroupOutput,
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
get_tokenizer)
from vllm.utils import Counter
Expand Down Expand Up @@ -363,7 +363,7 @@ def _check_beam_search_early_stopping(
return current_worst_score >= highest_attainable_score

def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutputs) -> None:
outputs: SequenceGroupOutput) -> None:
# Process prompt logprobs
prompt_logprobs = outputs.prompt_logprobs
if prompt_logprobs is not None:
Expand All @@ -384,7 +384,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,

# Process the child samples for each parent sequence
for parent in parent_seqs:
child_samples: List[SequenceOutputs] = parent_child_dict[
child_samples: List[SequenceOutput] = parent_child_dict[
parent.seq_id]
if len(child_samples) == 0:
# This parent sequence has no children samples. Remove
Expand Down
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
tensor_model_parallel_all_gather)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
SequenceData, SequenceGroupOutputs, SequenceOutputs)
SequenceData, SequenceGroupOutput, SequenceOutput)

_SAMPLING_EPS = 1e-5

Expand Down Expand Up @@ -641,7 +641,7 @@ def _build_sampler_output(
next_token_ids,
group_sample_logprobs):
seq_outputs.append(
SequenceOutputs(seq_ids[parent_id], next_token_id, logprobs))
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
sampler_output.append(
SequenceGroupOutputs(seq_outputs, group_prompt_logprobs))
SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
return sampler_output
20 changes: 10 additions & 10 deletions vllm/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def __init__(
self.block_tables = block_tables


class SequenceOutputs:
class SequenceOutput:
"""The model output associated with a sequence.
Args:
Expand All @@ -374,40 +374,40 @@ def __init__(
self.logprobs = logprobs

def __repr__(self) -> str:
return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
f"output_token={self.output_token}, "
f"logprobs={self.logprobs})")

def __eq__(self, other: object) -> bool:
if not isinstance(other, SequenceOutputs):
if not isinstance(other, SequenceOutput):
raise NotImplementedError()
return (self.parent_seq_id == other.parent_seq_id
and self.output_token == other.output_token
and self.logprobs == other.logprobs)


class SequenceGroupOutputs:
"""The model outputs associated with a sequence group."""
class SequenceGroupOutput:
"""The model output associated with a sequence group."""

def __init__(
self,
samples: List[SequenceOutputs],
samples: List[SequenceOutput],
prompt_logprobs: Optional[PromptLogprobs],
) -> None:
self.samples = samples
self.prompt_logprobs = prompt_logprobs

def __repr__(self) -> str:
return (f"SequenceGroupOutputs(samples={self.samples}, "
return (f"SequenceGroupOutput(samples={self.samples}, "
f"prompt_logprobs={self.prompt_logprobs})")

def __eq__(self, other: object) -> bool:
if not isinstance(other, SequenceGroupOutputs):
if not isinstance(other, SequenceGroupOutput):
raise NotImplementedError()
return (self.samples == other.samples
and self.prompt_logprobs == other.prompt_logprobs)


# For each sequence group, we generate a list of SequenceOutputs object,
# For each sequence group, we generate a list of SequenceOutput object,
# each of which contains one possible candidate for the next token.
SamplerOutput = List[SequenceGroupOutputs]
SamplerOutput = List[SequenceGroupOutput]

0 comments on commit 708e6c1

Please sign in to comment.