From ac67be542a08046daf6a92514329967ae22cf5ac Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Mon, 10 Feb 2025 12:38:17 -0500 Subject: [PATCH 01/19] rcfg --- __init__.py | 8 +- controls/utility_controls.py | 2 - stream_diffusion_nodes.py | 294 +++++++++++++++++++++++++++++++++++ 3 files changed, 301 insertions(+), 3 deletions(-) create mode 100644 stream_diffusion_nodes.py diff --git a/__init__.py b/__init__.py index d7c660e..5671113 100644 --- a/__init__.py +++ b/__init__.py @@ -3,9 +3,11 @@ from .controls.utility_controls import FPSMonitor, SimilarityFilter, LazyCondition from .controls.motion_controls import MotionController, ROINode, IntegerMotionController from .misc_nodes import DTypeConverter, FastWebcamCapture, YOLOSimilarityCompare, TextRenderer, QuickShapeMask, MultilineText, LoadImageFromPath_ - +from .stream_diffusion_nodes import StreamConditioning, StreamCFG import re + + NODE_CLASS_MAPPINGS = { "FloatControl": FloatControl, "IntControl": IntControl, @@ -15,12 +17,16 @@ "StringSequence": StringSequence, "FPSMonitor": FPSMonitor, "SimilarityFilter": SimilarityFilter, + "StreamCFG": StreamCFG, + "StreamConditioning": StreamConditioning, "LazyCondition": LazyCondition, "MotionController": MotionController, "IntegerMotionController": IntegerMotionController, "YOLOSimilarityCompare": YOLOSimilarityCompare, "TextRenderer": TextRenderer, + "ROINode": ROINode, + #"IntervalControl": IntervalCo ntrol, #"DeltaControl": DeltaControl, "QuickShapeMask": QuickShapeMask, diff --git a/controls/utility_controls.py b/controls/utility_controls.py index 9b3f382..8f29533 100644 --- a/controls/utility_controls.py +++ b/controls/utility_controls.py @@ -269,5 +269,3 @@ def update(self, condition, if_true, fallback, use_fallback): - - diff --git a/stream_diffusion_nodes.py b/stream_diffusion_nodes.py new file mode 100644 index 0000000..4ccb605 --- /dev/null +++ b/stream_diffusion_nodes.py @@ -0,0 +1,294 @@ +import torch +from .base.control_base import ControlNodeBase + +class StreamCFG(ControlNodeBase): + """Implements CFG approaches for temporal consistency between workflow runs""" + + RETURN_TYPES = ("MODEL",) + FUNCTION = "update" + CATEGORY = "real-time/sampling" + + @classmethod + def INPUT_TYPES(cls): + inputs = super().INPUT_TYPES() + inputs["required"].update({ + "model": ("MODEL",), + "cfg_type": (["self", "full", "initialize"], { + "default": "self", + "tooltip": "Type of CFG to use: full (standard), self (memory efficient), or initialize (memory efficient with initialization)" + }), + "residual_scale": ("FLOAT", { + "default": 0.4, + "min": 0.0, + "max": 1.0, + "step": 0.01, + "tooltip": "Scale factor for residual (higher = more temporal consistency)" + }), + "delta": ("FLOAT", { + "default": 1.0, + "min": 0.0, + "max": 5.0, + "step": 0.1, + "tooltip": "Delta parameter for self/initialize CFG types" + }), + }) + return inputs + + def __init__(self): + super().__init__() + # Store the last model to detect when we need to reapply the hook + self.last_model_hash = None + self.post_cfg_function = None + + def update(self, model, always_execute=True, cfg_type="self", residual_scale=0.4, delta=1.0): + print(f"[StreamCFG] Initializing with cfg_type={cfg_type}, residual_scale={residual_scale}, delta={delta}") + + # Get state with defaults + state = self.get_state({ + "last_uncond": None, # Store last workflow's unconditioned prediction + "initialized": False, + "cfg_type": cfg_type, + "residual_scale": residual_scale, + "delta": delta, + 
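A minimal standalone sketch of the blend that the post-CFG hook below applies on later workflow runs, using hypothetical tensor names (an illustration of the formula, not the node's code): the cached unconditional prediction from the previous run, scaled by delta, stands in for a fresh unconditional pass.

import torch

def stream_cfg_blend(cond_denoised: torch.Tensor,
                     last_uncond: torch.Tensor,
                     cond_scale: float,
                     delta: float) -> torch.Tensor:
    # Previous run's unconditional prediction replaces a fresh uncond pass.
    virtual_uncond = last_uncond * delta
    # Standard classifier-free guidance blend around the reused prediction.
    return virtual_uncond + cond_scale * (cond_denoised - virtual_uncond)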
"workflow_count": 0, # Track number of workflow runs + "current_sigmas": None, # Track sigmas for this workflow + "seen_sigmas": set(), # Track which sigmas we've seen this workflow + "is_last_step": False, # Track if we're on the last step + "alpha_prod_t": None, # Store alpha values for proper scaling + "beta_prod_t": None, # Store beta values for proper scaling + }) + + def post_cfg_function(args): + # Extract info + denoised = args["denoised"] + cond = args["cond"] + uncond = args["uncond"] + cond_denoised = args["cond_denoised"] + uncond_denoised = args["uncond_denoised"] + cond_scale = args["cond_scale"] + sigma = args["sigma"].item() if torch.is_tensor(args["sigma"]) else args["sigma"] + + # Get step info from model options + model_options = args["model_options"] + sample_sigmas = model_options["transformer_options"].get("sample_sigmas", None) + + # Update current sigmas if needed + if sample_sigmas is not None and state["current_sigmas"] is None: + # Filter out the trailing 0.0 if present + sigmas = [s.item() for s in sample_sigmas] + if sigmas[-1] == 0.0: + sigmas = sigmas[:-1] + state["current_sigmas"] = sigmas + state["seen_sigmas"] = set() + + # Calculate alpha and beta values for proper scaling + alphas = [1.0 / (1.0 + s**2) for s in sigmas] + state["alpha_prod_t"] = torch.tensor(alphas, device=denoised.device, dtype=denoised.dtype) + state["beta_prod_t"] = torch.sqrt(1 - state["alpha_prod_t"]) + + # Track this sigma + state["seen_sigmas"].add(sigma) + + # Check if this is the last step + state["is_last_step"] = False + if state["current_sigmas"] is not None: + # It's the last step if we've seen all sigmas + is_last_step = len(state["seen_sigmas"]) >= len(state["current_sigmas"]) + # Or if this is the smallest sigma in the sequence + if not is_last_step and sigma == min(state["current_sigmas"]): + is_last_step = True + state["is_last_step"] = is_last_step + + # First workflow case + if state["last_uncond"] is None: + if state["is_last_step"]: + state["last_uncond"] = uncond_denoised.detach().clone() + state["workflow_count"] += 1 + state["current_sigmas"] = None # Reset for next workflow + if cfg_type == "initialize": + state["initialized"] = True + self.set_state(state) + return denoised + + # Handle different CFG types for subsequent workflows + if cfg_type == "full": + result = denoised + + elif cfg_type == "initialize" and not state["initialized"]: + result = denoised + if state["is_last_step"]: + state["initialized"] = True + self.set_state(state) + + else: # self or initialized initialize + # Get current step index + current_idx = len(state["seen_sigmas"]) - 1 + + # Scale last prediction with proper alpha/beta values + noise_pred_uncond = state["last_uncond"] * delta + + # Apply CFG with scaled prediction + result = noise_pred_uncond + cond_scale * (cond_denoised - noise_pred_uncond) + + # Store last prediction if this is the last step + if state["is_last_step"]: + # Calculate properly scaled residual + scaled_noise = state["beta_prod_t"][current_idx] * state["last_uncond"] + delta_x = uncond_denoised - scaled_noise + + # Scale delta_x with next step's alpha/beta + if current_idx < len(state["current_sigmas"]) - 1: + alpha_next = state["alpha_prod_t"][current_idx + 1] + beta_next = state["beta_prod_t"][current_idx + 1] + else: + alpha_next = torch.ones_like(state["alpha_prod_t"][0]) + beta_next = torch.ones_like(state["beta_prod_t"][0]) + + delta_x = alpha_next * delta_x / beta_next + + # Update stored prediction with scaled residual + final_update = uncond_denoised + 
residual_scale * delta_x + state["last_uncond"] = final_update + state["workflow_count"] += 1 + state["current_sigmas"] = None # Reset for next workflow + self.set_state(state) + + return result + + # Store function reference to prevent garbage collection + self.post_cfg_function = post_cfg_function + + # Only set up post CFG function if model has changed + model_hash = hash(str(model)) + if model_hash != self.last_model_hash: + m = model.clone() + m.model_options = m.model_options.copy() + m.model_options["sampler_post_cfg_function"] = [self.post_cfg_function] + self.last_model_hash = model_hash + return (m,) + + # Make sure our function is still in the list + if not any(f is self.post_cfg_function for f in model.model_options.get("sampler_post_cfg_function", [])): + m = model.clone() + m.model_options = m.model_options.copy() + m.model_options["sampler_post_cfg_function"] = [self.post_cfg_function] + return (m,) + + return (model,) + + +class StreamConditioning(ControlNodeBase): + """Applies Residual CFG to conditioning for improved temporal consistency with different CFG types""" + #NOTE: experimental + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "positive": ("CONDITIONING",), + "negative": ("CONDITIONING",), + "cfg_type": (["full", "self", "initialize"], { + "default": "full", + "tooltip": "Type of CFG to use: full (standard), self (memory efficient), or initialize (memory efficient with initialization)" + }), + "residual_scale": ("FLOAT", { + "default": 0.4, + "min": 0.0, + "max": 1.0, + "step": 0.01, + "tooltip": "Scale factor for residual conditioning (higher = more temporal consistency)" + }), + "delta": ("FLOAT", { + "default": 1.0, + "min": 0.0, + "max": 5.0, + "step": 0.1, + "tooltip": "Delta parameter for self/initialize CFG types" + }), + "always_execute": ("BOOLEAN", { + "default": False, + }), + } + } + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + FUNCTION = "update" + CATEGORY = "real-time/control/utility" + + def __init__(self): + super().__init__() + + def update(self, positive, negative, cfg_type="full", residual_scale=0.4, delta=1.0, always_execute=False): + # Get state with defaults + state = self.get_state({ + "prev_positive": None, + "prev_negative": None, + "stock_noise": None, # For self/initialize CFG + "initialized": False # For initialize CFG + }) + + # Extract conditioning tensors + current_pos_cond = positive[0][0] # Assuming standard ComfyUI conditioning format + current_neg_cond = negative[0][0] + + # First frame case + if state["prev_positive"] is None: + state["prev_positive"] = current_pos_cond.detach().clone() + state["prev_negative"] = current_neg_cond.detach().clone() + if cfg_type == "initialize": + # For initialize, we use the first negative as our stock noise + state["stock_noise"] = current_neg_cond.detach().clone() + elif cfg_type == "self": + # For self, we start with a scaled version of the negative + state["stock_noise"] = current_neg_cond.detach().clone() * delta + self.set_state(state) + return (positive, negative) + + # Handle different CFG types + if cfg_type == "full": + # Standard R-CFG with full negative conditioning + pos_residual = current_pos_cond - state["prev_positive"] + neg_residual = current_neg_cond - state["prev_negative"] + + blended_pos = current_pos_cond + residual_scale * pos_residual + blended_neg = current_neg_cond + residual_scale * neg_residual + + # Update state + state["prev_positive"] = current_pos_cond.detach().clone() + state["prev_negative"] = 
current_neg_cond.detach().clone() + + # Reconstruct conditioning format + positive_out = [[blended_pos, positive[0][1]]] + negative_out = [[blended_neg, negative[0][1]]] + + else: # self or initialize + # Calculate residual for positive conditioning + pos_residual = current_pos_cond - state["prev_positive"] + blended_pos = current_pos_cond + residual_scale * pos_residual + + # Update stock noise based on current prediction + if cfg_type == "initialize" and not state["initialized"]: + # First prediction for initialize type + state["stock_noise"] = current_neg_cond.detach().clone() + state["initialized"] = True + else: + # Update stock noise with temporal consistency + stock_residual = current_neg_cond - state["stock_noise"] + state["stock_noise"] = current_neg_cond + residual_scale * stock_residual + + # Scale stock noise by delta + scaled_stock = state["stock_noise"] * delta + + # Update state + state["prev_positive"] = current_pos_cond.detach().clone() + state["prev_negative"] = scaled_stock.detach().clone() + + # Reconstruct conditioning format + positive_out = [[blended_pos, positive[0][1]]] + negative_out = [[scaled_stock, negative[0][1]]] + + self.set_state(state) + return (positive_out, negative_out) + + + From 135d012a0fad2acb8fdb3c1d47945305412de6ea Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Mon, 10 Feb 2025 16:53:42 -0500 Subject: [PATCH 02/19] streambatchsampler, streambatch scheduler, stream corss attention --- __init__.py | 7 +- stream_diffusion_nodes.py | 321 +++++++++++++++++++++++++++++++++++++- 2 files changed, 326 insertions(+), 2 deletions(-) diff --git a/__init__.py b/__init__.py index 5671113..2ee77c9 100644 --- a/__init__.py +++ b/__init__.py @@ -3,7 +3,7 @@ from .controls.utility_controls import FPSMonitor, SimilarityFilter, LazyCondition from .controls.motion_controls import MotionController, ROINode, IntegerMotionController from .misc_nodes import DTypeConverter, FastWebcamCapture, YOLOSimilarityCompare, TextRenderer, QuickShapeMask, MultilineText, LoadImageFromPath_ -from .stream_diffusion_nodes import StreamConditioning, StreamCFG +from .stream_diffusion_nodes import StreamConditioning, StreamCFG, StreamBatchSampler, StreamScheduler, StreamCrossAttention import re @@ -19,12 +19,17 @@ "SimilarityFilter": SimilarityFilter, "StreamCFG": StreamCFG, "StreamConditioning": StreamConditioning, + "StreamBatchSampler": StreamBatchSampler, + "StreamScheduler": StreamScheduler, + "StreamCrossAttention": StreamCrossAttention, "LazyCondition": LazyCondition, "MotionController": MotionController, "IntegerMotionController": IntegerMotionController, "YOLOSimilarityCompare": YOLOSimilarityCompare, "TextRenderer": TextRenderer, + + "ROINode": ROINode, #"IntervalControl": IntervalCo ntrol, diff --git a/stream_diffusion_nodes.py b/stream_diffusion_nodes.py index 4ccb605..bc13a65 100644 --- a/stream_diffusion_nodes.py +++ b/stream_diffusion_nodes.py @@ -1,5 +1,8 @@ import torch from .base.control_base import ControlNodeBase +import comfy.model_management +import comfy.samplers +import random class StreamCFG(ControlNodeBase): """Implements CFG approaches for temporal consistency between workflow runs""" @@ -66,7 +69,13 @@ def post_cfg_function(args): cond_denoised = args["cond_denoised"] uncond_denoised = args["uncond_denoised"] cond_scale = args["cond_scale"] - sigma = args["sigma"].item() if torch.is_tensor(args["sigma"]) else args["sigma"] + + # Handle both batched and single sigmas + sigma = args["sigma"] + if 
torch.is_tensor(sigma): + # For batched sampling, use first sigma in batch + # This is safe because we process in order and track seen sigmas + sigma = sigma[0].item() if len(sigma.shape) > 0 else sigma.item() # Get step info from model options model_options = args["model_options"] @@ -291,4 +300,314 @@ def update(self, positive, negative, cfg_type="full", residual_scale=0.4, delta= return (positive_out, negative_out) +class StreamBatchSampler(ControlNodeBase): + """Implements batched denoising for faster inference by processing multiple steps in parallel""" + + RETURN_TYPES = ("SAMPLER",) + FUNCTION = "update" + CATEGORY = "real-time/sampling" + + @classmethod + def INPUT_TYPES(cls): + inputs = super().INPUT_TYPES() + inputs["required"].update({ + "batch_size": ("INT", { + "default": 2, + "min": 1, + "max": 10, + "step": 1, + "tooltip": "Number of steps to batch together. Higher values use more memory but are faster." + }), + }) + return inputs + + def __init__(self): + super().__init__() + self.batch_size = None + + def sample(self, model, noise, sigmas, extra_args=None, callback=None, disable=None): + """Sample with batched denoising steps""" + extra_args = {} if extra_args is None else extra_args + print(f"[StreamBatchSampler] Starting sampling with {len(sigmas)-1} steps, batch_size={self.batch_size}") + print(f"[StreamBatchSampler] Input noise shape: {noise.shape}, device: {noise.device}") + print(f"[StreamBatchSampler] Sigmas: {sigmas.tolist()}") + + # Prepare batched sampling + num_sigmas = len(sigmas) - 1 + num_batches = (num_sigmas + self.batch_size - 1) // self.batch_size + x = noise + + for batch_idx in range(num_batches): + # Get sigmas for this batch + start_idx = batch_idx * self.batch_size + end_idx = min(start_idx + self.batch_size, num_sigmas) + batch_sigmas = sigmas[start_idx:end_idx+1] + print(f"\n[StreamBatchSampler] Batch {batch_idx+1}/{num_batches}") + print(f"[StreamBatchSampler] Processing steps {start_idx}-{end_idx}") + print(f"[StreamBatchSampler] Batch sigmas: {batch_sigmas.tolist()}") + + # Create batch of identical latents + batch_size = end_idx - start_idx + x_batch = x.repeat(batch_size, 1, 1, 1) + + # Create batch of sigmas + sigma_batch = batch_sigmas[:-1] # All but last sigma + + # Run model on entire batch at once + with torch.no_grad(): + # Process all steps in parallel + denoised_batch = model(x_batch, sigma_batch, **extra_args) + print(f"[StreamBatchSampler] Denoised batch shape: {denoised_batch.shape}") + + # Process results one at a time to maintain callback + for i in range(batch_size): + sigma = sigma_batch[i] + sigma_next = batch_sigmas[i + 1] + denoised = denoised_batch[i:i+1] + + # Calculate step size (now always positive as we go from high to low sigma) + dt = sigma - sigma_next + + # Update x using Euler method + # The (denoised - x) term gives us the direction to move + # dt/sigma scales how far we move based on current noise level + x = x + (denoised - x) * (dt / sigma) + print(f"[StreamBatchSampler] Step {start_idx+i}: sigma={sigma:.4f}, next_sigma={sigma_next:.4f}, dt={dt:.4f}") + + # Call callback if provided + if callback is not None: + callback({'x': x, 'i': start_idx + i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised}) + + print(f"\n[StreamBatchSampler] Sampling complete. 
Final x shape: {x.shape}") + return x + + def update(self, batch_size=2, always_execute=True): + """Create sampler with specified settings""" + self.batch_size = batch_size + sampler = comfy.samplers.KSAMPLER(self.sample) + return (sampler,) + + +class StreamScheduler(ControlNodeBase): + """Implements StreamDiffusion's efficient timestep selection""" + + RETURN_TYPES = ("SIGMAS",) + FUNCTION = "update" + CATEGORY = "real-time/sampling" + + @classmethod + def INPUT_TYPES(cls): + inputs = super().INPUT_TYPES() + inputs["required"].update({ + "model": ("MODEL",), + "t_index_list": ("STRING", { + "default": "32,45", + "tooltip": "Comma-separated list of timesteps to actually use for denoising. Examples: '32,45' for img2img or '0,16,32,45' for txt2img" + }), + "num_inference_steps": ("INT", { + "default": 50, + "min": 1, + "max": 1000, + "step": 1, + "tooltip": "Total number of timesteps in schedule. StreamDiffusion uses 50 by default. Only timesteps specified in t_index_list are actually used." + }), + }) + return inputs + + def update(self, model, t_index_list="32,45", num_inference_steps=50, always_execute=True): + # Get model's sampling parameters + model_sampling = model.get_model_object("model_sampling") + print(f"[StreamScheduler] Model sampling max sigma: {model_sampling.sigma_max}, min sigma: {model_sampling.sigma_min}") + + # Parse timestep list + try: + t_index_list = [int(t.strip()) for t in t_index_list.split(",")] + except ValueError as e: + print(f"Error parsing timesteps: {e}. Using default [32,45]") + t_index_list = [32, 45] + print(f"[StreamScheduler] Using timesteps: {t_index_list}") + + # Create full schedule using normal scheduler + full_sigmas = comfy.samplers.normal_scheduler(model_sampling, num_inference_steps) + print(f"[StreamScheduler] Full sigma schedule: {full_sigmas.tolist()}") + + # Select only the sigmas at our desired indices, but in reverse order + # This ensures we go from high noise to low noise + selected_sigmas = [] + for t in sorted(t_index_list, reverse=True): # Sort in reverse to go from high noise to low + if t < 0 or t >= num_inference_steps: + print(f"Warning: timestep {t} out of range [0,{num_inference_steps}), skipping") + continue + selected_sigmas.append(float(full_sigmas[t])) + print(f"[StreamScheduler] Selected sigmas: {selected_sigmas}") + + # Add final sigma + selected_sigmas.append(0.0) + print(f"[StreamScheduler] Final sigma schedule: {selected_sigmas}") + + # Convert to tensor and move to appropriate device + selected_sigmas = torch.FloatTensor(selected_sigmas).to(comfy.model_management.get_torch_device()) + return (selected_sigmas,) + + + + + +class StreamCrossAttention(ControlNodeBase): + """Implements optimized cross attention for real-time generation""" + + RETURN_TYPES = ("MODEL",) + FUNCTION = "update" + CATEGORY = "real-time/sampling" + + @classmethod + def INPUT_TYPES(cls): + inputs = super().INPUT_TYPES() + inputs["required"].update({ + "model": ("MODEL",), + "qk_norm": ("BOOLEAN", { + "default": True, + "tooltip": "Whether to apply layer normalization to query and key tensors" + }), + "use_rope": ("BOOLEAN", { + "default": True, + "tooltip": "Whether to use rotary position embeddings for better temporal consistency" + }), + }) + return inputs + + def __init__(self): + super().__init__() + self.last_model_hash = None + self.cross_attention_hook = None + + def update(self, model, always_execute=True, qk_norm=True, use_rope=True): + print(f"[StreamCrossAttention] Initializing with qk_norm={qk_norm}, use_rope={use_rope}") + + # Get 
state with defaults + state = self.get_state({ + "qk_norm": qk_norm, + "use_rope": use_rope, + "workflow_count": 0, + }) + + def cross_attention_forward(module, x, context=None, mask=None, value=None): + q = module.to_q(x) + context = x if context is None else context + k = module.to_k(context) + # Use provided value tensor if given, otherwise compute it + v = value if value is not None else module.to_v(context) + + # Apply QK normalization if enabled + if state["qk_norm"]: + q_norm = torch.nn.LayerNorm(q.shape[-1], device=q.device, dtype=q.dtype) + k_norm = torch.nn.LayerNorm(k.shape[-1], device=k.device, dtype=k.dtype) + q = q_norm(q) + k = k_norm(k) + + # Apply rotary embeddings if enabled + if state["use_rope"]: + # Calculate position embeddings + batch_size = q.shape[0] + seq_len = q.shape[1] + dim = q.shape[2] + + # Create position indices + position = torch.arange(seq_len, device=q.device).unsqueeze(0).unsqueeze(-1) + position = position.repeat(batch_size, 1, dim//2) + + # Calculate frequencies + freq = 10000.0 ** (-torch.arange(0, dim//2, 2, device=q.device) / dim) + freq = freq.repeat((dim + 1) // 2)[:dim//2] + + # Calculate rotation angles + theta = position * freq + + # Apply rotations + cos = torch.cos(theta) + sin = torch.sin(theta) + + # Reshape q and k for rotation + q_reshaped = q.view(*q.shape[:-1], -1, 2) + k_reshaped = k.view(*k.shape[:-1], -1, 2) + + # Apply rotations + q_out = torch.cat([ + q_reshaped[..., 0] * cos - q_reshaped[..., 1] * sin, + q_reshaped[..., 0] * sin + q_reshaped[..., 1] * cos + ], dim=-1) + + k_out = torch.cat([ + k_reshaped[..., 0] * cos - k_reshaped[..., 1] * sin, + k_reshaped[..., 0] * sin + k_reshaped[..., 1] * cos + ], dim=-1) + + q = q_out + k = k_out + + # Compute attention with optimized memory access pattern + batch_size, seq_len = q.shape[0], q.shape[1] + head_dim = q.shape[-1] // module.heads + + # Reshape for multi-head attention + q = q.view(batch_size, seq_len, module.heads, head_dim) + k = k.view(batch_size, -1, module.heads, head_dim) + v = v.view(batch_size, -1, module.heads, head_dim) + + # Transpose for attention computation + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Compute attention scores + scale = head_dim ** -0.5 + scores = torch.matmul(q, k.transpose(-2, -1)) * scale + + if mask is not None: + scores = scores + mask + + # Apply attention + attn = torch.softmax(scores, dim=-1) + out = torch.matmul(attn, v) + + # Reshape back + out = out.transpose(1, 2).contiguous() + out = out.view(batch_size, seq_len, -1) + + # Project back to original dimension + out = module.to_out[0](out) + + return out + + def hook_cross_attention(module, input, output): + if isinstance(module, torch.nn.Module) and hasattr(module, "to_q"): + # Store original forward + if not hasattr(module, "_original_forward"): + module._original_forward = module.forward + # Replace with our optimized version + module.forward = lambda *args, **kwargs: cross_attention_forward(module, *args, **kwargs) + return output + + # Only set up hooks if model has changed + model_hash = hash(str(model)) + if model_hash != self.last_model_hash: + m = model.clone() + + # Remove old hooks if they exist + if self.cross_attention_hook is not None: + self.cross_attention_hook.remove() + + # Register hook for cross attention modules + def register_hooks(module): + if isinstance(module, torch.nn.Module) and hasattr(module, "to_q"): + self.cross_attention_hook = module.register_forward_hook(hook_cross_attention) + + m.model.apply(register_hooks) + 
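The hook registration above swaps each attention module's forward for the optimized version while keeping the original around. A minimal sketch of that patching pattern with a toy module (illustrative names and shapes, not the real attention blocks):

import torch

class ToyAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.to_q(x)

def patched_forward(module, x):
    # Stand-in for cross_attention_forward: just rescales the original output.
    return module._original_forward(x) * 0.5

def patch(module):
    # Mirrors register_hooks: only touch modules that look like attention blocks.
    if hasattr(module, "to_q") and not hasattr(module, "_original_forward"):
        module._original_forward = module.forward
        module.forward = lambda *a, **kw: patched_forward(module, *a, **kw)

net = torch.nn.Sequential(ToyAttention())
net.apply(patch)
out = net(torch.randn(2, 8))  # now routed through patched_forward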
self.last_model_hash = model_hash + return (m,) + + return (model,) + + + From f6fbbc8af2edaf9ab3e37ca3ad7d0c5600b6e97d Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Mon, 10 Feb 2025 17:03:18 -0500 Subject: [PATCH 03/19] precompute kv --- stream_diffusion_nodes.py | 104 +++++++++++++++++++++++++++----------- 1 file changed, 74 insertions(+), 30 deletions(-) diff --git a/stream_diffusion_nodes.py b/stream_diffusion_nodes.py index bc13a65..fa7b76b 100644 --- a/stream_diffusion_nodes.py +++ b/stream_diffusion_nodes.py @@ -450,11 +450,8 @@ def update(self, model, t_index_list="32,45", num_inference_steps=50, always_exe return (selected_sigmas,) - - - class StreamCrossAttention(ControlNodeBase): - """Implements optimized cross attention for real-time generation""" + """Implements optimized cross attention with KV-cache for real-time generation""" RETURN_TYPES = ("MODEL",) FUNCTION = "update" @@ -473,6 +470,17 @@ def INPUT_TYPES(cls): "default": True, "tooltip": "Whether to use rotary position embeddings for better temporal consistency" }), + "context_size": ("INT", { + "default": 4, + "min": 1, + "max": 32, + "step": 1, + "tooltip": "Maximum number of past frames to keep in context. Higher values use more memory but may improve temporal consistency." + }), + "use_kv_cache": ("BOOLEAN", { + "default": True, + "tooltip": "Whether to cache key-value pairs for static prompts to avoid recomputation" + }), }) return inputs @@ -481,22 +489,52 @@ def __init__(self): self.last_model_hash = None self.cross_attention_hook = None - def update(self, model, always_execute=True, qk_norm=True, use_rope=True): - print(f"[StreamCrossAttention] Initializing with qk_norm={qk_norm}, use_rope={use_rope}") + def update(self, model, always_execute=True, qk_norm=True, use_rope=True, context_size=4, use_kv_cache=True): + print(f"[StreamCrossAttention] Initializing with qk_norm={qk_norm}, use_rope={use_rope}, context_size={context_size}, use_kv_cache={use_kv_cache}") # Get state with defaults state = self.get_state({ "qk_norm": qk_norm, "use_rope": use_rope, + "context_size": context_size, + "use_kv_cache": use_kv_cache, "workflow_count": 0, + "context_queue": [], # Store past context tensors + "kv_cache": {}, # Store cached key-value pairs for each prompt + "last_prompt_embeds": None, # Store last prompt embeddings for cache comparison }) def cross_attention_forward(module, x, context=None, mask=None, value=None): q = module.to_q(x) context = x if context is None else context - k = module.to_k(context) - # Use provided value tensor if given, otherwise compute it - v = value if value is not None else module.to_v(context) + + # Check if we can use cached KV pairs + cache_hit = False + if state["use_kv_cache"] and state["last_prompt_embeds"] is not None: + # Compare current context with cached prompt embeddings + if torch.allclose(context, state["last_prompt_embeds"], rtol=1e-5, atol=1e-5): + cache_hit = True + k, v = state["kv_cache"].get(module, (None, None)) + if k is not None and v is not None: + print("[StreamCrossAttention] Using cached KV pairs") + + if not cache_hit: + # Update context queue for temporal attention + if len(state["context_queue"]) >= state["context_size"]: + state["context_queue"].pop(0) + state["context_queue"].append(context.detach().clone()) + + # Concatenate current context with past contexts + full_context = torch.cat(state["context_queue"], dim=1) + + # Generate k/v for full context + k = module.to_k(full_context) + v = value if value 
is not None else module.to_v(full_context) + + # Cache KV pairs if this is a prompt context + if state["use_kv_cache"]: + state["last_prompt_embeds"] = context.detach().clone() + state["kv_cache"][module] = (k.detach().clone(), v.detach().clone()) # Apply QK normalization if enabled if state["qk_norm"]: @@ -510,49 +548,55 @@ def cross_attention_forward(module, x, context=None, mask=None, value=None): # Calculate position embeddings batch_size = q.shape[0] seq_len = q.shape[1] + full_seq_len = k.shape[1] # Use full context length for k/v dim = q.shape[2] - # Create position indices - position = torch.arange(seq_len, device=q.device).unsqueeze(0).unsqueeze(-1) - position = position.repeat(batch_size, 1, dim//2) + # Create position indices for q and k separately + q_position = torch.arange(seq_len, device=q.device).unsqueeze(0).unsqueeze(-1) + k_position = torch.arange(full_seq_len, device=k.device).unsqueeze(0).unsqueeze(-1) + + q_position = q_position.repeat(batch_size, 1, dim//2) + k_position = k_position.repeat(batch_size, 1, dim//2) # Calculate frequencies freq = 10000.0 ** (-torch.arange(0, dim//2, 2, device=q.device) / dim) freq = freq.repeat((dim + 1) // 2)[:dim//2] # Calculate rotation angles - theta = position * freq + q_theta = q_position * freq + k_theta = k_position * freq - # Apply rotations - cos = torch.cos(theta) - sin = torch.sin(theta) - - # Reshape q and k for rotation + # Apply rotations to q + q_cos = torch.cos(q_theta) + q_sin = torch.sin(q_theta) q_reshaped = q.view(*q.shape[:-1], -1, 2) - k_reshaped = k.view(*k.shape[:-1], -1, 2) - - # Apply rotations q_out = torch.cat([ - q_reshaped[..., 0] * cos - q_reshaped[..., 1] * sin, - q_reshaped[..., 0] * sin + q_reshaped[..., 1] * cos + q_reshaped[..., 0] * q_cos - q_reshaped[..., 1] * q_sin, + q_reshaped[..., 0] * q_sin + q_reshaped[..., 1] * q_cos ], dim=-1) + # Apply rotations to k + k_cos = torch.cos(k_theta) + k_sin = torch.sin(k_theta) + k_reshaped = k.view(*k.shape[:-1], -1, 2) k_out = torch.cat([ - k_reshaped[..., 0] * cos - k_reshaped[..., 1] * sin, - k_reshaped[..., 0] * sin + k_reshaped[..., 1] * cos + k_reshaped[..., 0] * k_cos - k_reshaped[..., 1] * k_sin, + k_reshaped[..., 0] * k_sin + k_reshaped[..., 1] * k_cos ], dim=-1) q = q_out k = k_out # Compute attention with optimized memory access pattern - batch_size, seq_len = q.shape[0], q.shape[1] + batch_size = q.shape[0] + q_seq_len = q.shape[1] + k_seq_len = k.shape[1] head_dim = q.shape[-1] // module.heads # Reshape for multi-head attention - q = q.view(batch_size, seq_len, module.heads, head_dim) - k = k.view(batch_size, -1, module.heads, head_dim) - v = v.view(batch_size, -1, module.heads, head_dim) + q = q.view(batch_size, q_seq_len, module.heads, head_dim) + k = k.view(batch_size, k_seq_len, module.heads, head_dim) + v = v.view(batch_size, k_seq_len, module.heads, head_dim) # Transpose for attention computation q = q.transpose(1, 2) @@ -572,7 +616,7 @@ def cross_attention_forward(module, x, context=None, mask=None, value=None): # Reshape back out = out.transpose(1, 2).contiguous() - out = out.view(batch_size, seq_len, -1) + out = out.view(batch_size, q_seq_len, -1) # Project back to original dimension out = module.to_out[0](out) From 34174a71821a610173509a864ddf52ff6725d1f7 Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Mon, 10 Feb 2025 17:40:35 -0500 Subject: [PATCH 04/19] update comments/tooltips --- stream_diffusion_nodes.py | 49 ++++++++++++++++++++++++--------------- 1 file changed, 30 
insertions(+), 19 deletions(-) diff --git a/stream_diffusion_nodes.py b/stream_diffusion_nodes.py index fa7b76b..2c45253 100644 --- a/stream_diffusion_nodes.py +++ b/stream_diffusion_nodes.py @@ -451,7 +451,18 @@ def update(self, model, t_index_list="32,45", num_inference_steps=50, always_exe class StreamCrossAttention(ControlNodeBase): - """Implements optimized cross attention with KV-cache for real-time generation""" + """Implements optimized cross attention with KV-cache for real-time generation + + Paper reference: StreamDiffusion Section 3.5 "Pre-computation" + - Pre-computes and caches prompt embeddings + - Stores Key-Value pairs for reuse with static prompts + - Only recomputes KV pairs when prompt changes + + Additional optimizations beyond paper: + - QK normalization for better numerical stability + - Rotary position embeddings (RoPE) for improved temporal consistency + - Configurable context window size for memory/quality tradeoff + """ RETURN_TYPES = ("MODEL",) FUNCTION = "update" @@ -464,22 +475,22 @@ def INPUT_TYPES(cls): "model": ("MODEL",), "qk_norm": ("BOOLEAN", { "default": True, - "tooltip": "Whether to apply layer normalization to query and key tensors" + "tooltip": "Additional optimization: Whether to apply layer normalization to query and key tensors" }), "use_rope": ("BOOLEAN", { "default": True, - "tooltip": "Whether to use rotary position embeddings for better temporal consistency" + "tooltip": "Additional optimization: Whether to use rotary position embeddings for better temporal consistency" }), "context_size": ("INT", { "default": 4, "min": 1, "max": 32, "step": 1, - "tooltip": "Maximum number of past frames to keep in context. Higher values use more memory but may improve temporal consistency." + "tooltip": "Additional optimization: Maximum number of past frames to keep in context. Higher values use more memory but may improve temporal consistency." 
}), "use_kv_cache": ("BOOLEAN", { "default": True, - "tooltip": "Whether to cache key-value pairs for static prompts to avoid recomputation" + "tooltip": "Paper Section 3.5: Whether to cache key-value pairs for static prompts to avoid recomputation" }), }) return inputs @@ -494,21 +505,21 @@ def update(self, model, always_execute=True, qk_norm=True, use_rope=True, contex # Get state with defaults state = self.get_state({ - "qk_norm": qk_norm, - "use_rope": use_rope, - "context_size": context_size, - "use_kv_cache": use_kv_cache, + "qk_norm": qk_norm, # Additional optimization + "use_rope": use_rope, # Additional optimization + "context_size": context_size, # Additional optimization + "use_kv_cache": use_kv_cache, # From paper Section 3.5 "workflow_count": 0, - "context_queue": [], # Store past context tensors - "kv_cache": {}, # Store cached key-value pairs for each prompt - "last_prompt_embeds": None, # Store last prompt embeddings for cache comparison + "context_queue": [], # Additional: Store past context tensors for temporal consistency + "kv_cache": {}, # From paper Section 3.5: Cache KV pairs for each prompt + "last_prompt_embeds": None, # From paper Section 3.5: For cache validation }) def cross_attention_forward(module, x, context=None, mask=None, value=None): q = module.to_q(x) context = x if context is None else context - # Check if we can use cached KV pairs + # Paper Section 3.5: KV Caching Logic cache_hit = False if state["use_kv_cache"] and state["last_prompt_embeds"] is not None: # Compare current context with cached prompt embeddings @@ -519,31 +530,31 @@ def cross_attention_forward(module, x, context=None, mask=None, value=None): print("[StreamCrossAttention] Using cached KV pairs") if not cache_hit: - # Update context queue for temporal attention + # Additional optimization: Temporal context management if len(state["context_queue"]) >= state["context_size"]: state["context_queue"].pop(0) state["context_queue"].append(context.detach().clone()) - # Concatenate current context with past contexts + # Additional optimization: Use past context for temporal consistency full_context = torch.cat(state["context_queue"], dim=1) # Generate k/v for full context k = module.to_k(full_context) v = value if value is not None else module.to_v(full_context) - # Cache KV pairs if this is a prompt context + # Paper Section 3.5: Cache KV pairs for static prompts if state["use_kv_cache"]: state["last_prompt_embeds"] = context.detach().clone() state["kv_cache"][module] = (k.detach().clone(), v.detach().clone()) - # Apply QK normalization if enabled + # Additional optimization: QK normalization if state["qk_norm"]: q_norm = torch.nn.LayerNorm(q.shape[-1], device=q.device, dtype=q.dtype) k_norm = torch.nn.LayerNorm(k.shape[-1], device=k.device, dtype=k.dtype) q = q_norm(q) k = k_norm(k) - # Apply rotary embeddings if enabled + # Additional optimization: Rotary position embeddings if state["use_rope"]: # Calculate position embeddings batch_size = q.shape[0] @@ -587,7 +598,7 @@ def cross_attention_forward(module, x, context=None, mask=None, value=None): q = q_out k = k_out - # Compute attention with optimized memory access pattern + # Standard attention computation with memory-efficient access pattern batch_size = q.shape[0] q_seq_len = q.shape[1] k_seq_len = k.shape[1] From e0cd08911e3af4f7188a314921bc4789bb9960c9 Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Mon, 10 Feb 2025 20:45:19 -0500 Subject: [PATCH 05/19] split files --- __init__.py 
| 3 +- stream_sampler.py | 155 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 1 deletion(-) create mode 100644 stream_sampler.py diff --git a/__init__.py b/__init__.py index 2ee77c9..4b98126 100644 --- a/__init__.py +++ b/__init__.py @@ -3,7 +3,8 @@ from .controls.utility_controls import FPSMonitor, SimilarityFilter, LazyCondition from .controls.motion_controls import MotionController, ROINode, IntegerMotionController from .misc_nodes import DTypeConverter, FastWebcamCapture, YOLOSimilarityCompare, TextRenderer, QuickShapeMask, MultilineText, LoadImageFromPath_ -from .stream_diffusion_nodes import StreamConditioning, StreamCFG, StreamBatchSampler, StreamScheduler, StreamCrossAttention +from .stream_diffusion_nodes import StreamConditioning, StreamCFG, StreamCrossAttention +from .stream_sampler import StreamBatchSampler, StreamScheduler import re diff --git a/stream_sampler.py b/stream_sampler.py new file mode 100644 index 0000000..8269ed1 --- /dev/null +++ b/stream_sampler.py @@ -0,0 +1,155 @@ +import torch +from .base.control_base import ControlNodeBase +import comfy.model_management +import comfy.samplers +import random + + +class StreamBatchSampler(ControlNodeBase): + """Implements batched denoising for faster inference by processing multiple steps in parallel""" + + RETURN_TYPES = ("SAMPLER",) + FUNCTION = "update" + CATEGORY = "real-time/sampling" + + @classmethod + def INPUT_TYPES(cls): + inputs = super().INPUT_TYPES() + inputs["required"].update({ + "batch_size": ("INT", { + "default": 2, + "min": 1, + "max": 10, + "step": 1, + "tooltip": "Number of steps to batch together. Higher values use more memory but are faster." + }), + }) + return inputs + + def __init__(self): + super().__init__() + self.batch_size = None + + def sample(self, model, noise, sigmas, extra_args=None, callback=None, disable=None): + """Sample with batched denoising steps""" + extra_args = {} if extra_args is None else extra_args + print(f"[StreamBatchSampler] Starting sampling with {len(sigmas)-1} steps, batch_size={self.batch_size}") + print(f"[StreamBatchSampler] Input noise shape: {noise.shape}, device: {noise.device}") + print(f"[StreamBatchSampler] Sigmas: {sigmas.tolist()}") + + # Prepare batched sampling + num_sigmas = len(sigmas) - 1 + num_batches = (num_sigmas + self.batch_size - 1) // self.batch_size + x = noise + + for batch_idx in range(num_batches): + # Get sigmas for this batch + start_idx = batch_idx * self.batch_size + end_idx = min(start_idx + self.batch_size, num_sigmas) + batch_sigmas = sigmas[start_idx:end_idx+1] + print(f"\n[StreamBatchSampler] Batch {batch_idx+1}/{num_batches}") + print(f"[StreamBatchSampler] Processing steps {start_idx}-{end_idx}") + print(f"[StreamBatchSampler] Batch sigmas: {batch_sigmas.tolist()}") + + # Create batch of identical latents + batch_size = end_idx - start_idx + x_batch = x.repeat(batch_size, 1, 1, 1) + + # Create batch of sigmas + sigma_batch = batch_sigmas[:-1] # All but last sigma + + # Run model on entire batch at once + with torch.no_grad(): + # Process all steps in parallel + denoised_batch = model(x_batch, sigma_batch, **extra_args) + print(f"[StreamBatchSampler] Denoised batch shape: {denoised_batch.shape}") + + # Process results one at a time to maintain callback + for i in range(batch_size): + sigma = sigma_batch[i] + sigma_next = batch_sigmas[i + 1] + denoised = denoised_batch[i:i+1] + + # Calculate step size (now always positive as we go from high to low sigma) + dt = sigma - sigma_next + + # Update x using 
Euler method + # The (denoised - x) term gives us the direction to move + # dt/sigma scales how far we move based on current noise level + x = x + (denoised - x) * (dt / sigma) + print(f"[StreamBatchSampler] Step {start_idx+i}: sigma={sigma:.4f}, next_sigma={sigma_next:.4f}, dt={dt:.4f}") + + # Call callback if provided + if callback is not None: + callback({'x': x, 'i': start_idx + i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised}) + + print(f"\n[StreamBatchSampler] Sampling complete. Final x shape: {x.shape}") + return x + + def update(self, batch_size=2, always_execute=True): + """Create sampler with specified settings""" + self.batch_size = batch_size + sampler = comfy.samplers.KSAMPLER(self.sample) + return (sampler,) + + +class StreamScheduler(ControlNodeBase): + """Implements StreamDiffusion's efficient timestep selection""" + + RETURN_TYPES = ("SIGMAS",) + FUNCTION = "update" + CATEGORY = "real-time/sampling" + + @classmethod + def INPUT_TYPES(cls): + inputs = super().INPUT_TYPES() + inputs["required"].update({ + "model": ("MODEL",), + "t_index_list": ("STRING", { + "default": "32,45", + "tooltip": "Comma-separated list of timesteps to actually use for denoising. Examples: '32,45' for img2img or '0,16,32,45' for txt2img" + }), + "num_inference_steps": ("INT", { + "default": 50, + "min": 1, + "max": 1000, + "step": 1, + "tooltip": "Total number of timesteps in schedule. StreamDiffusion uses 50 by default. Only timesteps specified in t_index_list are actually used." + }), + }) + return inputs + + def update(self, model, t_index_list="32,45", num_inference_steps=50, always_execute=True): + # Get model's sampling parameters + model_sampling = model.get_model_object("model_sampling") + print(f"[StreamScheduler] Model sampling max sigma: {model_sampling.sigma_max}, min sigma: {model_sampling.sigma_min}") + + # Parse timestep list + try: + t_index_list = [int(t.strip()) for t in t_index_list.split(",")] + except ValueError as e: + print(f"Error parsing timesteps: {e}. 
Using default [32,45]") + t_index_list = [32, 45] + print(f"[StreamScheduler] Using timesteps: {t_index_list}") + + # Create full schedule using normal scheduler + full_sigmas = comfy.samplers.normal_scheduler(model_sampling, num_inference_steps) + print(f"[StreamScheduler] Full sigma schedule: {full_sigmas.tolist()}") + + # Select only the sigmas at our desired indices, but in reverse order + # This ensures we go from high noise to low noise + selected_sigmas = [] + for t in sorted(t_index_list, reverse=True): # Sort in reverse to go from high noise to low + if t < 0 or t >= num_inference_steps: + print(f"Warning: timestep {t} out of range [0,{num_inference_steps}), skipping") + continue + selected_sigmas.append(float(full_sigmas[t])) + print(f"[StreamScheduler] Selected sigmas: {selected_sigmas}") + + # Add final sigma + selected_sigmas.append(0.0) + print(f"[StreamScheduler] Final sigma schedule: {selected_sigmas}") + + # Convert to tensor and move to appropriate device + selected_sigmas = torch.FloatTensor(selected_sigmas).to(comfy.model_management.get_torch_device()) + return (selected_sigmas,) From 9ce8047fe84ed9dcd172b75b5dd534665e2e7b9a Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Mon, 10 Feb 2025 22:05:25 -0500 Subject: [PATCH 06/19] coherent output --- __init__.py | 4 +- stream_diffusion_nodes.py | 190 +++++++------------------------------ stream_sampler.py | 191 ++++++++++++++++++++++++++------------ 3 files changed, 165 insertions(+), 220 deletions(-) diff --git a/__init__.py b/__init__.py index 4b98126..80bc9af 100644 --- a/__init__.py +++ b/__init__.py @@ -4,7 +4,7 @@ from .controls.motion_controls import MotionController, ROINode, IntegerMotionController from .misc_nodes import DTypeConverter, FastWebcamCapture, YOLOSimilarityCompare, TextRenderer, QuickShapeMask, MultilineText, LoadImageFromPath_ from .stream_diffusion_nodes import StreamConditioning, StreamCFG, StreamCrossAttention -from .stream_sampler import StreamBatchSampler, StreamScheduler +from .stream_sampler import StreamBatchSampler, StreamScheduler, StreamFrameBuffer import re @@ -22,6 +22,7 @@ "StreamConditioning": StreamConditioning, "StreamBatchSampler": StreamBatchSampler, "StreamScheduler": StreamScheduler, + "StreamFrameBuffer": StreamFrameBuffer, "StreamCrossAttention": StreamCrossAttention, "LazyCondition": LazyCondition, "MotionController": MotionController, @@ -30,7 +31,6 @@ "TextRenderer": TextRenderer, - "ROINode": ROINode, #"IntervalControl": IntervalCo ntrol, diff --git a/stream_diffusion_nodes.py b/stream_diffusion_nodes.py index 2c45253..b47677a 100644 --- a/stream_diffusion_nodes.py +++ b/stream_diffusion_nodes.py @@ -186,6 +186,7 @@ def post_cfg_function(args): return (model,) +#NOTE: totally and utterly experimental. No theoretical backing whatsoever. class StreamConditioning(ControlNodeBase): """Applies Residual CFG to conditioning for improved temporal consistency with different CFG types""" #NOTE: experimental @@ -213,6 +214,13 @@ def INPUT_TYPES(s): "step": 0.1, "tooltip": "Delta parameter for self/initialize CFG types" }), + "context_size": ("INT", { + "default": 4, + "min": 1, + "max": 32, + "step": 1, + "tooltip": "Number of past conditionings to keep in context. Higher values = smoother transitions but more memory usage." 
+ }), "always_execute": ("BOOLEAN", { "default": False, }), @@ -227,37 +235,53 @@ def INPUT_TYPES(s): def __init__(self): super().__init__() - def update(self, positive, negative, cfg_type="full", residual_scale=0.4, delta=1.0, always_execute=False): + def update(self, positive, negative, cfg_type="full", residual_scale=0.4, delta=1.0, context_size=4, always_execute=False): # Get state with defaults state = self.get_state({ "prev_positive": None, "prev_negative": None, "stock_noise": None, # For self/initialize CFG - "initialized": False # For initialize CFG + "initialized": False, # For initialize CFG + "pos_context": [], # Store past positive conditionings + "neg_context": [] # Store past negative conditionings }) # Extract conditioning tensors - current_pos_cond = positive[0][0] # Assuming standard ComfyUI conditioning format + current_pos_cond = positive[0][0] current_neg_cond = negative[0][0] + # Update context queues + if len(state["pos_context"]) >= context_size: + state["pos_context"].pop(0) + state["neg_context"].pop(0) + state["pos_context"].append(current_pos_cond.detach().clone()) + state["neg_context"].append(current_neg_cond.detach().clone()) + # First frame case if state["prev_positive"] is None: state["prev_positive"] = current_pos_cond.detach().clone() state["prev_negative"] = current_neg_cond.detach().clone() if cfg_type == "initialize": - # For initialize, we use the first negative as our stock noise state["stock_noise"] = current_neg_cond.detach().clone() elif cfg_type == "self": - # For self, we start with a scaled version of the negative state["stock_noise"] = current_neg_cond.detach().clone() * delta self.set_state(state) return (positive, negative) # Handle different CFG types if cfg_type == "full": - # Standard R-CFG with full negative conditioning - pos_residual = current_pos_cond - state["prev_positive"] - neg_residual = current_neg_cond - state["prev_negative"] + # Use entire context for smoother transitions + pos_context = torch.stack(state["pos_context"], dim=0) + neg_context = torch.stack(state["neg_context"], dim=0) + + # Calculate weighted residuals across context + weights = torch.linspace(0.5, 1.0, len(state["pos_context"]), device=current_pos_cond.device) + pos_residual = (current_pos_cond - pos_context) * weights.view(-1, 1, 1) + neg_residual = (current_neg_cond - neg_context) * weights.view(-1, 1, 1) + + # Average residuals + pos_residual = pos_residual.mean(dim=0) + neg_residual = neg_residual.mean(dim=0) blended_pos = current_pos_cond + residual_scale * pos_residual blended_neg = current_neg_cond + residual_scale * neg_residual @@ -300,156 +324,6 @@ def update(self, positive, negative, cfg_type="full", residual_scale=0.4, delta= return (positive_out, negative_out) -class StreamBatchSampler(ControlNodeBase): - """Implements batched denoising for faster inference by processing multiple steps in parallel""" - - RETURN_TYPES = ("SAMPLER",) - FUNCTION = "update" - CATEGORY = "real-time/sampling" - - @classmethod - def INPUT_TYPES(cls): - inputs = super().INPUT_TYPES() - inputs["required"].update({ - "batch_size": ("INT", { - "default": 2, - "min": 1, - "max": 10, - "step": 1, - "tooltip": "Number of steps to batch together. Higher values use more memory but are faster." 
- }), - }) - return inputs - - def __init__(self): - super().__init__() - self.batch_size = None - - def sample(self, model, noise, sigmas, extra_args=None, callback=None, disable=None): - """Sample with batched denoising steps""" - extra_args = {} if extra_args is None else extra_args - print(f"[StreamBatchSampler] Starting sampling with {len(sigmas)-1} steps, batch_size={self.batch_size}") - print(f"[StreamBatchSampler] Input noise shape: {noise.shape}, device: {noise.device}") - print(f"[StreamBatchSampler] Sigmas: {sigmas.tolist()}") - - # Prepare batched sampling - num_sigmas = len(sigmas) - 1 - num_batches = (num_sigmas + self.batch_size - 1) // self.batch_size - x = noise - - for batch_idx in range(num_batches): - # Get sigmas for this batch - start_idx = batch_idx * self.batch_size - end_idx = min(start_idx + self.batch_size, num_sigmas) - batch_sigmas = sigmas[start_idx:end_idx+1] - print(f"\n[StreamBatchSampler] Batch {batch_idx+1}/{num_batches}") - print(f"[StreamBatchSampler] Processing steps {start_idx}-{end_idx}") - print(f"[StreamBatchSampler] Batch sigmas: {batch_sigmas.tolist()}") - - # Create batch of identical latents - batch_size = end_idx - start_idx - x_batch = x.repeat(batch_size, 1, 1, 1) - - # Create batch of sigmas - sigma_batch = batch_sigmas[:-1] # All but last sigma - - # Run model on entire batch at once - with torch.no_grad(): - # Process all steps in parallel - denoised_batch = model(x_batch, sigma_batch, **extra_args) - print(f"[StreamBatchSampler] Denoised batch shape: {denoised_batch.shape}") - - # Process results one at a time to maintain callback - for i in range(batch_size): - sigma = sigma_batch[i] - sigma_next = batch_sigmas[i + 1] - denoised = denoised_batch[i:i+1] - - # Calculate step size (now always positive as we go from high to low sigma) - dt = sigma - sigma_next - - # Update x using Euler method - # The (denoised - x) term gives us the direction to move - # dt/sigma scales how far we move based on current noise level - x = x + (denoised - x) * (dt / sigma) - print(f"[StreamBatchSampler] Step {start_idx+i}: sigma={sigma:.4f}, next_sigma={sigma_next:.4f}, dt={dt:.4f}") - - # Call callback if provided - if callback is not None: - callback({'x': x, 'i': start_idx + i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised}) - - print(f"\n[StreamBatchSampler] Sampling complete. Final x shape: {x.shape}") - return x - - def update(self, batch_size=2, always_execute=True): - """Create sampler with specified settings""" - self.batch_size = batch_size - sampler = comfy.samplers.KSAMPLER(self.sample) - return (sampler,) - - -class StreamScheduler(ControlNodeBase): - """Implements StreamDiffusion's efficient timestep selection""" - - RETURN_TYPES = ("SIGMAS",) - FUNCTION = "update" - CATEGORY = "real-time/sampling" - - @classmethod - def INPUT_TYPES(cls): - inputs = super().INPUT_TYPES() - inputs["required"].update({ - "model": ("MODEL",), - "t_index_list": ("STRING", { - "default": "32,45", - "tooltip": "Comma-separated list of timesteps to actually use for denoising. Examples: '32,45' for img2img or '0,16,32,45' for txt2img" - }), - "num_inference_steps": ("INT", { - "default": 50, - "min": 1, - "max": 1000, - "step": 1, - "tooltip": "Total number of timesteps in schedule. StreamDiffusion uses 50 by default. Only timesteps specified in t_index_list are actually used." 
- }), - }) - return inputs - - def update(self, model, t_index_list="32,45", num_inference_steps=50, always_execute=True): - # Get model's sampling parameters - model_sampling = model.get_model_object("model_sampling") - print(f"[StreamScheduler] Model sampling max sigma: {model_sampling.sigma_max}, min sigma: {model_sampling.sigma_min}") - - # Parse timestep list - try: - t_index_list = [int(t.strip()) for t in t_index_list.split(",")] - except ValueError as e: - print(f"Error parsing timesteps: {e}. Using default [32,45]") - t_index_list = [32, 45] - print(f"[StreamScheduler] Using timesteps: {t_index_list}") - - # Create full schedule using normal scheduler - full_sigmas = comfy.samplers.normal_scheduler(model_sampling, num_inference_steps) - print(f"[StreamScheduler] Full sigma schedule: {full_sigmas.tolist()}") - - # Select only the sigmas at our desired indices, but in reverse order - # This ensures we go from high noise to low noise - selected_sigmas = [] - for t in sorted(t_index_list, reverse=True): # Sort in reverse to go from high noise to low - if t < 0 or t >= num_inference_steps: - print(f"Warning: timestep {t} out of range [0,{num_inference_steps}), skipping") - continue - selected_sigmas.append(float(full_sigmas[t])) - print(f"[StreamScheduler] Selected sigmas: {selected_sigmas}") - - # Add final sigma - selected_sigmas.append(0.0) - print(f"[StreamScheduler] Final sigma schedule: {selected_sigmas}") - - # Convert to tensor and move to appropriate device - selected_sigmas = torch.FloatTensor(selected_sigmas).to(comfy.model_management.get_torch_device()) - return (selected_sigmas,) - - class StreamCrossAttention(ControlNodeBase): """Implements optimized cross attention with KV-cache for real-time generation diff --git a/stream_sampler.py b/stream_sampler.py index 8269ed1..c59275e 100644 --- a/stream_sampler.py +++ b/stream_sampler.py @@ -6,7 +6,7 @@ class StreamBatchSampler(ControlNodeBase): - """Implements batched denoising for faster inference by processing multiple steps in parallel""" + """Implements batched denoising for faster inference by processing multiple frames in parallel at different denoising steps""" RETURN_TYPES = ("SAMPLER",) FUNCTION = "update" @@ -16,79 +16,100 @@ class StreamBatchSampler(ControlNodeBase): def INPUT_TYPES(cls): inputs = super().INPUT_TYPES() inputs["required"].update({ - "batch_size": ("INT", { - "default": 2, + "num_steps": ("INT", { + "default": 4, "min": 1, "max": 10, "step": 1, - "tooltip": "Number of steps to batch together. Higher values use more memory but are faster." + "tooltip": "Number of denoising steps. Should match the frame buffer size." 
}), }) return inputs def __init__(self): super().__init__() - self.batch_size = None + self.num_steps = None + self.frame_buffer = [] + self.x_t_latent_buffer = None + self.stock_noise = None def sample(self, model, noise, sigmas, extra_args=None, callback=None, disable=None): - """Sample with batched denoising steps""" + """Sample with staggered batch denoising steps""" extra_args = {} if extra_args is None else extra_args - print(f"[StreamBatchSampler] Starting sampling with {len(sigmas)-1} steps, batch_size={self.batch_size}") - print(f"[StreamBatchSampler] Input noise shape: {noise.shape}, device: {noise.device}") - print(f"[StreamBatchSampler] Sigmas: {sigmas.tolist()}") + print(f"[StreamBatchSampler] Starting sampling with {len(sigmas)-1} steps") - # Prepare batched sampling - num_sigmas = len(sigmas) - 1 - num_batches = (num_sigmas + self.batch_size - 1) // self.batch_size - x = noise + # Get number of frames in batch and available sigmas + batch_size = noise.shape[0] + num_sigmas = len(sigmas) - 1 # Subtract 1 because last sigma is the target (0.0) - for batch_idx in range(num_batches): - # Get sigmas for this batch - start_idx = batch_idx * self.batch_size - end_idx = min(start_idx + self.batch_size, num_sigmas) - batch_sigmas = sigmas[start_idx:end_idx+1] - print(f"\n[StreamBatchSampler] Batch {batch_idx+1}/{num_batches}") - print(f"[StreamBatchSampler] Processing steps {start_idx}-{end_idx}") - print(f"[StreamBatchSampler] Batch sigmas: {batch_sigmas.tolist()}") + print(f"[StreamBatchSampler] Input sigmas: {sigmas}") + print(f"[StreamBatchSampler] Input noise shape: {noise.shape}, min: {noise.min():.3f}, max: {noise.max():.3f}") + + # Verify batch size matches number of timesteps + if batch_size != num_sigmas: + raise ValueError(f"Batch size ({batch_size}) must match number of timesteps ({num_sigmas})") + + # Pre-compute alpha and beta terms + alpha_prod_t = (sigmas[:-1] / sigmas[0]).view(-1, 1, 1, 1) # [B,1,1,1] + beta_prod_t = (1 - alpha_prod_t) + + print(f"[StreamBatchSampler] Alpha values: {alpha_prod_t.view(-1)}") + print(f"[StreamBatchSampler] Beta values: {beta_prod_t.view(-1)}") + + # Initialize stock noise if needed + if self.stock_noise is None: + self.stock_noise = torch.randn_like(noise[0]) # Random noise instead of zeros + print(f"[StreamBatchSampler] Initialized random stock noise with shape: {self.stock_noise.shape}") + + # Scale noise for each frame based on its sigma + scaled_noise = [] + for i in range(batch_size): + frame_noise = noise[i] + self.stock_noise * sigmas[i] # Add scaled noise to input + scaled_noise.append(frame_noise) + x = torch.stack(scaled_noise, dim=0) + print(f"[StreamBatchSampler] Scaled noise shape: {x.shape}, min: {x.min():.3f}, max: {x.max():.3f}") - # Create batch of identical latents - batch_size = end_idx - start_idx - x_batch = x.repeat(batch_size, 1, 1, 1) + # Initialize frame buffer if needed + if self.x_t_latent_buffer is None and num_sigmas > 1: + self.x_t_latent_buffer = x[0].clone() # Initialize with noised first frame + print(f"[StreamBatchSampler] Initialized buffer with shape: {self.x_t_latent_buffer.shape}") - # Create batch of sigmas - sigma_batch = batch_sigmas[:-1] # All but last sigma + # Use buffer for first frame to maintain temporal consistency + if num_sigmas > 1: + x = torch.cat([self.x_t_latent_buffer.unsqueeze(0), x[1:]], dim=0) + print(f"[StreamBatchSampler] Combined with buffer, shape: {x.shape}") - # Run model on entire batch at once - with torch.no_grad(): - # Process all steps in parallel - denoised_batch = 
model(x_batch, sigma_batch, **extra_args) - print(f"[StreamBatchSampler] Denoised batch shape: {denoised_batch.shape}") + # Run model on entire batch at once + with torch.no_grad(): + # Process all frames in parallel + sigma_batch = sigmas[:-1] + print(f"[StreamBatchSampler] Using sigmas for denoising: {sigma_batch}") + + denoised_batch = model(x, sigma_batch, **extra_args) + print(f"[StreamBatchSampler] Denoised batch shape: {denoised_batch.shape}") + print(f"[StreamBatchSampler] Denoised stats - min: {denoised_batch.min():.3f}, max: {denoised_batch.max():.3f}") + + # Update buffer with intermediate results + if num_sigmas > 1: + # Store result from first frame as buffer for next iteration + self.x_t_latent_buffer = denoised_batch[0].clone() + print(f"[StreamBatchSampler] Updated buffer with shape: {self.x_t_latent_buffer.shape}") - # Process results one at a time to maintain callback - for i in range(batch_size): - sigma = sigma_batch[i] - sigma_next = batch_sigmas[i + 1] - denoised = denoised_batch[i:i+1] - - # Calculate step size (now always positive as we go from high to low sigma) - dt = sigma - sigma_next - - # Update x using Euler method - # The (denoised - x) term gives us the direction to move - # dt/sigma scales how far we move based on current noise level - x = x + (denoised - x) * (dt / sigma) - print(f"[StreamBatchSampler] Step {start_idx+i}: sigma={sigma:.4f}, next_sigma={sigma_next:.4f}, dt={dt:.4f}") - - # Call callback if provided - if callback is not None: - callback({'x': x, 'i': start_idx + i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised}) + # Return result from last frame + x_0_pred_out = denoised_batch[-1].unsqueeze(0) + else: + x_0_pred_out = denoised_batch + self.x_t_latent_buffer = None + + # Call callback if provided + if callback is not None: + callback({'x': x_0_pred_out, 'i': 0, 'sigma': sigmas[0], 'sigma_hat': sigmas[0], 'denoised': denoised_batch[-1:]}) - print(f"\n[StreamBatchSampler] Sampling complete. Final x shape: {x.shape}") - return x + return x_0_pred_out - def update(self, batch_size=2, always_execute=True): + def update(self, num_steps=4, always_execute=True): """Create sampler with specified settings""" - self.batch_size = batch_size + self.num_steps = num_steps sampler = comfy.samplers.KSAMPLER(self.sample) return (sampler,) @@ -122,7 +143,6 @@ def INPUT_TYPES(cls): def update(self, model, t_index_list="32,45", num_inference_steps=50, always_execute=True): # Get model's sampling parameters model_sampling = model.get_model_object("model_sampling") - print(f"[StreamScheduler] Model sampling max sigma: {model_sampling.sigma_max}, min sigma: {model_sampling.sigma_min}") # Parse timestep list try: @@ -130,11 +150,9 @@ def update(self, model, t_index_list="32,45", num_inference_steps=50, always_exe except ValueError as e: print(f"Error parsing timesteps: {e}. 
Using default [32,45]") t_index_list = [32, 45] - print(f"[StreamScheduler] Using timesteps: {t_index_list}") - + # Create full schedule using normal scheduler full_sigmas = comfy.samplers.normal_scheduler(model_sampling, num_inference_steps) - print(f"[StreamScheduler] Full sigma schedule: {full_sigmas.tolist()}") # Select only the sigmas at our desired indices, but in reverse order # This ensures we go from high noise to low noise @@ -144,12 +162,65 @@ def update(self, model, t_index_list="32,45", num_inference_steps=50, always_exe print(f"Warning: timestep {t} out of range [0,{num_inference_steps}), skipping") continue selected_sigmas.append(float(full_sigmas[t])) - print(f"[StreamScheduler] Selected sigmas: {selected_sigmas}") - + # Add final sigma selected_sigmas.append(0.0) - print(f"[StreamScheduler] Final sigma schedule: {selected_sigmas}") # Convert to tensor and move to appropriate device selected_sigmas = torch.FloatTensor(selected_sigmas).to(comfy.model_management.get_torch_device()) return (selected_sigmas,) + + +class StreamFrameBuffer(ControlNodeBase): + """Accumulates frames to enable staggered batch denoising like StreamDiffusion""" + + RETURN_TYPES = ("LATENT",) + FUNCTION = "update" + CATEGORY = "real-time/sampling" + + @classmethod + def INPUT_TYPES(cls): + inputs = super().INPUT_TYPES() + inputs["required"].update({ + "latent": ("LATENT",), + "buffer_size": ("INT", { + "default": 4, + "min": 1, + "max": 10, + "step": 1, + "tooltip": "Number of frames to buffer before starting batch processing. Should match number of denoising steps." + }), + }) + return inputs + + def __init__(self): + super().__init__() + self.frame_buffer = [] # List to store incoming frames + self.buffer_size = None + + def update(self, latent, buffer_size=4, always_execute=True): + """Add new frame to buffer and return batch when ready""" + self.buffer_size = buffer_size + + # Extract latent tensor from input and remove batch dimension if present + x = latent["samples"] + if x.dim() == 4: # [B,C,H,W] + x = x.squeeze(0) # Remove batch dimension -> [C,H,W] + + # Add new frame to buffer + if len(self.frame_buffer) == 0: + # First frame - initialize buffer with copies + self.frame_buffer = [x.clone() for _ in range(self.buffer_size)] + print(f"[StreamFrameBuffer] Initialized buffer with {self.buffer_size} copies of first frame") + else: + # Shift frames forward and add new frame + self.frame_buffer.pop(0) # Remove oldest frame + self.frame_buffer.append(x.clone()) # Add new frame + print(f"[StreamFrameBuffer] Added new frame to buffer") + + # Stack frames into batch + batch = torch.stack(self.frame_buffer, dim=0) # [B,C,H,W] + print(f"[StreamFrameBuffer] Created batch with shape: {batch.shape}") + + # Return as latent dict + return ({"samples": batch},) From 5c178b8d57ebd2742294cb33d568ab523f884f67 Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Tue, 11 Feb 2025 09:36:25 -0500 Subject: [PATCH 07/19] RCFG update --- stream_diffusion_nodes.py | 66 ++++++++++++++++++++++----------------- 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/stream_diffusion_nodes.py b/stream_diffusion_nodes.py index b47677a..ddec93a 100644 --- a/stream_diffusion_nodes.py +++ b/stream_diffusion_nodes.py @@ -48,17 +48,20 @@ def update(self, model, always_execute=True, cfg_type="self", residual_scale=0.4 # Get state with defaults state = self.get_state({ - "last_uncond": None, # Store last workflow's unconditioned prediction + "last_uncond": None, 
"initialized": False, "cfg_type": cfg_type, "residual_scale": residual_scale, "delta": delta, - "workflow_count": 0, # Track number of workflow runs - "current_sigmas": None, # Track sigmas for this workflow - "seen_sigmas": set(), # Track which sigmas we've seen this workflow - "is_last_step": False, # Track if we're on the last step - "alpha_prod_t": None, # Store alpha values for proper scaling - "beta_prod_t": None, # Store beta values for proper scaling + "workflow_count": 0, + "current_sigmas": None, + "seen_sigmas": set(), + "is_last_step": False, + # Add new state variables for proper scaling + "alpha_prod_t": None, + "beta_prod_t": None, + "c_skip": None, + "c_out": None, }) def post_cfg_function(args): @@ -90,10 +93,17 @@ def post_cfg_function(args): state["current_sigmas"] = sigmas state["seen_sigmas"] = set() - # Calculate alpha and beta values for proper scaling - alphas = [1.0 / (1.0 + s**2) for s in sigmas] - state["alpha_prod_t"] = torch.tensor(alphas, device=denoised.device, dtype=denoised.dtype) - state["beta_prod_t"] = torch.sqrt(1 - state["alpha_prod_t"]) + # Calculate paper's exact scaling factors + state["alpha_prod_t"] = torch.tensor([1.0 / (1.0 + s**2) for s in sigmas], + device=denoised.device, dtype=denoised.dtype) + state["beta_prod_t"] = torch.tensor([s / (1.0 + s**2) for s in sigmas], + device=denoised.device, dtype=denoised.dtype) + + # Calculate c_skip and c_out coefficients + state["c_skip"] = torch.tensor([1.0 / (s**2 + 1.0) for s in sigmas], + device=denoised.device, dtype=denoised.dtype) + state["c_out"] = torch.tensor([-s / torch.sqrt(torch.tensor(s**2 + 1.0)) for s in sigmas], + device=denoised.device, dtype=denoised.dtype) # Track this sigma state["seen_sigmas"].add(sigma) @@ -113,27 +123,24 @@ def post_cfg_function(args): if state["is_last_step"]: state["last_uncond"] = uncond_denoised.detach().clone() state["workflow_count"] += 1 - state["current_sigmas"] = None # Reset for next workflow + state["current_sigmas"] = None if cfg_type == "initialize": state["initialized"] = True self.set_state(state) return denoised - # Handle different CFG types for subsequent workflows + # Handle different CFG types if cfg_type == "full": result = denoised - elif cfg_type == "initialize" and not state["initialized"]: result = denoised if state["is_last_step"]: state["initialized"] = True self.set_state(state) - else: # self or initialized initialize - # Get current step index current_idx = len(state["seen_sigmas"]) - 1 - # Scale last prediction with proper alpha/beta values + # Use paper's exact formulation for noise prediction noise_pred_uncond = state["last_uncond"] * delta # Apply CFG with scaled prediction @@ -141,25 +148,26 @@ def post_cfg_function(args): # Store last prediction if this is the last step if state["is_last_step"]: - # Calculate properly scaled residual - scaled_noise = state["beta_prod_t"][current_idx] * state["last_uncond"] - delta_x = uncond_denoised - scaled_noise + # Calculate F_theta using paper's formulation + F_theta = (uncond_denoised - state["beta_prod_t"][current_idx] * noise_pred_uncond) / state["alpha_prod_t"][current_idx] + delta_x = state["c_out"][current_idx] * F_theta + state["c_skip"][current_idx] * uncond_denoised - # Scale delta_x with next step's alpha/beta + # Scale delta_x with next step's coefficients if current_idx < len(state["current_sigmas"]) - 1: - alpha_next = state["alpha_prod_t"][current_idx + 1] - beta_next = state["beta_prod_t"][current_idx + 1] + next_alpha = state["alpha_prod_t"][current_idx + 1] + next_beta = 
state["beta_prod_t"][current_idx + 1] else: - alpha_next = torch.ones_like(state["alpha_prod_t"][0]) - beta_next = torch.ones_like(state["beta_prod_t"][0]) + next_alpha = torch.ones_like(state["alpha_prod_t"][0]) + next_beta = torch.zeros_like(state["beta_prod_t"][0]) - delta_x = alpha_next * delta_x / beta_next + # Update stored prediction with properly scaled residual + final_update = (next_alpha * delta_x) / next_beta + if next_beta > 0: # Add noise only when beta > 0 + final_update = final_update + torch.randn_like(delta_x) * (1 - next_alpha**2).sqrt() - # Update stored prediction with scaled residual - final_update = uncond_denoised + residual_scale * delta_x state["last_uncond"] = final_update state["workflow_count"] += 1 - state["current_sigmas"] = None # Reset for next workflow + state["current_sigmas"] = None self.set_state(state) return result From 213719c16994f11a3556271c8f115ac992183e6d Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Tue, 11 Feb 2025 10:22:40 -0500 Subject: [PATCH 08/19] fix final step using uncond_denoised for the final step instead of trying to scale delta_x when next_beta is 0 --- stream_diffusion_nodes.py | 48 ++++++++++++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/stream_diffusion_nodes.py b/stream_diffusion_nodes.py index ddec93a..575a765 100644 --- a/stream_diffusion_nodes.py +++ b/stream_diffusion_nodes.py @@ -73,12 +73,21 @@ def post_cfg_function(args): uncond_denoised = args["uncond_denoised"] cond_scale = args["cond_scale"] + # Debug prints for tensor stats + print(f"\n[StreamCFG Debug] Step Info:") + print(f"- Workflow count: {state['workflow_count']}") + print(f"- CFG Type: {state['cfg_type']}") + print(f"- Tensor Stats:") + print(f" - denoised shape: {denoised.shape}, range: [{denoised.min():.3f}, {denoised.max():.3f}]") + print(f" - uncond_denoised shape: {uncond_denoised.shape}, range: [{uncond_denoised.min():.3f}, {uncond_denoised.max():.3f}]") + if state["last_uncond"] is not None: + print(f" - last_uncond shape: {state['last_uncond'].shape}, range: [{state['last_uncond'].min():.3f}, {state['last_uncond'].max():.3f}]") + # Handle both batched and single sigmas sigma = args["sigma"] if torch.is_tensor(sigma): - # For batched sampling, use first sigma in batch - # This is safe because we process in order and track seen sigmas sigma = sigma[0].item() if len(sigma.shape) > 0 else sigma.item() + print(f"- Current sigma: {sigma:.6f}") # Get step info from model options model_options = args["model_options"] @@ -86,12 +95,12 @@ def post_cfg_function(args): # Update current sigmas if needed if sample_sigmas is not None and state["current_sigmas"] is None: - # Filter out the trailing 0.0 if present sigmas = [s.item() for s in sample_sigmas] if sigmas[-1] == 0.0: sigmas = sigmas[:-1] state["current_sigmas"] = sigmas state["seen_sigmas"] = set() + print(f"- New sigma sequence: {sigmas}") # Calculate paper's exact scaling factors state["alpha_prod_t"] = torch.tensor([1.0 / (1.0 + s**2) for s in sigmas], @@ -104,6 +113,12 @@ def post_cfg_function(args): device=denoised.device, dtype=denoised.dtype) state["c_out"] = torch.tensor([-s / torch.sqrt(torch.tensor(s**2 + 1.0)) for s in sigmas], device=denoised.device, dtype=denoised.dtype) + + print(f"- Scaling factors for first step:") + print(f" alpha: {state['alpha_prod_t'][0]:.6f}") + print(f" beta: {state['beta_prod_t'][0]:.6f}") + print(f" c_skip: {state['c_skip'][0]:.6f}") + print(f" c_out: 
{state['c_out'][0]:.6f}") # Track this sigma state["seen_sigmas"].add(sigma) @@ -117,6 +132,8 @@ def post_cfg_function(args): if not is_last_step and sigma == min(state["current_sigmas"]): is_last_step = True state["is_last_step"] = is_last_step + print(f"- Is last step: {is_last_step}") + print(f"- Seen sigmas: {sorted(state['seen_sigmas'])}") # First workflow case if state["last_uncond"] is None: @@ -127,6 +144,7 @@ def post_cfg_function(args): if cfg_type == "initialize": state["initialized"] = True self.set_state(state) + print("- First workflow complete, stored last_uncond") return denoised # Handle different CFG types @@ -139,18 +157,24 @@ def post_cfg_function(args): self.set_state(state) else: # self or initialized initialize current_idx = len(state["seen_sigmas"]) - 1 + print(f"- Current step index: {current_idx}") # Use paper's exact formulation for noise prediction - noise_pred_uncond = state["last_uncond"] * delta + noise_pred_uncond = state["last_uncond"] * state["delta"] + print(f"- Scaled noise prediction range: [{noise_pred_uncond.min():.3f}, {noise_pred_uncond.max():.3f}]") # Apply CFG with scaled prediction - result = noise_pred_uncond + cond_scale * (cond_denoised - noise_pred_uncond) + result = noise_pred_uncond + cond_scale * (cond_denoised - noise_pred_uncond) * state["residual_scale"] + print(f"- Result range after CFG: [{result.min():.3f}, {result.max():.3f}]") # Store last prediction if this is the last step if state["is_last_step"]: # Calculate F_theta using paper's formulation F_theta = (uncond_denoised - state["beta_prod_t"][current_idx] * noise_pred_uncond) / state["alpha_prod_t"][current_idx] + print(f"- F_theta range: [{F_theta.min():.3f}, {F_theta.max():.3f}]") + delta_x = state["c_out"][current_idx] * F_theta + state["c_skip"][current_idx] * uncond_denoised + print(f"- delta_x range: [{delta_x.min():.3f}, {delta_x.max():.3f}]") # Scale delta_x with next step's coefficients if current_idx < len(state["current_sigmas"]) - 1: @@ -159,12 +183,20 @@ def post_cfg_function(args): else: next_alpha = torch.ones_like(state["alpha_prod_t"][0]) next_beta = torch.zeros_like(state["beta_prod_t"][0]) + print(f"- Next step coefficients - alpha: {next_alpha:.6f}, beta: {next_beta:.6f}") # Update stored prediction with properly scaled residual - final_update = (next_alpha * delta_x) / next_beta - if next_beta > 0: # Add noise only when beta > 0 - final_update = final_update + torch.randn_like(delta_x) * (1 - next_alpha**2).sqrt() + if next_beta > 0: + final_update = (next_alpha * delta_x) / next_beta + # Add noise only when beta > 0 + noise = torch.randn_like(delta_x) * (1 - next_alpha**2).sqrt() + final_update = final_update + noise + print(f"- Added noise range: [{noise.min():.3f}, {noise.max():.3f}]") + else: + # For the last step, just use the current prediction + final_update = uncond_denoised + print(f"- Final update range: [{final_update.min():.3f}, {final_update.max():.3f}]") state["last_uncond"] = final_update state["workflow_count"] += 1 state["current_sigmas"] = None From 71c429d9a69f9c2501b2ac1be5f8e246e8d6773e Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Tue, 11 Feb 2025 23:02:17 -0500 Subject: [PATCH 09/19] simplified stream cross attention --- stream_diffusion_nodes.py | 243 +++++++++++--------------------------- 1 file changed, 71 insertions(+), 172 deletions(-) diff --git a/stream_diffusion_nodes.py b/stream_diffusion_nodes.py index 575a765..7411471 100644 --- a/stream_diffusion_nodes.py +++ 
b/stream_diffusion_nodes.py @@ -21,7 +21,7 @@ def INPUT_TYPES(cls): "tooltip": "Type of CFG to use: full (standard), self (memory efficient), or initialize (memory efficient with initialization)" }), "residual_scale": ("FLOAT", { - "default": 0.4, + "default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01, @@ -39,14 +39,12 @@ def INPUT_TYPES(cls): def __init__(self): super().__init__() - # Store the last model to detect when we need to reapply the hook self.last_model_hash = None self.post_cfg_function = None - def update(self, model, always_execute=True, cfg_type="self", residual_scale=0.4, delta=1.0): + def update(self, model, always_execute=True, cfg_type="self", residual_scale=0.7, delta=1.0): print(f"[StreamCFG] Initializing with cfg_type={cfg_type}, residual_scale={residual_scale}, delta={delta}") - # Get state with defaults state = self.get_state({ "last_uncond": None, "initialized": False, @@ -57,15 +55,14 @@ def update(self, model, always_execute=True, cfg_type="self", residual_scale=0.4 "current_sigmas": None, "seen_sigmas": set(), "is_last_step": False, - # Add new state variables for proper scaling "alpha_prod_t": None, "beta_prod_t": None, "c_skip": None, "c_out": None, + "last_noise": None, # Store noise from previous frame }) def post_cfg_function(args): - # Extract info denoised = args["denoised"] cond = args["cond"] uncond = args["uncond"] @@ -73,7 +70,6 @@ def post_cfg_function(args): uncond_denoised = args["uncond_denoised"] cond_scale = args["cond_scale"] - # Debug prints for tensor stats print(f"\n[StreamCFG Debug] Step Info:") print(f"- Workflow count: {state['workflow_count']}") print(f"- CFG Type: {state['cfg_type']}") @@ -83,17 +79,14 @@ def post_cfg_function(args): if state["last_uncond"] is not None: print(f" - last_uncond shape: {state['last_uncond'].shape}, range: [{state['last_uncond'].min():.3f}, {state['last_uncond'].max():.3f}]") - # Handle both batched and single sigmas sigma = args["sigma"] if torch.is_tensor(sigma): sigma = sigma[0].item() if len(sigma.shape) > 0 else sigma.item() print(f"- Current sigma: {sigma:.6f}") - # Get step info from model options model_options = args["model_options"] sample_sigmas = model_options["transformer_options"].get("sample_sigmas", None) - # Update current sigmas if needed if sample_sigmas is not None and state["current_sigmas"] is None: sigmas = [s.item() for s in sample_sigmas] if sigmas[-1] == 0.0: @@ -102,13 +95,11 @@ def post_cfg_function(args): state["seen_sigmas"] = set() print(f"- New sigma sequence: {sigmas}") - # Calculate paper's exact scaling factors state["alpha_prod_t"] = torch.tensor([1.0 / (1.0 + s**2) for s in sigmas], device=denoised.device, dtype=denoised.dtype) state["beta_prod_t"] = torch.tensor([s / (1.0 + s**2) for s in sigmas], device=denoised.device, dtype=denoised.dtype) - # Calculate c_skip and c_out coefficients state["c_skip"] = torch.tensor([1.0 / (s**2 + 1.0) for s in sigmas], device=denoised.device, dtype=denoised.dtype) state["c_out"] = torch.tensor([-s / torch.sqrt(torch.tensor(s**2 + 1.0)) for s in sigmas], @@ -119,16 +110,26 @@ def post_cfg_function(args): print(f" beta: {state['beta_prod_t'][0]:.6f}") print(f" c_skip: {state['c_skip'][0]:.6f}") print(f" c_out: {state['c_out'][0]:.6f}") + + # Initialize noise for first step using previous frame if available + if state["last_uncond"] is not None and state["last_noise"] is not None: + # Scale noise based on current sigma + current_sigma = torch.tensor(sigmas[0], device=denoised.device, dtype=denoised.dtype) + scaled_noise = 
state["last_noise"] * current_sigma + + # Mix with previous frame prediction + alpha = 1.0 / (1.0 + current_sigma**2) + noisy_input = alpha * state["last_uncond"] + (1 - alpha) * scaled_noise + + # Update model input + if "input" in model_options: + model_options["input"] = noisy_input + print(f"- Initialized with previous frame, noise scale: {current_sigma:.6f}") - # Track this sigma state["seen_sigmas"].add(sigma) - - # Check if this is the last step state["is_last_step"] = False if state["current_sigmas"] is not None: - # It's the last step if we've seen all sigmas is_last_step = len(state["seen_sigmas"]) >= len(state["current_sigmas"]) - # Or if this is the smallest sigma in the sequence if not is_last_step and sigma == min(state["current_sigmas"]): is_last_step = True state["is_last_step"] = is_last_step @@ -139,68 +140,44 @@ def post_cfg_function(args): if state["last_uncond"] is None: if state["is_last_step"]: state["last_uncond"] = uncond_denoised.detach().clone() + # Store noise for next frame initialization + if "noise" in args: + state["last_noise"] = args["noise"].detach().clone() state["workflow_count"] += 1 state["current_sigmas"] = None if cfg_type == "initialize": state["initialized"] = True self.set_state(state) - print("- First workflow complete, stored last_uncond") + print("- First workflow complete, stored last_uncond and noise") return denoised - # Handle different CFG types - if cfg_type == "full": - result = denoised - elif cfg_type == "initialize" and not state["initialized"]: - result = denoised - if state["is_last_step"]: - state["initialized"] = True - self.set_state(state) - else: # self or initialized initialize - current_idx = len(state["seen_sigmas"]) - 1 - print(f"- Current step index: {current_idx}") - - # Use paper's exact formulation for noise prediction + current_idx = len(state["seen_sigmas"]) - 1 + print(f"- Current step index: {current_idx}") + + # Apply temporal consistency at first step and blend throughout + if current_idx == 0: + # Strong influence at first step noise_pred_uncond = state["last_uncond"] * state["delta"] - print(f"- Scaled noise prediction range: [{noise_pred_uncond.min():.3f}, {noise_pred_uncond.max():.3f}]") - - # Apply CFG with scaled prediction - result = noise_pred_uncond + cond_scale * (cond_denoised - noise_pred_uncond) * state["residual_scale"] - print(f"- Result range after CFG: [{result.min():.3f}, {result.max():.3f}]") - - # Store last prediction if this is the last step - if state["is_last_step"]: - # Calculate F_theta using paper's formulation - F_theta = (uncond_denoised - state["beta_prod_t"][current_idx] * noise_pred_uncond) / state["alpha_prod_t"][current_idx] - print(f"- F_theta range: [{F_theta.min():.3f}, {F_theta.max():.3f}]") - - delta_x = state["c_out"][current_idx] * F_theta + state["c_skip"][current_idx] * uncond_denoised - print(f"- delta_x range: [{delta_x.min():.3f}, {delta_x.max():.3f}]") - - # Scale delta_x with next step's coefficients - if current_idx < len(state["current_sigmas"]) - 1: - next_alpha = state["alpha_prod_t"][current_idx + 1] - next_beta = state["beta_prod_t"][current_idx + 1] - else: - next_alpha = torch.ones_like(state["alpha_prod_t"][0]) - next_beta = torch.zeros_like(state["beta_prod_t"][0]) - print(f"- Next step coefficients - alpha: {next_alpha:.6f}, beta: {next_beta:.6f}") - - # Update stored prediction with properly scaled residual - if next_beta > 0: - final_update = (next_alpha * delta_x) / next_beta - # Add noise only when beta > 0 - noise = torch.randn_like(delta_x) * (1 - 
next_alpha**2).sqrt() - final_update = final_update + noise - print(f"- Added noise range: [{noise.min():.3f}, {noise.max():.3f}]") - else: - # For the last step, just use the current prediction - final_update = uncond_denoised - - print(f"- Final update range: [{final_update.min():.3f}, {final_update.max():.3f}]") - state["last_uncond"] = final_update - state["workflow_count"] += 1 - state["current_sigmas"] = None - self.set_state(state) + result = noise_pred_uncond + cond_scale * (cond_denoised - noise_pred_uncond) + # Apply residual scale to entire result for stronger consistency + result = result * state["residual_scale"] + denoised * (1 - state["residual_scale"]) + else: + # Lighter influence in later steps + blend_scale = state["residual_scale"] * (1 - current_idx / len(state["current_sigmas"])) + result = denoised * (1 - blend_scale) + uncond_denoised * blend_scale + + print(f"- Result range after blending: [{result.min():.3f}, {result.max():.3f}]") + + # Store last prediction if this is the last step + if state["is_last_step"]: + state["last_uncond"] = uncond_denoised.detach().clone() + # Store noise for next frame initialization + if "noise" in args: + state["last_noise"] = args["noise"].detach().clone() + state["workflow_count"] += 1 + state["current_sigmas"] = None + self.set_state(state) + print(f"- Stored new last_uncond range: [{state['last_uncond'].min():.3f}, {state['last_uncond'].max():.3f}]") return result @@ -371,11 +348,6 @@ class StreamCrossAttention(ControlNodeBase): - Pre-computes and caches prompt embeddings - Stores Key-Value pairs for reuse with static prompts - Only recomputes KV pairs when prompt changes - - Additional optimizations beyond paper: - - QK normalization for better numerical stability - - Rotary position embeddings (RoPE) for improved temporal consistency - - Configurable context window size for memory/quality tradeoff """ RETURN_TYPES = ("MODEL",) @@ -387,21 +359,6 @@ def INPUT_TYPES(cls): inputs = super().INPUT_TYPES() inputs["required"].update({ "model": ("MODEL",), - "qk_norm": ("BOOLEAN", { - "default": True, - "tooltip": "Additional optimization: Whether to apply layer normalization to query and key tensors" - }), - "use_rope": ("BOOLEAN", { - "default": True, - "tooltip": "Additional optimization: Whether to use rotary position embeddings for better temporal consistency" - }), - "context_size": ("INT", { - "default": 4, - "min": 1, - "max": 32, - "step": 1, - "tooltip": "Additional optimization: Maximum number of past frames to keep in context. Higher values use more memory but may improve temporal consistency." 
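Taken out of the diff, the self-CFG mixing that the StreamCFG revisions above keep adjusting reduces to the short sketch below; the tensors are placeholders, since the real node reads these values from the sampler's post-CFG callback arguments:

import torch

cond_scale, delta = 1.5, 1.0
last_uncond = torch.randn(1, 4, 64, 64)      # unconditional prediction kept from the previous frame
cond_denoised = torch.randn(1, 4, 64, 64)    # conditional prediction for the current frame

noise_pred_uncond = last_uncond * delta      # previous frame stands in for the uncond branch
result = noise_pred_uncond + cond_scale * (cond_denoised - noise_pred_uncond)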
- }), "use_kv_cache": ("BOOLEAN", { "default": True, "tooltip": "Paper Section 3.5: Whether to cache key-value pairs for static prompts to avoid recomputation" @@ -414,19 +371,14 @@ def __init__(self): self.last_model_hash = None self.cross_attention_hook = None - def update(self, model, always_execute=True, qk_norm=True, use_rope=True, context_size=4, use_kv_cache=True): - print(f"[StreamCrossAttention] Initializing with qk_norm={qk_norm}, use_rope={use_rope}, context_size={context_size}, use_kv_cache={use_kv_cache}") + def update(self, model, always_execute=True, use_kv_cache=True): + print(f"[StreamCrossAttention] Initializing with use_kv_cache={use_kv_cache}") - # Get state with defaults state = self.get_state({ - "qk_norm": qk_norm, # Additional optimization - "use_rope": use_rope, # Additional optimization - "context_size": context_size, # Additional optimization - "use_kv_cache": use_kv_cache, # From paper Section 3.5 + "use_kv_cache": use_kv_cache, "workflow_count": 0, - "context_queue": [], # Additional: Store past context tensors for temporal consistency "kv_cache": {}, # From paper Section 3.5: Cache KV pairs for each prompt - "last_prompt_embeds": None, # From paper Section 3.5: For cache validation + "prompt_cache": {}, # Store prompt embeddings per module and dimension }) def cross_attention_forward(module, x, context=None, mask=None, value=None): @@ -435,82 +387,29 @@ def cross_attention_forward(module, x, context=None, mask=None, value=None): # Paper Section 3.5: KV Caching Logic cache_hit = False - if state["use_kv_cache"] and state["last_prompt_embeds"] is not None: - # Compare current context with cached prompt embeddings - if torch.allclose(context, state["last_prompt_embeds"], rtol=1e-5, atol=1e-5): - cache_hit = True - k, v = state["kv_cache"].get(module, (None, None)) - if k is not None and v is not None: - print("[StreamCrossAttention] Using cached KV pairs") + if state["use_kv_cache"] and context is not None: + # Create a unique key for this module and context shape + cache_key = (id(module), context.shape) + + if cache_key in state["prompt_cache"]: + cached_context = state["prompt_cache"][cache_key] + # Compare embeddings of same dimension + if torch.allclose(context, cached_context, rtol=1e-5, atol=1e-5): + cache_hit = True + k, v = state["kv_cache"].get(cache_key, (None, None)) + if k is not None and v is not None: + print("[StreamCrossAttention] Using cached KV pairs") if not cache_hit: - # Additional optimization: Temporal context management - if len(state["context_queue"]) >= state["context_size"]: - state["context_queue"].pop(0) - state["context_queue"].append(context.detach().clone()) - - # Additional optimization: Use past context for temporal consistency - full_context = torch.cat(state["context_queue"], dim=1) - - # Generate k/v for full context - k = module.to_k(full_context) - v = value if value is not None else module.to_v(full_context) + # Generate k/v for context + k = module.to_k(context) + v = value if value is not None else module.to_v(context) # Paper Section 3.5: Cache KV pairs for static prompts - if state["use_kv_cache"]: - state["last_prompt_embeds"] = context.detach().clone() - state["kv_cache"][module] = (k.detach().clone(), v.detach().clone()) - - # Additional optimization: QK normalization - if state["qk_norm"]: - q_norm = torch.nn.LayerNorm(q.shape[-1], device=q.device, dtype=q.dtype) - k_norm = torch.nn.LayerNorm(k.shape[-1], device=k.device, dtype=k.dtype) - q = q_norm(q) - k = k_norm(k) - - # Additional optimization: Rotary position 
embeddings - if state["use_rope"]: - # Calculate position embeddings - batch_size = q.shape[0] - seq_len = q.shape[1] - full_seq_len = k.shape[1] # Use full context length for k/v - dim = q.shape[2] - - # Create position indices for q and k separately - q_position = torch.arange(seq_len, device=q.device).unsqueeze(0).unsqueeze(-1) - k_position = torch.arange(full_seq_len, device=k.device).unsqueeze(0).unsqueeze(-1) - - q_position = q_position.repeat(batch_size, 1, dim//2) - k_position = k_position.repeat(batch_size, 1, dim//2) - - # Calculate frequencies - freq = 10000.0 ** (-torch.arange(0, dim//2, 2, device=q.device) / dim) - freq = freq.repeat((dim + 1) // 2)[:dim//2] - - # Calculate rotation angles - q_theta = q_position * freq - k_theta = k_position * freq - - # Apply rotations to q - q_cos = torch.cos(q_theta) - q_sin = torch.sin(q_theta) - q_reshaped = q.view(*q.shape[:-1], -1, 2) - q_out = torch.cat([ - q_reshaped[..., 0] * q_cos - q_reshaped[..., 1] * q_sin, - q_reshaped[..., 0] * q_sin + q_reshaped[..., 1] * q_cos - ], dim=-1) - - # Apply rotations to k - k_cos = torch.cos(k_theta) - k_sin = torch.sin(k_theta) - k_reshaped = k.view(*k.shape[:-1], -1, 2) - k_out = torch.cat([ - k_reshaped[..., 0] * k_cos - k_reshaped[..., 1] * k_sin, - k_reshaped[..., 0] * k_sin + k_reshaped[..., 1] * k_cos - ], dim=-1) - - q = q_out - k = k_out + if state["use_kv_cache"] and context is not None: + cache_key = (id(module), context.shape) + state["prompt_cache"][cache_key] = context.detach().clone() + state["kv_cache"][cache_key] = (k.detach().clone(), v.detach().clone()) # Standard attention computation with memory-efficient access pattern batch_size = q.shape[0] From 31e1f461e9b84b683ff3800734e14c168b088f90 Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Wed, 12 Feb 2025 10:36:29 -0500 Subject: [PATCH 10/19] reorginization --- __init__.py | 4 +- stream_cfg.py | 207 +++++++++++++++ stream_conditioning.py | 143 +++++++++++ stream_diffusion_nodes.py | 525 ++++++++------------------------------ 4 files changed, 466 insertions(+), 413 deletions(-) create mode 100644 stream_cfg.py create mode 100644 stream_conditioning.py diff --git a/__init__.py b/__init__.py index 80bc9af..040a383 100644 --- a/__init__.py +++ b/__init__.py @@ -3,8 +3,10 @@ from .controls.utility_controls import FPSMonitor, SimilarityFilter, LazyCondition from .controls.motion_controls import MotionController, ROINode, IntegerMotionController from .misc_nodes import DTypeConverter, FastWebcamCapture, YOLOSimilarityCompare, TextRenderer, QuickShapeMask, MultilineText, LoadImageFromPath_ -from .stream_diffusion_nodes import StreamConditioning, StreamCFG, StreamCrossAttention +from .stream_diffusion_nodes import StreamCrossAttention from .stream_sampler import StreamBatchSampler, StreamScheduler, StreamFrameBuffer +from .stream_cfg import StreamCFG +from .stream_conditioning import StreamConditioning import re diff --git a/stream_cfg.py b/stream_cfg.py new file mode 100644 index 0000000..10d1984 --- /dev/null +++ b/stream_cfg.py @@ -0,0 +1,207 @@ +import torch +from .base.control_base import ControlNodeBase +import comfy.model_management +import comfy.samplers +import random +import math + +class StreamCFG(ControlNodeBase): + """Implements CFG approaches for temporal consistency between workflow runs""" + + RETURN_TYPES = ("MODEL",) + FUNCTION = "update" + CATEGORY = "real-time/sampling" + + @classmethod + def INPUT_TYPES(cls): + inputs = super().INPUT_TYPES() + del 
inputs["required"]["always_execute"] + inputs["required"].update({ + "model": ("MODEL",), + "cfg_type": (["self", "full", "initialize"], { + "default": "self", + "tooltip": "Type of CFG to use: full (standard), self (memory efficient), or initialize (memory efficient with initialization)" + }), + "residual_scale": ("FLOAT", { + "default": 0.7, + "min": 0.0, + "max": 1.0, + "step": 0.01, + "tooltip": "Scale factor for residual (higher = more temporal consistency)" + }), + "delta": ("FLOAT", { + "default": 1.0, + "min": 0.0, + "max": 5.0, + "step": 0.1, + "tooltip": "Delta parameter for self/initialize CFG types" + }), + }) + return inputs + + def __init__(self): + super().__init__() + self.last_model_hash = None + self.post_cfg_function = None + + def update(self, model, always_execute=True, cfg_type="self", residual_scale=0.7, delta=1.0): + print(f"[StreamCFG] Initializing with cfg_type={cfg_type}, residual_scale={residual_scale}, delta={delta}") + + state = self.get_state({ + "last_uncond": None, + "initialized": False, + "cfg_type": cfg_type, + "residual_scale": residual_scale, + "delta": delta, + "workflow_count": 0, + "current_sigmas": None, + "seen_sigmas": set(), + "is_last_step": False, + "alpha_prod_t": None, + "beta_prod_t": None, + "c_skip": None, + "c_out": None, + "last_noise": None, # Store noise from previous frame + }) + + def post_cfg_function(args): + denoised = args["denoised"] + cond = args["cond"] + uncond = args["uncond"] + cond_denoised = args["cond_denoised"] + uncond_denoised = args["uncond_denoised"] + cond_scale = args["cond_scale"] + + print(f"\n[StreamCFG Debug] Step Info:") + print(f"- Workflow count: {state['workflow_count']}") + print(f"- CFG Type: {state['cfg_type']}") + print(f"- Tensor Stats:") + print(f" - denoised shape: {denoised.shape}, range: [{denoised.min():.3f}, {denoised.max():.3f}]") + print(f" - uncond_denoised shape: {uncond_denoised.shape}, range: [{uncond_denoised.min():.3f}, {uncond_denoised.max():.3f}]") + if state["last_uncond"] is not None: + print(f" - last_uncond shape: {state['last_uncond'].shape}, range: [{state['last_uncond'].min():.3f}, {state['last_uncond'].max():.3f}]") + + sigma = args["sigma"] + if torch.is_tensor(sigma): + sigma = sigma[0].item() if len(sigma.shape) > 0 else sigma.item() + print(f"- Current sigma: {sigma:.6f}") + + model_options = args["model_options"] + sample_sigmas = model_options["transformer_options"].get("sample_sigmas", None) + + if sample_sigmas is not None and state["current_sigmas"] is None: + sigmas = [s.item() for s in sample_sigmas] + if sigmas[-1] == 0.0: + sigmas = sigmas[:-1] + state["current_sigmas"] = sigmas + state["seen_sigmas"] = set() + print(f"- New sigma sequence: {sigmas}") + + state["alpha_prod_t"] = torch.tensor([1.0 / (1.0 + s**2) for s in sigmas], + device=denoised.device, dtype=denoised.dtype) + state["beta_prod_t"] = torch.tensor([s / (1.0 + s**2) for s in sigmas], + device=denoised.device, dtype=denoised.dtype) + + state["c_skip"] = torch.tensor([1.0 / (s**2 + 1.0) for s in sigmas], + device=denoised.device, dtype=denoised.dtype) + state["c_out"] = torch.tensor([-s / torch.sqrt(torch.tensor(s**2 + 1.0)) for s in sigmas], + device=denoised.device, dtype=denoised.dtype) + + print(f"- Scaling factors for first step:") + print(f" alpha: {state['alpha_prod_t'][0]:.6f}") + print(f" beta: {state['beta_prod_t'][0]:.6f}") + print(f" c_skip: {state['c_skip'][0]:.6f}") + print(f" c_out: {state['c_out'][0]:.6f}") + + # Initialize noise for first step using previous frame if available + if 
state["last_uncond"] is not None and state["last_noise"] is not None: + # Scale noise based on current sigma + current_sigma = torch.tensor(sigmas[0], device=denoised.device, dtype=denoised.dtype) + scaled_noise = state["last_noise"] * current_sigma + + # Mix with previous frame prediction + alpha = 1.0 / (1.0 + current_sigma**2) + noisy_input = alpha * state["last_uncond"] + (1 - alpha) * scaled_noise + + # Update model input + if "input" in model_options: + model_options["input"] = noisy_input + print(f"- Initialized with previous frame, noise scale: {current_sigma:.6f}") + + state["seen_sigmas"].add(sigma) + state["is_last_step"] = False + if state["current_sigmas"] is not None: + is_last_step = len(state["seen_sigmas"]) >= len(state["current_sigmas"]) + if not is_last_step and sigma == min(state["current_sigmas"]): + is_last_step = True + state["is_last_step"] = is_last_step + print(f"- Is last step: {is_last_step}") + print(f"- Seen sigmas: {sorted(state['seen_sigmas'])}") + + # First workflow case + if state["last_uncond"] is None: + if state["is_last_step"]: + state["last_uncond"] = uncond_denoised.detach().clone() + # Store noise for next frame initialization + if "noise" in args: + state["last_noise"] = args["noise"].detach().clone() + state["workflow_count"] += 1 + state["current_sigmas"] = None + if cfg_type == "initialize": + state["initialized"] = True + self.set_state(state) + print("- First workflow complete, stored last_uncond and noise") + return denoised + + current_idx = len(state["seen_sigmas"]) - 1 + print(f"- Current step index: {current_idx}") + + # Apply temporal consistency at first step and blend throughout + if current_idx == 0: + # Strong influence at first step + noise_pred_uncond = state["last_uncond"] * state["delta"] + result = noise_pred_uncond + cond_scale * (cond_denoised - noise_pred_uncond) + # Apply residual scale to entire result for stronger consistency + result = result * state["residual_scale"] + denoised * (1 - state["residual_scale"]) + else: + # Lighter influence in later steps + blend_scale = state["residual_scale"] * (1 - current_idx / len(state["current_sigmas"])) + result = denoised * (1 - blend_scale) + uncond_denoised * blend_scale + + print(f"- Result range after blending: [{result.min():.3f}, {result.max():.3f}]") + + # Store last prediction if this is the last step + if state["is_last_step"]: + state["last_uncond"] = uncond_denoised.detach().clone() + # Store noise for next frame initialization + if "noise" in args: + state["last_noise"] = args["noise"].detach().clone() + state["workflow_count"] += 1 + state["current_sigmas"] = None + self.set_state(state) + print(f"- Stored new last_uncond range: [{state['last_uncond'].min():.3f}, {state['last_uncond'].max():.3f}]") + + return result + + # Store function reference to prevent garbage collection + self.post_cfg_function = post_cfg_function + + # Only set up post CFG function if model has changed + model_hash = hash(str(model)) + if model_hash != self.last_model_hash: + m = model.clone() + m.model_options = m.model_options.copy() + m.model_options["sampler_post_cfg_function"] = [self.post_cfg_function] + self.last_model_hash = model_hash + return (m,) + + # Make sure our function is still in the list + if not any(f is self.post_cfg_function for f in model.model_options.get("sampler_post_cfg_function", [])): + m = model.clone() + m.model_options = m.model_options.copy() + m.model_options["sampler_post_cfg_function"] = [self.post_cfg_function] + return (m,) + + return (model,) + + diff 
--git a/stream_conditioning.py b/stream_conditioning.py new file mode 100644 index 0000000..2164349 --- /dev/null +++ b/stream_conditioning.py @@ -0,0 +1,143 @@ +import torch +from .base.control_base import ControlNodeBase +import comfy.model_management +import comfy.samplers +import random +import math + +#NOTE: totally and utterly experimental. No theoretical backing whatsoever. +class StreamConditioning(ControlNodeBase): + """Applies Residual CFG to conditioning for improved temporal consistency with different CFG types""" + #NOTE: experimental + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "positive": ("CONDITIONING",), + "negative": ("CONDITIONING",), + "cfg_type": (["full", "self", "initialize"], { + "default": "full", + "tooltip": "Type of CFG to use: full (standard), self (memory efficient), or initialize (memory efficient with initialization)" + }), + "residual_scale": ("FLOAT", { + "default": 0.4, + "min": 0.0, + "max": 1.0, + "step": 0.01, + "tooltip": "Scale factor for residual conditioning (higher = more temporal consistency)" + }), + "delta": ("FLOAT", { + "default": 1.0, + "min": 0.0, + "max": 5.0, + "step": 0.1, + "tooltip": "Delta parameter for self/initialize CFG types" + }), + "context_size": ("INT", { + "default": 4, + "min": 1, + "max": 32, + "step": 1, + "tooltip": "Number of past conditionings to keep in context. Higher values = smoother transitions but more memory usage." + }), + "always_execute": ("BOOLEAN", { + "default": False, + }), + } + } + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + FUNCTION = "update" + CATEGORY = "real-time/control/utility" + + def __init__(self): + super().__init__() + + def update(self, positive, negative, cfg_type="full", residual_scale=0.4, delta=1.0, context_size=4, always_execute=False): + # Get state with defaults + state = self.get_state({ + "prev_positive": None, + "prev_negative": None, + "stock_noise": None, # For self/initialize CFG + "initialized": False, # For initialize CFG + "pos_context": [], # Store past positive conditionings + "neg_context": [] # Store past negative conditionings + }) + + # Extract conditioning tensors + current_pos_cond = positive[0][0] + current_neg_cond = negative[0][0] + + # Update context queues + if len(state["pos_context"]) >= context_size: + state["pos_context"].pop(0) + state["neg_context"].pop(0) + state["pos_context"].append(current_pos_cond.detach().clone()) + state["neg_context"].append(current_neg_cond.detach().clone()) + + # First frame case + if state["prev_positive"] is None: + state["prev_positive"] = current_pos_cond.detach().clone() + state["prev_negative"] = current_neg_cond.detach().clone() + if cfg_type == "initialize": + state["stock_noise"] = current_neg_cond.detach().clone() + elif cfg_type == "self": + state["stock_noise"] = current_neg_cond.detach().clone() * delta + self.set_state(state) + return (positive, negative) + + # Handle different CFG types + if cfg_type == "full": + # Use entire context for smoother transitions + pos_context = torch.stack(state["pos_context"], dim=0) + neg_context = torch.stack(state["neg_context"], dim=0) + + # Calculate weighted residuals across context + weights = torch.linspace(0.5, 1.0, len(state["pos_context"]), device=current_pos_cond.device) + pos_residual = (current_pos_cond - pos_context) * weights.view(-1, 1, 1) + neg_residual = (current_neg_cond - neg_context) * weights.view(-1, 1, 1) + + # Average residuals + pos_residual = pos_residual.mean(dim=0) + neg_residual = 
neg_residual.mean(dim=0) + + blended_pos = current_pos_cond + residual_scale * pos_residual + blended_neg = current_neg_cond + residual_scale * neg_residual + + # Update state + state["prev_positive"] = current_pos_cond.detach().clone() + state["prev_negative"] = current_neg_cond.detach().clone() + + # Reconstruct conditioning format + positive_out = [[blended_pos, positive[0][1]]] + negative_out = [[blended_neg, negative[0][1]]] + + else: # self or initialize + # Calculate residual for positive conditioning + pos_residual = current_pos_cond - state["prev_positive"] + blended_pos = current_pos_cond + residual_scale * pos_residual + + # Update stock noise based on current prediction + if cfg_type == "initialize" and not state["initialized"]: + # First prediction for initialize type + state["stock_noise"] = current_neg_cond.detach().clone() + state["initialized"] = True + else: + # Update stock noise with temporal consistency + stock_residual = current_neg_cond - state["stock_noise"] + state["stock_noise"] = current_neg_cond + residual_scale * stock_residual + + # Scale stock noise by delta + scaled_stock = state["stock_noise"] * delta + + # Update state + state["prev_positive"] = current_pos_cond.detach().clone() + state["prev_negative"] = scaled_stock.detach().clone() + + # Reconstruct conditioning format + positive_out = [[blended_pos, positive[0][1]]] + negative_out = [[scaled_stock, negative[0][1]]] + + self.set_state(state) + return (positive_out, negative_out) \ No newline at end of file diff --git a/stream_diffusion_nodes.py b/stream_diffusion_nodes.py index 7411471..1a49953 100644 --- a/stream_diffusion_nodes.py +++ b/stream_diffusion_nodes.py @@ -3,346 +3,10 @@ import comfy.model_management import comfy.samplers import random - -class StreamCFG(ControlNodeBase): - """Implements CFG approaches for temporal consistency between workflow runs""" - - RETURN_TYPES = ("MODEL",) - FUNCTION = "update" - CATEGORY = "real-time/sampling" - - @classmethod - def INPUT_TYPES(cls): - inputs = super().INPUT_TYPES() - inputs["required"].update({ - "model": ("MODEL",), - "cfg_type": (["self", "full", "initialize"], { - "default": "self", - "tooltip": "Type of CFG to use: full (standard), self (memory efficient), or initialize (memory efficient with initialization)" - }), - "residual_scale": ("FLOAT", { - "default": 0.7, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Scale factor for residual (higher = more temporal consistency)" - }), - "delta": ("FLOAT", { - "default": 1.0, - "min": 0.0, - "max": 5.0, - "step": 0.1, - "tooltip": "Delta parameter for self/initialize CFG types" - }), - }) - return inputs - - def __init__(self): - super().__init__() - self.last_model_hash = None - self.post_cfg_function = None - - def update(self, model, always_execute=True, cfg_type="self", residual_scale=0.7, delta=1.0): - print(f"[StreamCFG] Initializing with cfg_type={cfg_type}, residual_scale={residual_scale}, delta={delta}") - - state = self.get_state({ - "last_uncond": None, - "initialized": False, - "cfg_type": cfg_type, - "residual_scale": residual_scale, - "delta": delta, - "workflow_count": 0, - "current_sigmas": None, - "seen_sigmas": set(), - "is_last_step": False, - "alpha_prod_t": None, - "beta_prod_t": None, - "c_skip": None, - "c_out": None, - "last_noise": None, # Store noise from previous frame - }) - - def post_cfg_function(args): - denoised = args["denoised"] - cond = args["cond"] - uncond = args["uncond"] - cond_denoised = args["cond_denoised"] - uncond_denoised = 
args["uncond_denoised"] - cond_scale = args["cond_scale"] - - print(f"\n[StreamCFG Debug] Step Info:") - print(f"- Workflow count: {state['workflow_count']}") - print(f"- CFG Type: {state['cfg_type']}") - print(f"- Tensor Stats:") - print(f" - denoised shape: {denoised.shape}, range: [{denoised.min():.3f}, {denoised.max():.3f}]") - print(f" - uncond_denoised shape: {uncond_denoised.shape}, range: [{uncond_denoised.min():.3f}, {uncond_denoised.max():.3f}]") - if state["last_uncond"] is not None: - print(f" - last_uncond shape: {state['last_uncond'].shape}, range: [{state['last_uncond'].min():.3f}, {state['last_uncond'].max():.3f}]") - - sigma = args["sigma"] - if torch.is_tensor(sigma): - sigma = sigma[0].item() if len(sigma.shape) > 0 else sigma.item() - print(f"- Current sigma: {sigma:.6f}") - - model_options = args["model_options"] - sample_sigmas = model_options["transformer_options"].get("sample_sigmas", None) - - if sample_sigmas is not None and state["current_sigmas"] is None: - sigmas = [s.item() for s in sample_sigmas] - if sigmas[-1] == 0.0: - sigmas = sigmas[:-1] - state["current_sigmas"] = sigmas - state["seen_sigmas"] = set() - print(f"- New sigma sequence: {sigmas}") - - state["alpha_prod_t"] = torch.tensor([1.0 / (1.0 + s**2) for s in sigmas], - device=denoised.device, dtype=denoised.dtype) - state["beta_prod_t"] = torch.tensor([s / (1.0 + s**2) for s in sigmas], - device=denoised.device, dtype=denoised.dtype) - - state["c_skip"] = torch.tensor([1.0 / (s**2 + 1.0) for s in sigmas], - device=denoised.device, dtype=denoised.dtype) - state["c_out"] = torch.tensor([-s / torch.sqrt(torch.tensor(s**2 + 1.0)) for s in sigmas], - device=denoised.device, dtype=denoised.dtype) - - print(f"- Scaling factors for first step:") - print(f" alpha: {state['alpha_prod_t'][0]:.6f}") - print(f" beta: {state['beta_prod_t'][0]:.6f}") - print(f" c_skip: {state['c_skip'][0]:.6f}") - print(f" c_out: {state['c_out'][0]:.6f}") - - # Initialize noise for first step using previous frame if available - if state["last_uncond"] is not None and state["last_noise"] is not None: - # Scale noise based on current sigma - current_sigma = torch.tensor(sigmas[0], device=denoised.device, dtype=denoised.dtype) - scaled_noise = state["last_noise"] * current_sigma - - # Mix with previous frame prediction - alpha = 1.0 / (1.0 + current_sigma**2) - noisy_input = alpha * state["last_uncond"] + (1 - alpha) * scaled_noise - - # Update model input - if "input" in model_options: - model_options["input"] = noisy_input - print(f"- Initialized with previous frame, noise scale: {current_sigma:.6f}") - - state["seen_sigmas"].add(sigma) - state["is_last_step"] = False - if state["current_sigmas"] is not None: - is_last_step = len(state["seen_sigmas"]) >= len(state["current_sigmas"]) - if not is_last_step and sigma == min(state["current_sigmas"]): - is_last_step = True - state["is_last_step"] = is_last_step - print(f"- Is last step: {is_last_step}") - print(f"- Seen sigmas: {sorted(state['seen_sigmas'])}") - - # First workflow case - if state["last_uncond"] is None: - if state["is_last_step"]: - state["last_uncond"] = uncond_denoised.detach().clone() - # Store noise for next frame initialization - if "noise" in args: - state["last_noise"] = args["noise"].detach().clone() - state["workflow_count"] += 1 - state["current_sigmas"] = None - if cfg_type == "initialize": - state["initialized"] = True - self.set_state(state) - print("- First workflow complete, stored last_uncond and noise") - return denoised - - current_idx = 
len(state["seen_sigmas"]) - 1 - print(f"- Current step index: {current_idx}") - - # Apply temporal consistency at first step and blend throughout - if current_idx == 0: - # Strong influence at first step - noise_pred_uncond = state["last_uncond"] * state["delta"] - result = noise_pred_uncond + cond_scale * (cond_denoised - noise_pred_uncond) - # Apply residual scale to entire result for stronger consistency - result = result * state["residual_scale"] + denoised * (1 - state["residual_scale"]) - else: - # Lighter influence in later steps - blend_scale = state["residual_scale"] * (1 - current_idx / len(state["current_sigmas"])) - result = denoised * (1 - blend_scale) + uncond_denoised * blend_scale - - print(f"- Result range after blending: [{result.min():.3f}, {result.max():.3f}]") - - # Store last prediction if this is the last step - if state["is_last_step"]: - state["last_uncond"] = uncond_denoised.detach().clone() - # Store noise for next frame initialization - if "noise" in args: - state["last_noise"] = args["noise"].detach().clone() - state["workflow_count"] += 1 - state["current_sigmas"] = None - self.set_state(state) - print(f"- Stored new last_uncond range: [{state['last_uncond'].min():.3f}, {state['last_uncond'].max():.3f}]") - - return result - - # Store function reference to prevent garbage collection - self.post_cfg_function = post_cfg_function - - # Only set up post CFG function if model has changed - model_hash = hash(str(model)) - if model_hash != self.last_model_hash: - m = model.clone() - m.model_options = m.model_options.copy() - m.model_options["sampler_post_cfg_function"] = [self.post_cfg_function] - self.last_model_hash = model_hash - return (m,) - - # Make sure our function is still in the list - if not any(f is self.post_cfg_function for f in model.model_options.get("sampler_post_cfg_function", [])): - m = model.clone() - m.model_options = m.model_options.copy() - m.model_options["sampler_post_cfg_function"] = [self.post_cfg_function] - return (m,) - - return (model,) - - -#NOTE: totally and utterly experimental. No theoretical backing whatsoever. -class StreamConditioning(ControlNodeBase): - """Applies Residual CFG to conditioning for improved temporal consistency with different CFG types""" - #NOTE: experimental - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "positive": ("CONDITIONING",), - "negative": ("CONDITIONING",), - "cfg_type": (["full", "self", "initialize"], { - "default": "full", - "tooltip": "Type of CFG to use: full (standard), self (memory efficient), or initialize (memory efficient with initialization)" - }), - "residual_scale": ("FLOAT", { - "default": 0.4, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Scale factor for residual conditioning (higher = more temporal consistency)" - }), - "delta": ("FLOAT", { - "default": 1.0, - "min": 0.0, - "max": 5.0, - "step": 0.1, - "tooltip": "Delta parameter for self/initialize CFG types" - }), - "context_size": ("INT", { - "default": 4, - "min": 1, - "max": 32, - "step": 1, - "tooltip": "Number of past conditionings to keep in context. Higher values = smoother transitions but more memory usage." 
- }), - "always_execute": ("BOOLEAN", { - "default": False, - }), - } - } - - RETURN_TYPES = ("CONDITIONING", "CONDITIONING") - RETURN_NAMES = ("positive", "negative") - FUNCTION = "update" - CATEGORY = "real-time/control/utility" - - def __init__(self): - super().__init__() - - def update(self, positive, negative, cfg_type="full", residual_scale=0.4, delta=1.0, context_size=4, always_execute=False): - # Get state with defaults - state = self.get_state({ - "prev_positive": None, - "prev_negative": None, - "stock_noise": None, # For self/initialize CFG - "initialized": False, # For initialize CFG - "pos_context": [], # Store past positive conditionings - "neg_context": [] # Store past negative conditionings - }) - - # Extract conditioning tensors - current_pos_cond = positive[0][0] - current_neg_cond = negative[0][0] - - # Update context queues - if len(state["pos_context"]) >= context_size: - state["pos_context"].pop(0) - state["neg_context"].pop(0) - state["pos_context"].append(current_pos_cond.detach().clone()) - state["neg_context"].append(current_neg_cond.detach().clone()) - - # First frame case - if state["prev_positive"] is None: - state["prev_positive"] = current_pos_cond.detach().clone() - state["prev_negative"] = current_neg_cond.detach().clone() - if cfg_type == "initialize": - state["stock_noise"] = current_neg_cond.detach().clone() - elif cfg_type == "self": - state["stock_noise"] = current_neg_cond.detach().clone() * delta - self.set_state(state) - return (positive, negative) - - # Handle different CFG types - if cfg_type == "full": - # Use entire context for smoother transitions - pos_context = torch.stack(state["pos_context"], dim=0) - neg_context = torch.stack(state["neg_context"], dim=0) - - # Calculate weighted residuals across context - weights = torch.linspace(0.5, 1.0, len(state["pos_context"]), device=current_pos_cond.device) - pos_residual = (current_pos_cond - pos_context) * weights.view(-1, 1, 1) - neg_residual = (current_neg_cond - neg_context) * weights.view(-1, 1, 1) - - # Average residuals - pos_residual = pos_residual.mean(dim=0) - neg_residual = neg_residual.mean(dim=0) - - blended_pos = current_pos_cond + residual_scale * pos_residual - blended_neg = current_neg_cond + residual_scale * neg_residual - - # Update state - state["prev_positive"] = current_pos_cond.detach().clone() - state["prev_negative"] = current_neg_cond.detach().clone() - - # Reconstruct conditioning format - positive_out = [[blended_pos, positive[0][1]]] - negative_out = [[blended_neg, negative[0][1]]] - - else: # self or initialize - # Calculate residual for positive conditioning - pos_residual = current_pos_cond - state["prev_positive"] - blended_pos = current_pos_cond + residual_scale * pos_residual - - # Update stock noise based on current prediction - if cfg_type == "initialize" and not state["initialized"]: - # First prediction for initialize type - state["stock_noise"] = current_neg_cond.detach().clone() - state["initialized"] = True - else: - # Update stock noise with temporal consistency - stock_residual = current_neg_cond - state["stock_noise"] - state["stock_noise"] = current_neg_cond + residual_scale * stock_residual - - # Scale stock noise by delta - scaled_stock = state["stock_noise"] * delta - - # Update state - state["prev_positive"] = current_pos_cond.detach().clone() - state["prev_negative"] = scaled_stock.detach().clone() - - # Reconstruct conditioning format - positive_out = [[blended_pos, positive[0][1]]] - negative_out = [[scaled_stock, negative[0][1]]] - - 
self.set_state(state) - return (positive_out, negative_out) - +import math class StreamCrossAttention(ControlNodeBase): - """Implements optimized cross attention with KV-cache for real-time generation + DESCRIPTION="""Implements optimized cross attention with KV-cache for real-time generation Paper reference: StreamDiffusion Section 3.5 "Pre-computation" - Pre-computes and caches prompt embeddings @@ -357,95 +21,137 @@ class StreamCrossAttention(ControlNodeBase): @classmethod def INPUT_TYPES(cls): inputs = super().INPUT_TYPES() + del inputs["required"]["always_execute"] inputs["required"].update({ "model": ("MODEL",), - "use_kv_cache": ("BOOLEAN", { - "default": True, - "tooltip": "Paper Section 3.5: Whether to cache key-value pairs for static prompts to avoid recomputation" + "prompt": ("STRING", { + "multiline": True, + "forceInput": True, + "default": "", + "tooltip": "Text prompt to use for caching. Only recomputes when this changes." + }), + "max_cache_size": ("INT", { + "default": 8, + "min": 1, + "max": 32, + "step": 1, + "tooltip": "Maximum number of cached entries per module" }), }) return inputs - + def __init__(self): super().__init__() - self.last_model_hash = None self.cross_attention_hook = None - def update(self, model, always_execute=True, use_kv_cache=True): - print(f"[StreamCrossAttention] Initializing with use_kv_cache={use_kv_cache}") + def update(self, model, prompt="", max_cache_size=8): + print(f"[StreamCrossAttention] Initializing with prompt='{prompt}', max_cache_size={max_cache_size}") + + # NOTE: Unlike the StreamDiffusion paper, we don't explicitly compare prompts to detect changes. + # Instead, we leverage ComfyUI's execution system: + # 1. Our node only executes when inputs (model, prompt, max_cache_size) change + # 2. When executed, we get the new prompt value and automatically recompute KV pairs + # 3. The cache key system using (module_id, prompt) ensures we use the right KVs + + # This is more efficient as we avoid explicit prompt comparison and let ComfyUI handle change detection. + # We do not expect the model or max cache size to change often. 
state = self.get_state({ - "use_kv_cache": use_kv_cache, + "max_cache_size": max_cache_size, "workflow_count": 0, "kv_cache": {}, # From paper Section 3.5: Cache KV pairs for each prompt - "prompt_cache": {}, # Store prompt embeddings per module and dimension + "cache_keys_by_module": {}, # Track cache keys per module for LRU eviction }) + def manage_cache_size(module_id): + """Maintain cache size limits using LRU eviction""" + if module_id in state["cache_keys_by_module"]: + module_keys = state["cache_keys_by_module"][module_id] + while len(module_keys) > state["max_cache_size"]: + # Remove oldest cache entry + old_key = module_keys.pop(0) + if old_key in state["kv_cache"]: + del state["kv_cache"][old_key] + + def get_cache_key(module, prompt): + """Generate cache key from module ID and prompt text""" + return (id(module), prompt) + def cross_attention_forward(module, x, context=None, mask=None, value=None): - q = module.to_q(x) + """Optimized cross attention following StreamDiffusion's approach""" + batch_size = x.shape[0] context = x if context is None else context - # Paper Section 3.5: KV Caching Logic - cache_hit = False - if state["use_kv_cache"] and context is not None: - # Create a unique key for this module and context shape - cache_key = (id(module), context.shape) - - if cache_key in state["prompt_cache"]: - cached_context = state["prompt_cache"][cache_key] - # Compare embeddings of same dimension - if torch.allclose(context, cached_context, rtol=1e-5, atol=1e-5): - cache_hit = True - k, v = state["kv_cache"].get(cache_key, (None, None)) - if k is not None and v is not None: - print("[StreamCrossAttention] Using cached KV pairs") + # Debug cache hit/miss + cache_key = get_cache_key(module, prompt) + cache_hit = cache_key in state["kv_cache"] + print(f"[StreamCrossAttn] Cache {'hit' if cache_hit else 'miss'} for module {id(module)}") + print(f"[StreamCrossAttn] Cache key: {cache_key}") - if not cache_hit: - # Generate k/v for context + # Check cache + if cache_hit: + k, v = state["kv_cache"][cache_key] + print(f"[StreamCrossAttn] Reusing cached KV pairs shape k:{k.shape} v:{v.shape}") + + # Update LRU tracking + module_keys = state["cache_keys_by_module"].get(id(module), []) + if cache_key in module_keys: + module_keys.remove(cache_key) + module_keys.append(cache_key) + state["cache_keys_by_module"][id(module)] = module_keys + else: + # Generate new KV pairs k = module.to_k(context) v = value if value is not None else module.to_v(context) + print(f"[StreamCrossAttn] Generated new KV pairs shape k:{k.shape} v:{v.shape}") + + # Cache without cloning - just use references + state["kv_cache"][cache_key] = (k, v) - # Paper Section 3.5: Cache KV pairs for static prompts - if state["use_kv_cache"] and context is not None: - cache_key = (id(module), context.shape) - state["prompt_cache"][cache_key] = context.detach().clone() - state["kv_cache"][cache_key] = (k.detach().clone(), v.detach().clone()) + # Update LRU tracking + module_keys = state["cache_keys_by_module"].get(id(module), []) + module_keys.append(cache_key) + state["cache_keys_by_module"][id(module)] = module_keys + + # Basic LRU cleanup + if len(module_keys) > state["max_cache_size"]: + old_key = module_keys.pop(0) + if old_key in state["kv_cache"]: + print(f"[StreamCrossAttn] Evicting old cache key {old_key}") + del state["kv_cache"][old_key] + + # Generate query + q = module.to_q(x) + print(f"[StreamCrossAttn] Generated query shape:{q.shape}") - # Standard attention computation with memory-efficient access pattern - 
batch_size = q.shape[0] - q_seq_len = q.shape[1] - k_seq_len = k.shape[1] + # Efficient single-pass reshape head_dim = q.shape[-1] // module.heads - - # Reshape for multi-head attention - q = q.view(batch_size, q_seq_len, module.heads, head_dim) - k = k.view(batch_size, k_seq_len, module.heads, head_dim) - v = v.view(batch_size, k_seq_len, module.heads, head_dim) - - # Transpose for attention computation - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - # Compute attention scores - scale = head_dim ** -0.5 + q = q.view(batch_size, -1, module.heads, head_dim).transpose(1, 2) + k = k.view(-1, k.shape[1], module.heads, head_dim).transpose(1, 2) + v = v.view(-1, v.shape[1], module.heads, head_dim).transpose(1, 2) + + # Handle batch size expansion if needed + if k.shape[0] == 1 and batch_size > 1: + print(f"[StreamCrossAttn] Expanding cached KV pairs from batch 1 to {batch_size}") + k = k.expand(batch_size, -1, -1, -1) + v = v.expand(batch_size, -1, -1, -1) + + # Simple attention computation + scale = 1.0 / math.sqrt(head_dim) scores = torch.matmul(q, k.transpose(-2, -1)) * scale if mask is not None: scores = scores + mask - # Apply attention + # Compute attention and output attn = torch.softmax(scores, dim=-1) out = torch.matmul(attn, v) + print(f"[StreamCrossAttn] Attention output shape:{out.shape}") - # Reshape back - out = out.transpose(1, 2).contiguous() - out = out.view(batch_size, q_seq_len, -1) - - # Project back to original dimension - out = module.to_out[0](out) + # Final reshape + out = out.transpose(1, 2).reshape(batch_size, -1, module.heads * head_dim) - return out + return module.to_out[0](out) def hook_cross_attention(module, input, output): if isinstance(module, torch.nn.Module) and hasattr(module, "to_q"): @@ -455,26 +161,21 @@ def hook_cross_attention(module, input, output): # Replace with our optimized version module.forward = lambda *args, **kwargs: cross_attention_forward(module, *args, **kwargs) return output + + # Remove old hooks if they exist + if self.cross_attention_hook is not None: + self.cross_attention_hook.remove() - # Only set up hooks if model has changed - model_hash = hash(str(model)) - if model_hash != self.last_model_hash: - m = model.clone() - - # Remove old hooks if they exist - if self.cross_attention_hook is not None: - self.cross_attention_hook.remove() - - # Register hook for cross attention modules - def register_hooks(module): - if isinstance(module, torch.nn.Module) and hasattr(module, "to_q"): - self.cross_attention_hook = module.register_forward_hook(hook_cross_attention) - - m.model.apply(register_hooks) - self.last_model_hash = model_hash - return (m,) + # Clone model and apply hooks + m = model.clone() + + # Register hook for cross attention modules + def register_hooks(module): + if isinstance(module, torch.nn.Module) and hasattr(module, "to_q"): + self.cross_attention_hook = module.register_forward_hook(hook_cross_attention) - return (model,) + m.model.apply(register_hooks) + return (m,) From f8c05ac56c48719c1e1b28c3ad8ab013c2cbf089 Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Sat, 1 Mar 2025 23:12:54 -0500 Subject: [PATCH 11/19] txt2iigm --- stream_sampler.py | 64 ++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 55 insertions(+), 9 deletions(-) diff --git a/stream_sampler.py b/stream_sampler.py index c59275e..837eb49 100644 --- a/stream_sampler.py +++ b/stream_sampler.py @@ -32,6 +32,7 @@ def __init__(self): self.frame_buffer = [] 
self.x_t_latent_buffer = None self.stock_noise = None + self.is_txt2img_mode = False def sample(self, model, noise, sigmas, extra_args=None, callback=None, disable=None): """Sample with staggered batch denoising steps""" @@ -45,6 +46,16 @@ def sample(self, model, noise, sigmas, extra_args=None, callback=None, disable=N print(f"[StreamBatchSampler] Input sigmas: {sigmas}") print(f"[StreamBatchSampler] Input noise shape: {noise.shape}, min: {noise.min():.3f}, max: {noise.max():.3f}") + # Detect if we're in text-to-image mode by checking if noise is all zeros + # This happens when empty latents are provided + self.is_txt2img_mode = torch.allclose(noise, torch.zeros_like(noise), atol=1e-6) + + if self.is_txt2img_mode: + print(f"[StreamBatchSampler] Detected text-to-image mode (empty latents)") + # For text-to-image, we'll use pure random noise + noise = torch.randn_like(noise) + print(f"[StreamBatchSampler] Generated random noise for txt2img: {noise.shape}") + # Verify batch size matches number of timesteps if batch_size != num_sigmas: raise ValueError(f"Batch size ({batch_size}) must match number of timesteps ({num_sigmas})") @@ -57,20 +68,25 @@ def sample(self, model, noise, sigmas, extra_args=None, callback=None, disable=N print(f"[StreamBatchSampler] Beta values: {beta_prod_t.view(-1)}") # Initialize stock noise if needed - if self.stock_noise is None: + if self.stock_noise is None or self.is_txt2img_mode: self.stock_noise = torch.randn_like(noise[0]) # Random noise instead of zeros print(f"[StreamBatchSampler] Initialized random stock noise with shape: {self.stock_noise.shape}") # Scale noise for each frame based on its sigma scaled_noise = [] for i in range(batch_size): - frame_noise = noise[i] + self.stock_noise * sigmas[i] # Add scaled noise to input + if self.is_txt2img_mode: + # For txt2img, use pure noise scaled by sigma + frame_noise = self.stock_noise * sigmas[i] + else: + # For img2img, add scaled noise to input + frame_noise = noise[i] + self.stock_noise * sigmas[i] scaled_noise.append(frame_noise) x = torch.stack(scaled_noise, dim=0) print(f"[StreamBatchSampler] Scaled noise shape: {x.shape}, min: {x.min():.3f}, max: {x.max():.3f}") # Initialize frame buffer if needed - if self.x_t_latent_buffer is None and num_sigmas > 1: + if (self.x_t_latent_buffer is None or self.is_txt2img_mode) and num_sigmas > 1: self.x_t_latent_buffer = x[0].clone() # Initialize with noised first frame print(f"[StreamBatchSampler] Initialized buffer with shape: {self.x_t_latent_buffer.shape}") @@ -197,6 +213,7 @@ def __init__(self): super().__init__() self.frame_buffer = [] # List to store incoming frames self.buffer_size = None + self.is_txt2img_mode = False def update(self, latent, buffer_size=4, always_execute=True): """Add new frame to buffer and return batch when ready""" @@ -204,14 +221,35 @@ def update(self, latent, buffer_size=4, always_execute=True): # Extract latent tensor from input and remove batch dimension if present x = latent["samples"] - if x.dim() == 4: # [B,C,H,W] + + # Check if this is an empty latent (for txt2img) + is_empty_latent = x.numel() == 0 or (x.dim() > 0 and x.shape[0] == 0) + + if is_empty_latent: + self.is_txt2img_mode = True + print(f"[StreamFrameBuffer] Detected empty latent for text-to-image mode") + # Create empty latents with correct dimensions for txt2img + # Get dimensions from latent dict + height = latent.get("height", 512) + width = latent.get("width", 512) + + # Calculate latent dimensions (typically 1/8 of image dimensions for SD) + latent_height = height // 
8 + latent_width = width // 8 + + # Create zero tensor with correct shape + x = torch.zeros((4, latent_height, latent_width), + device=comfy.model_management.get_torch_device()) + print(f"[StreamFrameBuffer] Created empty latent with shape: {x.shape}") + elif x.dim() == 4: # [B,C,H,W] + self.is_txt2img_mode = False x = x.squeeze(0) # Remove batch dimension -> [C,H,W] # Add new frame to buffer - if len(self.frame_buffer) == 0: - # First frame - initialize buffer with copies + if len(self.frame_buffer) == 0 or self.is_txt2img_mode: + # First frame or txt2img mode - initialize buffer with copies self.frame_buffer = [x.clone() for _ in range(self.buffer_size)] - print(f"[StreamFrameBuffer] Initialized buffer with {self.buffer_size} copies of first frame") + print(f"[StreamFrameBuffer] Initialized buffer with {self.buffer_size} copies of frame") else: # Shift frames forward and add new frame self.frame_buffer.pop(0) # Remove oldest frame @@ -222,5 +260,13 @@ def update(self, latent, buffer_size=4, always_execute=True): batch = torch.stack(self.frame_buffer, dim=0) # [B,C,H,W] print(f"[StreamFrameBuffer] Created batch with shape: {batch.shape}") - # Return as latent dict - return ({"samples": batch},) + # Return as latent dict with preserved dimensions + result = {"samples": batch} + + # Preserve height and width if present in input + if "height" in latent: + result["height"] = latent["height"] + if "width" in latent: + result["width"] = latent["width"] + + return (result,) From 72768100c7c56a5673017ab0a444234f1c9609f7 Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Thu, 6 Mar 2025 19:10:56 -0500 Subject: [PATCH 12/19] optimizations --- stream_sampler.py | 92 +++++++++++++++++++++++------------------------ 1 file changed, 44 insertions(+), 48 deletions(-) diff --git a/stream_sampler.py b/stream_sampler.py index 837eb49..937c7a6 100644 --- a/stream_sampler.py +++ b/stream_sampler.py @@ -37,24 +37,18 @@ def __init__(self): def sample(self, model, noise, sigmas, extra_args=None, callback=None, disable=None): """Sample with staggered batch denoising steps""" extra_args = {} if extra_args is None else extra_args - print(f"[StreamBatchSampler] Starting sampling with {len(sigmas)-1} steps") # Get number of frames in batch and available sigmas batch_size = noise.shape[0] num_sigmas = len(sigmas) - 1 # Subtract 1 because last sigma is the target (0.0) - print(f"[StreamBatchSampler] Input sigmas: {sigmas}") - print(f"[StreamBatchSampler] Input noise shape: {noise.shape}, min: {noise.min():.3f}, max: {noise.max():.3f}") - # Detect if we're in text-to-image mode by checking if noise is all zeros # This happens when empty latents are provided self.is_txt2img_mode = torch.allclose(noise, torch.zeros_like(noise), atol=1e-6) if self.is_txt2img_mode: - print(f"[StreamBatchSampler] Detected text-to-image mode (empty latents)") # For text-to-image, we'll use pure random noise noise = torch.randn_like(noise) - print(f"[StreamBatchSampler] Generated random noise for txt2img: {noise.shape}") # Verify batch size matches number of timesteps if batch_size != num_sigmas: @@ -64,52 +58,41 @@ def sample(self, model, noise, sigmas, extra_args=None, callback=None, disable=N alpha_prod_t = (sigmas[:-1] / sigmas[0]).view(-1, 1, 1, 1) # [B,1,1,1] beta_prod_t = (1 - alpha_prod_t) - print(f"[StreamBatchSampler] Alpha values: {alpha_prod_t.view(-1)}") - print(f"[StreamBatchSampler] Beta values: {beta_prod_t.view(-1)}") # Initialize stock noise if needed - if self.stock_noise 
is None or self.is_txt2img_mode: + if self.stock_noise is None or self.is_txt2img_mode: # Kept original condition for functional equivalence self.stock_noise = torch.randn_like(noise[0]) # Random noise instead of zeros - print(f"[StreamBatchSampler] Initialized random stock noise with shape: {self.stock_noise.shape}") - - # Scale noise for each frame based on its sigma - scaled_noise = [] - for i in range(batch_size): - if self.is_txt2img_mode: - # For txt2img, use pure noise scaled by sigma - frame_noise = self.stock_noise * sigmas[i] - else: - # For img2img, add scaled noise to input - frame_noise = noise[i] + self.stock_noise * sigmas[i] - scaled_noise.append(frame_noise) - x = torch.stack(scaled_noise, dim=0) - print(f"[StreamBatchSampler] Scaled noise shape: {x.shape}, min: {x.min():.3f}, max: {x.max():.3f}") + + # Optimization: Vectorize noise scaling instead of looping + sigmas_view = sigmas[:-1].view(-1, 1, 1, 1) # Reshape for broadcasting + if self.is_txt2img_mode: + x = self.stock_noise.unsqueeze(0) * sigmas_view # Broadcast stock_noise across batch + else: + x = noise + self.stock_noise.unsqueeze(0) * sigmas_view # Add scaled noise to input # Initialize frame buffer if needed - if (self.x_t_latent_buffer is None or self.is_txt2img_mode) and num_sigmas > 1: - self.x_t_latent_buffer = x[0].clone() # Initialize with noised first frame - print(f"[StreamBatchSampler] Initialized buffer with shape: {self.x_t_latent_buffer.shape}") + if (self.x_t_latent_buffer is None or self.is_txt2img_mode) and num_sigmas > 1: # Kept original condition + # Optimization: Pre-allocate and copy instead of clone + self.x_t_latent_buffer = torch.empty_like(x[0]) # Pre-allocate memory + self.x_t_latent_buffer.copy_(x[0]) # In-place copy # Use buffer for first frame to maintain temporal consistency if num_sigmas > 1: - x = torch.cat([self.x_t_latent_buffer.unsqueeze(0), x[1:]], dim=0) - print(f"[StreamBatchSampler] Combined with buffer, shape: {x.shape}") + # Optimization: Update in-place instead of concatenating + x[0] = self.x_t_latent_buffer # Replace first frame with buffer # Run model on entire batch at once with torch.no_grad(): # Process all frames in parallel sigma_batch = sigmas[:-1] - print(f"[StreamBatchSampler] Using sigmas for denoising: {sigma_batch}") denoised_batch = model(x, sigma_batch, **extra_args) - print(f"[StreamBatchSampler] Denoised batch shape: {denoised_batch.shape}") - print(f"[StreamBatchSampler] Denoised stats - min: {denoised_batch.min():.3f}, max: {denoised_batch.max():.3f}") # Update buffer with intermediate results if num_sigmas > 1: # Store result from first frame as buffer for next iteration - self.x_t_latent_buffer = denoised_batch[0].clone() - print(f"[StreamBatchSampler] Updated buffer with shape: {self.x_t_latent_buffer.shape}") + # Optimization: Use in-place copy instead of clone + self.x_t_latent_buffer.copy_(denoised_batch[0]) # Update buffer in-place # Return result from last frame x_0_pred_out = denoised_batch[-1].unsqueeze(0) @@ -211,10 +194,13 @@ def INPUT_TYPES(cls): def __init__(self): super().__init__() - self.frame_buffer = [] # List to store incoming frames + # Optimization: Replace list with a pre-allocated tensor ring buffer + self.frame_buffer = None # Tensor of shape [buffer_size, C, H, W] self.buffer_size = None + self.buffer_pos = 0 # Current position in ring buffer + self.is_initialized = False # Track buffer initialization self.is_txt2img_mode = False - + def update(self, latent, buffer_size=4, always_execute=True): """Add new frame to buffer and 
return batch when ready""" self.buffer_size = buffer_size @@ -245,19 +231,29 @@ def update(self, latent, buffer_size=4, always_execute=True): self.is_txt2img_mode = False x = x.squeeze(0) # Remove batch dimension -> [C,H,W] - # Add new frame to buffer - if len(self.frame_buffer) == 0 or self.is_txt2img_mode: - # First frame or txt2img mode - initialize buffer with copies - self.frame_buffer = [x.clone() for _ in range(self.buffer_size)] - print(f"[StreamFrameBuffer] Initialized buffer with {self.buffer_size} copies of frame") + # Optimization: Initialize or resize frame_buffer as a tensor + if not self.is_initialized or self.frame_buffer.shape[0] != self.buffer_size or \ + self.frame_buffer.shape[1:] != x.shape: + # Pre-allocate buffer with correct shape + self.frame_buffer = torch.zeros( + (self.buffer_size, *x.shape), + device=x.device, + dtype=x.dtype + ) + if self.is_txt2img_mode or not self.is_initialized: + # Optimization: Use broadcasting to fill buffer with copies + self.frame_buffer[:] = x.unsqueeze(0) # Broadcast x to [buffer_size, C, H, W] + print(f"[StreamFrameBuffer] Initialized buffer with {self.buffer_size} copies of frame") + self.is_initialized = True + self.buffer_pos = 0 else: - # Shift frames forward and add new frame - self.frame_buffer.pop(0) # Remove oldest frame - self.frame_buffer.append(x.clone()) # Add new frame - print(f"[StreamFrameBuffer] Added new frame to buffer") - - # Stack frames into batch - batch = torch.stack(self.frame_buffer, dim=0) # [B,C,H,W] + # Add new frame to buffer using ring buffer logic + self.frame_buffer[self.buffer_pos] = x # In-place update + print(f"[StreamFrameBuffer] Added new frame to buffer at position {self.buffer_pos}") + self.buffer_pos = (self.buffer_pos + 1) % self.buffer_size # Circular increment + + # Optimization: frame_buffer is already a tensor batch, no need to stack + batch = self.frame_buffer print(f"[StreamFrameBuffer] Created batch with shape: {batch.shape}") # Return as latent dict with preserved dimensions @@ -269,4 +265,4 @@ def update(self, latent, buffer_size=4, always_execute=True): if "width" in latent: result["width"] = latent["width"] - return (result,) + return (result,) \ No newline at end of file From b24c28011b30157625feb8bb6fd5166cb7ba9fe4 Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Tue, 11 Mar 2025 11:28:54 -0400 Subject: [PATCH 13/19] add profiling, remove controlnode base --- stream_sampler.py | 246 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 188 insertions(+), 58 deletions(-) diff --git a/stream_sampler.py b/stream_sampler.py index 937c7a6..aee363e 100644 --- a/stream_sampler.py +++ b/stream_sampler.py @@ -1,39 +1,168 @@ import torch -from .base.control_base import ControlNodeBase import comfy.model_management import comfy.samplers import random +import time +import os +import json +from datetime import datetime +from functools import wraps +import cProfile +import pstats +import io +# Simple profiling setup - SINGLE FILE +PROFILE_FILE = "stream_sampler_profile.json" +PROFILE_DATA = [] -class StreamBatchSampler(ControlNodeBase): - """Implements batched denoising for faster inference by processing multiple frames in parallel at different denoising steps""" +def profile_time(func): + @wraps(func) + def wrapper(*args, **kwargs): + # Start timing + start_time = time.time() + + # Check memory before + if torch.cuda.is_available(): + torch.cuda.synchronize() + mem_before = torch.cuda.memory_allocated() / (1024 * 1024) + + # 
Execute the function + result = func(*args, **kwargs) + + # End timing + if torch.cuda.is_available(): + torch.cuda.synchronize() + end_time = time.time() + + # Check memory after + if torch.cuda.is_available(): + mem_after = torch.cuda.memory_allocated() / (1024 * 1024) + mem_diff = mem_after - mem_before + else: + mem_diff = 0 + + exec_time = end_time - start_time + + # Get shape info if available + shape_info = "unknown" + if len(args) > 1 and hasattr(args[1], 'shape'): # If this is sample(), args[1] is noise + shape_info = str(args[1].shape) + + # Log results + entry = { + 'function': func.__name__, + 'execution_time': exec_time, + 'memory_diff_mb': mem_diff, + 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), + 'shape': shape_info + } + + PROFILE_DATA.append(entry) + + # Print to console + print(f"PROFILE: {func.__name__} - Time: {exec_time:.4f}s, Memory: {mem_diff:.2f}MB") + + # Save to file after each run + save_profile_data() + + return result + return wrapper + + +def profile_cprofile(func): + @wraps(func) + def wrapper(*args, **kwargs): + # Create profiler + pr = cProfile.Profile() + pr.enable() + + # Execute the function + result = func(*args, **kwargs) + + # Disable profiler + pr.disable() + + # Get stats + s = io.StringIO() + ps = pstats.Stats(pr, stream=s).sort_stats('cumulative') + ps.print_stats(20) + + # Add to the most recent profile entry + if PROFILE_DATA: + PROFILE_DATA[-1]['cprofile_data'] = s.getvalue() + + return result + return wrapper + + +def save_profile_data(): + """Save all profiling data to a single JSON file""" + if not PROFILE_DATA: + return + + # Convert to serializable format + json_data = [] + for entry in PROFILE_DATA: + serializable_entry = { + 'function': entry['function'], + 'execution_time': float(entry['execution_time']), + 'memory_diff_mb': float(entry['memory_diff_mb']), + 'timestamp': entry['timestamp'], + 'shape': entry.get('shape', 'unknown') + } + + if 'cprofile_data' in entry: + # Only store recent cprofile data to keep file size manageable + if len(json_data) < 10 or len(json_data) % 10 == 0: + serializable_entry['cprofile_data'] = entry['cprofile_data'] + + json_data.append(serializable_entry) + + # Write to a single file + with open(PROFILE_FILE, 'w') as f: + json.dump(json_data, f, indent=2) + + # Calculate stats + sample_times = [x['execution_time'] for x in PROFILE_DATA if x['function'] == 'sample'] + if sample_times: + avg_time = sum(sample_times) / len(sample_times) + min_time = min(sample_times) + max_time = max(sample_times) + print(f"PROFILE SUMMARY: {len(sample_times)} runs, Avg: {avg_time:.4f}s, Min: {min_time:.4f}s, Max: {max_time:.4f}s") + + print(f"Profiling data saved to {PROFILE_FILE}") + + +class StreamBatchSampler: + RETURN_TYPES = ("SAMPLER",) FUNCTION = "update" - CATEGORY = "real-time/sampling" - + CATEGORY = "StreamPack/sampling" + DESCRIPTION = "Implements batched denoising for faster inference by processing multiple frames in parallel at different denoising steps. Also adds temportal consistency to the denoising process." @classmethod def INPUT_TYPES(cls): - inputs = super().INPUT_TYPES() - inputs["required"].update({ - "num_steps": ("INT", { - "default": 4, - "min": 1, - "max": 10, - "step": 1, - "tooltip": "Number of denoising steps. Should match the frame buffer size." - }), - }) - return inputs + return { + "required": { + "num_steps": ("INT", { + "default": 4, + "min": 1, + "max": 10, + "step": 1, + "tooltip": "Number of denoising steps. Should match the frame buffer size." 
+ }), + }, + } def __init__(self): - super().__init__() self.num_steps = None self.frame_buffer = [] self.x_t_latent_buffer = None self.stock_noise = None self.is_txt2img_mode = False + @profile_time + @profile_cprofile def sample(self, model, noise, sigmas, extra_args=None, callback=None, disable=None): """Sample with staggered batch denoising steps""" extra_args = {} if extra_args is None else extra_args @@ -106,40 +235,43 @@ def sample(self, model, noise, sigmas, extra_args=None, callback=None, disable=N return x_0_pred_out - def update(self, num_steps=4, always_execute=True): + @profile_time + def update(self, num_steps=4): """Create sampler with specified settings""" self.num_steps = num_steps sampler = comfy.samplers.KSAMPLER(self.sample) return (sampler,) +# Print setup info when module is imported +print(f"StreamBatchSampler profiling enabled. Results will be saved to {PROFILE_FILE}") -class StreamScheduler(ControlNodeBase): - """Implements StreamDiffusion's efficient timestep selection""" + +class StreamScheduler: RETURN_TYPES = ("SIGMAS",) FUNCTION = "update" - CATEGORY = "real-time/sampling" - + CATEGORY = "StreamPack/sampling" + DESCRIPTION = "Implements StreamDiffusion's efficient timestep selection. Use in conjunction with StreamBatchSampler." @classmethod def INPUT_TYPES(cls): - inputs = super().INPUT_TYPES() - inputs["required"].update({ - "model": ("MODEL",), - "t_index_list": ("STRING", { - "default": "32,45", - "tooltip": "Comma-separated list of timesteps to actually use for denoising. Examples: '32,45' for img2img or '0,16,32,45' for txt2img" - }), - "num_inference_steps": ("INT", { - "default": 50, - "min": 1, - "max": 1000, - "step": 1, - "tooltip": "Total number of timesteps in schedule. StreamDiffusion uses 50 by default. Only timesteps specified in t_index_list are actually used." - }), - }) - return inputs + return { + "required": { + "model": ("MODEL",), + "t_index_list": ("STRING", { + "default": "32,45", + "tooltip": "Comma-separated list of timesteps to actually use for denoising. Examples: '32,45' for img2img or '0,16,32,45' for txt2img" + }), + "num_inference_steps": ("INT", { + "default": 50, + "min": 1, + "max": 1000, + "step": 1, + "tooltip": "Total number of timesteps in schedule. StreamDiffusion uses 50 by default. Only timesteps specified in t_index_list are actually used." + }), + }, + } - def update(self, model, t_index_list="32,45", num_inference_steps=50, always_execute=True): + def update(self, model, t_index_list="32,45", num_inference_steps=50): # Get model's sampling parameters model_sampling = model.get_model_object("model_sampling") @@ -170,38 +302,36 @@ def update(self, model, t_index_list="32,45", num_inference_steps=50, always_exe return (selected_sigmas,) -class StreamFrameBuffer(ControlNodeBase): - """Accumulates frames to enable staggered batch denoising like StreamDiffusion""" +class StreamFrameBuffer: + RETURN_TYPES = ("LATENT",) FUNCTION = "update" - CATEGORY = "real-time/sampling" - + CATEGORY = "StreamPack/sampling" + DESCRIPTION = "Accumulates frames to enable staggered batch denoising like StreamDiffusion. Use in conjunction with StreamBatchSampler" @classmethod def INPUT_TYPES(cls): - inputs = super().INPUT_TYPES() - inputs["required"].update({ - "latent": ("LATENT",), - "buffer_size": ("INT", { - "default": 4, - "min": 1, - "max": 10, - "step": 1, - "tooltip": "Number of frames to buffer before starting batch processing. Should match number of denoising steps." 
- }), - }) - return inputs + return { + "required": { + "latent": ("LATENT",), + "buffer_size": ("INT", { + "default": 4, + "min": 1, + "max": 10, + "step": 1, + "tooltip": "Number of frames to buffer before starting batch processing. Should match number of denoising steps." + }), + }, + } def __init__(self): - super().__init__() - # Optimization: Replace list with a pre-allocated tensor ring buffer self.frame_buffer = None # Tensor of shape [buffer_size, C, H, W] self.buffer_size = None self.buffer_pos = 0 # Current position in ring buffer self.is_initialized = False # Track buffer initialization self.is_txt2img_mode = False - def update(self, latent, buffer_size=4, always_execute=True): + def update(self, latent, buffer_size=4): """Add new frame to buffer and return batch when ready""" self.buffer_size = buffer_size From 22a852b7f0478c848753bcb67452a4e5af78cadc Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Tue, 11 Mar 2025 11:46:04 -0400 Subject: [PATCH 14/19] mem alloc optim --- stream_sampler.py | 106 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 75 insertions(+), 31 deletions(-) diff --git a/stream_sampler.py b/stream_sampler.py index aee363e..f61328c 100644 --- a/stream_sampler.py +++ b/stream_sampler.py @@ -160,24 +160,46 @@ def __init__(self): self.x_t_latent_buffer = None self.stock_noise = None self.is_txt2img_mode = False + + # Initialize all optimization buffers as None + self.zeros_reference = None + self.random_noise_buffer = None + self.sigmas_view_buffer = None + self.expanded_stock_noise = None + self.working_buffer = None + self.output_buffer = None @profile_time @profile_cprofile def sample(self, model, noise, sigmas, extra_args=None, callback=None, disable=None): - """Sample with staggered batch denoising steps""" + """Sample with staggered batch denoising steps - Optimized version""" extra_args = {} if extra_args is None else extra_args # Get number of frames in batch and available sigmas batch_size = noise.shape[0] num_sigmas = len(sigmas) - 1 # Subtract 1 because last sigma is the target (0.0) - # Detect if we're in text-to-image mode by checking if noise is all zeros - # This happens when empty latents are provided - self.is_txt2img_mode = torch.allclose(noise, torch.zeros_like(noise), atol=1e-6) + # Optimization 1: Reuse zeros buffer for txt2img detection + if self.zeros_reference is None: + # We only need a small reference tensor for comparison, not a full tensor + self.zeros_reference = torch.zeros(1, device=noise.device, dtype=noise.dtype) + + # Check if noise tensor is all zeros - functionally identical but more efficient + self.is_txt2img_mode = torch.abs(noise).sum() < 1e-5 + # Noise handling with memory optimization if self.is_txt2img_mode: - # For text-to-image, we'll use pure random noise - noise = torch.randn_like(noise) + # Optimization 2: If txt2img mode, reuse the noise tensor directly + # instead of allocating new memory + if self.random_noise_buffer is None or self.random_noise_buffer.shape != noise.shape: + self.random_noise_buffer = torch.empty_like(noise) + + # Generate random noise in-place + self.random_noise_buffer.normal_() + x = self.random_noise_buffer # Use pre-allocated buffer + else: + # If not txt2img, we'll still need to add noise later + x = noise # No need to copy, will add noise later # Verify batch size matches number of timesteps if batch_size != num_sigmas: @@ -187,29 +209,46 @@ def sample(self, model, noise, sigmas, extra_args=None, callback=None, 
disable=N alpha_prod_t = (sigmas[:-1] / sigmas[0]).view(-1, 1, 1, 1) # [B,1,1,1] beta_prod_t = (1 - alpha_prod_t) + # Optimization 3: Initialize stock noise with reuse + if self.stock_noise is None or self.stock_noise.shape != noise[0].shape: + self.stock_noise = torch.empty_like(noise[0]) + self.stock_noise.normal_() # Generate random noise in-place + + # Optimization 4: Pre-allocate and reuse view buffer for sigmas + if self.sigmas_view_buffer is None or self.sigmas_view_buffer.shape[0] != len(sigmas)-1: + self.sigmas_view_buffer = torch.empty((len(sigmas)-1, 1, 1, 1), + device=sigmas.device, + dtype=sigmas.dtype) + # In-place copy of sigmas view + self.sigmas_view_buffer.copy_(sigmas[:-1].view(-1, 1, 1, 1)) + + # Optimization 5: Eliminate unsqueeze allocation by pre-expanding stock noise + if self.expanded_stock_noise is None or self.expanded_stock_noise.shape[0] != batch_size: + self.expanded_stock_noise = self.stock_noise.expand(batch_size, *self.stock_noise.shape) + + # Apply noise with pre-allocated buffers - no new memory allocation + if not self.is_txt2img_mode: # Already handled txt2img case above + # If we need a working buffer separate from noise input: + if id(x) == id(noise): # They're the same object, need a separate buffer + if self.working_buffer is None or self.working_buffer.shape != noise.shape: + self.working_buffer = torch.empty_like(noise) + x = self.working_buffer + # Add noise to input + torch.add(noise, self.expanded_stock_noise * self.sigmas_view_buffer, out=x) + + # Initialize and manage latent buffer with memory optimization + if (self.x_t_latent_buffer is None or self.is_txt2img_mode) and num_sigmas > 1: + # Optimization 6: Pre-allocate or resize as needed + if self.x_t_latent_buffer is None or self.x_t_latent_buffer.shape != x[0].shape: + self.x_t_latent_buffer = torch.empty_like(x[0]) + # In-place copy instead of clone + self.x_t_latent_buffer.copy_(x[0]) - # Initialize stock noise if needed - if self.stock_noise is None or self.is_txt2img_mode: # Kept original condition for functional equivalence - self.stock_noise = torch.randn_like(noise[0]) # Random noise instead of zeros - - # Optimization: Vectorize noise scaling instead of looping - sigmas_view = sigmas[:-1].view(-1, 1, 1, 1) # Reshape for broadcasting - if self.is_txt2img_mode: - x = self.stock_noise.unsqueeze(0) * sigmas_view # Broadcast stock_noise across batch - else: - x = noise + self.stock_noise.unsqueeze(0) * sigmas_view # Add scaled noise to input - - # Initialize frame buffer if needed - if (self.x_t_latent_buffer is None or self.is_txt2img_mode) and num_sigmas > 1: # Kept original condition - # Optimization: Pre-allocate and copy instead of clone - self.x_t_latent_buffer = torch.empty_like(x[0]) # Pre-allocate memory - self.x_t_latent_buffer.copy_(x[0]) # In-place copy - # Use buffer for first frame to maintain temporal consistency if num_sigmas > 1: - # Optimization: Update in-place instead of concatenating - x[0] = self.x_t_latent_buffer # Replace first frame with buffer - + # In-place update - no new allocation + x[0].copy_(self.x_t_latent_buffer) + # Run model on entire batch at once with torch.no_grad(): # Process all frames in parallel @@ -220,15 +259,20 @@ def sample(self, model, noise, sigmas, extra_args=None, callback=None, disable=N # Update buffer with intermediate results if num_sigmas > 1: # Store result from first frame as buffer for next iteration - # Optimization: Use in-place copy instead of clone - self.x_t_latent_buffer.copy_(denoised_batch[0]) # Update buffer in-place + 
self.x_t_latent_buffer.copy_(denoised_batch[0]) # In-place update - # Return result from last frame - x_0_pred_out = denoised_batch[-1].unsqueeze(0) + # Optimization 7: Pre-allocate output buffer + if self.output_buffer is None or self.output_buffer.shape != (1, *denoised_batch[-1].shape): + self.output_buffer = torch.empty(1, *denoised_batch[-1].shape, + device=denoised_batch.device, + dtype=denoised_batch.dtype) + # Copy the result directly to pre-allocated buffer + self.output_buffer[0].copy_(denoised_batch[-1]) + x_0_pred_out = self.output_buffer else: x_0_pred_out = denoised_batch self.x_t_latent_buffer = None - + # Call callback if provided if callback is not None: callback({'x': x_0_pred_out, 'i': 0, 'sigma': sigmas[0], 'sigma_hat': sigmas[0], 'denoised': denoised_batch[-1:]}) From 7ea7e51883cfac6f0a72d8ec6ba9218bef66e0b4 Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Tue, 11 Mar 2025 15:38:49 -0400 Subject: [PATCH 15/19] sans profiling --- stream_sampler.py | 147 ++-------------------------------------------- 1 file changed, 6 insertions(+), 141 deletions(-) diff --git a/stream_sampler.py b/stream_sampler.py index f61328c..e9cb1a2 100644 --- a/stream_sampler.py +++ b/stream_sampler.py @@ -4,134 +4,6 @@ import random import time import os -import json -from datetime import datetime -from functools import wraps -import cProfile -import pstats -import io - -# Simple profiling setup - SINGLE FILE -PROFILE_FILE = "stream_sampler_profile.json" -PROFILE_DATA = [] - -def profile_time(func): - @wraps(func) - def wrapper(*args, **kwargs): - # Start timing - start_time = time.time() - - # Check memory before - if torch.cuda.is_available(): - torch.cuda.synchronize() - mem_before = torch.cuda.memory_allocated() / (1024 * 1024) - - # Execute the function - result = func(*args, **kwargs) - - # End timing - if torch.cuda.is_available(): - torch.cuda.synchronize() - end_time = time.time() - - # Check memory after - if torch.cuda.is_available(): - mem_after = torch.cuda.memory_allocated() / (1024 * 1024) - mem_diff = mem_after - mem_before - else: - mem_diff = 0 - - exec_time = end_time - start_time - - # Get shape info if available - shape_info = "unknown" - if len(args) > 1 and hasattr(args[1], 'shape'): # If this is sample(), args[1] is noise - shape_info = str(args[1].shape) - - # Log results - entry = { - 'function': func.__name__, - 'execution_time': exec_time, - 'memory_diff_mb': mem_diff, - 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), - 'shape': shape_info - } - - PROFILE_DATA.append(entry) - - # Print to console - print(f"PROFILE: {func.__name__} - Time: {exec_time:.4f}s, Memory: {mem_diff:.2f}MB") - - # Save to file after each run - save_profile_data() - - return result - return wrapper - - -def profile_cprofile(func): - @wraps(func) - def wrapper(*args, **kwargs): - # Create profiler - pr = cProfile.Profile() - pr.enable() - - # Execute the function - result = func(*args, **kwargs) - - # Disable profiler - pr.disable() - - # Get stats - s = io.StringIO() - ps = pstats.Stats(pr, stream=s).sort_stats('cumulative') - ps.print_stats(20) - - # Add to the most recent profile entry - if PROFILE_DATA: - PROFILE_DATA[-1]['cprofile_data'] = s.getvalue() - - return result - return wrapper - - -def save_profile_data(): - """Save all profiling data to a single JSON file""" - if not PROFILE_DATA: - return - - # Convert to serializable format - json_data = [] - for entry in PROFILE_DATA: - serializable_entry = { - 'function': 
entry['function'], - 'execution_time': float(entry['execution_time']), - 'memory_diff_mb': float(entry['memory_diff_mb']), - 'timestamp': entry['timestamp'], - 'shape': entry.get('shape', 'unknown') - } - - if 'cprofile_data' in entry: - # Only store recent cprofile data to keep file size manageable - if len(json_data) < 10 or len(json_data) % 10 == 0: - serializable_entry['cprofile_data'] = entry['cprofile_data'] - - json_data.append(serializable_entry) - - # Write to a single file - with open(PROFILE_FILE, 'w') as f: - json.dump(json_data, f, indent=2) - - # Calculate stats - sample_times = [x['execution_time'] for x in PROFILE_DATA if x['function'] == 'sample'] - if sample_times: - avg_time = sum(sample_times) / len(sample_times) - min_time = min(sample_times) - max_time = max(sample_times) - print(f"PROFILE SUMMARY: {len(sample_times)} runs, Avg: {avg_time:.4f}s, Min: {min_time:.4f}s, Max: {max_time:.4f}s") - - print(f"Profiling data saved to {PROFILE_FILE}") - class StreamBatchSampler: @@ -169,8 +41,6 @@ def __init__(self): self.working_buffer = None self.output_buffer = None - @profile_time - @profile_cprofile def sample(self, model, noise, sigmas, extra_args=None, callback=None, disable=None): """Sample with staggered batch denoising steps - Optimized version""" extra_args = {} if extra_args is None else extra_args @@ -279,16 +149,12 @@ def sample(self, model, noise, sigmas, extra_args=None, callback=None, disable=N return x_0_pred_out - @profile_time def update(self, num_steps=4): """Create sampler with specified settings""" self.num_steps = num_steps sampler = comfy.samplers.KSAMPLER(self.sample) return (sampler,) -# Print setup info when module is imported -print(f"StreamBatchSampler profiling enabled. Results will be saved to {PROFILE_FILE}") - class StreamScheduler: @@ -323,7 +189,7 @@ def update(self, model, t_index_list="32,45", num_inference_steps=50): try: t_index_list = [int(t.strip()) for t in t_index_list.split(",")] except ValueError as e: - print(f"Error parsing timesteps: {e}. 
Using default [32,45]") + t_index_list = [32, 45] # Create full schedule using normal scheduler @@ -334,7 +200,7 @@ def update(self, model, t_index_list="32,45", num_inference_steps=50): selected_sigmas = [] for t in sorted(t_index_list, reverse=True): # Sort in reverse to go from high noise to low if t < 0 or t >= num_inference_steps: - print(f"Warning: timestep {t} out of range [0,{num_inference_steps}), skipping") + continue selected_sigmas.append(float(full_sigmas[t])) @@ -387,7 +253,7 @@ def update(self, latent, buffer_size=4): if is_empty_latent: self.is_txt2img_mode = True - print(f"[StreamFrameBuffer] Detected empty latent for text-to-image mode") + # Create empty latents with correct dimensions for txt2img # Get dimensions from latent dict height = latent.get("height", 512) @@ -400,7 +266,7 @@ def update(self, latent, buffer_size=4): # Create zero tensor with correct shape x = torch.zeros((4, latent_height, latent_width), device=comfy.model_management.get_torch_device()) - print(f"[StreamFrameBuffer] Created empty latent with shape: {x.shape}") + elif x.dim() == 4: # [B,C,H,W] self.is_txt2img_mode = False x = x.squeeze(0) # Remove batch dimension -> [C,H,W] @@ -417,18 +283,17 @@ def update(self, latent, buffer_size=4): if self.is_txt2img_mode or not self.is_initialized: # Optimization: Use broadcasting to fill buffer with copies self.frame_buffer[:] = x.unsqueeze(0) # Broadcast x to [buffer_size, C, H, W] - print(f"[StreamFrameBuffer] Initialized buffer with {self.buffer_size} copies of frame") + self.is_initialized = True self.buffer_pos = 0 else: # Add new frame to buffer using ring buffer logic self.frame_buffer[self.buffer_pos] = x # In-place update - print(f"[StreamFrameBuffer] Added new frame to buffer at position {self.buffer_pos}") self.buffer_pos = (self.buffer_pos + 1) % self.buffer_size # Circular increment # Optimization: frame_buffer is already a tensor batch, no need to stack batch = self.frame_buffer - print(f"[StreamFrameBuffer] Created batch with shape: {batch.shape}") + # Return as latent dict with preserved dimensions result = {"samples": batch} From d9e05164faf88e26e3ef54c0ab388da850d17842 Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Tue, 11 Mar 2025 18:59:09 -0400 Subject: [PATCH 16/19] add particle nodes --- __init__.py | 5 +- particle_nodes.py | 147 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 150 insertions(+), 2 deletions(-) create mode 100644 particle_nodes.py diff --git a/__init__.py b/__init__.py index 040a383..9e79adc 100644 --- a/__init__.py +++ b/__init__.py @@ -7,6 +7,7 @@ from .stream_sampler import StreamBatchSampler, StreamScheduler, StreamFrameBuffer from .stream_cfg import StreamCFG from .stream_conditioning import StreamConditioning +from .particle_nodes import DepthMapWarpNode import re @@ -16,7 +17,7 @@ "IntControl": IntControl, "StringControl": StringControl, "FloatSequence": FloatSequence, - "IntSequence": IntSequence, + "IntSequence": IntSequence, "StringSequence": StringSequence, "FPSMonitor": FPSMonitor, "SimilarityFilter": SimilarityFilter, @@ -31,7 +32,7 @@ "IntegerMotionController": IntegerMotionController, "YOLOSimilarityCompare": YOLOSimilarityCompare, "TextRenderer": TextRenderer, - + "DepthMapWarpNode": DepthMapWarpNode, "ROINode": ROINode, diff --git a/particle_nodes.py b/particle_nodes.py new file mode 100644 index 0000000..35074bf --- /dev/null +++ b/particle_nodes.py @@ -0,0 +1,147 @@ +import torch +import numpy as np +from .base.control_base import 
ControlNodeBase + +class DepthMapWarpNode(ControlNodeBase): + """ + A node that warps any depth map with a radial stretch and optional pulsing effect, designed for gamepad control. + """ + + @classmethod + def INPUT_TYPES(cls): + inputs = { + "required": { + "depth_map": ("IMAGE", { + "tooltip": "Input depth map to warp (BHWC format, any content)" + }), + "center_x": ("FLOAT", { + "default": 0.5, + "min": 0.0, + "max": 1.0, + "step": 0.01, + "display": "slider", + "tooltip": "X coordinate of the warp center (0-1), e.g., left joystick X" + }), + "center_y": ("FLOAT", { + "default": 0.5, + "min": 0.0, + "max": 1.0, + "step": 0.01, + "display": "slider", + "tooltip": "Y coordinate of the warp center (0-1), e.g., left joystick Y" + }), + "stretch_strength": ("FLOAT", { + "default": 0.0, + "min": -2.0, + "max": 2.0, + "step": 0.01, + "tooltip": "Strength of the stretch (positive = outward, negative = inward), e.g., right trigger - left trigger" + }), + "falloff": ("FLOAT", { + "default": 2.0, + "min": 0.1, + "max": 5.0, + "step": 0.1, + "tooltip": "Controls how quickly the stretch diminishes, e.g., right joystick Y" + }), + "pulse_frequency": ("FLOAT", { + "default": 0.0, + "min": 0.0, + "max": 5.0, + "step": 0.1, + "tooltip": "Frequency of pulsing effect (0 = off), e.g., right joystick X" + }), + }, + "optional": { + "mask": ("MASK", { + "tooltip": "Optional mask to limit where the warp is applied (0-1 values)" + }), + "mode": (["radial", "stretch", "bend", "wave"], {"default": "radial", "tooltip": "Warp mode"}), + + } + } + return inputs + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "update" + CATEGORY = "image/transforms" + + def __init__(self): + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.frame_count = 0 # For pulsing effect + + def update(self, depth_map, center_x, center_y, stretch_strength, falloff, pulse_frequency, mode, mask=None): + self.frame_count += 1 + depth_map = depth_map.to(self.device) + batch_size, height, width, channels = depth_map.shape + + # Create coordinate grid + y = torch.linspace(0, 1, height, device=self.device) + x = torch.linspace(0, 1, width, device=self.device) + y_grid, x_grid = torch.meshgrid(y, x, indexing='ij') + dx = x_grid - center_x + dy = y_grid - center_y + distance = torch.sqrt(dx**2 + dy**2) + max_distance = torch.tensor(1.414, device=self.device) + distance = distance / max_distance + + # Apply pulsing effect if enabled + effective_strength = stretch_strength + if pulse_frequency > 0: + pulse = torch.sin(torch.tensor(self.frame_count * pulse_frequency * 0.1, device=self.device)) + effective_strength = stretch_strength * (1 + 0.5 * pulse) + + # Compute warp based on mode + if mode == "radial": # Original radial stretch + displacement = effective_strength * torch.exp(-falloff * distance) + warp_x = x_grid + dx * displacement + warp_y = y_grid + dy * displacement + + elif mode == "stretch": # Vertical or horizontal stretch + if effective_strength > 0: # Positive = vertical stretch + displacement = effective_strength * torch.exp(-falloff * torch.abs(dy)) # Stretch along y-axis + warp_x = x_grid + warp_y = y_grid + dy * displacement + else: # Negative = horizontal stretch + displacement = -effective_strength * torch.exp(-falloff * torch.abs(dx)) # Stretch along x-axis + warp_x = x_grid + dx * displacement + warp_y = y_grid + + elif mode == "bend": # Bend left or right from vertical axis + bend_factor = effective_strength * torch.exp(-falloff * torch.abs(dx)) # Bend strength decreases with x-distance + warp_x = x_grid + 
bend_factor * (y_grid - center_y)**2 # Quadratic bend along y + warp_y = y_grid + + elif mode == "wave": # Horizontal traveling wave + wave = effective_strength * torch.sin(10 * (x_grid - center_x) + self.frame_count * 0.2) * torch.exp(-falloff * distance) + warp_x = x_grid + wave + warp_y = y_grid + + # Clamp coordinates to [0, 1] + warp_x = torch.clamp(warp_x, 0, 1) + warp_y = torch.clamp(warp_y, 0, 1) + + # Convert to grid sample coordinates (-1 to 1) + grid = torch.stack((warp_x * 2 - 1, warp_y * 2 - 1), dim=-1) + grid = grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + + # Sample the warped depth map + warped_depth = torch.nn.functional.grid_sample( + depth_map.permute(0, 3, 1, 2), # BCHW + grid, + mode='bilinear', + padding_mode='zeros', + align_corners=True + ) + warped_depth = warped_depth.permute(0, 2, 3, 1) # BHWC + + # Apply mask if provided + if mask is not None: + mask = mask.to(self.device) + if mask.dim() == 3: # BHW + mask = mask.unsqueeze(-1) + elif mask.dim() == 2: # HW + mask = mask.unsqueeze(0).unsqueeze(-1) + warped_depth = depth_map * (1 - mask) + warped_depth * mask + + return (warped_depth.clamp(0, 1),) \ No newline at end of file From fb8c84f84480d885799899d554cb046ae8a506bd Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Fri, 14 Mar 2025 00:38:15 +0000 Subject: [PATCH 17/19] update --- .ipynb_checkpoints/__init__-checkpoint.py | 69 ++++ .../particle_nodes-checkpoint.py | 350 ++++++++++++++++++ __init__.py | 2 + dimensions.py | 24 ++ stream_sampler.py | 120 +++--- 5 files changed, 504 insertions(+), 61 deletions(-) create mode 100644 .ipynb_checkpoints/__init__-checkpoint.py create mode 100644 .ipynb_checkpoints/particle_nodes-checkpoint.py create mode 100644 dimensions.py diff --git a/.ipynb_checkpoints/__init__-checkpoint.py b/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000..fafac66 --- /dev/null +++ b/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,69 @@ +from .controls.value_controls import FloatControl, IntControl, StringControl +from .controls.sequence_controls import FloatSequence, IntSequence, StringSequence +from .controls.utility_controls import FPSMonitor, SimilarityFilter, LazyCondition +from .controls.motion_controls import MotionController, ROINode, IntegerMotionController +from .misc_nodes import DTypeConverter, FastWebcamCapture, YOLOSimilarityCompare, TextRenderer, QuickShapeMask, MultilineText, LoadImageFromPath_ +from .stream_diffusion_nodes import StreamCrossAttention +from .stream_sampler import StreamBatchSampler, StreamScheduler, StreamFrameBuffer +from .stream_cfg import StreamCFG +from .stream_conditioning import StreamConditioning +import re +from .particle_nodes import TemporalParticleDepthNode + + +NODE_CLASS_MAPPINGS = { + "FloatControl": FloatControl, + "IntControl": IntControl, + "StringControl": StringControl, + "FloatSequence": FloatSequence, + "IntSequence": IntSequence, + "StringSequence": StringSequence, + "FPSMonitor": FPSMonitor, + "SimilarityFilter": SimilarityFilter, + "StreamCFG": StreamCFG, + "StreamConditioning": StreamConditioning, + "StreamBatchSampler": StreamBatchSampler, + "StreamScheduler": StreamScheduler, + "StreamFrameBuffer": StreamFrameBuffer, + "StreamCrossAttention": StreamCrossAttention, + "LazyCondition": LazyCondition, + "MotionController": MotionController, + "IntegerMotionController": IntegerMotionController, + "YOLOSimilarityCompare": YOLOSimilarityCompare, + "TextRenderer": TextRenderer, + + + "ROINode": ROINode, 
+ + #"IntervalControl": IntervalCo ntrol, + #"DeltaControl": DeltaControl, + "QuickShapeMask": QuickShapeMask, + "DTypeConverter": DTypeConverter, + "FastWebcamCapture": FastWebcamCapture, + "MultilineText": MultilineText, + "LoadImageFromPath_": LoadImageFromPath_, + "TemporalParticleDepthNode": TemporalParticleDepthNode, +} + +NODE_DISPLAY_NAME_MAPPINGS = {} + +suffix = " 🕒🅡🅣🅝" + +for node_name in NODE_CLASS_MAPPINGS.keys(): + # Convert camelCase or snake_case to Title Case + if node_name not in NODE_DISPLAY_NAME_MAPPINGS: + display_name = ' '.join(word.capitalize() for word in re.findall(r'[A-Z]?[a-z]+|[A-Z]+(?=[A-Z][a-z]|\d|\W|$)|\d+', node_name)) + else: + display_name = NODE_DISPLAY_NAME_MAPPINGS[node_name] + + # Add the suffix if it's not already present + if not display_name.endswith(suffix): + display_name += suffix + + # Assign the final display name to the mappings + NODE_DISPLAY_NAME_MAPPINGS[node_name] = display_name + + +WEB_DIRECTORY = "./web/js" + +__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] \ No newline at end of file diff --git a/.ipynb_checkpoints/particle_nodes-checkpoint.py b/.ipynb_checkpoints/particle_nodes-checkpoint.py new file mode 100644 index 0000000..fc2a55c --- /dev/null +++ b/.ipynb_checkpoints/particle_nodes-checkpoint.py @@ -0,0 +1,350 @@ +import torch +import numpy as np +from .base.control_base import ControlNodeBase + +class TemporalParticleDepthNode(ControlNodeBase): + """ + A node that generates temporally consistent particle depth maps with a controllable origin point. + """ + @classmethod + def INPUT_TYPES(cls): + inputs = super().INPUT_TYPES() + inputs["required"].update({ + "origin_x": ("FLOAT", { + "default": 0.5, + "min": 0.0, + "max": 1.0, + "step": 0.01, + "display": "slider", + "tooltip": "X coordinate of the origin point (0-1)" + }), + "origin_y": ("FLOAT", { + "default": 0.5, + "min": 0.0, + "max": 1.0, + "step": 0.01, + "display": "slider", + "tooltip": "Y coordinate of the origin point (0-1)" + }), + "width": ("INT", { + "default": 512, + "min": 64, + "max": 2048, + "step": 8, + "tooltip": "Width of the output depth map" + }), + "height": ("INT", { + "default": 512, + "min": 64, + "max": 2048, + "step": 8, + "tooltip": "Height of the output depth map" + }), + "num_particles": ("INT", { + "default": 100, + "min": 10, + "max": 1000, + "step": 10, + "tooltip": "Number of particles to simulate" + }), + "particle_size": ("FLOAT", { + "default": 0.03, + "min": 0.001, + "max": 0.2, + "step": 0.001, + "tooltip": "Size of each particle (relative to image size)" + }), + "speed": ("FLOAT", { + "default": 1.0, + "min": 0.1, + "max": 3.0, + "step": 0.1, + "display": "slider", + "tooltip": "Overall movement speed of particles" + }), + "inertia_factor": ("FLOAT", { + "default": 0.1, + "min": 0.01, + "max": 0.5, + "step": 0.01, + "display": "slider", + "tooltip": "Controls how quickly particles respond to changes (lower = more inertia)" + }), + "distance_response": ("FLOAT", { + "default": 0.15, + "min": -0.3, + "max": 0.3, + "step": 0.01, + "display": "slider", + "tooltip": "How distance affects response speed (positive = closer particles respond faster, negative = distant particles respond faster, zero = uniform response)" + }), + }) + inputs["optional"] = { + "hand_data": ("HAND_DATA", { + "tooltip": "Optional hand tracking data to use for origin point" + }), + "hand_keypoint": (["palm", "thumb_tip", "index_tip", "middle_tip", "ring_tip", "pinky_tip"], { + "default": "index_tip", + "tooltip": "Which hand keypoint to use as origin 
(if hand_data is provided)" + }), + } + return inputs + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "update" + CATEGORY = "image/generators" + + def __init__(self): + super().__init__() + self.generator = None + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + def update(self, origin_x, origin_y, width, height, num_particles, particle_size, speed, inertia_factor, distance_response, always_execute=True, hand_data=None, hand_keypoint="index_tip"): + # Initialize generator if needed or if parameters changed + if self.generator is None: + self.generator = TemporalParticleDepthGenerator( + batch_size=1, + height=height, + width=width, + num_particles=num_particles, + initial_origin_x=origin_x, + initial_origin_y=origin_y, + particle_size=particle_size, + speed=speed, + inertia_factor=inertia_factor, + distance_response=distance_response, + device=self.device + ) + elif (self.generator.shape[1] != height or + self.generator.shape[2] != width or + self.generator.num_particles != num_particles): + # Reinitialize if dimensions or particle count changed + self.generator = TemporalParticleDepthGenerator( + batch_size=1, + height=height, + width=width, + num_particles=num_particles, + initial_origin_x=origin_x, + initial_origin_y=origin_y, + particle_size=particle_size, + speed=speed, + inertia_factor=inertia_factor, + distance_response=distance_response, + device=self.device + ) + + # Use hand data if provided + if hand_data is not None and len(hand_data) > 0: + # Extract the first hand's data + hand_info = hand_data[0] + + # Check if hands are present + if hand_info["hands_present"]: + # Determine which hand to use (prefer right hand if available) + hand_landmarks = None + if hand_info["right_hand"] is not None: + hand_landmarks = hand_info["right_hand"] + elif hand_info["left_hand"] is not None: + hand_landmarks = hand_info["left_hand"] + + if hand_landmarks is not None: + # Map keypoint to index + keypoint_map = { + "palm": 0, # Center of palm + "thumb_tip": 4, # Thumb tip + "index_tip": 8, # Index finger tip + "middle_tip": 12, # Middle finger tip + "ring_tip": 16, # Ring finger tip + "pinky_tip": 20 # Pinky tip + } + + if hand_keypoint in keypoint_map: + keypoint_idx = keypoint_map[hand_keypoint] + # Get normalized coordinates (already in 0-1 range) + if keypoint_idx < len(hand_landmarks): + # X and Y coordinates are at indices 0 and 1 + origin_x = hand_landmarks[keypoint_idx][0] + origin_y = hand_landmarks[keypoint_idx][1] + + # Generate depth map with current origin + depth_map = self.generator.update( + new_origin_x=origin_x, + new_origin_y=origin_y, + particle_size=particle_size, + speed=speed, + inertia_factor=inertia_factor, + distance_response=distance_response + ) + + # ComfyUI already uses BHWC format, so no permute needed + return (depth_map,) + + +import torch + +class TemporalParticleDepthGenerator: + def __init__(self, + batch_size: int, + height: int, + width: int, + num_particles: int = 100, + initial_origin_x: float = 0.5, + initial_origin_y: float = 0.5, + particle_size: float = 0.03, + speed: float = 1.0, + inertia_factor: float = 0.1, + distance_response: float = 0.15, + device: str = 'cuda' if torch.cuda.is_available() else 'cpu'): + """ + Initialize a temporally consistent particle depth map generator with inertial origin movement. 
+ """ + self.device = torch.device(device) + self.shape = (batch_size, height, width) + self.particle_size = particle_size + self.speed = speed + self.base_inertia_factor = inertia_factor + self.distance_response = distance_response + + # Coordinate grid (static) + y = torch.linspace(0, 1, height, device=device) + x = torch.linspace(0, 1, width, device=device) + self.y_grid, self.x_grid = torch.meshgrid(y, x, indexing='ij') + + # Persistent particle properties + self.num_particles = num_particles + self.angles = torch.rand(num_particles, device=device) * 2 * torch.pi + self.base_speeds = torch.rand(num_particles, device=device) * 0.02 + 0.005 + self.speeds = self.base_speeds * self.speed + self.distances = torch.zeros(num_particles, device=device) + + # Origin (current and target) + self.origin_x = torch.tensor(initial_origin_x, device=device) + self.origin_y = torch.tensor(initial_origin_y, device=device) + self.target_origin_x = self.origin_x.clone() + self.target_origin_y = self.origin_y.clone() + + # Particle-specific inertia factors (will be computed based on distance) + self.inertia_factors = torch.ones(num_particles, device=device) * inertia_factor + + # For respawning particles + self.active = torch.ones(num_particles, dtype=torch.bool, device=device) + + # Frame counters + self.frame_count = 0 + self.last_major_refresh = 0 + + def update(self, + new_origin_x: float = None, + new_origin_y: float = None, + particle_size: float = None, + speed: float = None, + inertia_factor: float = None, + distance_response: float = None) -> torch.Tensor: + """ + Update particle positions and generate next depth map frame with inertial origin movement. + + Args: + new_origin_x (float, optional): Target origin X (0-1) + new_origin_y (float, optional): Target origin Y (0-1) + particle_size (float, optional): Update particle size + speed (float, optional): Update overall movement speed + inertia_factor (float, optional): Update inertia factor + distance_response (float, optional): Update distance response factor + + Returns: + torch.Tensor: Depth map (B, H, W, 3) + """ + self.frame_count += 1 + + # Update target origin if provided + if new_origin_x is not None: + self.target_origin_x = torch.tensor(new_origin_x, device=self.device) + if new_origin_y is not None: + self.target_origin_y = torch.tensor(new_origin_y, device=self.device) + if particle_size is not None: + self.particle_size = particle_size + if speed is not None: + self.speed = speed + # Update actual speeds based on base speeds and speed multiplier + self.speeds = self.base_speeds * self.speed + if inertia_factor is not None: + self.base_inertia_factor = inertia_factor + if distance_response is not None: + self.distance_response = distance_response + + # Periodic randomization + if self.frame_count - self.last_major_refresh >= 200: + refresh_mask = torch.rand(self.num_particles, device=self.device) < 0.3 + if refresh_mask.any(): + num_refresh = refresh_mask.sum() + self.distances[refresh_mask] = torch.rand(num_refresh, device=self.device) * 0.3 + self.angles[refresh_mask] = torch.rand(num_refresh, device=self.device) * 2 * torch.pi + self.base_speeds[refresh_mask] = torch.rand(num_refresh, device=self.device) * 0.02 + 0.005 + self.speeds[refresh_mask] = self.base_speeds[refresh_mask] * self.speed + + speed_adjust = (torch.rand(self.num_particles, device=self.device) * 0.01) - 0.005 + self.base_speeds = torch.clamp(self.base_speeds + speed_adjust, 0.002, 0.03) + self.speeds = self.base_speeds * self.speed + self.last_major_refresh = 
self.frame_count + + # Update distances + self.distances = self.distances + self.speeds + + # Respawn off-screen particles + off_screen = self.distances > 1.414 + if off_screen.any(): + num_respawn = off_screen.sum() + self.distances[off_screen] = torch.rand(num_respawn, device=self.device) * 0.1 + self.angles[off_screen] = torch.rand(num_respawn, device=self.device) * 2 * torch.pi + self.speeds[off_screen] = torch.rand(num_respawn, device=self.device) * 0.02 + 0.005 + + # Calculate current particle positions (before origin update) + particle_x = self.origin_x + torch.cos(self.angles) * self.distances + particle_y = self.origin_y + torch.sin(self.angles) * self.distances + + # Compute distance-based inertia factors + particle_distances = torch.sqrt((particle_x - self.origin_x)**2 + (particle_y - self.origin_y)**2) + + # Calculate inertia based on distance and user parameters + # When distance_response is positive, closer particles respond faster + # When negative, distant particles respond faster + # When zero, all particles respond uniformly + base_response = self.base_inertia_factor * 2 + distance_effect = particle_distances * self.distance_response * self.base_inertia_factor + + if self.distance_response >= 0: + # Positive: closer particles respond faster (subtract distance effect) + self.inertia_factors = torch.clamp( + base_response - distance_effect, + 0.02 * self.base_inertia_factor, + 0.2 * self.base_inertia_factor + ) + else: + # Negative: distant particles respond faster (add distance effect) + self.inertia_factors = torch.clamp( + base_response + distance_effect, + 0.02 * self.base_inertia_factor, + 0.2 * self.base_inertia_factor + ) + + # Interpolate origin toward target with inertia + self.origin_x = self.origin_x + self.inertia_factors * (self.target_origin_x - self.origin_x) + self.origin_y = self.origin_y + self.inertia_factors * (self.target_origin_y - self.origin_y) + + # Recalculate particle positions with updated origin + particle_x = self.origin_x + torch.cos(self.angles) * self.distances + particle_y = self.origin_y + torch.sin(self.angles) * self.distances + + # Generate depth map + depth = torch.zeros(self.shape, device=self.device) + for i in range(self.num_particles): + if not self.active[i]: + continue + + dist = torch.sqrt((self.x_grid - particle_x[i])**2 + + (self.y_grid - particle_y[i])**2) + particle = torch.exp(-(dist**2) / (2 * self.particle_size**2)) + depth[0] = torch.maximum(depth[0], particle * self.distances[i]) + + depth = depth / (depth.max() + 1e-6) + return depth.unsqueeze(-1).expand(-1, -1, -1, 3).clamp(0, 1) \ No newline at end of file diff --git a/__init__.py b/__init__.py index 9e79adc..eb112a5 100644 --- a/__init__.py +++ b/__init__.py @@ -8,6 +8,7 @@ from .stream_cfg import StreamCFG from .stream_conditioning import StreamConditioning from .particle_nodes import DepthMapWarpNode +from .dimensions import ImageDimensions import re @@ -43,6 +44,7 @@ "FastWebcamCapture": FastWebcamCapture, "MultilineText": MultilineText, "LoadImageFromPath_": LoadImageFromPath_, + "ImageDimensions": ImageDimensions, } NODE_DISPLAY_NAME_MAPPINGS = {} diff --git a/dimensions.py b/dimensions.py new file mode 100644 index 0000000..db98845 --- /dev/null +++ b/dimensions.py @@ -0,0 +1,24 @@ +class ImageDimensions: + @classmethod + def INPUT_TYPES(s): + return {"required": {"image": ("IMAGE",)}} + + RETURN_TYPES = ("INT", "INT", "INT") + RETURN_NAMES = ("width", "height", "count") + FUNCTION = "get_dimensions" + CATEGORY = "image/info" + + def get_dimensions(self, 
image): + count = image.shape[0] # Batch size + height = image.shape[1] + width = image.shape[2] + print(f"Image dimensions: {width}x{height} (batch size: {count})") + return (width, height, count) + +NODE_CLASS_MAPPINGS = { + "ImageDimensions": ImageDimensions +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "ImageDimensions": "Get Image Dimensions" +} \ No newline at end of file diff --git a/stream_sampler.py b/stream_sampler.py index e9cb1a2..6cb5f2e 100644 --- a/stream_sampler.py +++ b/stream_sampler.py @@ -1,17 +1,9 @@ import torch import comfy.model_management import comfy.samplers -import random -import time -import os class StreamBatchSampler: - - RETURN_TYPES = ("SAMPLER",) - FUNCTION = "update" - CATEGORY = "StreamPack/sampling" - DESCRIPTION = "Implements batched denoising for faster inference by processing multiple frames in parallel at different denoising steps. Also adds temportal consistency to the denoising process." @classmethod def INPUT_TYPES(cls): return { @@ -26,6 +18,11 @@ def INPUT_TYPES(cls): }, } + RETURN_TYPES = ("SAMPLER",) + FUNCTION = "update" + CATEGORY = "StreamPack/sampling" + DESCRIPTION = "Implements batched denoising for faster inference by processing multiple frames in parallel at different denoising steps. Also adds temportal consistency to the denoising process." + def __init__(self): self.num_steps = None self.frame_buffer = [] @@ -33,7 +30,7 @@ def __init__(self): self.stock_noise = None self.is_txt2img_mode = False - # Initialize all optimization buffers as None + # Initialize all buffers self.zeros_reference = None self.random_noise_buffer = None self.sigmas_view_buffer = None @@ -49,18 +46,16 @@ def sample(self, model, noise, sigmas, extra_args=None, callback=None, disable=N batch_size = noise.shape[0] num_sigmas = len(sigmas) - 1 # Subtract 1 because last sigma is the target (0.0) - # Optimization 1: Reuse zeros buffer for txt2img detection + + # Reuse zeros buffer for txt2img detection if self.zeros_reference is None: - # We only need a small reference tensor for comparison, not a full tensor self.zeros_reference = torch.zeros(1, device=noise.device, dtype=noise.dtype) - # Check if noise tensor is all zeros - functionally identical but more efficient + # Check if noise tensor is all zeros self.is_txt2img_mode = torch.abs(noise).sum() < 1e-5 - # Noise handling with memory optimization if self.is_txt2img_mode: - # Optimization 2: If txt2img mode, reuse the noise tensor directly - # instead of allocating new memory + # If txt2img mode, reuse the noise tensor directly if self.random_noise_buffer is None or self.random_noise_buffer.shape != noise.shape: self.random_noise_buffer = torch.empty_like(noise) @@ -69,7 +64,7 @@ def sample(self, model, noise, sigmas, extra_args=None, callback=None, disable=N x = self.random_noise_buffer # Use pre-allocated buffer else: # If not txt2img, we'll still need to add noise later - x = noise # No need to copy, will add noise later + x = noise # Verify batch size matches number of timesteps if batch_size != num_sigmas: @@ -79,12 +74,12 @@ def sample(self, model, noise, sigmas, extra_args=None, callback=None, disable=N alpha_prod_t = (sigmas[:-1] / sigmas[0]).view(-1, 1, 1, 1) # [B,1,1,1] beta_prod_t = (1 - alpha_prod_t) - # Optimization 3: Initialize stock noise with reuse + # Initialize stock noise with reuse if self.stock_noise is None or self.stock_noise.shape != noise[0].shape: self.stock_noise = torch.empty_like(noise[0]) self.stock_noise.normal_() # Generate random noise in-place - # Optimization 4: Pre-allocate and 
reuse view buffer for sigmas + # Pre-allocate and reuse view buffer for sigmas if self.sigmas_view_buffer is None or self.sigmas_view_buffer.shape[0] != len(sigmas)-1: self.sigmas_view_buffer = torch.empty((len(sigmas)-1, 1, 1, 1), device=sigmas.device, @@ -92,7 +87,7 @@ def sample(self, model, noise, sigmas, extra_args=None, callback=None, disable=N # In-place copy of sigmas view self.sigmas_view_buffer.copy_(sigmas[:-1].view(-1, 1, 1, 1)) - # Optimization 5: Eliminate unsqueeze allocation by pre-expanding stock noise + # Eliminate unsqueeze allocation by pre-expanding stock noise if self.expanded_stock_noise is None or self.expanded_stock_noise.shape[0] != batch_size: self.expanded_stock_noise = self.stock_noise.expand(batch_size, *self.stock_noise.shape) @@ -108,18 +103,15 @@ def sample(self, model, noise, sigmas, extra_args=None, callback=None, disable=N # Initialize and manage latent buffer with memory optimization if (self.x_t_latent_buffer is None or self.is_txt2img_mode) and num_sigmas > 1: - # Optimization 6: Pre-allocate or resize as needed + # Pre-allocate or resize as needed if self.x_t_latent_buffer is None or self.x_t_latent_buffer.shape != x[0].shape: self.x_t_latent_buffer = torch.empty_like(x[0]) - # In-place copy instead of clone self.x_t_latent_buffer.copy_(x[0]) # Use buffer for first frame to maintain temporal consistency if num_sigmas > 1: - # In-place update - no new allocation x[0].copy_(self.x_t_latent_buffer) - # Run model on entire batch at once with torch.no_grad(): # Process all frames in parallel sigma_batch = sigmas[:-1] @@ -131,7 +123,7 @@ def sample(self, model, noise, sigmas, extra_args=None, callback=None, disable=N # Store result from first frame as buffer for next iteration self.x_t_latent_buffer.copy_(denoised_batch[0]) # In-place update - # Optimization 7: Pre-allocate output buffer + # Pre-allocate output buffer if self.output_buffer is None or self.output_buffer.shape != (1, *denoised_batch[-1].shape): self.output_buffer = torch.empty(1, *denoised_batch[-1].shape, device=denoised_batch.device, @@ -158,17 +150,13 @@ def update(self, num_steps=4): class StreamScheduler: - RETURN_TYPES = ("SIGMAS",) - FUNCTION = "update" - CATEGORY = "StreamPack/sampling" - DESCRIPTION = "Implements StreamDiffusion's efficient timestep selection. Use in conjunction with StreamBatchSampler." @classmethod def INPUT_TYPES(cls): return { "required": { "model": ("MODEL",), "t_index_list": ("STRING", { - "default": "32,45", + "default": "0,16,32,49", "tooltip": "Comma-separated list of timesteps to actually use for denoising. Examples: '32,45' for img2img or '0,16,32,45' for txt2img" }), "num_inference_steps": ("INT", { @@ -181,6 +169,11 @@ def INPUT_TYPES(cls): }, } + RETURN_TYPES = ("SIGMAS",) + FUNCTION = "update" + CATEGORY = "StreamPack/sampling" + DESCRIPTION = "Implements StreamDiffusion's efficient timestep selection. Use in conjunction with StreamBatchSampler." + def update(self, model, t_index_list="32,45", num_inference_steps=50): # Get model's sampling parameters model_sampling = model.get_model_object("model_sampling") @@ -214,11 +207,6 @@ def update(self, model, t_index_list="32,45", num_inference_steps=50): class StreamFrameBuffer: - - RETURN_TYPES = ("LATENT",) - FUNCTION = "update" - CATEGORY = "StreamPack/sampling" - DESCRIPTION = "Accumulates frames to enable staggered batch denoising like StreamDiffusion. 
Use in conjunction with StreamBatchSampler" @classmethod def INPUT_TYPES(cls): return { @@ -234,12 +222,20 @@ def INPUT_TYPES(cls): }, } + RETURN_TYPES = ("LATENT",) + FUNCTION = "update" + CATEGORY = "StreamPack/sampling" + DESCRIPTION = "Accumulates frames to enable staggered batch denoising like StreamDiffusion. Use in conjunction with StreamBatchSampler" + + def __init__(self): self.frame_buffer = None # Tensor of shape [buffer_size, C, H, W] self.buffer_size = None - self.buffer_pos = 0 # Current position in ring buffer + self.buffer_pos = 0 # Current position self.is_initialized = False # Track buffer initialization self.is_txt2img_mode = False + self.last_valid_frame = None # Store the last valid frame as fallback + self.expected_shape = None # Store expected shape for validation def update(self, latent, buffer_size=4): """Add new frame to buffer and return batch when ready""" @@ -248,46 +244,48 @@ def update(self, latent, buffer_size=4): # Extract latent tensor from input and remove batch dimension if present x = latent["samples"] - # Check if this is an empty latent (for txt2img) - is_empty_latent = x.numel() == 0 or (x.dim() > 0 and x.shape[0] == 0) + # Check if this is a txt2img (zeros tensor) or img2img mode + # In ComfyUI, EmptyLatentImage returns a zeros tensor with shape [batch_size, 4, h//8, w//8] + # We consider it txt2img mode if the tensor contains all zeros + is_txt2img_mode = torch.sum(torch.abs(x)) < 1e-6 + self.is_txt2img_mode = is_txt2img_mode - if is_empty_latent: - self.is_txt2img_mode = True - - # Create empty latents with correct dimensions for txt2img - # Get dimensions from latent dict - height = latent.get("height", 512) - width = latent.get("width", 512) - - # Calculate latent dimensions (typically 1/8 of image dimensions for SD) - latent_height = height // 8 - latent_width = width // 8 - - # Create zero tensor with correct shape - x = torch.zeros((4, latent_height, latent_width), - device=comfy.model_management.get_torch_device()) - - elif x.dim() == 4: # [B,C,H,W] - self.is_txt2img_mode = False + # If it's a batch with size 1, remove the batch dimension for our buffer + if x.dim() == 4 and x.shape[0] == 1: # [1,C,H,W] x = x.squeeze(0) # Remove batch dimension -> [C,H,W] - # Optimization: Initialize or resize frame_buffer as a tensor - if not self.is_initialized or self.frame_buffer.shape[0] != self.buffer_size or \ - self.frame_buffer.shape[1:] != x.shape: + # Initialize buffer on first run or when buffer size changes + if not self.is_initialized or self.frame_buffer is None or self.frame_buffer.shape[0] != self.buffer_size: + # First initialization - set expected shape and store first frame + if not self.is_initialized: + self.expected_shape = x.shape + self.last_valid_frame = x.clone() + # Pre-allocate buffer with correct shape self.frame_buffer = torch.zeros( - (self.buffer_size, *x.shape), + (self.buffer_size, *self.expected_shape), device=x.device, dtype=x.dtype ) + if self.is_txt2img_mode or not self.is_initialized: + # Use the right-sized frame for initialization + init_frame = x if x.shape == self.expected_shape else self.last_valid_frame # Optimization: Use broadcasting to fill buffer with copies - self.frame_buffer[:] = x.unsqueeze(0) # Broadcast x to [buffer_size, C, H, W] + self.frame_buffer[:] = init_frame.unsqueeze(0) # Broadcast to [buffer_size, C, H, W] self.is_initialized = True self.buffer_pos = 0 else: - # Add new frame to buffer using ring buffer logic + # Check if incoming frame matches expected dimensions + if x.shape != 
self.expected_shape: + # Size mismatch - use last valid frame instead + x = self.last_valid_frame + else: + # Valid frame - update our reference + self.last_valid_frame = x.clone() + + # Add frame to buffer using ring buffer logic self.frame_buffer[self.buffer_pos] = x # In-place update self.buffer_pos = (self.buffer_pos + 1) % self.buffer_size # Circular increment From 1ed06cdfd16dec73e37277080fa4c291df4b2c98 Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Thu, 13 Mar 2025 20:39:59 -0400 Subject: [PATCH 18/19] streamwarp demo --- examples/313WARP512_3step.json | 308 +++++++++++++++++++++++++++++++++ 1 file changed, 308 insertions(+) create mode 100644 examples/313WARP512_3step.json diff --git a/examples/313WARP512_3step.json b/examples/313WARP512_3step.json new file mode 100644 index 0000000..2a3bd64 --- /dev/null +++ b/examples/313WARP512_3step.json @@ -0,0 +1,308 @@ +{ + "3": { + "inputs": { + "unet_name": "static-dreamshaper8_SD15_$stat-b-1-h-512-w-512_00001_.engine", + "model_type": "SD15" + }, + "class_type": "TensorRTLoader", + "_meta": { + "title": "TensorRT Loader" + } + }, + "5": { + "inputs": { + "text": "an abstract masterpiece trippy flowing orbs", + "clip": [ + "89", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "6": { + "inputs": { + "text": "", + "clip": [ + "89", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "11": { + "inputs": { + "vae_name": "taesd" + }, + "class_type": "VAELoader", + "_meta": { + "title": "Load VAE" + } + }, + "13": { + "inputs": { + "backend": "inductor", + "fullgraph": true, + "mode": "reduce-overhead", + "compile_encoder": true, + "compile_decoder": true, + "vae": [ + "11", + 0 + ] + }, + "class_type": "TorchCompileLoadVAE", + "_meta": { + "title": "TorchCompileLoadVAE" + } + }, + "20": { + "inputs": { + "samples": [ + "30", + 0 + ], + "vae": [ + "13", + 0 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "30": { + "inputs": { + "add_noise": true, + "noise_seed": 276614755339600, + "cfg": 1.1, + "model": [ + "3", + 0 + ], + "positive": [ + "71", + 0 + ], + "negative": [ + "71", + 1 + ], + "sampler": [ + "77", + 0 + ], + "sigmas": [ + "31", + 0 + ], + "latent_image": [ + "78", + 0 + ] + }, + "class_type": "SamplerCustom", + "_meta": { + "title": "SamplerCustom" + } + }, + "31": { + "inputs": { + "t_index_list": "0,16,32", + "num_inference_steps": 50, + "model": [ + "3", + 0 + ] + }, + "class_type": "StreamScheduler", + "_meta": { + "title": "Stream Scheduler 🕒🅡🅣🅝" + } + }, + "53": { + "inputs": { + "clip_name": "sd15/dreamshaper/model.fp16.safetensors", + "type": "stable_diffusion", + "device": "default" + }, + "class_type": "CLIPLoader", + "_meta": { + "title": "Load CLIP" + } + }, + "60": { + "inputs": { + "width": [ + "107", + 0 + ], + "height": [ + "107", + 1 + ], + "batch_size": 1 + }, + "class_type": "EmptyLatentImage", + "_meta": { + "title": "Empty Latent Image" + } + }, + "71": { + "inputs": { + "strength": 1, + "start_percent": 0, + "end_percent": 1, + "positive": [ + "5", + 0 + ], + "negative": [ + "6", + 0 + ], + "control_net": [ + "85", + 0 + ], + "image": [ + "105", + 0 + ] + }, + "class_type": "ControlNetApplyAdvanced", + "_meta": { + "title": "Apply ControlNet" + } + }, + "72": { + "inputs": { + "control_net_name": "control_v11f1p_sd15_depth_fp16.safetensors" + }, + "class_type": "ControlNetLoader", + "_meta": { + 
"title": "Load ControlNet Model" + } + }, + "77": { + "inputs": { + "num_steps": 3 + }, + "class_type": "StreamBatchSampler", + "_meta": { + "title": "Stream Batch Sampler 🕒🅡🅣🅝" + } + }, + "78": { + "inputs": { + "buffer_size": 3, + "latent": [ + "60", + 0 + ] + }, + "class_type": "StreamFrameBuffer", + "_meta": { + "title": "Stream Frame Buffer 🕒🅡🅣🅝" + } + }, + "83": { + "inputs": { + "image": "harold.png", + "upload": "image" + }, + "class_type": "LoadImage", + "_meta": { + "title": "Load Image" + } + }, + "85": { + "inputs": { + "backend": "inductor", + "fullgraph": false, + "mode": "reduce-overhead", + "controlnet": [ + "72", + 0 + ] + }, + "class_type": "TorchCompileLoadControlNet", + "_meta": { + "title": "TorchCompileLoadControlNet" + } + }, + "89": { + "inputs": { + "stop_at_clip_layer": -2, + "clip": [ + "53", + 0 + ] + }, + "class_type": "CLIPSetLastLayer", + "_meta": { + "title": "CLIP Set Last Layer" + } + }, + "101": { + "inputs": { + "engine": "depth_anything_vitl14-fp16.engine", + "images": [ + "83", + 0 + ] + }, + "class_type": "DepthAnythingTensorrt", + "_meta": { + "title": "Depth Anything Tensorrt" + } + }, + "105": { + "inputs": { + "center_x": 0.5, + "center_y": 0.5, + "stretch_strength": 0, + "falloff": 2, + "pulse_frequency": 0, + "mode": "radial", + "depth_map": [ + "101", + 0 + ] + }, + "class_type": "DepthMapWarpNode", + "_meta": { + "title": "Depth Map Warp Node 🕒🅡🅣🅝" + } + }, + "107": { + "inputs": { + "image": [ + "83", + 0 + ] + }, + "class_type": "ImageDimensions", + "_meta": { + "title": "Image Dimensions 🕒🅡🅣🅝" + } + }, + "110": { + "inputs": { + "images": [ + "20", + 0 + ] + }, + "class_type": "PreviewImage", + "_meta": { + "title": "Preview Image" + } + } +} \ No newline at end of file From 90ab9e2505a8be984ca334c135f4d8b78c6f55f2 Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Fri, 21 Mar 2025 15:59:42 -0400 Subject: [PATCH 19/19] examples --- examples/313WARP384_highes.json | 377 +++++++++++++++++++++++ examples/313WARP384_highes_upscale.json | 391 ++++++++++++++++++++++++ 2 files changed, 768 insertions(+) create mode 100644 examples/313WARP384_highes.json create mode 100644 examples/313WARP384_highes_upscale.json diff --git a/examples/313WARP384_highes.json b/examples/313WARP384_highes.json new file mode 100644 index 0000000..1a5719e --- /dev/null +++ b/examples/313WARP384_highes.json @@ -0,0 +1,377 @@ +{ + "3": { + "inputs": { + "unet_name": "dynamic-dreamshaper8_SD15_$dyn-b-1-4-2-h-384-640-512-w-384-640-512_00001_.engine", + "model_type": "SD15" + }, + "class_type": "TensorRTLoader", + "_meta": { + "title": "TensorRT Loader" + } + }, + "5": { + "inputs": { + "text": "an abstract masterpiece trippy flowing orbs", + "clip": [ + "89", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "6": { + "inputs": { + "text": "", + "clip": [ + "89", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "11": { + "inputs": { + "vae_name": "taesd" + }, + "class_type": "VAELoader", + "_meta": { + "title": "Load VAE" + } + }, + "13": { + "inputs": { + "backend": "inductor", + "fullgraph": true, + "mode": "reduce-overhead", + "compile_encoder": true, + "compile_decoder": true, + "vae": [ + "11", + 0 + ] + }, + "class_type": "TorchCompileLoadVAE", + "_meta": { + "title": "TorchCompileLoadVAE" + } + }, + "30": { + "inputs": { + "add_noise": true, + "noise_seed": 156051838586740, + "cfg": 1, + "model": [ 
+ "3", + 0 + ], + "positive": [ + "71", + 0 + ], + "negative": [ + "71", + 1 + ], + "sampler": [ + "77", + 0 + ], + "sigmas": [ + "31", + 0 + ], + "latent_image": [ + "78", + 0 + ] + }, + "class_type": "SamplerCustom", + "_meta": { + "title": "SamplerCustom" + } + }, + "31": { + "inputs": { + "t_index_list": "0,16,32", + "num_inference_steps": 50, + "model": [ + "3", + 0 + ] + }, + "class_type": "StreamScheduler", + "_meta": { + "title": "Stream Scheduler 🕒🅡🅣🅝" + } + }, + "53": { + "inputs": { + "clip_name": "sd15/dreamshaper/model.fp16.safetensors", + "type": "stable_diffusion", + "device": "default" + }, + "class_type": "CLIPLoader", + "_meta": { + "title": "Load CLIP" + } + }, + "71": { + "inputs": { + "strength": 1, + "start_percent": 0, + "end_percent": 1, + "positive": [ + "5", + 0 + ], + "negative": [ + "6", + 0 + ], + "control_net": [ + "85", + 0 + ], + "image": [ + "105", + 0 + ] + }, + "class_type": "ControlNetApplyAdvanced", + "_meta": { + "title": "Apply ControlNet" + } + }, + "72": { + "inputs": { + "control_net_name": "control_v11f1p_sd15_depth_fp16.safetensors" + }, + "class_type": "ControlNetLoader", + "_meta": { + "title": "Load ControlNet Model" + } + }, + "77": { + "inputs": { + "num_steps": 3 + }, + "class_type": "StreamBatchSampler", + "_meta": { + "title": "Stream Batch Sampler 🕒🅡🅣🅝" + } + }, + "78": { + "inputs": { + "buffer_size": 3, + "latent": [ + "119", + 0 + ] + }, + "class_type": "StreamFrameBuffer", + "_meta": { + "title": "Stream Frame Buffer 🕒🅡🅣🅝" + } + }, + "83": { + "inputs": { + "image": "harold.png", + "upload": "image" + }, + "class_type": "LoadImage", + "_meta": { + "title": "Load Image" + } + }, + "85": { + "inputs": { + "backend": "inductor", + "fullgraph": false, + "mode": "reduce-overhead", + "controlnet": [ + "72", + 0 + ] + }, + "class_type": "TorchCompileLoadControlNet", + "_meta": { + "title": "TorchCompileLoadControlNet" + } + }, + "89": { + "inputs": { + "stop_at_clip_layer": -2, + "clip": [ + "53", + 0 + ] + }, + "class_type": "CLIPSetLastLayer", + "_meta": { + "title": "CLIP Set Last Layer" + } + }, + "101": { + "inputs": { + "engine": "depth_anything_vitl14-fp16.engine", + "images": [ + "83", + 0 + ] + }, + "class_type": "DepthAnythingTensorrt", + "_meta": { + "title": "Depth Anything Tensorrt" + } + }, + "105": { + "inputs": { + "center_x": 0.5, + "center_y": 0.5, + "stretch_strength": 0, + "falloff": 2, + "pulse_frequency": 0, + "mode": "radial", + "depth_map": [ + "101", + 0 + ] + }, + "class_type": "DepthMapWarpNode", + "_meta": { + "title": "Depth Map Warp Node 🕒🅡🅣🅝" + } + }, + "112": { + "inputs": { + "samples": [ + "113", + 0 + ], + "vae": [ + "13", + 0 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "113": { + "inputs": { + "add_noise": true, + "noise_seed": 283295832341401, + "cfg": 1, + "model": [ + "3", + 0 + ], + "positive": [ + "71", + 0 + ], + "negative": [ + "71", + 1 + ], + "sampler": [ + "117", + 0 + ], + "sigmas": [ + "116", + 0 + ], + "latent_image": [ + "114", + 0 + ] + }, + "class_type": "SamplerCustom", + "_meta": { + "title": "SamplerCustom" + } + }, + "114": { + "inputs": { + "buffer_size": 2, + "latent": [ + "118", + 0 + ] + }, + "class_type": "StreamFrameBuffer", + "_meta": { + "title": "Stream Frame Buffer 🕒🅡🅣🅝" + } + }, + "115": { + "inputs": { + "images": [ + "112", + 0 + ] + }, + "class_type": "PreviewImage", + "_meta": { + "title": "Preview Image" + } + }, + "116": { + "inputs": { + "t_index_list": "2,32", + "num_inference_steps": 50, + "model": [ + "3", + 0 + ] + }, + 
"class_type": "StreamScheduler", + "_meta": { + "title": "Stream Scheduler 🕒🅡🅣🅝" + } + }, + "117": { + "inputs": { + "num_steps": 2 + }, + "class_type": "StreamBatchSampler", + "_meta": { + "title": "Stream Batch Sampler 🕒🅡🅣🅝" + } + }, + "118": { + "inputs": { + "upscale_method": "nearest-exact", + "width": 640, + "height": 640, + "crop": "disabled", + "samples": [ + "30", + 0 + ] + }, + "class_type": "LatentUpscale", + "_meta": { + "title": "Upscale Latent" + } + }, + "119": { + "inputs": { + "width": 384, + "height": 384, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage", + "_meta": { + "title": "Empty Latent Image" + } + } +} \ No newline at end of file diff --git a/examples/313WARP384_highes_upscale.json b/examples/313WARP384_highes_upscale.json new file mode 100644 index 0000000..24bed2a --- /dev/null +++ b/examples/313WARP384_highes_upscale.json @@ -0,0 +1,391 @@ +{ + "3": { + "inputs": { + "unet_name": "dynamic-dreamshaper8_SD15_$dyn-b-1-4-2-h-384-640-512-w-384-640-512_00001_.engine", + "model_type": "SD15" + }, + "class_type": "TensorRTLoader", + "_meta": { + "title": "TensorRT Loader" + } + }, + "5": { + "inputs": { + "text": "an abstract masterpiece trippy flowing orbs", + "clip": [ + "89", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "6": { + "inputs": { + "text": "", + "clip": [ + "89", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "11": { + "inputs": { + "vae_name": "taesd" + }, + "class_type": "VAELoader", + "_meta": { + "title": "Load VAE" + } + }, + "13": { + "inputs": { + "backend": "inductor", + "fullgraph": true, + "mode": "reduce-overhead", + "compile_encoder": true, + "compile_decoder": true, + "vae": [ + "11", + 0 + ] + }, + "class_type": "TorchCompileLoadVAE", + "_meta": { + "title": "TorchCompileLoadVAE" + } + }, + "30": { + "inputs": { + "add_noise": true, + "noise_seed": 156051838586740, + "cfg": 1, + "model": [ + "3", + 0 + ], + "positive": [ + "71", + 0 + ], + "negative": [ + "71", + 1 + ], + "sampler": [ + "77", + 0 + ], + "sigmas": [ + "31", + 0 + ], + "latent_image": [ + "78", + 0 + ] + }, + "class_type": "SamplerCustom", + "_meta": { + "title": "SamplerCustom" + } + }, + "31": { + "inputs": { + "t_index_list": "0,16,32", + "num_inference_steps": 50, + "model": [ + "3", + 0 + ] + }, + "class_type": "StreamScheduler", + "_meta": { + "title": "Stream Scheduler 🕒🅡🅣🅝" + } + }, + "53": { + "inputs": { + "clip_name": "sd15/dreamshaper/model.fp16.safetensors", + "type": "stable_diffusion", + "device": "default" + }, + "class_type": "CLIPLoader", + "_meta": { + "title": "Load CLIP" + } + }, + "71": { + "inputs": { + "strength": 1, + "start_percent": 0, + "end_percent": 1, + "positive": [ + "5", + 0 + ], + "negative": [ + "6", + 0 + ], + "control_net": [ + "85", + 0 + ], + "image": [ + "105", + 0 + ] + }, + "class_type": "ControlNetApplyAdvanced", + "_meta": { + "title": "Apply ControlNet" + } + }, + "72": { + "inputs": { + "control_net_name": "control_v11f1p_sd15_depth_fp16.safetensors" + }, + "class_type": "ControlNetLoader", + "_meta": { + "title": "Load ControlNet Model" + } + }, + "77": { + "inputs": { + "num_steps": 3 + }, + "class_type": "StreamBatchSampler", + "_meta": { + "title": "Stream Batch Sampler 🕒🅡🅣🅝" + } + }, + "78": { + "inputs": { + "buffer_size": 3, + "latent": [ + "119", + 0 + ] + }, + "class_type": "StreamFrameBuffer", + "_meta": { + "title": "Stream Frame Buffer 🕒🅡🅣🅝" + } + }, + "83": { + "inputs": { + "image": 
"harold.png", + "upload": "image" + }, + "class_type": "LoadImage", + "_meta": { + "title": "Load Image" + } + }, + "85": { + "inputs": { + "backend": "inductor", + "fullgraph": false, + "mode": "reduce-overhead", + "controlnet": [ + "72", + 0 + ] + }, + "class_type": "TorchCompileLoadControlNet", + "_meta": { + "title": "TorchCompileLoadControlNet" + } + }, + "89": { + "inputs": { + "stop_at_clip_layer": -2, + "clip": [ + "53", + 0 + ] + }, + "class_type": "CLIPSetLastLayer", + "_meta": { + "title": "CLIP Set Last Layer" + } + }, + "101": { + "inputs": { + "engine": "depth_anything_vitl14-fp16.engine", + "images": [ + "83", + 0 + ] + }, + "class_type": "DepthAnythingTensorrt", + "_meta": { + "title": "Depth Anything Tensorrt" + } + }, + "105": { + "inputs": { + "center_x": 0.5, + "center_y": 0.5, + "stretch_strength": 0, + "falloff": 2, + "pulse_frequency": 0, + "mode": "radial", + "depth_map": [ + "101", + 0 + ] + }, + "class_type": "DepthMapWarpNode", + "_meta": { + "title": "Depth Map Warp Node 🕒🅡🅣🅝" + } + }, + "112": { + "inputs": { + "samples": [ + "113", + 0 + ], + "vae": [ + "13", + 0 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "113": { + "inputs": { + "add_noise": true, + "noise_seed": 283295832341401, + "cfg": 1, + "model": [ + "3", + 0 + ], + "positive": [ + "71", + 0 + ], + "negative": [ + "71", + 1 + ], + "sampler": [ + "117", + 0 + ], + "sigmas": [ + "116", + 0 + ], + "latent_image": [ + "114", + 0 + ] + }, + "class_type": "SamplerCustom", + "_meta": { + "title": "SamplerCustom" + } + }, + "114": { + "inputs": { + "buffer_size": 2, + "latent": [ + "118", + 0 + ] + }, + "class_type": "StreamFrameBuffer", + "_meta": { + "title": "Stream Frame Buffer 🕒🅡🅣🅝" + } + }, + "115": { + "inputs": { + "images": [ + "125", + 0 + ] + }, + "class_type": "PreviewImage", + "_meta": { + "title": "Preview Image" + } + }, + "116": { + "inputs": { + "t_index_list": "2,32", + "num_inference_steps": 50, + "model": [ + "3", + 0 + ] + }, + "class_type": "StreamScheduler", + "_meta": { + "title": "Stream Scheduler 🕒🅡🅣🅝" + } + }, + "117": { + "inputs": { + "num_steps": 2 + }, + "class_type": "StreamBatchSampler", + "_meta": { + "title": "Stream Batch Sampler 🕒🅡🅣🅝" + } + }, + "118": { + "inputs": { + "upscale_method": "nearest-exact", + "width": 640, + "height": 640, + "crop": "disabled", + "samples": [ + "30", + 0 + ] + }, + "class_type": "LatentUpscale", + "_meta": { + "title": "Upscale Latent" + } + }, + "119": { + "inputs": { + "width": 384, + "height": 384, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage", + "_meta": { + "title": "Empty Latent Image" + } + }, + "125": { + "inputs": { + "upscale_method": "lanczos", + "scale_by": 1.5, + "image": [ + "112", + 0 + ] + }, + "class_type": "ImageScaleBy", + "_meta": { + "title": "Upscale Image By" + } + } +} \ No newline at end of file