File tree Expand file tree Collapse file tree 2 files changed +25
-2
lines changed Expand file tree Collapse file tree 2 files changed +25
-2
lines changed Original file line number Diff line number Diff line change @@ -25,7 +25,8 @@ def get_model(
2525 torch_dtype = STR_DTYPE_TO_TORCH_DTYPE [dtype .lower ()]
2626 else :
2727 torch_dtype = dtype
28- for model_class , model in MODEL_CLASSES .items ():
28+ for model_class , hf_model in MODEL_CLASSES .items ():
2929 if model_class in model_name :
30- return model .from_pretrained (model_name , torch_dtype = torch_dtype )
30+ model = hf_model .from_pretrained (model_name , torch_dtype = torch_dtype )
31+ return model .eval ()
3132 raise ValueError (f'Invalid model name: { model_name } ' )
Original file line number Diff line number Diff line change @@ -232,6 +232,28 @@ def __init__(self, config):
232232 # Initialize weights and apply final processing
233233 self .post_init ()
234234
235+ # NOTE(woosuk): While the following methods are not called in the model code,
236+ # they may be internally used by the transformers library.
237+ # For example, tie_weights() does not work without these methods.
238+ # Thus, do not delete these methods.
239+ def get_input_embeddings (self ):
240+ return self .model .decoder .embed_tokens
241+
242+ def set_input_embeddings (self , value ):
243+ self .model .decoder .embed_tokens = value
244+
245+ def get_output_embeddings (self ):
246+ return self .lm_head
247+
248+ def set_output_embeddings (self , new_embeddings ):
249+ self .lm_head = new_embeddings
250+
251+ def set_decoder (self , decoder ):
252+ self .model .decoder = decoder
253+
254+ def get_decoder (self ):
255+ return self .model .decoder
256+
235257 def forward (
236258 self ,
237259 input_ids : torch .LongTensor ,
You can’t perform that action at this time.
0 commit comments