Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ __pycache__

output.mp4

src/scope/core/pipelines/**/*.mp4

notes/

# Cursor IDE files
Expand Down
108 changes: 107 additions & 1 deletion src/scope/core/pipelines/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
if TYPE_CHECKING:
from diffusers.modular_pipelines import PipelineState

from .interface import Pipeline
from .interface import Pipeline, Requirements
from .schema import BasePipelineConfig

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -112,3 +112,109 @@ def apply_mode_defaults_to_state(
state.set("noise_scale", config.noise_scale)
if "noise_controller" not in kwargs and config.noise_controller is not None:
state.set("noise_controller", config.noise_controller)


# -----------------------------------------------------------------------------
# Multi-mode pipeline helpers
# -----------------------------------------------------------------------------


def calculate_video_input_size(components_config: dict) -> int:
"""Calculate video input size from pipeline component config.

Video input size = num_frame_per_block * vae_temporal_downsample_factor

Args:
components_config: Dictionary with pipeline config values (typically
from components.config)

Returns:
Number of video frames required for video mode input
"""
num_frame_per_block = components_config.get("num_frame_per_block", 3)
vae_temporal_downsample_factor = components_config.get(
"vae_temporal_downsample_factor", 4
)
return num_frame_per_block * vae_temporal_downsample_factor


def prepare_for_mode(
pipeline_class: type["Pipeline"],
components_config: dict,
kwargs: dict,
video_input_size: int | None = None,
) -> "Requirements | None":
"""Determine input requirements based on current mode.

This is the shared implementation for multi-mode pipeline prepare() methods.
Returns video requirements when video mode is active, None for text mode.

Args:
pipeline_class: The pipeline class (for accessing config defaults)
components_config: Dictionary with pipeline config (for calculating
video_input_size if not provided)
kwargs: Call kwargs that may contain 'video' key
video_input_size: Override for video input size. If None, calculated
from components_config.

Returns:
Requirements with input_size for video mode, None for text mode
"""
from .interface import Requirements

# Calculate video input size if not provided
if video_input_size is None:
video_input_size = calculate_video_input_size(components_config)

# If video is explicitly provided, use video mode
if kwargs.get("video") is not None:
return Requirements(input_size=video_input_size)

# Fall back to schema's default mode
config = get_pipeline_config(pipeline_class)
if config.input_size is not None:
return Requirements(input_size=config.input_size)

return None


def handle_mode_transition(
state: "PipelineState",
vae: Any,
first_call: bool,
last_mode: str | None,
kwargs: dict,
) -> tuple[bool, str]:
"""Handle mode changes and cache management for multi-mode pipelines.

Detects mode transitions and manages cache initialization accordingly.
On first call or mode change, sets init_cache=True and clears VAE cache.

Args:
state: PipelineState to update with init_cache
vae: VAE component with clear_cache() method
first_call: Whether this is the first call to the pipeline
last_mode: Previous mode (None if first call)
kwargs: Call kwargs for resolving current mode

Returns:
Tuple of (new_first_call, current_mode) to update pipeline state
"""
current_mode = resolve_input_mode(kwargs)
mode_changed = last_mode is not None and last_mode != current_mode

if first_call or mode_changed:
state.set("init_cache", True)
if mode_changed:
logger.info(
"handle_mode_transition: Mode changed from %s to %s, resetting cache",
last_mode,
current_mode,
)
vae.clear_cache()
first_call = False
else:
# This will be overridden if init_cache is passed in kwargs
state.set("init_cache", False)

return first_call, current_mode
4 changes: 2 additions & 2 deletions src/scope/core/pipelines/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ class Pipeline(ABC):
- API introspection and automatic UI generation

See schema.py for the BasePipelineConfig model and pipeline-specific configs.
For multi-mode pipeline support (text/video), use multi_mode.MultiModePipeline
as the base class which provides declarative mode configuration.
For multi-mode pipeline support (text/video), pipelines use helper functions
from defaults.py (resolve_input_mode, apply_mode_defaults_to_state, etc.).
"""

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def get_context_frames(components, state: BlockState) -> torch.Tensor:
vae_device = next(components.vae.parameters()).device
decoded_first_frame = state.decoded_frame_buffer[:, :1].to(vae_device)
reencoded_latent = components.vae.encode_to_latent(
rearrange(decoded_first_frame, "B T C H W -> B C T H W")
rearrange(decoded_first_frame, "B T C H W -> B C T H W"), use_cache=False
)
return torch.cat(
[reencoded_latent, state.context_frame_buffer.to(generator_device)],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
from .vae import WanVAEWrapper

__all__ = ["WanVAEWrapper"]
238 changes: 0 additions & 238 deletions src/scope/core/pipelines/krea_realtime_video/components/vae.py

This file was deleted.

1 change: 1 addition & 0 deletions src/scope/core/pipelines/krea_realtime_video/model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ num_frame_per_block: 3
kv_cache_num_frames: 3
local_attn_size: 6
vae_spatial_downsample_factor: 8
vae_temporal_downsample_factor: 4
patch_embedding_spatial_downsample_factor: 2
max_rope_freq_table_seq_len: 1024
Loading
Loading