 import dataclasses
 import platform
-from typing import Union
+from typing import Optional, Union
 
 import psutil
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from hivemind import get_logger
 from torch import nn
+from transformers import PretrainedConfig
 
 logger = get_logger(__name__)
 
@@ -21,15 +22,16 @@ class LMHeadConfig:
 
 
 class LMHead(nn.Module):
-    """
-    The modified language modeling head which does not create an extra tensor for the linear layer with weights tied to the input
-    embeddings. Thus, it reduces initial memory consumption, which might be crucial for large dictionaries.
-    In addition, it provides an efficient way to deal with half-precision word embeddings on CPU.
-    """
-
-    def __init__(self, config: LMHeadConfig, word_embeddings: nn.Embedding):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()
-        self.word_embeddings = word_embeddings
+
+        if not config.tie_word_embeddings:
+            self.weight = nn.Parameter(torch.zeros((config.vocab_size, config.hidden_size), requires_grad=False))
+        else:
+            self.weight = None  # Will be set to get_input_embeddings().weight while loading the model
+        self.bias = None
+        self.in_features = config.hidden_size  # Similar to nn.Linear attributes
+        self.out_features = config.vocab_size
 
         self.use_chunked_forward = config.use_chunked_forward
         if self.use_chunked_forward == "auto":
@@ -45,35 +47,17 @@ def __init__(self, config: LMHeadConfig, word_embeddings: nn.Embedding):
         self.chunked_forward_step = config.chunked_forward_step
         self._bf16_warning_shown = False
 
-    @property
-    def in_features(self) -> int:
-        return self.word_embeddings.num_embeddings
-
-    @property
-    def out_features(self) -> int:
-        return self.word_embeddings.embedding_dim
-
-    @property
-    def weight(self):
-        return self.word_embeddings.weight
-
-    @property
-    def bias(self):
-        return None
-
     def forward(self, hidden_states):
-        word_embeddings = self.word_embeddings.weight
-
         if (
-            word_embeddings.dtype in [torch.float16, torch.bfloat16]
-            and word_embeddings.device.type == "cpu"
+            self.weight.dtype in [torch.float16, torch.bfloat16]
+            and self.weight.device.type == "cpu"
             and self.use_chunked_forward
         ):
             lm_logits = self.chunked_forward(hidden_states)
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
-            hidden_states = hidden_states.to(word_embeddings.dtype)
-            lm_logits = F.linear(hidden_states, word_embeddings)
+            hidden_states = hidden_states.to(self.weight.dtype)
+            lm_logits = F.linear(hidden_states, self.weight)
         return lm_logits
 
     def chunked_forward(self, hidden_states):
@@ -83,20 +67,17 @@ def chunked_forward(self, hidden_states):
         assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive"
 
         if not self._bf16_warning_shown:
-            if self.word_embeddings.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
+            if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
                 logger.warning(
                     "Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. "
                     "Consider loading the model with torch_dtype='float32'"
                 )
                 self._bf16_warning_shown = True
 
-        word_embeddings = self.word_embeddings.weight
-        num_embeddings = self.word_embeddings.num_embeddings
-
         hidden_states = hidden_states.float()
-        output = torch.empty(*hidden_states.shape[:-1], num_embeddings)
+        output = torch.empty(*hidden_states.shape[:-1], self.out_features)
 
-        for i in range(0, num_embeddings, self.chunked_forward_step):
-            chunk = word_embeddings[i : i + self.chunked_forward_step].float()
+        for i in range(0, self.out_features, self.chunked_forward_step):
+            chunk = self.weight[i : i + self.chunked_forward_step].float()
             output[..., i : i + self.chunked_forward_step] = F.linear(hidden_states, chunk)
         return output
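
A standalone sketch (not part of this diff) of the chunked-forward idea used above, assuming an fp16/bf16 weight matrix on CPU; the tensor sizes and chunk step below are placeholder values chosen only for illustration:

import torch
import torch.nn.functional as F

hidden_states = torch.randn(2, 5, 64)                  # (batch, seq, hidden_size), placeholder sizes
weight = torch.randn(1000, 64, dtype=torch.bfloat16)   # (vocab_size, hidden_size), as if tied to bf16 embeddings
step = 256                                             # vocab rows upcast per chunk

# Upcast the weight one vocabulary slice at a time, so a full fp32 copy of the
# vocab_size x hidden_size matrix is never materialized in memory at once.
hidden_states = hidden_states.float()
output = torch.empty(*hidden_states.shape[:-1], weight.shape[0])
for i in range(0, weight.shape[0], step):
    chunk = weight[i : i + step].float()
    output[..., i : i + step] = F.linear(hidden_states, chunk)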