Skip to content

Commit

Permalink
Prevent unloading LoRAs if multiple model loaders are used
Browse files Browse the repository at this point in the history
  • Loading branch information
kijai committed Dec 22, 2024
1 parent b9ea9bf commit cbfc632
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
5 changes: 3 additions & 2 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def INPUT_TYPES(s):
def loadmodel(self, model, base_precision, load_device, quantization,
compile_args=None, attention_mode="sdpa", block_swap_args=None, lora=None):
transformer = None
mm.unload_all_models()
#mm.unload_all_models()
mm.soft_empty_cache()
manual_offloading = True
if "sage" in attention_mode:
Expand Down Expand Up @@ -328,7 +328,7 @@ def loadmodel(self, model, base_precision, load_device, quantization,

patcher, _ = load_lora_for_models(patcher, None, lora_sd, lora_strength, 0)

comfy.model_management.load_models_gpu([patcher], force_full_load=True, force_patch_weights=True)
comfy.model_management.load_models_gpu([patcher])
if load_device == "offload_device":
patcher.model.diffusion_model.to(offload_device)

Expand Down Expand Up @@ -488,6 +488,7 @@ def loadmodel(self, model_name, precision, compile_args=None):
if compile_args is not None:
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
vae = torch.compile(vae, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])


return (vae,)

Expand Down
20 changes: 19 additions & 1 deletion nodes_rf_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from diffusers.utils.torch_utils import randn_tensor
import comfy.model_management as mm
from .hyvideo.diffusion.pipelines.pipeline_hunyuan_video import get_rotary_pos_embed
from .enhance_a_video.globals import enable_enhance, disable_enhance, set_enhance_weight

script_directory = os.path.dirname(os.path.abspath(__file__))

Expand Down Expand Up @@ -288,6 +289,7 @@ def INPUT_TYPES(s):
},
"optional": {
"interpolation_curve": ("FLOAT", {"forceInput": True, "tooltip": "The strength of the inversed latents along time, in latent space"}),
"feta_args": ("FETAARGS", ),

}
}
Expand Down Expand Up @@ -472,6 +474,7 @@ def INPUT_TYPES(s):
},
"optional": {
"interpolation_curve": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "forceInput": True, "tooltip": "The strength of the inversed latents along time, in latent space"}),
"feta_args": ("FETAARGS", ),
}
}

Expand All @@ -482,7 +485,7 @@ def INPUT_TYPES(s):
EXPERIMENTAL = True

def process(self, model, width, height, num_frames, hyvid_embeds, hyvid_embeds_2, flow_shift, steps, embedded_guidance_scale,
seed, force_offload, alpha, interpolation_curve=None):
seed, force_offload, alpha, interpolation_curve=None, feta_args=None):
model = model.model
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
Expand Down Expand Up @@ -561,6 +564,14 @@ def process(self, model, width, height, num_frames, hyvid_embeds, hyvid_embeds_2
latents_1 = latents.clone()
latents_2 = latents.clone()

if feta_args is not None:
set_enhance_weight(feta_args["weight"])
feta_start_percent = feta_args["start_percent"]
feta_end_percent = feta_args["end_percent"]
enable_enhance(feta_args["single_blocks"], feta_args["double_blocks"])
else:
disable_enhance()

# 7. Denoising loop
self._num_timesteps = len(timesteps)

Expand All @@ -574,6 +585,13 @@ def process(self, model, width, height, num_frames, hyvid_embeds, hyvid_embeds_2

with tqdm(total=len(timesteps)) as progress_bar:
for idx, t in enumerate(timesteps):
current_step_percentage = idx / len(timesteps)

if feta_args is not None:
if feta_start_percent <= current_step_percentage <= feta_end_percent:
enable_enhance(feta_args["single_blocks"], feta_args["double_blocks"])
else:
disable_enhance()

# Pre-compute weighted latents
weighted_latents_1 = torch.zeros_like(latents_1)
Expand Down

0 comments on commit cbfc632

Please sign in to comment.