Commit

Model: Add override base seq len
Some models (such as Mistral and Mixtral) set their base sequence
length to 32k due to assumptions of support for sliding window
attention.

Therefore, add this parameter to override the base sequence length
of a model, which helps with auto-calculation of rope alpha.

If auto-calculation of rope alpha isn't being used, the max_seq_len
parameter works fine as is.

Signed-off-by: kingbri <bdashore3@proton.me>
bdashore3 committed Dec 20, 2023
1 parent 5368ed7 commit ab10b26
Showing 3 changed files with 19 additions and 7 deletions.
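For context, rope alpha auto-calculation derives the NTK alpha from the ratio of the target context length to the model's base sequence length, which is why an inflated base (32k on Mistral/Mixtral) breaks it. A minimal sketch, assuming the quadratic NTK-aware fit commonly used by exllamav2 frontends (the coefficients and function name are assumptions, not part of this commit):

```python
def calculate_rope_alpha(base_seq_len: int, target_seq_len: int) -> float:
    """Estimate the NTK rope alpha needed to reach target_seq_len.

    The ratio is computed against the model's *base* sequence length, so a
    model whose config.json reports an inflated base (e.g. 32k on Mistral
    due to sliding-window assumptions) yields a ratio <= 1 and no scaling,
    even when scaling is actually required.
    """
    ratio = target_seq_len / base_seq_len
    if ratio <= 1:
        return 1.0
    # Quadratic fit for NTK-aware rope scaling (assumed coefficients)
    return -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
```

With Mistral's reported 32k base, a 16k target produces no scaling (alpha stays 1); overriding the base down to its true 8k gives a ratio of 2 and a meaningful alpha.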
5 changes: 3 additions & 2 deletions OAI/types/model.py
@@ -34,8 +34,9 @@ class DraftModelLoadRequest(BaseModel):
 class ModelLoadRequest(BaseModel):
     name: str
 
-    # Max seq len is defaulted when loading the model itself
-    max_seq_len: Optional[int] = None
+    # Max seq len is fetched from config.json of the model by default
+    max_seq_len: Optional[int] = Field(description = "Leave this blank to use the model's base sequence length", default = None)
+    override_base_seq_len: Optional[int] = Field(description = "Overrides the model's base sequence length. Leave blank if unsure", default = None)
     gpu_split_auto: Optional[bool] = True
     gpu_split: Optional[List[float]] = Field(default_factory=list)
     rope_scale: Optional[float] = 1.0
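As a usage sketch, a load request against this schema could set both fields; the model name and values here are hypothetical, and the endpoint path is not part of this commit:

```python
# Hypothetical payload for a model load request using the new field.
# config.json claims a 32k base; the true pre-sliding-window base is 8k.
payload = {
    "name": "Mistral-7B-instruct-exl2",   # example model name
    "max_seq_len": 16384,                  # target context the user wants
    "override_base_seq_len": 8192,         # corrected base for rope alpha math
}
```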
8 changes: 7 additions & 1 deletion config_sample.yml
@@ -37,9 +37,15 @@ model:
 
   # The below parameters apply only if model_name is set
 
-  # Override maximum model context length (default: None)
+  # Max sequence length (default: None)
+  # Fetched from the model's base sequence length in config.json by default
   max_seq_len:
 
+  # Overrides base model context length (default: None)
+  # WARNING: Don't set this unless you know what you're doing!
+  # Only use this if the model's base sequence length in config.json is incorrect (ex. Mistral/Mixtral models)
+  override_base_seq_len:
+
   # Automatically allocate resources to GPUs (default: True)
   gpu_split_auto: True
13 changes: 9 additions & 4 deletions model.py
@@ -85,14 +85,19 @@ def progress(loaded_modules: int, total_modules: int, loading_draft: bool)
             self.config.max_seq_len = 4096
         self.config.prepare()
 
-        # Then override the max_seq_len if present
-        override_max_seq_len = kwargs.get("max_seq_len")
-        if override_max_seq_len:
-            self.config.max_seq_len = kwargs.get("max_seq_len")
+        # Then override the base_seq_len if present
+        override_base_seq_len = kwargs.get("override_base_seq_len")
+        if override_base_seq_len:
+            self.config.max_seq_len = override_base_seq_len
+
+        # Grab the base model's sequence length before overrides for rope calculations
+        base_seq_len = self.config.max_seq_len
+
+        # Set the target seq len if present
+        target_max_seq_len = kwargs.get("max_seq_len")
+        if target_max_seq_len:
+            self.config.max_seq_len = target_max_seq_len
 
         # Set the rope scale
         self.config.scale_pos_emb = unwrap(kwargs.get("rope_scale"), 1.0)
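Standalone, the override order introduced in model.py can be sketched as follows; `Config` here is a stand-in dataclass, not exllamav2's real config object, and the helper name is hypothetical:

```python
from dataclasses import dataclass


@dataclass
class Config:
    # Stand-in for the exllamav2 config; real code reads this from config.json
    max_seq_len: int = 4096


def apply_seq_len_overrides(config: Config, **kwargs) -> int:
    """Apply overrides in the commit's order and return the base length
    that rope alpha auto-calculation should use."""
    # 1. Replace an incorrect base length from config.json, if requested
    override_base = kwargs.get("override_base_seq_len")
    if override_base:
        config.max_seq_len = override_base

    # 2. Snapshot the (possibly corrected) base length for rope math
    base_seq_len = config.max_seq_len

    # 3. Finally apply the user's target context length
    target = kwargs.get("max_seq_len")
    if target:
        config.max_seq_len = target

    return base_seq_len
```

For a Mixtral-style model that reports 32k, passing `override_base_seq_len=8192` with `max_seq_len=16384` yields a base of 8192 for rope calculations and a final context of 16384.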
