Skip to content

Commit 1e6a32c

Browse files
committed
Reuse existing model when possible (perf)
reloads when model changes or any lora weight changes
1 parent b130f16 commit 1e6a32c

File tree

7 files changed

+670
-471
lines changed

7 files changed

+670
-471
lines changed

modules/generators/base_generator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,9 +295,9 @@ def load_loras(self, selected_loras: List[str], lora_folder: str, lora_loaded_na
295295
lora_loaded_names: The master list of ALL available LoRA names, used for correct weight indexing.
296296
lora_values: A list of strength values corresponding to lora_loaded_names.
297297
"""
298-
self.unload_loras()
299-
300298
if not selected_loras:
299+
# Only unload at this point if no LoRAs are selected
300+
self.unload_loras()
301301
print("No LoRAs selected, skipping loading.")
302302
return
303303

@@ -394,6 +394,7 @@ def _find_model_files(model_path):
394394
print(f"Error loading LoRAs with kohya_ss loader: {e}")
395395
traceback.print_exc()
396396
else:
397+
self.unload_loras()
397398
adapter_names = []
398399
strengths = []
399400

modules/generators/model_configuration.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from hashlib import md5
1+
import hashlib
2+
import json
23
from typing import Optional, cast
3-
from dataclasses import dataclass, field
4+
from dataclasses import dataclass, field, asdict
45

56
DEFAULT_WEIGHT: float = 0.8
67

@@ -188,7 +189,7 @@ class ModelConfiguration:
188189

189190
@property
190191
def _hash(self) -> str:
191-
return md5(json.dumps(dataclasses.asdict(self), sort_keys=True).encode()).hexdigest()
192+
return hashlib.md5(json.dumps(asdict(self), sort_keys=True).encode()).hexdigest()
192193

193194
def add_lora_setting(self, setting: ModelLoraSetting) -> None:
194195
self.settings.add_lora_setting(setting)
@@ -205,38 +206,38 @@ def validate(self) -> bool:
205206
return valid
206207

207208
@staticmethod
208-
def from_config(model_name: str, settings: ModelSettings | dict | None):
209-
model_config: ModelSettings | None = None
209+
def from_settings(model_name: str, settings: ModelSettings | dict | None):
210+
model_settings: ModelSettings | None = None
210211
if settings is None:
211-
model_config = ModelSettings()
212+
model_settings = ModelSettings()
212213
elif isinstance(settings, ModelSettings):
213-
model_config = settings
214+
model_settings = settings
214215
elif isinstance(settings, dict):
215-
model_config = ModelSettings(lora_settings=ModelLoraSetting.parse_settings(settings))
216+
model_settings = ModelSettings(lora_settings=ModelLoraSetting.parse_settings(settings))
216217

217-
if model_config is None:
218+
if model_settings is None:
218219
raise ValueError("Invalid config type for ModelConfiguration")
219220

220-
return ModelConfiguration(model_name=model_name, settings=model_config)
221+
return ModelConfiguration(model_name=model_name, settings=model_settings)
221222

222223
@staticmethod
223224
def from_lora_names_and_weights(model_name: str, lora_names: list[str], lora_weights: Optional[list[float | int]] = None) -> "ModelConfiguration":
224225
weights: list[float] = [float(weight) for weight in (lora_weights or [])]
225226
lora_settings = ModelLoraSetting.from_names_and_weights(lora_names, lora_weights=weights)
226-
model_config = ModelSettings(lora_settings=lora_settings)
227+
model_settings = ModelSettings(lora_settings=lora_settings)
227228
del weights, lora_settings
228-
return ModelConfiguration.from_config(model_name=model_name, settings=model_config)
229+
return ModelConfiguration.from_settings(model_name=model_name, settings=model_settings)
229230

230231
def set_model_name(self, model_name: str) -> "ModelConfiguration":
231232
self.model_name = model_name
232233
return self
233234

234-
def set_config(self, config: ModelSettings) -> "ModelConfiguration":
235-
self.config = config
235+
def set_settings(self, settings: ModelSettings) -> "ModelConfiguration":
236+
self.settings = settings
236237
return self
237238

238239
def update_lora_setting(self, lora_settings: list[ModelLoraSetting] | str | list[str] | dict[str, dict]) -> "ModelConfiguration":
239-
self.config.lora_settings = ModelLoraSetting.parse_settings(lora_settings)
240+
self.settings.lora_settings = ModelLoraSetting.parse_settings(lora_settings)
240241
return self
241242

242243

@@ -251,9 +252,7 @@ def update_lora_setting(self, lora_settings: list[ModelLoraSetting] | str | list
251252
)
252253
logger.info(f"Model Name: {config.model_name}")
253254
logger.info(f"LoRA Settings: {config.settings.lora_settings}")
254-
import dataclasses
255-
import json
256-
logger.debug(json.dumps(dataclasses.asdict(config), indent=4))
255+
logger.debug(json.dumps(asdict(config), indent=4))
257256
logger.debug("hash: {0}".format(config._hash))
258257
config.model_name = "changed"
259258
logger.debug("hash: {0}".format(config._hash))
@@ -263,10 +262,10 @@ def update_lora_setting(self, lora_settings: list[ModelLoraSetting] | str | list
263262
)
264263
logger.debug("hash: {0}".format(config._hash))
265264
config.settings.lora_settings.append(ModelLoraSetting(name="lora_D", weight=0.75))
266-
logger.debug(json.dumps(dataclasses.asdict(config), indent=4))
265+
logger.debug(json.dumps(asdict(config), indent=4))
267266
logger.debug("hash: {0}".format(config._hash))
268267
config.add_lora("lora_E")
269-
logger.debug(json.dumps(dataclasses.asdict(config), indent=4))
268+
logger.debug(json.dumps(asdict(config), indent=4))
270269
logger.debug("hash: {0}".format(config._hash))
271270
valid = config.validate()
272271
logger.debug("Config validation result: {0}".format(valid))

0 commit comments

Comments
 (0)