Skip to content

Commit fc24715

Browse files
authored
Implement EasyCache and Invent LazyCache (Comfy-Org#9496)
* Attempting a universal implementation of EasyCache, starting with flux as test; I screwed up the math a bit, but when I set it just right it works. * Fixed math to make threshold work as expected, refactored code to use EasyCacheHolder instead of a dict wrapped by object * Use sigmas from transformer_options instead of timesteps to be compatible with a greater amount of models, make end_percent work * Make log statement when not skipping useful, preparing for per-cond caching * Added DIFFUSION_MODEL wrapper around forward function for wan model * Add subsampling for heuristic inputs * Add subsampling to output_prev (output_prev_subsampled now) * Properly consider conds in EasyCache logic * Created SuperEasyCache to test what happens if caching and reuse is moved outside the scope of conds, added PREDICT_NOISE wrapper to facilitate this test * Change max reuse_threshold to 3.0 * Mark EasyCache/SuperEasyCache as experimental (beta) * Make Lumina2 compatible with EasyCache * Add EasyCache support for Qwen Image * Fix missing comma, curse you Cursor * Add EasyCache support to AceStep * Add EasyCache support to Chroma * Added EasyCache support to Cosmos Predict t2i * Make EasyCache not crash with Cosmos Predict ImageToVideo latents, but does not work well at all * Add EasyCache support to hidream * Added EasyCache support to hunyuan video * Added EasyCache support to hunyuan3d * Added EasyCache support to LTXV (not very good, but does not crash) * Implemented EasyCache for aura_flow * Renamed SuperEasyCache to LazyCache, hardcoded subsample_factor to 8 on nodes * Extra logging when verbose is true for EasyCache
1 parent fe31ad0 commit fc24715

File tree

17 files changed

+639
-7
lines changed

17 files changed

+639
-7
lines changed

comfy/ldm/ace/model.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from torch import nn
2020

2121
import comfy.model_management
22+
import comfy.patcher_extension
2223

2324
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
2425
from .attention import LinearTransformerBlock, t2i_modulate
@@ -343,7 +344,28 @@ def decode(
343344
output = self.final_layer(hidden_states, embedded_timestep, output_length)
344345
return output
345346

346-
def forward(
347+
def forward(self,
348+
x,
349+
timestep,
350+
attention_mask=None,
351+
context: Optional[torch.Tensor] = None,
352+
text_attention_mask: Optional[torch.LongTensor] = None,
353+
speaker_embeds: Optional[torch.FloatTensor] = None,
354+
lyric_token_idx: Optional[torch.LongTensor] = None,
355+
lyric_mask: Optional[torch.LongTensor] = None,
356+
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
357+
controlnet_scale: Union[float, torch.Tensor] = 1.0,
358+
lyrics_strength=1.0,
359+
**kwargs
360+
):
361+
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
362+
self._forward,
363+
self,
364+
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
365+
).execute(x, timestep, attention_mask, context, text_attention_mask, speaker_embeds, lyric_token_idx, lyric_mask, block_controlnet_hidden_states,
366+
controlnet_scale, lyrics_strength, **kwargs)
367+
368+
def _forward(
347369
self,
348370
x,
349371
timestep,

comfy/ldm/aura/mmdit.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from comfy.ldm.modules.attention import optimized_attention
1111
import comfy.ops
12+
import comfy.patcher_extension
1213
import comfy.ldm.common_dit
1314

1415
def modulate(x, shift, scale):
@@ -436,6 +437,13 @@ def apply_pos_embeds(self, x, h, w):
436437
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
437438

438439
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
440+
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
441+
self._forward,
442+
self,
443+
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
444+
).execute(x, timestep, context, transformer_options, **kwargs)
445+
446+
def _forward(self, x, timestep, context, transformer_options={}, **kwargs):
439447
patches_replace = transformer_options.get("patches_replace", {})
440448
# patchify x, add PE
441449
b, c, h, w = x.shape

comfy/ldm/chroma/model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch
66
from torch import Tensor, nn
77
from einops import rearrange, repeat
8+
import comfy.patcher_extension
89
import comfy.ldm.common_dit
910

1011
from comfy.ldm.flux.layers import (
@@ -253,6 +254,13 @@ def block_wrap(args):
253254
return img
254255

255256
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
257+
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
258+
self._forward,
259+
self,
260+
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
261+
).execute(x, timestep, context, guidance, control, transformer_options, **kwargs)
262+
263+
def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
256264
bs, c, h, w = x.shape
257265
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
258266

comfy/ldm/cosmos/model.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from enum import Enum
2828
import logging
2929

30+
import comfy.patcher_extension
31+
3032
from .blocks import (
3133
FinalLayer,
3234
GeneralDITTransformerBlock,
@@ -435,6 +437,42 @@ def forward(
435437
latent_condition_sigma: Optional[torch.Tensor] = None,
436438
condition_video_augment_sigma: Optional[torch.Tensor] = None,
437439
**kwargs,
440+
):
441+
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
442+
self._forward,
443+
self,
444+
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
445+
).execute(x,
446+
timesteps,
447+
context,
448+
attention_mask,
449+
fps,
450+
image_size,
451+
padding_mask,
452+
scalar_feature,
453+
data_type,
454+
latent_condition,
455+
latent_condition_sigma,
456+
condition_video_augment_sigma,
457+
**kwargs)
458+
459+
def _forward(
460+
self,
461+
x: torch.Tensor,
462+
timesteps: torch.Tensor,
463+
context: torch.Tensor,
464+
attention_mask: Optional[torch.Tensor] = None,
465+
# crossattn_emb: torch.Tensor,
466+
# crossattn_mask: Optional[torch.Tensor] = None,
467+
fps: Optional[torch.Tensor] = None,
468+
image_size: Optional[torch.Tensor] = None,
469+
padding_mask: Optional[torch.Tensor] = None,
470+
scalar_feature: Optional[torch.Tensor] = None,
471+
data_type: Optional[DataType] = DataType.VIDEO,
472+
latent_condition: Optional[torch.Tensor] = None,
473+
latent_condition_sigma: Optional[torch.Tensor] = None,
474+
condition_video_augment_sigma: Optional[torch.Tensor] = None,
475+
**kwargs,
438476
):
439477
"""
440478
Args:

comfy/ldm/cosmos/predict2.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
1212
from torchvision import transforms
1313

14+
import comfy.patcher_extension
1415
from comfy.ldm.modules.attention import optimized_attention
1516

1617
def apply_rotary_pos_emb(
@@ -805,7 +806,21 @@ def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor:
805806
)
806807
return x_B_C_Tt_Hp_Wp
807808

808-
def forward(
809+
def forward(self,
810+
x: torch.Tensor,
811+
timesteps: torch.Tensor,
812+
context: torch.Tensor,
813+
fps: Optional[torch.Tensor] = None,
814+
padding_mask: Optional[torch.Tensor] = None,
815+
**kwargs,
816+
):
817+
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
818+
self._forward,
819+
self,
820+
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
821+
).execute(x, timesteps, context, fps, padding_mask, **kwargs)
822+
823+
def _forward(
809824
self,
810825
x: torch.Tensor,
811826
timesteps: torch.Tensor,

comfy/ldm/flux/model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from torch import Tensor, nn
77
from einops import rearrange, repeat
88
import comfy.ldm.common_dit
9+
import comfy.patcher_extension
910

1011
from .layers import (
1112
DoubleStreamBlock,
@@ -214,6 +215,13 @@ def process_img(self, x, index=0, h_offset=0, w_offset=0):
214215
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
215216

216217
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
218+
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
219+
self._forward,
220+
self,
221+
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
222+
).execute(x, timestep, context, y, guidance, ref_latents, control, transformer_options, **kwargs)
223+
224+
def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
217225
bs, c, h_orig, w_orig = x.shape
218226
patch_size = self.patch_size
219227

comfy/ldm/hidream/model.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from comfy.ldm.modules.attention import optimized_attention
1515
import comfy.model_management
16+
import comfy.patcher_extension
1617
import comfy.ldm.common_dit
1718

1819

@@ -692,7 +693,23 @@ def patchify(self, x, max_seq, img_sizes=None):
692693
raise NotImplementedError
693694
return x, x_masks, img_sizes
694695

695-
def forward(
696+
def forward(self,
697+
x: torch.Tensor,
698+
t: torch.Tensor,
699+
y: Optional[torch.Tensor] = None,
700+
context: Optional[torch.Tensor] = None,
701+
encoder_hidden_states_llama3=None,
702+
image_cond=None,
703+
control = None,
704+
transformer_options = {},
705+
):
706+
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
707+
self._forward,
708+
self,
709+
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
710+
).execute(x, t, y, context, encoder_hidden_states_llama3, image_cond, control, transformer_options)
711+
712+
def _forward(
696713
self,
697714
x: torch.Tensor,
698715
t: torch.Tensor,

comfy/ldm/hunyuan3d/model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
SingleStreamBlock,
88
timestep_embedding,
99
)
10+
import comfy.patcher_extension
1011

1112

1213
class Hunyuan3Dv2(nn.Module):
@@ -67,6 +68,13 @@ def __init__(
6768
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
6869

6970
def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
71+
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
72+
self._forward,
73+
self,
74+
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
75+
).execute(x, timestep, context, guidance, transformer_options, **kwargs)
76+
77+
def _forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
7078
x = x.movedim(-1, -2)
7179
timestep = 1.0 - timestep
7280
txt = context

comfy/ldm/hunyuan_video/model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#Based on Flux code because of weird hunyuan video code license.
22

33
import torch
4+
import comfy.patcher_extension
45
import comfy.ldm.flux.layers
56
import comfy.ldm.modules.diffusionmodules.mmdit
67
from comfy.ldm.modules.attention import optimized_attention
@@ -348,6 +349,13 @@ def img_ids(self, x):
348349
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
349350

350351
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
352+
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
353+
self._forward,
354+
self,
355+
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
356+
).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
357+
358+
def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
351359
bs, c, t, h, w = x.shape
352360
img_ids = self.img_ids(x)
353361
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)

comfy/ldm/lightricks/model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
from torch import nn
3+
import comfy.patcher_extension
34
import comfy.ldm.modules.attention
45
import comfy.ldm.common_dit
56
from einops import rearrange
@@ -420,6 +421,13 @@ def __init__(self,
420421
self.patchifier = SymmetricPatchifier(1)
421422

422423
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
424+
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
425+
self._forward,
426+
self,
427+
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
428+
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
429+
430+
def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
423431
patches_replace = transformer_options.get("patches_replace", {})
424432

425433
orig_shape = list(x.shape)

0 commit comments

Comments
 (0)