Skip to content

Commit c058acb

Browse files
authored
Dev/download hook (#143)
* add registry model download func * update download export
1 parent af9b398 commit c058acb

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

diffsynth_engine/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@
1717
from .models.flux import FluxControlNet, FluxIPAdapter, FluxRedux
1818
from .models.sd import SDControlNet
1919
from .models.sdxl import SDXLControlNetUnion
20-
from .utils.download import fetch_model, fetch_modelscope_model, fetch_civitai_model
20+
from .utils.download import (
21+
fetch_model,
22+
fetch_modelscope_model,
23+
fetch_civitai_model,
24+
register_fetch_modelscope_model,
25+
reset_fetch_modelscope_model,
26+
)
2127
from .utils.video import load_video, save_video
2228
from .tools import (
2329
FluxInpaintingTool,
@@ -52,6 +58,8 @@
5258
"ControlType",
5359
"fetch_model",
5460
"fetch_modelscope_model",
61+
"register_fetch_modelscope_model",
62+
"reset_fetch_modelscope_model",
5563
"fetch_civitai_model",
5664
"load_video",
5765
"save_video",

diffsynth_engine/utils/download.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,31 @@
2020

2121
MODEL_SOURCES = ["modelscope", "civitai"]
2222

23+
# Global registry for custom fetch function
24+
_CUSTOM_MODELSCOPE_FETCHER = None
25+
26+
27+
def register_fetch_modelscope_model(fetch_func):
28+
"""
29+
Register a global custom fetch function for ModelScope models.
30+
31+
Args:
32+
fetch_func (callable): Custom fetch function that should accept the same parameters
33+
as fetch_modelscope_model and return the model path(s)
34+
"""
35+
global _CUSTOM_MODELSCOPE_FETCHER
36+
_CUSTOM_MODELSCOPE_FETCHER = fetch_func
37+
logger.info("Registered global custom ModelScope fetcher")
38+
39+
40+
def reset_fetch_modelscope_model():
41+
"""
42+
Reset the global custom fetch function for ModelScope models.
43+
"""
44+
global _CUSTOM_MODELSCOPE_FETCHER
45+
_CUSTOM_MODELSCOPE_FETCHER = None
46+
logger.info("Reset global custom ModelScope fetcher")
47+
2348

2449
def fetch_model(
2550
model_uri: str,
@@ -43,6 +68,11 @@ def fetch_modelscope_model(
4368
access_token: Optional[str] = None,
4469
fetch_safetensors: bool = True,
4570
) -> str:
71+
# Check if there's a global custom fetcher registered
72+
if _CUSTOM_MODELSCOPE_FETCHER is not None:
73+
logger.info(f"Using global custom fetcher for model: {model_id}")
74+
return _CUSTOM_MODELSCOPE_FETCHER(model_id, revision, path, access_token, fetch_safetensors)
75+
4676
lock_file_name = f"modelscope.{model_id.replace('/', '--')}.{revision if revision else '__version'}.lock"
4777
lock_file_path = os.path.join(DIFFSYNTH_FILELOCK_DIR, lock_file_name)
4878
ensure_directory_exists(lock_file_path)

0 commit comments

Comments
 (0)