From 3f2fcbcc4512a942d8eb29a490e9a2c3a8bea621 Mon Sep 17 00:00:00 2001 From: DocShotgun <126566557+DocShotgun@users.noreply.github.com> Date: Tue, 5 Dec 2023 18:51:36 -0800 Subject: [PATCH] Add fallback to draft_rope_scale to 1.0 --- model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model.py b/model.py index 929916ff..360cf88f 100644 --- a/model.py +++ b/model.py @@ -111,7 +111,7 @@ def progress(loaded_modules: int, total_modules: int, loading_draft: bool) self.draft_config.model_dir = str(draft_model_path.resolve()) self.draft_config.prepare() - if "draft_rope_scale" in kwargs: self.draft_config.scale_pos_emb = kwargs["draft_rope_scale"] + self.draft_config.scale_pos_emb = kwargs.get("draft_rope_scale") or 1.0 self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len) self.draft_config.max_seq_len = self.config.max_seq_len