|
@@ -12,6 +12,7 @@
 
 from cacheflow.models import InputMetadata
 from cacheflow.models.attention import LlamaCacheFlowAttention
+from cacheflow.models.layernorm import RMSNorm
 from cacheflow.models.sample import Sampler
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -23,22 +24,6 @@
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class LlamaRMSNorm(nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        # convert into half-precision if necessary
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-        return self.weight * hidden_states
-
-
 class LlamaMLP(nn.Module):
 
     def __init__(
@@ -148,8 +133,8 @@ def __init__(self, config: LlamaConfig):
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
         )
-        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -190,7 +175,7 @@ def __init__(self, config: LlamaConfig):
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
                                                    perform_initialization=False)
         self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
-        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
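For context, below is a minimal sketch of the RMSNorm module that this diff imports from cacheflow.models.layernorm. It simply mirrors the LlamaRMSNorm class removed above; the module path and class name come from the new import line, but the real implementation is not shown in this diff (it may, for example, dispatch to a fused CUDA kernel), so treat this as an assumed reference version.

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Root-mean-square layer norm (assumed reference version, no fused kernel)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Learnable per-channel scale, initialized to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Accumulate the mean of squares in float32 for numerical stability.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Cast back to half precision if the weight is stored in half precision.
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            hidden_states = hidden_states.to(self.weight.dtype)
        return self.weight * hidden_states

With such a module in place, the three call sites touched by the diff (input_layernorm, post_attention_layernorm, and the final norm) construct RMSNorm(config.hidden_size, eps=config.rms_norm_eps) with the same arguments that LlamaRMSNorm received before.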
|