
Commit c098739

njhill authored and Robert Shaw committed
[Misc] Various simplifications and typing fixes (vllm-project#5368)
1 parent 10e0353 commit c098739


8 files changed: +63 -90 lines changed


vllm/engine/output_processor/multi_step.py

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ def process_outputs(self, sequence_group: SequenceGroup,

         # Since there's only one sequence per sequence group, we can take the
         # first sample.
-        samples = [outputs[step].samples[0] for step in range(len(outputs))]
+        samples = [output.samples[0] for output in outputs]

         # -1 means the output token is not valid (eg. due to spec decode
         # rejecting tokens).
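
The change here just swaps index-based access for direct iteration over the outputs list. A minimal standalone sketch of the equivalence, using a made-up Output container rather than vLLM's real sequence types:

from dataclasses import dataclass
from typing import List

@dataclass
class Output:
    """Stand-in for a per-step sampler output; illustrative only."""
    samples: List[str]

outputs = [Output(samples=["a1", "a2"]), Output(samples=["b1"])]

# Before: index into the list by position via range(len(...)).
samples_old = [outputs[step].samples[0] for step in range(len(outputs))]
# After: iterate directly over the list, dropping the redundant index.
samples_new = [output.samples[0] for output in outputs]

assert samples_old == samples_new == ["a1", "b1"]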

vllm/model_executor/layers/rejection_sampler.py

Lines changed: 4 additions & 2 deletions
@@ -306,8 +306,10 @@ def _create_output(

         # Fill in the first k columns of the output tensor using masks and data
         # tensors.
-        output[:, :k] = torch.where(accepted_mask, draft_token_ids,
-                                    -torch.ones_like(draft_token_ids))
+        torch.where(accepted_mask,
+                    draft_token_ids,
+                    -torch.ones_like(draft_token_ids),
+                    out=output)

         # Fill the last column.
         # We check output directly as accepted may have True values inconsistent
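
This rewrite leans on torch.where accepting an out= tensor, so the selected values are written straight into the preallocated output instead of materializing a temporary and assigning it to a slice. A small self-contained sketch of the pattern, with arbitrary toy shapes rather than the sampler's real inputs:

import torch

k = 4
draft_token_ids = torch.arange(2 * k).reshape(2, k)
accepted_mask = torch.tensor([[True, True, False, False],
                              [True, False, False, False]])
output = torch.empty(2, k + 1, dtype=torch.long)

# Accepted positions keep the draft token id; rejected positions become -1.
# out= writes the result directly into the first k columns of output.
torch.where(accepted_mask,
            draft_token_ids,
            -torch.ones_like(draft_token_ids),
            out=output[:, :k])

assert output[0, :k].tolist() == [0, 1, -1, -1]
assert output[1, :k].tolist() == [4, -1, -1, -1]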

vllm/spec_decode/batch_expansion.py

Lines changed: 12 additions & 27 deletions
@@ -80,7 +80,7 @@ def score_proposals(

         target_sampler_output = self._scorer_worker.execute_model(
             execute_model_req=execute_model_req.clone(
-                seq_group_metadata_list=target_seq_group_metadata_list, ))
+                seq_group_metadata_list=target_seq_group_metadata_list))
         assert len(target_sampler_output) == 1, "expected single-step output"
         target_sampler_output = target_sampler_output[0]

@@ -140,8 +140,7 @@ def _expand_batch(
                 num_scoring_tokens)

     def _contract_batch(
-            self, contracted_bs: int,
-            target_sampler_output: List[SamplerOutput],
+            self, contracted_bs: int, target_sampler_output: SamplerOutput,
             proposals: SpeculativeProposals, num_scoring_tokens: int,
             non_spec_indices: List[int], spec_indices: List[int],
             k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -167,30 +166,16 @@ def _contract_batch(
         non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
         spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs

-        target_token_ids = target_token_ids.squeeze().reshape(
-            spec_expanded_bs, k + 1)
-        target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
-                                                      self._vocab_size)
-        target_logprobs = target_logprobs.squeeze().reshape(
-            spec_expanded_bs, k + 1, self._vocab_size)
-
-        all_tokens = torch.full(size=(contracted_bs, k + 1),
-                                fill_value=-1,
-                                device=self._device,
-                                dtype=torch.long)
-        all_probs = torch.zeros(contracted_bs,
-                                k + 1,
-                                self._vocab_size,
-                                device=self._device,
-                                dtype=torch.float32)
-        all_logprobs = torch.full(size=(
-            contracted_bs,
-            k + 1,
-            self._vocab_size,
-        ),
-                                  fill_value=-float("inf"),
-                                  device=self._device,
-                                  dtype=torch.float32)
+        target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
+        target_probs = target_probs.reshape(*target_token_ids.shape,
+                                            self._vocab_size)
+        target_logprobs = target_logprobs.reshape(target_probs.shape)
+
+        all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
+                                               fill_value=-1)
+        all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
+        all_logprobs = target_logprobs.new_full(size=all_probs.shape,
+                                                fill_value=-float("inf"))

         if non_spec_indices:
             all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
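
The simplification above relies on Tensor.new_full and Tensor.new_zeros, which create tensors inheriting the dtype and device of the tensor they are called on, so the explicit dtype=/device= arguments (and the .squeeze() calls) drop out. A minimal sketch of the pattern with toy sizes, not the real batch or vocabulary dimensions:

import torch

contracted_bs, k, vocab_size = 2, 3, 10
target_token_ids = torch.randint(0, vocab_size, (contracted_bs, k + 1))
target_probs = torch.rand(contracted_bs, k + 1, vocab_size)

# new_full / new_zeros keep the source tensor's dtype and device, so the
# placeholder tensors automatically match the sampler outputs.
all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
                                       fill_value=-1)
all_probs = target_probs.new_zeros(*all_tokens.shape, vocab_size)

assert all_tokens.dtype == target_token_ids.dtype
assert all_probs.dtype == target_probs.dtype
assert (all_tokens == -1).all() and (all_probs == 0).all()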

vllm/spec_decode/spec_decode_worker.py

Lines changed: 7 additions & 7 deletions
@@ -3,6 +3,7 @@

 import torch

+from vllm.config import SpeculativeConfig
 from vllm.distributed.communication_op import broadcast_tensor_dict
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
@@ -30,7 +31,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
     WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
     """
     assert "speculative_config" in kwargs
-    speculative_config = kwargs.get("speculative_config")
+    speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
     assert speculative_config is not None

     target_worker = Worker(*args, **kwargs)
@@ -109,12 +110,11 @@ def create_worker(
         logger.info("Configuring SpecDecodeWorker with proposer=%s",
                     type(proposer_worker))

-        return SpecDecodeWorker(
-            proposer_worker,
-            scorer_worker,
-            disable_by_batch_size=disable_by_batch_size,
-            rejection_sampler=RejectionSampler(
-                disable_bonus_tokens=disable_bonus_tokens, ))
+        return SpecDecodeWorker(proposer_worker,
+                                scorer_worker,
+                                disable_by_batch_size=disable_by_batch_size,
+                                rejection_sampler=RejectionSampler(
+                                    disable_bonus_tokens=disable_bonus_tokens))

     def __init__(
         self,
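
The annotation added in create_spec_worker only helps static analysis: a value pulled out of **kwargs is untyped, and annotating the local tells the type checker what to expect before attributes are accessed on it. A generic sketch of the idiom, using an illustrative Config class rather than vLLM's SpeculativeConfig:

from dataclasses import dataclass
from typing import Any

@dataclass
class Config:
    """Stand-in config class; illustrative only."""
    num_speculative_tokens: int

def create_worker(**kwargs: Any) -> int:
    # The annotation narrows the otherwise-untyped kwargs.get result, so the
    # attribute access below is checked instead of being treated as Any.
    config: Config = kwargs.get("config")
    assert config is not None
    return config.num_speculative_tokens

assert create_worker(config=Config(num_speculative_tokens=5)) == 5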

vllm/spec_decode/top1_proposer.py

Lines changed: 18 additions & 27 deletions
@@ -148,7 +148,8 @@ def _split_by_proposal_len(
             nonzero_proposal_len_indices,
         )

-    def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output,
+    @staticmethod
+    def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
                                  nonzero_proposal_len_indices, transposed):
         """Remove sequences from nonzero_proposal_len_indices and reset
         their proposal_len to 0 the draft worker does not provide a proposal
@@ -207,7 +208,7 @@ def _merge_outputs(
         self,
         batch_size: int,
         proposal_len: int,
-        maybe_sampler_output: Optional[SamplerOutput],
+        maybe_sampler_output: Optional[List[SamplerOutput]],
         proposal_lens: List[int],
         nonzero_proposal_len_indices: List[int],
         sampler_transposed: bool,
@@ -218,25 +219,19 @@ def _merge_outputs(
         if maybe_sampler_output is None:
             # If no speculative tokens, the sampler output will be None.
             # In this case we return empty proposals.
-            proposal_tokens = torch.full(
-                size=(
-                    batch_size,
-                    proposal_len,
-                ),
-                fill_value=-1,
-                dtype=torch.long,
-                device=self._device,
-            )
-            proposal_probs = torch.zeros(
-                batch_size,
-                proposal_len,
-                self._vocab_size,
-                dtype=torch.float32,
-                device=self._device,
-            )
-            proposal_lens_tensor = torch.zeros(len(proposal_lens),
-                                               dtype=torch.long,
-                                               device=self._device)
+            proposal_tokens = torch.tensor(-1,
+                                           dtype=torch.long,
+                                           device=self._device).expand(
+                                               batch_size, proposal_len)
+            proposal_probs = torch.tensor(0,
+                                          dtype=torch.float32,
+                                          device=self._device).expand(
+                                              batch_size, proposal_len,
+                                              self._vocab_size)
+            proposal_lens_tensor = torch.tensor(0,
+                                                dtype=torch.long,
+                                                device=self._device).expand(
+                                                    len(proposal_lens))
             return proposal_tokens, proposal_probs, proposal_lens_tensor

         sampler_output = maybe_sampler_output
@@ -246,18 +241,14 @@ def _merge_outputs(
         # Now, reformat the output GPU tensors such that each sequence has
         # a proposal. the proposal can be empty, e.g. [-1, -1, -1]

-        entire_proposal_tokens = torch.full(
+        entire_proposal_tokens = proposal_tokens.new_full(
             size=(batch_size, *proposal_tokens.shape[1:]),
             fill_value=-1,
-            dtype=torch.long,
-            device=self._device,
         )
         entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
-        entire_proposal_probs = torch.zeros(
+        entire_proposal_probs = proposal_probs.new_zeros(
             batch_size,
             *proposal_probs.shape[1:],
-            dtype=torch.float32,
-            device=self._device,
         )
         entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
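
In the empty-proposal branch the constant tensors are now built by expanding a single scalar to the target shape; expand returns a broadcast view backed by one element, so nothing of size batch_size * proposal_len * vocab_size is actually allocated. A standalone sketch of the trick with toy dimensions:

import torch

batch_size, proposal_len, vocab_size = 4, 3, 8

# One -1 element, viewed (not copied) at the full proposal shape.
proposal_tokens = torch.tensor(-1, dtype=torch.long).expand(
    batch_size, proposal_len)
proposal_probs = torch.tensor(0, dtype=torch.float32).expand(
    batch_size, proposal_len, vocab_size)

assert proposal_tokens.shape == (batch_size, proposal_len)
assert (proposal_probs == 0).all()
# Every broadcast dimension has stride 0: the view shares one backing element.
assert proposal_tokens.stride() == (0, 0)

The non-empty branch applies the same new_full/new_zeros device-and-dtype inheritance already used in batch_expansion.py above.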

vllm/spec_decode/util.py

Lines changed: 3 additions & 8 deletions
@@ -1,12 +1,11 @@
 from contextlib import contextmanager
-from itertools import chain
 from typing import Dict, List, Tuple

 import torch

 from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SamplerOutput, SequenceGroupMetadata,
-                           SequenceGroupOutput, SequenceOutput)
+                           SequenceOutput)

 SeqId = int

@@ -16,11 +15,7 @@ def get_all_seq_ids(
     """Given a list of SequenceGroupMetadata, create a list of all
     sequence ids.
     """
-    return list(
-        chain.from_iterable([
-            seq_group_metadata.seq_data.keys()
-            for seq_group_metadata in seq_group_metadata_list
-        ]))
+    return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]


 def get_all_num_logprobs(
@@ -68,7 +63,7 @@ def create_sequence_group_output(
     seq_id: SeqId,
     topk_token_ids: List[int],
     topk_logprobs: List[float],
-) -> SequenceGroupOutput:
+) -> CompletionSequenceGroupOutput:
     """Create a SequenceGroupOutput given the sampling results.

     Args:
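
get_all_seq_ids now flattens with a nested comprehension instead of itertools.chain; iterating a dict yields its keys, so "for seq_id in sg.seq_data" matches the old .keys() behaviour. A small sketch with stand-in metadata objects, not vLLM's real SequenceGroupMetadata:

from dataclasses import dataclass
from itertools import chain
from typing import Dict

@dataclass
class SeqGroupMeta:
    """Stand-in for SequenceGroupMetadata; illustrative only."""
    seq_data: Dict[int, str]

groups = [SeqGroupMeta({0: "a", 1: "b"}), SeqGroupMeta({7: "c"})]

# Old style: flatten the per-group key views with itertools.chain.
old = list(chain.from_iterable(g.seq_data.keys() for g in groups))
# New style: one nested comprehension, no extra import needed.
new = [seq_id for g in groups for seq_id in g.seq_data]

assert old == new == [0, 1, 7]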

vllm/transformers_utils/config.py

Lines changed: 3 additions & 3 deletions
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Dict, Optional, Type

 from transformers import PretrainedConfig

@@ -9,7 +9,7 @@

 logger = init_logger(__name__)

-_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
+_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     "chatglm": ChatGLMConfig,
     "dbrx": DbrxConfig,
     "mpt": MPTConfig,
@@ -68,4 +68,4 @@ def get_hf_text_config(config: PretrainedConfig):
         assert hasattr(config.text_config, "num_attention_heads")
         return config.text_config
     else:
-        return config
+        return config
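
The registry maps model types to config classes, not instances, which is exactly what Type[...] expresses; under the old annotation a checker would expect PretrainedConfig instances as values. A minimal sketch of the distinction using a stand-in base class rather than the real transformers types:

from typing import Dict, Type

class BaseConfig:
    """Stand-in for PretrainedConfig; illustrative only."""

class ChatGLMLikeConfig(BaseConfig):
    """Stand-in for a model-specific config class."""

# The dictionary values are classes, so the value type is Type[BaseConfig].
_REGISTRY: Dict[str, Type[BaseConfig]] = {"chatglm": ChatGLMLikeConfig}

config_cls = _REGISTRY["chatglm"]
config = config_cls()  # checkers know this yields a BaseConfig instance
assert isinstance(config, BaseConfig)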

vllm/worker/model_runner.py

Lines changed: 15 additions & 15 deletions
@@ -527,28 +527,13 @@ def _prepare_model_input(
         )
         assert max_query_len > 0, ("query_lens: {}".format(query_lens))

-        context_lens_tensor = torch.tensor(context_lens,
-                                           dtype=torch.int,
-                                           device=self.device)
-        query_lens_tensor = torch.tensor(query_lens,
-                                         dtype=torch.long,
-                                         device=self.device)
-        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=self.device)
-
         seq_lens_tensor = torch.tensor(seq_lens,
                                        dtype=torch.int,
                                        device=self.device)
         seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                     dtype=torch.int32,
                                     device=self.device)

-        torch.cumsum(query_lens_tensor,
-                     dim=0,
-                     dtype=query_start_loc.dtype,
-                     out=query_start_loc[1:])
-
         torch.cumsum(seq_lens_tensor,
                      dim=0,
                      dtype=seq_start_loc.dtype,
@@ -601,6 +586,21 @@ def _prepare_model_input(
                 seq_start_loc=seq_start_loc,
                 data_type=kv_cache_dtype)
         else:
+            context_lens_tensor = torch.tensor(context_lens,
+                                               dtype=torch.int,
+                                               device=self.device)
+            query_lens_tensor = torch.tensor(query_lens,
+                                             dtype=torch.long,
+                                             device=self.device)
+            query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
+                                          dtype=torch.int32,
+                                          device=self.device)
+
+            torch.cumsum(query_lens_tensor,
+                         dim=0,
+                         dtype=query_start_loc.dtype,
+                         out=query_start_loc[1:])
+
             attn_metadata = self.attn_backend.make_metadata(
                 num_prefills=num_prefills,
                 slot_mapping=slot_mapping_tensor,
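
The query-length tensors are now constructed only in the branch that consumes them rather than unconditionally. The construction itself is an exclusive prefix sum: start offsets come from cumulative-summing the lengths into a zero-initialized tensor shifted by one slot. A self-contained sketch with toy lengths:

import torch

query_lens = [3, 1, 4]

query_lens_tensor = torch.tensor(query_lens, dtype=torch.long)
query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)

# Writing the cumulative sum into positions 1..N leaves query_start_loc[0]
# at 0, giving each sequence's start offset into the flattened token batch.
torch.cumsum(query_lens_tensor,
             dim=0,
             dtype=query_start_loc.dtype,
             out=query_start_loc[1:])

assert query_start_loc.tolist() == [0, 3, 4, 8]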

0 commit comments
