
Commit 29ff86d

reidliu41 and minpeter authored and committed
[Misc] refactor context extension (vllm-project#19246)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
1 parent a08f72f commit 29ff86d

File tree

examples/offline_inference/context_extension.py

1 file changed: 51 additions, 30 deletions
@@ -1,37 +1,51 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+This script demonstrates how to extend the context length
+of a Qwen model using the YARN method (rope_scaling)
+and run a simple chat example.
+
+Usage:
+    python examples/offline_inference/context_extension.py
+"""
 
 from vllm import LLM, SamplingParams
 
-rope_theta = 1000000
-original_max_position_embeddings = 32768
-factor = 4.0
-
-# Use yarn to extend context
-hf_overrides = {
-    "rope_theta": rope_theta,
-    "rope_scaling": {
-        "rope_type": "yarn",
-        "factor": factor,
-        "original_max_position_embeddings": original_max_position_embeddings,
-    },
-    "max_model_len": int(original_max_position_embeddings * factor),
-}
-
-llm = LLM(model="Qwen/Qwen3-0.6B", hf_overrides=hf_overrides)
-
-sampling_params = SamplingParams(
-    temperature=0.8,
-    top_p=0.95,
-    max_tokens=128,
-)
-
-conversation = [
-    {"role": "system", "content": "You are a helpful assistant"},
-    {"role": "user", "content": "Hello"},
-    {"role": "assistant", "content": "Hello! How can I assist you today?"},
-]
-outputs = llm.chat(conversation, sampling_params, use_tqdm=False)
+
+def create_llm():
+    rope_theta = 1000000
+    original_max_position_embeddings = 32768
+    factor = 4.0
+
+    # Use yarn to extend context
+    hf_overrides = {
+        "rope_theta": rope_theta,
+        "rope_scaling": {
+            "rope_type": "yarn",
+            "factor": factor,
+            "original_max_position_embeddings": original_max_position_embeddings,
+        },
+        "max_model_len": int(original_max_position_embeddings * factor),
+    }
+
+    llm = LLM(model="Qwen/Qwen3-0.6B", hf_overrides=hf_overrides)
+    return llm
+
+
+def run_llm_chat(llm):
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=128,
+    )
+
+    conversation = [
+        {"role": "system", "content": "You are a helpful assistant"},
+        {"role": "user", "content": "Hello"},
+        {"role": "assistant", "content": "Hello! How can I assist you today?"},
+    ]
+    outputs = llm.chat(conversation, sampling_params, use_tqdm=False)
+    return outputs
 
 
 def print_outputs(outputs):
@@ -44,4 +58,11 @@ def print_outputs(outputs):
     print("-" * 80)
 
 
-print_outputs(outputs)
+def main():
+    llm = create_llm()
+    outputs = run_llm_chat(llm)
+    print_outputs(outputs)
+
+
+if __name__ == "__main__":
+    main()
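
For readers adapting this pattern to other RoPE-based models, the hf_overrides dict above can be factored into a reusable helper. The following is a minimal sketch, not part of this commit: build_yarn_overrides is a hypothetical name introduced here for illustration, while LLM and hf_overrides are the same vLLM APIs used in the diff. With the values from this example, the extended window works out to int(32768 * 4.0) == 131072 tokens.

# NOTE: sketch only. build_yarn_overrides is a hypothetical helper,
# not a vLLM API; it just generalizes the hf_overrides dict above.
from vllm import LLM


def build_yarn_overrides(rope_theta: int, original_max: int, factor: float) -> dict:
    """Return an hf_overrides dict that extends context via YARN rope scaling."""
    return {
        "rope_theta": rope_theta,
        "rope_scaling": {
            "rope_type": "yarn",
            "factor": factor,
            "original_max_position_embeddings": original_max,
        },
        # Extended window = original length * scaling factor,
        # e.g. int(32768 * 4.0) == 131072 tokens.
        "max_model_len": int(original_max * factor),
    }


llm = LLM(model="Qwen/Qwen3-0.6B", hf_overrides=build_yarn_overrides(1000000, 32768, 4.0))

The invariant to preserve when reusing this is visible in the commit itself: max_model_len is derived as original_max_position_embeddings times the YARN factor, so the engine's context limit and the rope scaling stay in agreement.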
