diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index 0b2ed14..27c7da8 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -220,53 +220,54 @@ def get_combined_hooked_patches(self, lora_hooks: LoraHookGroup): combined_patches[key] = current_patches return combined_patches - def model_patches_to(self, device): - super().model_patches_to(device) - - def patch_model(self, device_to=None, patch_weights=True): + def patch_model(self, *args, **kwargs): + was_injected = False + if self.currently_injected: + self.eject_model() + was_injected = True # first, perform model patching - if patch_weights: # TODO: keep only 'else' portion when don't need to worry about past comfy versions - patched_model = super().patch_model(device_to) - else: - patched_model = super().patch_model(device_to, patch_weights) - # finally, perform motion model injection - self.inject_model() + patched_model = super().patch_model(*args, **kwargs) + # bring injection back to original state + if was_injected and not self.currently_injected: + self.inject_model() return patched_model - def patch_model_lowvram(self, *args, **kwargs): + def load(self, device_to=None, lowvram_model_memory=0, *args, **kwargs): + self.eject_model() try: - return super().patch_model_lowvram(*args, **kwargs) + return super().load(device_to=device_to, lowvram_model_memory=lowvram_model_memory, *args, **kwargs) finally: - # check if any modules have weight_function or bias_function that is not None - # NOTE: this serves no purpose currently, but I have it here for future reasons - for n, m in self.model.named_modules(): - if not hasattr(m, "comfy_cast_weights"): - continue - if getattr(m, "weight_function", None) is not None: - self.model_params_lowvram = True - self.model_params_lowvram_keys[f"{n}.weight"] = n - if getattr(m, "bias_function", None) is not None: - self.model_params_lowvram = True - self.model_params_lowvram_keys[f"{n}.bias"] = n + self.inject_model() + if lowvram_model_memory > 0: + self._patch_lowvram_extras() + + def _patch_lowvram_extras(self): + # check if any modules have weight_function or bias_function that is not None + # NOTE: this serves no purpose currently, but I have it here for future reasons + self.model_params_lowvram = False + self.model_params_lowvram_keys.clear() + for n, m in self.model.named_modules(): + if not hasattr(m, "comfy_cast_weights"): + continue + if getattr(m, "weight_function", None) is not None: + self.model_params_lowvram = True + self.model_params_lowvram_keys[f"{n}.weight"] = n + if getattr(m, "bias_function", None) is not None: + self.model_params_lowvram = True + self.model_params_lowvram_keys[f"{n}.bias"] = n def unpatch_model(self, device_to=None, unpatch_weights=True): # first, eject motion model from unet self.eject_model() # finally, do normal model unpatching - if unpatch_weights: # TODO: keep only 'else' portion when don't need to worry about past comfy versions + if unpatch_weights: # handle hooked_patches first self.clean_hooks() - try: - return super().unpatch_model(device_to) - finally: - self.model_params_lowvram = False - self.model_params_lowvram_keys.clear() - else: - try: - return super().unpatch_model(device_to, unpatch_weights) - finally: - self.model_params_lowvram = False - self.model_params_lowvram_keys.clear() + try: + return super().unpatch_model(device_to, unpatch_weights) + finally: + self.model_params_lowvram = False + self.model_params_lowvram_keys.clear() def partially_load(self, *args, **kwargs): # partially_load calls patch_model, but we don't want to inject model in the intermediate call; @@ -625,7 +626,7 @@ def patch_hooked_replace_weight_to_device(self, model_sd: dict, replace_patches: else: comfy.utils.set_attr_param(self.model, key, out_weight) - def patch_model(self, device_to=None, patch_weights=True, *args, **kwargs): + def patch_model(self, device_to=None, *args, **kwargs): if self.desired_lora_hooks is not None: self.patches_backup = self.patches.copy() relevant_patches = self.get_combined_hooked_patches(lora_hooks=self.desired_lora_hooks) @@ -633,23 +634,29 @@ def patch_model(self, device_to=None, patch_weights=True, *args, **kwargs): self.patches.setdefault(key, []) self.patches[key].extend(relevant_patches[key]) self.current_lora_hooks = self.desired_lora_hooks - return super().patch_model(device_to, patch_weights, *args, **kwargs) + return super().patch_model(device_to, *args, **kwargs) - def patch_model_lowvram(self, *args, **kwargs): + def load(self, device_to=None, lowvram_model_memory=0, *args, **kwargs): try: - return super().patch_model_lowvram(*args, **kwargs) + return super().load(device_to=device_to, lowvram_model_memory=lowvram_model_memory, *args, **kwargs) finally: - # check if any modules have weight_function or bias_function that is not None - # NOTE: this serves no purpose currently, but I have it here for future reasons - for n, m in self.model.named_modules(): - if not hasattr(m, "comfy_cast_weights"): - continue - if getattr(m, "weight_function", None) is not None: - self.model_params_lowvram = True - self.model_params_lowvram_keys[f"{n}.weight"] = n - if getattr(m, "bias_function", None) is not None: - self.model_params_lowvram = True - self.model_params_lowvram_keys[f"{n}.weight"] = n + if lowvram_model_memory > 0: + self._patch_lowvram_extras() + + def _patch_lowvram_extras(self): + # check if any modules have weight_function or bias_function that is not None + # NOTE: this serves no purpose currently, but I have it here for future reasons + self.model_params_lowvram = False + self.model_params_lowvram_keys.clear() + for n, m in self.model.named_modules(): + if not hasattr(m, "comfy_cast_weights"): + continue + if getattr(m, "weight_function", None) is not None: + self.model_params_lowvram = True + self.model_params_lowvram_keys[f"{n}.weight"] = n + if getattr(m, "bias_function", None) is not None: + self.model_params_lowvram = True + self.model_params_lowvram_keys[f"{n}.weight"] = n def unpatch_model(self, device_to=None, unpatch_weights=True, *args, **kwargs): try: @@ -797,10 +804,14 @@ def __init__(self, *args, **kwargs): self.was_within_range = False self.prev_sub_idxs = None self.prev_batched_number = None - - def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, *args, **kwargs): - patched_model = super().patch_model_lowvram(device_to, lowvram_model_memory, force_patch_weights, *args, **kwargs) + def load(self, device_to=None, lowvram_model_memory=0, *args, **kwargs): + to_return = super().load(device_to=device_to, lowvram_model_memory=lowvram_model_memory, *args, **kwargs) + if lowvram_model_memory > 0: + self._patch_lowvram_extras(device_to=device_to) + return to_return + + def _patch_lowvram_extras(self, device_to=None): # figure out the tensors (likely pe's) that should be cast to device besides just the named_modules remaining_tensors = list(self.model.state_dict().keys()) named_modules = [] @@ -817,8 +828,6 @@ def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patc if device_to is not None: comfy.utils.set_attr(self.model, key, comfy.utils.get_attr(self.model, key).to(device_to)) - return patched_model - def pre_run(self, model: ModelPatcherAndInjector): self.cleanup() self.model.set_scale(self.scale_multival, self.per_block_list) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 6a73aea..8244f2c 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -12,12 +12,7 @@ from comfy.ldm.modules.diffusionmodules import openaimodel import comfy.model_management import comfy.samplers -import comfy.sample -SAMPLE_FALLBACK = False -try: - import comfy.sampler_helpers -except ImportError: - SAMPLE_FALLBACK = True +import comfy.sampler_helpers import comfy.utils from comfy.controlnet import ControlBase from comfy.model_base import BaseModel @@ -291,10 +286,7 @@ def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionPara self.orig_diffusion_model_forward = model.model.diffusion_model.forward self.orig_sampling_function = comfy.samplers.sampling_function # used to support sliding context windows in samplers self.orig_get_area_and_mult = comfy.samplers.get_area_and_mult - if SAMPLE_FALLBACK: # for backwards compatibility, for now - self.orig_get_additional_models = comfy.sample.get_additional_models - else: - self.orig_get_additional_models = comfy.sampler_helpers.get_additional_models + self.orig_get_additional_models = comfy.sampler_helpers.get_additional_models self.orig_apply_model = model.model.apply_model # Inject Functions openaimodel.forward_timestep_embed = forward_timestep_embed_factory() @@ -324,10 +316,7 @@ def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionPara del info comfy.samplers.sampling_function = evolved_sampling_function comfy.samplers.get_area_and_mult = get_area_and_mult_ADE - if SAMPLE_FALLBACK: # for backwards compatibility, for now - comfy.sample.get_additional_models = get_additional_models_factory(self.orig_get_additional_models, model.motion_models) - else: - comfy.sampler_helpers.get_additional_models = get_additional_models_factory(self.orig_get_additional_models, model.motion_models) + comfy.sampler_helpers.get_additional_models = get_additional_models_factory(self.orig_get_additional_models, model.motion_models) # create temp_uninjector to help facilitate uninjecting functions self.temp_uninjector = GroupnormUninjectHelper(self) @@ -341,10 +330,7 @@ def restore_functions(self, model: ModelPatcherAndInjector): model.model.diffusion_model.forward = self.orig_diffusion_model_forward comfy.samplers.sampling_function = self.orig_sampling_function comfy.samplers.get_area_and_mult = self.orig_get_area_and_mult - if SAMPLE_FALLBACK: # for backwards compatibility, for now - comfy.sample.get_additional_models = self.orig_get_additional_models - else: - comfy.sampler_helpers.get_additional_models = self.orig_get_additional_models + comfy.sampler_helpers.get_additional_models = self.orig_get_additional_models model.model.apply_model = self.orig_apply_model except AttributeError: logger.error("Encountered AttributeError while attempting to restore functions - likely, an error occured while trying " + \ @@ -505,17 +491,8 @@ def ad_callback(step, x0, x, total_steps): if is_custom: iter_kwargs[IterationOptions.SAMPLER] = None #args[-5] else: - if SAMPLE_FALLBACK: # backwards compatibility, for now - # in older comfy, model needs to be loaded to get proper model_sampling to be used for sigmas - comfy.model_management.load_model_gpu(model) - iter_model = model.model - else: - iter_model = model - current_device = None - if hasattr(model, "current_device"): # backwards compatibility, for now - current_device = model.current_device - else: - current_device = model.model.device + iter_model = model + current_device = model.model.device iter_kwargs[IterationOptions.SAMPLER] = comfy.samplers.KSampler( iter_model, steps=999, #steps=args[-7], device=current_device, sampler=args[-5], @@ -653,35 +630,20 @@ def evolved_sampling_function(model, x: Tensor, timestep: Tensor, uncond, cond, model_options["transformer_options"]["ad_params"] = ADGS.create_exposed_params() if not ADGS.is_using_sliding_context(): - cond_pred, uncond_pred = calc_cond_uncond_batch_wrapper(model, [cond, uncond_], x, timestep, model_options) + cond_pred, uncond_pred = calc_conds_batch_wrapper(model, [cond, uncond_], x, timestep, model_options) else: cond_pred, uncond_pred = sliding_calc_conds_batch(model, [cond, uncond_], x, timestep, model_options) - if hasattr(comfy.samplers, "cfg_function"): - if ADGS.sample_settings.custom_cfg is not None: - cond_scale = ADGS.sample_settings.custom_cfg.get_cfg_scale(cond_pred) - model_options = ADGS.sample_settings.custom_cfg.get_model_options(model_options) - try: - cached_calc_cond_batch = comfy.samplers.calc_cond_batch - # support hooks and sliding context for PAG/other sampler_post_cfg_function tech that may use calc_cond_batch - comfy.samplers.calc_cond_batch = wrapped_cfg_sliding_calc_cond_batch_factory(cached_calc_cond_batch) - return comfy.samplers.cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options, cond, uncond) - finally: - comfy.samplers.calc_cond_batch = cached_calc_cond_batch - else: # for backwards compatibility, for now - if "sampler_cfg_function" in model_options: - args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, - "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} - cfg_result = x - model_options["sampler_cfg_function"](args) - else: - cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale - - for fn in model_options.get("sampler_post_cfg_function", []): - args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, - "sigma": timestep, "model_options": model_options, "input": x} - cfg_result = fn(args) - - return cfg_result + if ADGS.sample_settings.custom_cfg is not None: + cond_scale = ADGS.sample_settings.custom_cfg.get_cfg_scale(cond_pred) + model_options = ADGS.sample_settings.custom_cfg.get_model_options(model_options) + try: + cached_calc_cond_batch = comfy.samplers.calc_cond_batch + # support hooks and sliding context for PAG/other sampler_post_cfg_function tech that may use calc_cond_batch + comfy.samplers.calc_cond_batch = wrapped_cfg_sliding_calc_cond_batch_factory(cached_calc_cond_batch) + return comfy.samplers.cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options, cond, uncond) + finally: + comfy.samplers.calc_cond_batch = cached_calc_cond_batch finally: ADGS.restore_special_model_features(model) @@ -745,7 +707,7 @@ def wrapped_cfg_sliding_calc_cond_batch(model, conds, x_in, timestep, model_opti # when inside sliding_calc_conds_batch, should return to original calc_cond_batch comfy.samplers.calc_cond_batch = orig_calc_cond_batch if not ADGS.is_using_sliding_context(): - return calc_cond_uncond_batch_wrapper(model, conds, x_in, timestep, model_options) + return calc_conds_batch_wrapper(model, conds, x_in, timestep, model_options) else: return sliding_calc_conds_batch(model, conds, x_in, timestep, model_options) finally: @@ -922,7 +884,7 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.OFF #logger.info(f"window: {curr_window_idx} - {model_options['transformer_options'][CONTEXTREF_MACHINE_STATE]}") - sub_conds_out = calc_cond_uncond_batch_wrapper(model, sub_conds, sub_x, sub_timestep, model_options) + sub_conds_out = calc_conds_batch_wrapper(model, sub_conds, sub_x, sub_timestep, model_options) if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE: full_length = ADGS.params.full_length @@ -1008,7 +970,7 @@ def get_conds_with_c_concat(conds: list[dict], c_concat: comfy.conds.CONDNoiseSh return new_conds -def calc_cond_uncond_batch_wrapper(model, conds: list[dict], x_in: Tensor, timestep, model_options): +def calc_conds_batch_wrapper(model, conds: list[dict], x_in: Tensor, timestep, model_options): # check if conds or unconds contain lora_hook or default_cond contains_lora_hooks = False has_default_cond = False @@ -1028,9 +990,6 @@ def calc_cond_uncond_batch_wrapper(model, conds: list[dict], x_in: Tensor, times ADGS.hooks_initialize(model, hook_groups=hook_groups) ADGS.prepare_hooks_current_keyframes(timestep, hook_groups=hook_groups) return calc_conds_batch_lora_hook(model, conds, x_in, timestep, model_options, has_default_cond) - # keep for backwards compatibility, for now - if not hasattr(comfy.samplers, "calc_cond_batch"): - return comfy.samplers.calc_cond_uncond_batch(model, conds[0], conds[1], x_in, timestep, model_options) return comfy.samplers.calc_cond_batch(model, conds, x_in, timestep, model_options) diff --git a/pyproject.toml b/pyproject.toml index ce99f35..5816d62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui-animatediff-evolved" description = "Improved AnimateDiff integration for ComfyUI." -version = "1.1.4" +version = "1.2.0" license = { file = "LICENSE" } dependencies = []