@@ -89,6 +89,7 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
89
89
90
90
self .to_json_file (output_config_file )
91
91
logger .info (f"ConfigMixinuration saved in { output_config_file } " )
92
+
92
93
93
94
@classmethod
94
95
def get_config_dict (
@@ -182,35 +183,42 @@ def get_config_dict(
182
183
logger .info (f"loading configuration file { config_file } " )
183
184
else :
184
185
logger .info (f"loading configuration file { config_file } from cache at { resolved_config_file } " )
186
+
187
+ return config_dict
185
188
189
@classmethod
def extract_init_dict(cls, config_dict, **kwargs):
    """Split a loaded config into ``__init__`` arguments and leftovers.

    The parameter names of ``cls.__init__`` (minus ``self``) define which
    keys are wanted. For each wanted key, a value passed through ``kwargs``
    takes precedence over the value stored in ``config_dict``.

    Args:
        config_dict: Mapping of configuration values. It is copied first,
            so the caller's mapping is left unmodified.
        **kwargs: Overrides for config values, plus any extra keyword
            arguments that should be handed back as unused.

    Returns:
        A ``(init_dict, unused_kwargs)`` tuple. ``init_dict`` holds the
        values to pass to ``cls.__init__``. ``unused_kwargs`` holds every
        remaining config entry and keyword argument that ``__init__`` does
        not accept (keyword arguments win on a name clash).

    A warning is logged when some ``__init__`` parameters were found in
    neither ``kwargs`` nor ``config_dict``.
    """
    expected_keys = set(inspect.signature(cls.__init__).parameters)
    expected_keys.discard("self")

    # Work on a copy so the caller's dict is not silently stripped of keys.
    remaining = dict(config_dict)

    init_dict = {}
    for key in expected_keys:
        if key in kwargs:
            # An explicit keyword argument overrides the stored value; drop
            # the stored one too so it is not reported as unused.
            init_dict[key] = kwargs.pop(key)
            remaining.pop(key, None)
        elif key in remaining:
            # Fall back to the value from the config dict.
            init_dict[key] = remaining.pop(key)

    # dict.update() returns None, so the leftovers must be merged explicitly.
    unused_kwargs = {**remaining, **kwargs}

    missing_keys = expected_keys.difference(init_dict)
    if missing_keys:
        logger.warning(
            f"{missing_keys} was not found in config. Values will be initialized to default values."
        )

    return init_dict, unused_kwargs
206
212
207
213
@classmethod
208
214
def from_config (cls , pretrained_model_name_or_path : Union [str , os .PathLike ], return_unused_kwargs = False , ** kwargs ):
209
- config_dict , unused_kwargs = cls .get_config_dict (
215
+ config_dict = cls .get_config_dict (
210
216
pretrained_model_name_or_path = pretrained_model_name_or_path , ** kwargs
211
217
)
212
218
213
- model = cls (** config_dict )
219
+ init_dict , unused_kwargs = cls .extract_init_dict (config_dict , ** kwargs )
220
+
221
+ model = cls (** init_dict )
214
222
215
223
if return_unused_kwargs :
216
224
return model , unused_kwargs
0 commit comments