diff --git a/nodes.py b/nodes.py index 536cf97..76c2224 100644 --- a/nodes.py +++ b/nodes.py @@ -235,7 +235,7 @@ def INPUT_TYPES(s): def loadmodel(self, model, base_precision, load_device, quantization, compile_args=None, attention_mode="sdpa", block_swap_args=None, lora=None): transformer = None - mm.unload_all_models() + #mm.unload_all_models() mm.soft_empty_cache() manual_offloading = True if "sage" in attention_mode: @@ -328,7 +328,7 @@ def loadmodel(self, model, base_precision, load_device, quantization, patcher, _ = load_lora_for_models(patcher, None, lora_sd, lora_strength, 0) - comfy.model_management.load_models_gpu([patcher], force_full_load=True, force_patch_weights=True) + comfy.model_management.load_models_gpu([patcher]) if load_device == "offload_device": patcher.model.diffusion_model.to(offload_device) @@ -488,6 +488,7 @@ def loadmodel(self, model_name, precision, compile_args=None): if compile_args is not None: torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"] vae = torch.compile(vae, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) + return (vae,) diff --git a/nodes_rf_inversion.py b/nodes_rf_inversion.py index 0eec2a1..3d1da2b 100644 --- a/nodes_rf_inversion.py +++ b/nodes_rf_inversion.py @@ -7,6 +7,7 @@ from diffusers.utils.torch_utils import randn_tensor import comfy.model_management as mm from .hyvideo.diffusion.pipelines.pipeline_hunyuan_video import get_rotary_pos_embed +from .enhance_a_video.globals import enable_enhance, disable_enhance, set_enhance_weight script_directory = os.path.dirname(os.path.abspath(__file__)) @@ -288,6 +289,7 @@ def INPUT_TYPES(s): }, "optional": { "interpolation_curve": ("FLOAT", {"forceInput": True, "tooltip": "The strength of the inversed latents along time, in latent space"}), + "feta_args": ("FETAARGS", ), } } @@ -472,6 +474,7 @@ def INPUT_TYPES(s): }, "optional": { "interpolation_curve": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "forceInput": True, "tooltip": "The strength of the inversed latents along time, in latent space"}), + "feta_args": ("FETAARGS", ), } } @@ -482,7 +485,7 @@ def INPUT_TYPES(s): EXPERIMENTAL = True def process(self, model, width, height, num_frames, hyvid_embeds, hyvid_embeds_2, flow_shift, steps, embedded_guidance_scale, - seed, force_offload, alpha, interpolation_curve=None): + seed, force_offload, alpha, interpolation_curve=None, feta_args=None): model = model.model device = mm.get_torch_device() offload_device = mm.unet_offload_device() @@ -561,6 +564,14 @@ def process(self, model, width, height, num_frames, hyvid_embeds, hyvid_embeds_2 latents_1 = latents.clone() latents_2 = latents.clone() + if feta_args is not None: + set_enhance_weight(feta_args["weight"]) + feta_start_percent = feta_args["start_percent"] + feta_end_percent = feta_args["end_percent"] + enable_enhance(feta_args["single_blocks"], feta_args["double_blocks"]) + else: + disable_enhance() + # 7. Denoising loop self._num_timesteps = len(timesteps) @@ -574,6 +585,13 @@ def process(self, model, width, height, num_frames, hyvid_embeds, hyvid_embeds_2 with tqdm(total=len(timesteps)) as progress_bar: for idx, t in enumerate(timesteps): + current_step_percentage = idx / len(timesteps) + + if feta_args is not None: + if feta_start_percent <= current_step_percentage <= feta_end_percent: + enable_enhance(feta_args["single_blocks"], feta_args["double_blocks"]) + else: + disable_enhance() # Pre-compute weighted latents weighted_latents_1 = torch.zeros_like(latents_1)