Make InternLM follow rope_scaling in config.json (#1956)
theFool32 and lijie8 authored Dec 7, 2023
Co-authored-by: lijie8 <lijie8@sensetime.com>
1 parent d940ce4 · commit ebede26
Showing 1 changed file with 4 additions and 1 deletion.
vllm/model_executor/models/internlm.py (4 additions, 1 deletion)
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -67,6 +67,7 @@ def __init__(
         rope_theta: float = 10000,
         max_position_embeddings: int = 8192,
         linear_method: Optional[LinearMethodBase] = None,
+        rope_scaling: Optional[Dict[str, Any]] = None,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -99,6 +100,7 @@ def __init__(
             rotary_dim=self.head_dim,
             max_position=self.max_position_embeddings,
             base=self.rope_theta,
+            rope_scaling=rope_scaling,
         )
         self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)
 
@@ -139,6 +141,7 @@ def __init__(
             rope_theta=rope_theta,
             max_position_embeddings=max_position_embeddings,
             linear_method=linear_method,
+            rope_scaling=getattr(config, "rope_scaling", None),
         )
         self.mlp = InternLMMLP(
             hidden_size=self.hidden_size,
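For context, here is a minimal, self-contained sketch of the plumbing this diff adds: the decoder layer reads rope_scaling from the model config (via getattr, since older config.json files may omit the field) and forwards it through the attention module to the rotary embedding. ToyConfig, ToyAttention, and ToyRotaryEmbedding below are hypothetical stand-ins, not vLLM classes, and the example rope_scaling dict follows the HF config.json convention.

# Sketch of the rope_scaling plumbing, with hypothetical stand-in classes.
from typing import Any, Dict, Optional


class ToyConfig:
    """Stand-in for a HuggingFace config loaded from config.json."""

    def __init__(self, rope_scaling: Optional[Dict[str, Any]] = None):
        if rope_scaling is not None:
            # Only present if config.json actually defines it.
            self.rope_scaling = rope_scaling


class ToyRotaryEmbedding:
    """Stand-in for the rotary embedding that interprets rope_scaling."""

    def __init__(self, base: float, rope_scaling: Optional[Dict[str, Any]] = None):
        self.base = base
        self.rope_scaling = rope_scaling


class ToyAttention:
    """Mirrors the pattern in InternLMAttention: accept rope_scaling and forward it."""

    def __init__(self, rope_theta: float = 10000,
                 rope_scaling: Optional[Dict[str, Any]] = None):
        self.rotary_emb = ToyRotaryEmbedding(base=rope_theta,
                                             rope_scaling=rope_scaling)


# Before this commit, rope_scaling from config.json was ignored; now the
# decoder layer reads it defensively and passes it down the stack.
config = ToyConfig(rope_scaling={"type": "dynamic", "factor": 2.0})
attn = ToyAttention(rope_scaling=getattr(config, "rope_scaling", None))
print(attn.rotary_emb.rope_scaling)  # {'type': 'dynamic', 'factor': 2.0}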
