Commit 2c2a05a

njhill authored and Alvant committed
[BugFix] Fix use of per-request seed with pipeline parallel (vllm-project#6698)
Signed-off-by: Alvant <alvasian@yandex.ru>
1 parent 3d3e19a commit 2c2a05a

21 files changed: +222 -137 lines changed

tests/samplers/test_rejection_sampler.py

Lines changed: 8 additions & 15 deletions
@@ -150,10 +150,9 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                     high=vocab_size,
                                     size=(batch_size, k),
                                     dtype=torch.int64)
-    generators = [None] * batch_size
 
     rejection_sampler(target_probs, bonus_token_ids, draft_probs,
-                      draft_token_ids, generators)
+                      draft_token_ids)
 
 
 @pytest.mark.parametrize("frac_seeded", [0.0, 0.25, 0.5, 1.0])
@@ -185,14 +184,13 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
 
     results = []
     for _ in range(n_rep):
-        generators = [
-            torch.Generator(
-                device=device).manual_seed(i) if seeded_mask[i] else None
-            for i in range(batch_size)
-        ]
+        seeded_seqs = {
+            i: torch.Generator(device=device).manual_seed(i)
+            for i in range(batch_size) if seeded_mask[i]
+        }
         results.append(
             rejection_sampler(target_probs, bonus_token_ids, draft_probs,
-                              draft_token_ids, generators))
+                              draft_token_ids, seeded_seqs))
 
     for i in range(batch_size):
         if seeded_mask[i]:
@@ -242,11 +240,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
         raise AssertionError()
 
     oob_token_ids[0][0] = rogue_token_id
-    generators = [None] * batch_size
 
     with pytest.raises(AssertionError):
         rejection_sampler(target_probs, bonus_token_ids, draft_probs,
-                          draft_token_ids, generators)
+                          draft_token_ids)
 
 
 @pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@@ -417,15 +414,11 @@ def _estimate_rejection_sampling_pdf(
                                       dtype=torch.int64,
                                       device="cuda").repeat(num_samples, 1)
 
-        # unseeded
-        generators = [None]
-
        # Get output tokens via rejection sampling.
        output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
                                                   bonus_token_ids.to("cuda"),
                                                   draft_probs.to("cuda"),
-                                                  draft_token_ids.to("cuda"),
-                                                  generators)
+                                                  draft_token_ids.to("cuda"))
 
         # Remove bonus tokens
         output_token_ids = output_token_ids[:, :-1].flatten()
tests/samplers/test_sampler.py

Lines changed: 4 additions & 1 deletion
@@ -510,13 +510,16 @@ def test_sampler_mixed(seed: int, device: str):
             ))
         seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
+    generators: Dict[str, torch.Generator] = {}
+
     def test_sampling():
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list,
             seq_lens,
             query_lens=seq_lens,
             device=device,
-            pin_memory=is_pin_memory_available())
+            pin_memory=is_pin_memory_available(),
+            generators=generators)
         sampler_output = sampler(logits=fake_logits,
                                  sampling_metadata=sampling_metadata)
 
tests/spec_decode/e2e/test_mlp_correctness.py

Lines changed: 53 additions & 1 deletion
@@ -21,7 +21,8 @@
 
 import pytest
 
-from .conftest import run_greedy_equality_correctness_test
+from .conftest import (run_equality_correctness_test,
+                       run_greedy_equality_correctness_test)
 
 # main model
 MAIN_MODEL = "JackFram/llama-160m"
@@ -77,6 +78,57 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
                                          force_output_len=True)
 
 
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model": MAIN_MODEL,
+
+        # Speculative model
+        "speculative_model": SPEC_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
+@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
+@pytest.mark.parametrize("output_len", [64])
+@pytest.mark.parametrize("batch_size", [1, 32])
+@pytest.mark.parametrize("temperature", [0.1, 1.0])
+@pytest.mark.parametrize("seed", [None])
+def test_mlp_e2e_seeded_correctness(baseline_llm_generator, test_llm_generator,
+                                    batch_size: int, output_len: int,
+                                    temperature: float):
+    """Verify seeded runs produce the same output."""
+    run_equality_correctness_test(baseline_llm_generator,
+                                  test_llm_generator,
+                                  batch_size,
+                                  max_output_len=output_len,
+                                  temperature=temperature,
+                                  seeded=True,
+                                  force_output_len=True)
+
+    # Ensure this same test does fail if we _don't_ include per-request seeds
+    with pytest.raises(AssertionError):
+        run_equality_correctness_test(baseline_llm_generator,
+                                      test_llm_generator,
+                                      batch_size,
+                                      max_output_len=output_len,
+                                      temperature=temperature,
+                                      seeded=False,
+                                      force_output_len=True)
+
+
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
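
The new `test_mlp_e2e_seeded_correctness` checks that, at nonzero temperature, repeated runs agree only when per-request seeds are supplied. A rough standalone sketch of that property using the offline `LLM` API; the model, prompt, and parameter values are placeholders, and the speculative-decoding settings of the real test are omitted:

from vllm import LLM, SamplingParams

llm = LLM(model="JackFram/llama-160m", enforce_eager=True)

# With a per-request seed, repeated sampling at temperature > 0 is reproducible.
seeded = SamplingParams(temperature=1.0, max_tokens=64, seed=1234)
first = llm.generate(["The president of the United States is"], seeded)
second = llm.generate(["The president of the United States is"], seeded)
assert first[0].outputs[0].text == second[0].outputs[0].text

# Without a seed, two runs at temperature > 0 are not expected to match.
unseeded = SamplingParams(temperature=1.0, max_tokens=64)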

tests/spec_decode/e2e/test_seed.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@
     "output_len",
     [
         # Use smaller output len for fast test.
-        10,
+        20,
     ])
 @pytest.mark.parametrize("seed", [None])
 def test_seeded_consistency(baseline_llm_generator, test_llm_generator,

tests/spec_decode/test_batch_expansion.py

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@ def test_create_single_target_seq_group_metadata(k: int):
         input_seq_id,
         target_seq_id,
         token_ids,
+        input_seq_group_metadata.sampling_params,
     )
 
     assert output.request_id == input_seq_group_metadata.request_id

tests/utils.py

Lines changed: 31 additions & 0 deletions
@@ -178,6 +178,37 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]):
             "usage": completion.usage,
         })
 
+        # test seeded random sampling
+        completion = client.completions.create(model=model,
+                                               prompt=prompt,
+                                               max_tokens=5,
+                                               seed=33,
+                                               temperature=1.0)
+
+        results.append({
+            "test": "seeded_sampling",
+            "text": completion.choices[0].text,
+            "finish_reason": completion.choices[0].finish_reason,
+            "usage": completion.usage,
+        })
+
+        # test seeded random sampling with multiple prompts
+        completion = client.completions.create(model=model,
+                                               prompt=[prompt, prompt],
+                                               max_tokens=5,
+                                               seed=33,
+                                               temperature=1.0)
+
+        results.append({
+            "test":
+            "seeded_sampling",
+            "text": [choice.text for choice in completion.choices],
+            "finish_reason":
+            [choice.finish_reason for choice in completion.choices],
+            "usage":
+            completion.usage,
+        })
+
         # test simple list
         batch = client.completions.create(
             model=model,
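
The block added to `compare_two_settings` exercises the `seed` parameter through the OpenAI-compatible completions endpoint. A standalone sketch of the same determinism check against an already-running server; the base URL, model name, and prompt are placeholders:

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

kwargs = dict(model="facebook/opt-125m",
              prompt="Hello, my name is",
              max_tokens=5,
              seed=33,
              temperature=1.0)
first = client.completions.create(**kwargs)
second = client.completions.create(**kwargs)

# Same seed and sampling settings -> the sampled text should be reproducible.
assert first.choices[0].text == second.choices[0].text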

vllm/core/scheduler.py

Lines changed: 0 additions & 1 deletion
@@ -1029,7 +1029,6 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
                 token_chunk_size=token_chunk_size,
                 lora_request=seq_group.lora_request,
                 computed_block_nums=common_computed_block_nums,
-                state=seq_group.state,
                 # `multi_modal_data` will only be present for the 1st comm
                 # between engine and worker.
                 # the subsequent comms can still use delta, but

vllm/model_executor/layers/rejection_sampler.py

Lines changed: 39 additions & 56 deletions
@@ -1,5 +1,5 @@
 from functools import cached_property
-from typing import List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 import torch.jit
@@ -36,7 +36,7 @@ def forward(
         bonus_token_ids: torch.Tensor,
         draft_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
-        generators: List[Optional[torch.Generator]],
+        seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
     ) -> torch.Tensor:
         """Sample token ids using rejection sampling. This accepts or rejects
         tokens proposed by the draft model using the probability of each token
@@ -66,6 +66,9 @@ def forward(
                 probabilities.
                 shape = [batch_size, num_speculative_tokens]
 
+            seeded_seqs: Dict of batch row index to torch generator, for
+                sequences using seeded generation.
+
         Returns:
             output_token_ids: The token ids sampled via rejection sampling,
                 or -1 if unable to sample a token because the previous token
@@ -83,7 +86,7 @@ def forward(
                 target_probs,
                 draft_probs,
                 draft_token_ids,
-                generators,
+                seeded_seqs,
             ))
 
         output_token_ids = self._create_output(
@@ -100,7 +103,7 @@ def _batch_modified_rejection_sampling(
         target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
         draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
         draft_token_ids: torch.Tensor,  # [batch_size, k]
-        generators: List[Optional[torch.Generator]],
+        seeded_seqs: Optional[Dict[int, torch.Generator]],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Perform modified rejection sampling on each sequence.
 
@@ -117,23 +120,17 @@ def _batch_modified_rejection_sampling(
 
         # shape [batch_size, k]
         accepted = self._get_accepted(target_probs, draft_probs,
-                                      draft_token_ids, generators)
+                                      draft_token_ids, seeded_seqs)
 
         recovered_probs = self._get_recovered_probs(
             target_probs, draft_probs).reshape(batch_size * k, vocab_size)
 
-        seed_indices, non_seed_indices = self._split_batch_by_seeded(
-            generators, k=k)
-
         # NOTE: the recovered_probs are overwritten by this method.
         recovered_token_ids = _multinomial(
             recovered_probs,
             num_samples=1,
             k=k,
-            generators=generators,
-            seed_indices=seed_indices,
-            # this arg is unused when None but torch.jit requires a list
-            non_seed_indices=non_seed_indices or [],
+            seeded_seqs=seeded_seqs or {},
        ).reshape(batch_size, k)
 
         return accepted, recovered_token_ids
@@ -143,7 +140,7 @@ def _get_accepted(
         target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
         draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
         draft_token_ids: torch.Tensor,  # [batch_size, k]
-        generators: List[Optional[torch.Generator]],
+        seeded_seqs: Optional[Dict[int, torch.Generator]],
     ) -> torch.Tensor:
         r"""Create bool matrix over the proposed draft tokens. If
         True, then a token can be accepted, else it should be
@@ -178,24 +175,26 @@ def _get_accepted(
         selected_target_probs = target_probs[batch_indices, probs_indicies,
                                              draft_token_ids]
 
-        seed_indices, non_seed_indices = self._split_batch_by_seeded(
-            generators)
-
-        if len(seed_indices) == 0:
+        if not seeded_seqs:
             uniform_rand = torch.rand_like(selected_target_probs)
         else:
             uniform_rand = torch.empty_like(selected_target_probs)
 
-            for idx in seed_indices:
-                uniform_rand[idx, :] = torch.rand(1,
-                                                  k,
-                                                  dtype=self.probs_dtype,
-                                                  device=target_probs.device,
-                                                  generator=generators[idx])
-
-            if non_seed_indices:
-                uniform_rand[non_seed_indices, :] = torch.rand(
-                    len(non_seed_indices),
+            non_seeded_indices = []
+            for idx in range(batch_size):
+                generator = seeded_seqs.get(idx)
+                if generator is None:
+                    non_seeded_indices.append(idx)
+                else:
+                    uniform_rand[idx, :] = torch.rand(
+                        1,
+                        k,
+                        dtype=self.probs_dtype,
+                        device=target_probs.device,
+                        generator=generator)
+            if non_seeded_indices:
+                uniform_rand[non_seeded_indices, :] = torch.rand(
+                    len(non_seeded_indices),
                     k,
                     dtype=self.probs_dtype,
                     device=target_probs.device)
@@ -272,27 +271,6 @@ def _smallest_positive_value(self) -> float:
         """
         return torch.finfo(self.probs_dtype).tiny
 
-    # partition batch into indices for which a generator is provided
-    # and indicies for which no generator is provided
-    @staticmethod
-    def _split_batch_by_seeded(
-        generators: List[Optional[torch.Generator]],
-        k: int = 1,
-    ) -> Tuple[List[int], Optional[List[int]]]:
-
-        if all(generator is None for generator in generators):
-            seed_indices: List[int] = []
-            non_seed_indices: Optional[List[int]] = None
-        else:
-            seed_indices, non_seed_indices = [], []
-            for i, generator in enumerate(generators):
-                if generator is None:
-                    non_seed_indices.extend(range(k * i, k * (i + 1)))
-                else:
-                    seed_indices.extend(range(k * i, k * (i + 1)))
-
-        return seed_indices, non_seed_indices
-
 
 # torch.multinomial forces a GPU<->CPU sync.
 # Therefore, we use an optimized implementation instead that skips the sync.
@@ -304,9 +282,7 @@ def _multinomial(
     probs: torch.Tensor,
     num_samples: int,
     k: int,
-    generators: List[Optional[torch.Generator]],
-    seed_indices: List[int],
-    non_seed_indices: List[int],
+    seeded_seqs: Dict[int, torch.Generator],
 ) -> torch.Tensor:
 
     if num_samples > 1:
@@ -315,13 +291,20 @@ def _multinomial(
         probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                          probs.shape[1]).contiguous().view(
                                              -1, probs.shape[1])
-
     q = torch.empty_like(probs)
-    if len(seed_indices) == 0:
+    if not seeded_seqs:
         q.exponential_(1.0)
     else:
-        q[non_seed_indices].exponential_(1.0)
-        for idx in seed_indices:
-            q[idx].exponential_(1.0, generator=generators[idx // k])
+        non_seeded_indices: List[int] = []
+        start = 0
+        for idx in range(len(q) // k):
+            end = start + k
+            generator = seeded_seqs.get(idx)
+            if generator is None:
+                non_seeded_indices.extend(list(range(start, end)))
+            else:
+                q[start:end].exponential_(1.0, generator=generator)
+            start = end
+        q[non_seeded_indices].exponential_(1.0)
 
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
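
For reference, `_multinomial` above samples from each row's categorical distribution by dividing the probabilities by Exponential(1) noise and taking an argmax (the Gumbel-max trick in exponential form), drawing the noise for seeded batch rows from their dedicated generator. A self-contained CPU sketch of that idea; shapes and seeds are illustrative, and the unseeded rows are filled via an explicit assignment rather than the in-place call used in the diff:

from typing import Dict, List

import torch


def multinomial_sketch(probs: torch.Tensor, k: int,
                       seeded_seqs: Dict[int, torch.Generator]) -> torch.Tensor:
    """Sample one token id per row of probs; rows i*k..(i+1)*k-1 belong to batch row i."""
    q = torch.empty_like(probs)
    if not seeded_seqs:
        q.exponential_(1.0)
    else:
        non_seeded: List[int] = []
        for row in range(len(q) // k):
            start, end = row * k, (row + 1) * k
            generator = seeded_seqs.get(row)
            if generator is None:
                non_seeded.extend(range(start, end))
            else:
                # Seeded rows draw their Exponential(1) noise from their own generator.
                q[start:end].exponential_(1.0, generator=generator)
        if non_seeded:
            # Fill the remaining rows with unseeded Exponential(1) noise.
            q[non_seeded] = torch.empty(len(non_seeded),
                                        q.shape[1]).exponential_(1.0)
    # argmax of p / E is distributed according to p (Gumbel-max trick).
    return probs.div(q).argmax(dim=1)


batch_size, k, vocab_size = 4, 2, 16
probs = torch.rand(batch_size * k, vocab_size).softmax(dim=-1)
seeded_seqs = {0: torch.Generator().manual_seed(7)}  # only batch row 0 is seeded
print(multinomial_sketch(probs, k, seeded_seqs))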
