-
Notifications
You must be signed in to change notification settings - Fork 118
feat: support Qwen3 128k context via YaRN scaling. #624
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
25765ff to
6961655
Compare
|
/gemini review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request adds support for YaRN scaling to the Qwen3 model, enabling a 128k context length. The changes involve updating the maximum context length calculation, adding a new utility function for YaRN rotary embeddings, and integrating this into the Qwen3 model implementation. The overall approach is sound. I've identified a minor issue regarding an unused function parameter that should be addressed to improve code clarity and maintainability.
| torch::Tensor get_yarn_rotary_embedding( | ||
| int64_t rotary_dim, | ||
| int64_t max_position_embeddings, | ||
| int64_t original_max_position_embeddings, | ||
| float rope_theta, | ||
| bool interleaved, | ||
| float scaling_factor, | ||
| const torch::TensorOptions& options, | ||
| float extrapolation_factor = 1.0f, | ||
| float attn_factor = 1.0f, | ||
| float beta_fast = 32.0f, | ||
| float beta_slow = 1.0f); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The max_position_embeddings parameter is unused in the implementation of get_yarn_rotary_embedding. To improve code clarity and prevent confusion, it should be removed from the function declaration. I will suggest corresponding changes in the implementation file and at the call site.
torch::Tensor get_yarn_rotary_embedding(
int64_t rotary_dim,
int64_t original_max_position_embeddings,
float rope_theta,
bool interleaved,
float scaling_factor,
const torch::TensorOptions& options,
float extrapolation_factor = 1.0f,
float attn_factor = 1.0f,
float beta_fast = 32.0f,
float beta_slow = 1.0f);| torch::Tensor get_yarn_rotary_embedding( | ||
| int64_t rotary_dim, | ||
| int64_t max_position_embeddings, | ||
| int64_t original_max_position_embeddings, | ||
| float rope_theta, | ||
| bool interleaved, | ||
| float scaling_factor, | ||
| const torch::TensorOptions& options, | ||
| float extrapolation_factor, | ||
| float attn_factor, | ||
| float beta_fast, | ||
| float beta_slow) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The max_position_embeddings parameter is not used within this function. It should be removed from the function signature to align with the declaration change and improve code maintainability.
torch::Tensor get_yarn_rotary_embedding(
int64_t rotary_dim,
int64_t original_max_position_embeddings,
float rope_theta,
bool interleaved,
float scaling_factor,
const torch::TensorOptions& options,
float extrapolation_factor,
float attn_factor,
float beta_fast,
float beta_slow) {| cos_sin_ = layer::rotary::get_yarn_rotary_embedding( | ||
| 128, | ||
| model_args.max_position_embeddings(), | ||
| model_args.rope_scaling_original_max_position_embeddings(), | ||
| model_args.rope_theta(), | ||
| false, | ||
| model_args.rope_scaling_factor(), | ||
| options); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This argument corresponds to the unused max_position_embeddings parameter in get_yarn_rotary_embedding. It should be removed from the function call to match the updated function signature.
cos_sin_ = layer::rotary::get_yarn_rotary_embedding(
128,
model_args.rope_scaling_original_max_position_embeddings(),
model_args.rope_theta(),
false,
model_args.rope_scaling_factor(),
options);
No description provided.