diff --git a/modules/shared.py b/modules/shared.py index a99b500b211..f66ef39485a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -550,6 +550,7 @@ def list_samplers(): options_templates.update(options_section(('ui', "User interface"), { "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(), "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).info("you can also manually enter any of themes from the gallery.").needs_reload_ui(), + "re_download_theme": OptionInfo(False, "Re-download the selected Gradio theme"), "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), @@ -846,6 +847,38 @@ def sd_model(self, value): progress_print_out = sys.stdout + +def from_hub_with_cache_wrapper(func): + def wrapper(*args, **kwargs): + import pickle + repo_name = '' + if 'key_name' in kwargs: + repo_name = kwargs['repo_name'] + elif args and len(args) >= 1: + repo_name = args[0] + if repo_name: + theme_cache_path = os.path.join(script_path, 'tmp', 'gradio_themes', repo_name.replace('/', '_')) + # if theme is cached use cache and same gradio version + if not opts.re_download_theme and os.path.exists(theme_cache_path): + with open(theme_cache_path, 'rb') as cached_theme: + theme_cache = pickle.load(cached_theme) + if gr.__version__ == theme_cache.get('gradio_version'): + return theme_cache.get('theme') + # get theme from hub + result = func(*args, **kwargs) + # save theme to cache + os.makedirs(os.path.dirname(theme_cache_path), exist_ok=True) + with open(theme_cache_path, 'wb') as cached_theme: + theme_cache = {'theme': result, 'gradio_version': gr.__version__} + pickle.dump(theme_cache, cached_theme) + + return result + return wrapper + + +gr.themes.ThemeClass.from_hub = from_hub_with_cache_wrapper(gr.themes.ThemeClass.from_hub) # decorates gr.themes.ThemeClass.from_hub with from_hub_with_cache_wrapper + + gradio_theme = gr.themes.Base() @@ -869,7 +902,6 @@ def reload_gradio_theme(theme_name=None): gradio_theme = gr.themes.Default(**default_theme_args) - class TotalTQDM: def __init__(self): self._tqdm = None