
Commit c1d51ff

DarkLight1337 and amitm02 authored and committed
[Doc] Update reproducibility doc and example (vllm-project#18741)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: amit <amit.man@gmail.com>
1 parent dcaa192 commit c1d51ff

File tree

2 files changed: +47 −43 lines


docs/usage/reproducibility.md

Lines changed: 32 additions & 31 deletions
```diff
@@ -1,51 +1,52 @@
 # Reproducibility
 
-## Overview
+vLLM does not guarantee the reproducibility of the results by default, for the sake of performance. You need to do the following to achieve
+reproducible results:
 
-The `seed` parameter in vLLM is used to control the random states for various random number generators. This parameter can affect the behavior of random operations in user code, especially when working with models in vLLM.
+- For V1: Turn off multiprocessing to make the scheduling deterministic by setting `VLLM_ENABLE_V1_MULTIPROCESSING=0`.
+- For V0: Set the global seed (see below).
 
-## Default Behavior
+Example: <gh-file:examples/offline_inference/reproducibility.py>
 
-By default, the `seed` parameter is set to `None`. When the `seed` parameter is `None`, the global random states for `random`, `np.random`, and `torch.manual_seed` are not set. This means that the random operations will behave as expected, without any fixed random states.
+!!! warning
 
-## Specifying a Seed
+    Applying the above settings [changes the random state in user code](#locality-of-random-state).
 
-If a specific seed value is provided, the global random states for `random`, `np.random`, and `torch.manual_seed` will be set accordingly. This can be useful for reproducibility, as it ensures that the random operations produce the same results across multiple runs.
+!!! note
 
-## Example Usage
+    Even with the above settings, vLLM only provides reproducibility
+    when it runs on the same hardware and the same vLLM version.
+    Also, the online serving API (`vllm serve`) does not support reproducibility
+    because it is almost impossible to make the scheduling deterministic in the
+    online setting.
 
-### Without Specifying a Seed
+## Setting the global seed
 
-```python
-import random
-from vllm import LLM
+The `seed` parameter in vLLM is used to control the random states for various random number generators.
 
-# Initialize a vLLM model without specifying a seed
-model = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")
+If a specific seed value is provided, the random states for `random`, `np.random`, and `torch.manual_seed` will be set accordingly.
 
-# Try generating random numbers
-print(random.randint(0, 100)) # Outputs different numbers across runs
-```
+However, in some cases, setting the seed will also [change the random state in user code](#locality-of-random-state).
 
-### Specifying a Seed
+### Default Behavior
 
-```python
-import random
-from vllm import LLM
+In V0, the `seed` parameter defaults to `None`. When the `seed` parameter is `None`, the random states for `random`, `np.random`, and `torch.manual_seed` are not set. This means that each run of vLLM will produce different results if `temperature > 0`, as expected.
 
-# Initialize a vLLM model with a specific seed
-model = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", seed=42)
+In V1, the `seed` parameter defaults to `0` which sets the random state for each worker, so the results will remain consistent for each vLLM run even if `temperature > 0`.
 
-# Try generating random numbers
-print(random.randint(0, 100)) # Outputs the same number across runs
-```
+!!! note
 
-## Important Notes
+    It is impossible to un-specify a seed for V1 because different workers need to sample the same outputs
+    for workflows such as speculative decoding.
+
+    For more information, see: <gh-pr:17929>
 
-- If the `seed` parameter is not specified, the behavior of global random states remains unaffected.
-- If a specific seed value is provided, the global random states for `random`, `np.random`, and `torch.manual_seed` will be set to that value.
-- This behavior can be useful for reproducibility but may lead to non-intuitive behavior if the user is not explicitly aware of it.
+### Locality of random state
 
-## Conclusion
+The random state in user code (i.e. the code that constructs [LLM][vllm.LLM] class) is updated by vLLM under the following conditions:
 
-Understanding the behavior of the `seed` parameter in vLLM is crucial for ensuring the expected behavior of random operations in your code. By default, the `seed` parameter is set to `None`, which means that the global random states are not affected. However, specifying a seed value can help achieve reproducibility in your experiments.
+- For V0: The seed is specified.
+- For V1: The workers are run in the same process as user code, i.e.: `VLLM_ENABLE_V1_MULTIPROCESSING=0`.
+
+By default, these conditions are not active so you can use vLLM without having to worry about
+accidentally making deterministic subsequent operations that rely on random state.
```
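Taken together, the recipe the updated page describes reduces to a short standalone script. The following is a minimal sketch of that recipe; the model name and sampling values are illustrative choices, not part of this commit:

```python
import os

# V1 only: turn off multiprocessing so scheduling is deterministic.
# This must be set before vLLM is imported.
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

from vllm import LLM, SamplingParams

# Passing an explicit seed covers V0 (where the default is None and thus
# not reproducible); in V1 the seed already defaults to 0.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", seed=42)

params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(["Hello, my name is"], params)

# On the same hardware and vLLM version, this text is identical across runs.
print(outputs[0].outputs[0].text)
```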
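For intuition about what "setting the global seed" pins down, the three generators the doc names can be seeded directly. This is only an illustration of the affected random states, not vLLM's internal seeding code:

```python
import random

import numpy as np
import torch

SEED = 42

# The three global random states referenced by the doc; fixing all of
# them is what makes downstream sampling repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print(random.randint(0, 100))     # same value on every run
print(np.random.randint(0, 100))  # same value on every run
print(torch.rand(1).item())       # same value on every run
```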
examples/offline_inference/reproducibility.py

Lines changed: 15 additions & 12 deletions
```diff
@@ -1,24 +1,22 @@
 # SPDX-License-Identifier: Apache-2.0
+"""
+Demonstrates how to achieve reproducibility in vLLM.
+
+Main article: https://docs.vllm.ai/en/latest/usage/reproducibility.html
+"""
+
 import os
+import random
 
 from vllm import LLM, SamplingParams
 
-# vLLM does not guarantee the reproducibility of the results by default,
-# for the sake of performance. You need to do the following to achieve
-# reproducible results:
-# 1. Turn off multiprocessing to make the scheduling deterministic.
-# NOTE(woosuk): This is not needed and will be ignored for V0.
+# V1 only: Turn off multiprocessing to make the scheduling deterministic.
 os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
-# 2. Fix the global seed for reproducibility. The default seed is None, which is
+
+# V0 only: Set the global seed. The default seed is None, which is
 # not reproducible.
 SEED = 42
 
-# NOTE(woosuk): Even with the above two settings, vLLM only provides
-# reproducibility when it runs on the same hardware and the same vLLM version.
-# Also, the online serving API (`vllm serve`) does not support reproducibility
-# because it is almost impossible to make the scheduling deterministic in the
-# online serving setting.
-
 prompts = [
     "Hello, my name is",
     "The president of the United States is",
@@ -38,6 +36,11 @@ def main():
         print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
         print("-" * 50)
 
+    # Try generating random numbers outside vLLM
+    # The same number is output across runs, meaning that the random state
+    # in the user code has been updated by vLLM
+    print(random.randint(0, 100))
+
 
 if __name__ == "__main__":
     main()
```
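The final `print(random.randint(0, 100))` in the example deliberately exposes the side effect on user code. A caller who wants reproducible vLLM output without disturbing their own randomness could snapshot and restore the global generator states around the vLLM calls. This guard is a hypothetical sketch, not part of the commit:

```python
import random

import numpy as np
import torch

# Snapshot the global generator states before vLLM seeds them ...
py_state = random.getstate()
np_state = np.random.get_state()
torch_state = torch.get_rng_state()

# ... construct LLM(...) and run inference here ...

# ... then restore them, so subsequent user-code randomness is
# independent of vLLM's seeding.
random.setstate(py_state)
np.random.set_state(np_state)
torch.set_rng_state(torch_state)

print(random.randint(0, 100))  # varies across runs again
```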

0 commit comments