[Misc] refactor context extension #19246

Merged
81 changes: 51 additions & 30 deletions examples/offline_inference/context_extension.py
@@ -1,37 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This script demonstrates how to extend the context length
of a Qwen model using the YARN method (rope_scaling)
and run a simple chat example.

Usage:
python examples/offline_inference/context_extension.py
"""

from vllm import LLM, SamplingParams

rope_theta = 1000000
original_max_position_embeddings = 32768
factor = 4.0

# Use yarn to extend context
hf_overrides = {
    "rope_theta": rope_theta,
    "rope_scaling": {
        "rope_type": "yarn",
        "factor": factor,
        "original_max_position_embeddings": original_max_position_embeddings,
    },
    "max_model_len": int(original_max_position_embeddings * factor),
}

llm = LLM(model="Qwen/Qwen3-0.6B", hf_overrides=hf_overrides)

sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=128,
)

conversation = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hello! How can I assist you today?"},
]
outputs = llm.chat(conversation, sampling_params, use_tqdm=False)

def create_llm():
    rope_theta = 1000000
    original_max_position_embeddings = 32768
    factor = 4.0

    # Use yarn to extend context
    hf_overrides = {
        "rope_theta": rope_theta,
        "rope_scaling": {
            "rope_type": "yarn",
            "factor": factor,
            "original_max_position_embeddings": original_max_position_embeddings,
        },
        "max_model_len": int(original_max_position_embeddings * factor),
    }

    llm = LLM(model="Qwen/Qwen3-0.6B", hf_overrides=hf_overrides)
    return llm


def run_llm_chat(llm):
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=128,
    )

    conversation = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hello! How can I assist you today?"},
    ]
    outputs = llm.chat(conversation, sampling_params, use_tqdm=False)
    return outputs


def print_outputs(outputs):
@@ -44,4 +58,11 @@ def print_outputs(outputs):
print("-" * 80)


print_outputs(outputs)
def main():
    llm = create_llm()
    outputs = run_llm_chat(llm)
    print_outputs(outputs)


if __name__ == "__main__":
    main()
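
A minimal usage sketch (not part of this PR; assumptions noted in the comments): once the refactored helpers are in place, the extended window can be exercised with a prompt longer than the original 32768-token limit. The repeated-word prompt and its token count are rough assumptions that depend on the tokenizer.

# Sketch only -- reuses create_llm() from the example above.
from vllm import SamplingParams

llm = create_llm()

# Roughly 40k tokens with typical tokenizers: past the original 32768-token
# limit but still inside the extended 131072-token window. Exact counts vary.
long_prompt = "hello " * 40000

params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate([long_prompt], params)
print(outputs[0].outputs[0].text)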