Add automatic NTK-aware alpha scaling to model
* Enables automatic calculation of NTK-aware alpha scaling for models when the rope_alpha argument is not passed in the config, using the same formula already used for draft models.
DocShotgun authored Dec 3, 2023
1 parent 61f6e51 commit 1c398b0
Showing 1 changed file with 9 additions and 2 deletions.
model.py (9 additions, 2 deletions)
@@ -69,10 +69,17 @@ def progress(loaded_modules: int, total_modules: int, loading_draft: bool)
         self.config = ExLlamaV2Config()
         self.config.model_dir = str(model_directory.resolve())
         self.config.prepare()
 
+        base_seq_len = self.config.max_seq_len
+
         if "max_seq_len" in kwargs: self.config.max_seq_len = kwargs["max_seq_len"]
         if "rope_scale" in kwargs: self.config.scale_pos_emb = kwargs["rope_scale"]
-        if "rope_alpha" in kwargs: self.config.scale_alpha_value = kwargs["rope_alpha"]
+        if "rope_alpha" in kwargs:
+            self.config.scale_alpha_value = kwargs["rope_alpha"]
+        else:
+            ratio = self.config.max_seq_len / base_seq_len
+            alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
+            self.config.scale_alpha_value = alpha
         if "no_flash_attn" in kwargs: self.config.no_flash_attn = kwargs["no_flash_attn"]
 
         if "low_mem" in kwargs and kwargs["low_mem"]:
@@ -102,7 +109,7 @@ def progress(loaded_modules: int, total_modules: int, loading_draft: bool)
             self.draft_config.prepare()
 
             if "draft_rope_alpha" in kwargs:
-                self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha") or 1
+                self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha")
             else:
                 ratio = self.config.max_seq_len / self.draft_config.max_seq_len
                 alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
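
Both hunks apply the same quadratic fit of alpha against the ratio of the requested context length to the model's native one. As a minimal, self-contained sketch of that fallback (the function name auto_rope_alpha is illustrative only; the commit computes this inline in model.py):

# Minimal sketch of the automatic NTK-aware alpha fallback added by this commit.
def auto_rope_alpha(target_seq_len: int, base_seq_len: int) -> float:
    """Quadratic fit mapping the context-length ratio to a RoPE alpha value."""
    ratio = target_seq_len / base_seq_len
    return -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2

# Example: stretching a native 4096-token model to 8192 tokens (ratio = 2)
# gives alpha of about 2.63; at ratio = 1 the fit returns about 0.96, not exactly 1.
print(auto_rope_alpha(8192, 4096))  # ~2.6298
print(auto_rope_alpha(4096, 4096))  # ~0.9594

Passing rope_alpha (or draft_rope_alpha) explicitly still overrides this fallback, as the if branches in the diff show.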
