88"""
99
1010import math
11- import inspect
1211from dataclasses import dataclass
1312
1413import torch
@@ -167,99 +166,6 @@ def _init_weights(self, module):
         elif isinstance(module, nn.Embedding):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
-    def forward(self, idx, targets=None):
-        device = idx.device
-        b, t = idx.size()
-        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
-        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
-
-        # forward the GPT model itself
-        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
-        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
-        x = self.transformer.drop(tok_emb + pos_emb)
-        for block in self.transformer.h:
-            x = block(x)
-        x = self.transformer.ln_f(x)
-
-        if targets is not None:
-            # if we are given some desired targets also calculate the loss
-            logits = self.lm_head(x)
-            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
-        else:
-            # inference-time mini-optimization: only forward the lm_head on the very last position
-            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
-            loss = None
-
-        return logits, loss
-
-    def crop_block_size(self, block_size):
-        # model surgery to decrease the block size if necessary
-        # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
-        # but want to use a smaller block size for some smaller, simpler model
-        assert block_size <= self.config.block_size
-        self.config.block_size = block_size
-        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
-        for block in self.transformer.h:
-            if hasattr(block.attn, 'bias'):
-                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
-
-    @classmethod
-    def from_pretrained(cls, model_type, override_args=None):
-        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
-        override_args = override_args or {} # default to empty dict
-        # only dropout can be overridden see more notes below
-        assert all(k == 'dropout' for k in override_args)
-        from transformers import GPT2LMHeadModel
-        print("loading weights from pretrained gpt: %s" % model_type)
-
-        # n_layer, n_head and n_embd are determined from model_type
-        config_args = {
-            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
-            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
-            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
-            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
-        }[model_type]
-        print("forcing vocab_size=50257, block_size=1024, bias=True")
-        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
-        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
-        config_args['bias'] = True # always True for GPT model checkpoints
-        # we can override the dropout rate, if desired
-        if 'dropout' in override_args:
-            print(f"overriding dropout rate to {override_args['dropout']}")
-            config_args['dropout'] = override_args['dropout']
-        # create a from-scratch initialized minGPT model
-        config = GPTConfig(**config_args)
-        model = GPT(config)
-        sd = model.state_dict()
-        sd_keys = sd.keys()
-        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
-
-        # init a huggingface/transformers model
-        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
-        sd_hf = model_hf.state_dict()
-
-        # copy while ensuring all of the parameters are aligned and match in names and shapes
-        sd_keys_hf = sd_hf.keys()
-        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
-        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
-        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
-        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
-        # this means that we have to transpose these weights when we import them
-        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
-        for k in sd_keys_hf:
-            if any(k.endswith(w) for w in transposed):
-                # special treatment for the Conv1D weights we need to transpose
-                assert sd_hf[k].shape[::-1] == sd[k].shape
-                with torch.no_grad():
-                    sd[k].copy_(sd_hf[k].t())
-            else:
-                # vanilla copy over the other parameters
-                assert sd_hf[k].shape == sd[k].shape
-                with torch.no_grad():
-                    sd[k].copy_(sd_hf[k])
-
-        return model
-
     def configure_optimizers(self, weight_decay, learning_rate, betas):
         # start with all of the candidate parameters
         param_dict = {pn: p for pn, p in self.named_parameters()}
@@ -278,12 +184,7 @@ def configure_optimizers(self, weight_decay, learning_rate, betas):
         print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
         print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
 
-        # Create AdamW optimizer and use the fused version if it is available
-        # fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
-        # use_fused = fused_available and device_type == 'cuda'
-        # extra_args = dict(fused=True) if use_fused else dict()
         optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, fused=False)
-        # print(f"using fused AdamW: {use_fused}")
 
         return optimizer
 
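For reference, a minimal sketch of the fused-AdamW selection that the removed commented-out lines describe, assuming the caller supplies a `device_type` string such as 'cuda' and that the local PyTorch build may or may not expose the `fused` keyword (the helper name `build_adamw` is hypothetical, not part of the file above):

import inspect
import torch

def build_adamw(optim_groups, learning_rate, betas, device_type='cuda'):
    # enable the fused AdamW kernel only when this PyTorch build exposes the
    # `fused` keyword and training actually runs on CUDA; otherwise fall back
    # to the default (non-fused) implementation
    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and device_type == 'cuda'
    extra_args = dict(fused=True) if use_fused else dict()
    return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)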