Skip to content

Commit 28ba0ff

Browse files
authored
make from hub import work
make from hub import work
2 parents 9616419 + 7335462 commit 28ba0ff

File tree

5 files changed

+422
-18
lines changed

5 files changed

+422
-18
lines changed

src/diffusers/configuration_utils.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
8989

9090
self.to_json_file(output_config_file)
9191
logger.info(f"ConfigMixinuration saved in {output_config_file}")
92+
9293

9394
@classmethod
9495
def get_config_dict(
@@ -182,35 +183,42 @@ def get_config_dict(
182183
logger.info(f"loading configuration file {config_file}")
183184
else:
184185
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
186+
187+
return config_dict
185188

189+
@classmethod
190+
def extract_init_dict(cls, config_dict, **kwargs):
186191
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
187192
expected_keys.remove("self")
188-
193+
init_dict = {}
189194
for key in expected_keys:
190195
if key in kwargs:
191196
# overwrite key
192-
config_dict[key] = kwargs.pop(key)
197+
init_dict[key] = kwargs.pop(key)
198+
elif key in config_dict:
199+
# use value from config dict
200+
init_dict[key] = config_dict.pop(key)
193201

194-
passed_keys = set(config_dict.keys())
195-
196-
unused_kwargs = kwargs
197-
for key in passed_keys - expected_keys:
198-
unused_kwargs[key] = config_dict.pop(key)
199202

203+
unused_kwargs = config_dict.update(kwargs)
204+
205+
passed_keys = set(init_dict.keys())
200206
if len(expected_keys - passed_keys) > 0:
201207
logger.warn(
202208
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
203209
)
204210

205-
return config_dict, unused_kwargs
211+
return init_dict, unused_kwargs
206212

207213
@classmethod
208214
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
209-
config_dict, unused_kwargs = cls.get_config_dict(
215+
config_dict = cls.get_config_dict(
210216
pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
211217
)
212218

213-
model = cls(**config_dict)
219+
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
220+
221+
model = cls(**init_dict)
214222

215223
if return_unused_kwargs:
216224
return model, unused_kwargs

0 commit comments

Comments
 (0)