
Commit 88385bb
[BugFix][Model] Fix commandr RoPE max_position_embeddings (vllm-proje…
esmeetu authored and andy-neuma committed Apr 12, 2024
1 parent ce01e55 commit 88385bb
Showing 1 changed file with 3 additions and 1 deletion.
vllm/model_executor/models/commandr.py (4 changes: 3 additions & 1 deletion)

@@ -140,7 +140,9 @@ def __init__(
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
-        self.max_position_embeddings = config.max_position_embeddings
+        self.max_position_embeddings = getattr(
+            config, "model_max_length", None) or getattr(
+                config, "max_position_embeddings", 8192)
         self.rope_theta = config.rope_theta
         self.rope_scaling = getattr(config, "rope_scaling", None)
         self.use_qk_norm = getattr(config, "use_qk_norm", False)
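For context, the fix replaces a direct read of config.max_position_embeddings with a two-step getattr fallback: prefer model_max_length when the config defines it, otherwise use max_position_embeddings, otherwise a hard default of 8192. This matters for checkpoints such as Command-R whose config publishes the full context window under model_max_length, so the old line could size RoPE for a shorter window than the model supports. Below is a minimal, self-contained sketch of the same fallback chain; the resolve_max_positions helper and the SimpleNamespace configs are illustrative stand-ins, not part of vLLM:

from types import SimpleNamespace

def resolve_max_positions(config) -> int:
    # Same chain as the patched line: prefer model_max_length,
    # fall back to max_position_embeddings, default to 8192.
    return getattr(config, "model_max_length", None) or getattr(
        config, "max_position_embeddings", 8192)

# Example configs (attribute values are illustrative, not taken
# from a specific checkpoint).
assert resolve_max_positions(
    SimpleNamespace(model_max_length=131072,
                    max_position_embeddings=8192)) == 131072
assert resolve_max_positions(
    SimpleNamespace(max_position_embeddings=8192)) == 8192
assert resolve_max_positions(SimpleNamespace()) == 8192

Note that because the chain uses "or" rather than a presence check, a config that sets model_max_length to None (or 0) also falls through to max_position_embeddings.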
