[Frontend] Dynamic RoPE scaling #4638

Merged · 5 commits · May 22, 2024
Changes from 4 commits
56 changes: 55 additions & 1 deletion tests/test_config.py
@@ -36,4 +36,58 @@ def test_get_sliding_window():
    assert mistral_model_config.get_sliding_window() is None

    mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
    assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
    assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW


def test_rope_scaling():
    TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0}
    LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0}

    llama_model_config = ModelConfig(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        "meta-llama/Meta-Llama-3-8B-Instruct",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None
    assert llama_model_config.max_model_len == 8192

    llama_model_config = ModelConfig(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        "meta-llama/Meta-Llama-3-8B-Instruct",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
        rope_scaling=TEST_ROPE_SCALING,
    )
    assert getattr(llama_model_config.hf_config, "rope_scaling",
                   None) == TEST_ROPE_SCALING
    assert llama_model_config.max_model_len == 16384

    longchat_model_config = ModelConfig(
        "lmsys/longchat-13b-16k",
        "lmsys/longchat-13b-16k",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    assert getattr(longchat_model_config.hf_config, "rope_scaling",
                   None) == LONGCHAT_ROPE_SCALING
    assert longchat_model_config.max_model_len == 16384

    longchat_model_config = ModelConfig(
        "lmsys/longchat-13b-16k",
        "lmsys/longchat-13b-16k",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
        rope_scaling=TEST_ROPE_SCALING,
    )
    assert getattr(longchat_model_config.hf_config, "rope_scaling",
                   None) == TEST_ROPE_SCALING
    assert longchat_model_config.max_model_len == 4096
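The expected max_model_len values above come from _get_and_verify_max_len (referenced in the vllm/config.py diff below). A minimal sketch of the arithmetic, assuming the derived length is simply the checkpoint's base (pre-scaling) context length times rope_scaling["factor"] for "linear" and "dynamic" scaling, and that the longchat base resolves to 2048 (both inferred from the assertions above, not stated in this diff):

# Sketch only -- the real derivation lives in vllm.config._get_and_verify_max_len.
# Assumption: for "linear" and "dynamic" RoPE scaling, max_model_len is the base
# context length from the HF config multiplied by rope_scaling["factor"].
from typing import Optional


def expected_max_model_len(base_len: int, rope_scaling: Optional[dict]) -> int:
    factor = rope_scaling["factor"] if rope_scaling is not None else 1.0
    return int(base_len * factor)


assert expected_max_model_len(8192, None) == 8192                                  # Llama-3, unscaled
assert expected_max_model_len(8192, {"type": "dynamic", "factor": 2.0}) == 16384   # Llama-3 + override
assert expected_max_model_len(2048, {"type": "linear", "factor": 8.0}) == 16384    # longchat as shipped
assert expected_max_model_len(2048, {"type": "dynamic", "factor": 2.0}) == 4096    # longchat + override
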
7 changes: 6 additions & 1 deletion vllm/config.py
@@ -45,6 +45,9 @@ class ModelConfig:
        code_revision: The specific revision to use for the model code on
            Hugging Face Hub. It can be a branch name, a tag name, or a
            commit id. If unspecified, will use the default version.
        rope_scaling: Dictionary containing the scaling configuration for the
            RoPE embeddings. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id. If unspecified, will use
            the default version.
@@ -84,6 +87,7 @@ def __init__(
        seed: int,
        revision: Optional[str] = None,
        code_revision: Optional[str] = None,
        rope_scaling: Optional[dict] = None,
        tokenizer_revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
@@ -102,6 +106,7 @@
        self.seed = seed
        self.revision = revision
        self.code_revision = code_revision
        self.rope_scaling = rope_scaling
        self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization
        self.quantization_param_path = quantization_param_path
@@ -116,7 +121,7 @@ def __init__(
        self.skip_tokenizer_init = skip_tokenizer_init

        self.hf_config = get_config(self.model, trust_remote_code, revision,
                                    code_revision)
                                    code_revision, self.rope_scaling)
        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
        self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
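The docstring's warning not to also raise max_position_embeddings follows from the derivation sketched above: the scaling factor is already folded into the derived length, so bumping the base as well would apply it twice. A hypothetical misuse, for illustration only:

# Hypothetical numbers, reusing the multiply-by-factor assumption from the sketch above.
base, factor = 8192, 2.0
print(int(base * factor))             # 16384 -- intended maximum with rope_scaling alone
print(int((base * factor) * factor))  # 32768 -- result if max_position_embeddings were also doubled
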
18 changes: 13 additions & 5 deletions vllm/engine/arg_utils.py
@@ -1,5 +1,6 @@
import argparse
import dataclasses
import json
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

@@ -49,6 +50,7 @@ class EngineArgs:
    disable_log_stats: bool = False
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    rope_scaling: Optional[dict] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: bool = False
@@ -330,6 +332,11 @@ def add_cli_args(
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
        parser.add_argument('--rope-scaling',
                            default=None,
                            type=json.loads,
                            help='RoPE scaling configuration in JSON format. '
                            'For example, {"type":"dynamic","factor":2.0}')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
@@ -548,11 +555,12 @@ def create_engine_config(self, ) -> EngineConfig:
        model_config = ModelConfig(
            self.model, self.tokenizer, self.tokenizer_mode,
            self.trust_remote_code, self.dtype, self.seed, self.revision,
            self.code_revision, self.tokenizer_revision, self.max_model_len,
            self.quantization, self.quantization_param_path,
            self.enforce_eager, self.max_context_len_to_capture,
            self.max_seq_len_to_capture, self.max_logprobs,
            self.skip_tokenizer_init, self.served_model_name)
            self.code_revision, self.rope_scaling, self.tokenizer_revision,
            self.max_model_len, self.quantization,
            self.quantization_param_path, self.enforce_eager,
            self.max_context_len_to_capture, self.max_seq_len_to_capture,
            self.max_logprobs, self.skip_tokenizer_init,
            self.served_model_name)
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space, self.kv_cache_dtype,
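For context, a minimal usage sketch of the new option outside this diff. The server-side flag takes the same dict as JSON (--rope-scaling '{"type":"dynamic","factor":2.0}'); the offline example below assumes vllm.LLM forwards extra keyword arguments through to EngineArgs, as it does for other engine options.

# Usage sketch, not part of this PR. Assumes LLM(**kwargs) reaches EngineArgs.rope_scaling.
from vllm import LLM

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    rope_scaling={"type": "dynamic", "factor": 2.0},  # doubles the 8192-token context, per the test above
)
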
10 changes: 6 additions & 4 deletions vllm/engine/llm_engine.py
@@ -104,10 +104,11 @@ def __init__(
            "Initializing an LLM engine (v%s) with config: "
            "model=%r, speculative_config=%r, tokenizer=%r, "
            "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
            "tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, "
            "max_seq_len=%d, download_dir=%r, load_format=%s, "
            "tensor_parallel_size=%d, disable_custom_all_reduce=%s, "
            "quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, "
            "rope_scaling=%r, tokenizer_revision=%s, "
            "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
            "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
            "disable_custom_all_reduce=%s, quantization=%s, "
            "enforce_eager=%s, kv_cache_dtype=%s, "
            "quantization_param_path=%s, device_config=%s, "
            "decoding_config=%r, seed=%d, served_model_name=%s)",
            vllm.__version__,
@@ -117,6 +118,7 @@
            model_config.skip_tokenizer_init,
            model_config.tokenizer_mode,
            model_config.revision,
            model_config.rope_scaling,
            model_config.tokenizer_revision,
            model_config.trust_remote_code,
            model_config.dtype,
10 changes: 9 additions & 1 deletion vllm/transformers_utils/config.py
@@ -2,9 +2,12 @@

from transformers import AutoConfig, PretrainedConfig

from vllm.logger import init_logger
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                             JAISConfig, MPTConfig, RWConfig)

logger = init_logger(__name__)

_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
    "chatglm": ChatGLMConfig,
    "dbrx": DbrxConfig,
@@ -18,7 +21,8 @@
def get_config(model: str,
               trust_remote_code: bool,
               revision: Optional[str] = None,
               code_revision: Optional[str] = None) -> PretrainedConfig:
               code_revision: Optional[str] = None,
               rope_scaling: Optional[dict] = None) -> PretrainedConfig:
    try:
        config = AutoConfig.from_pretrained(
            model,
@@ -41,6 +45,10 @@ def get_config(model: str,
        config = config_class.from_pretrained(model,
                                              revision=revision,
                                              code_revision=code_revision)
    if rope_scaling is not None:
        logger.info("Updating rope_scaling from %r to %r",
                    getattr(config, "rope_scaling", None), rope_scaling)
        config.update({"rope_scaling": rope_scaling})
    return config


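The override itself goes through PretrainedConfig.update, which sets each key of the supplied dict as an attribute on the config, so a rope_scaling entry shipped in the checkpoint's config.json is simply replaced. A small illustration using only the Hugging Face API; the longchat values shown are the ones asserted in the test above, not re-verified here.

# Mirrors what get_config now does after loading the config; illustrative only.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("lmsys/longchat-13b-16k")
print(config.rope_scaling)   # {"type": "linear", "factor": 8.0} per the test above
config.update({"rope_scaling": {"type": "dynamic", "factor": 2.0}})
print(config.rope_scaling)   # the user-supplied override now wins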