Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 65 additions & 44 deletions frontend/src/components/SettingsPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ import {
} from "../data/pipelines";
import { PARAMETER_METADATA } from "../data/parameterMetadata";
import { DenoisingStepsSlider } from "./DenoisingStepsSlider";
import {
getResolutionScaleFactor,
adjustResolutionForPipeline,
} from "../lib/utils";
import { useLocalSliderValue } from "../hooks/useLocalSliderValue";
import type {
PipelineId,
Expand Down Expand Up @@ -140,6 +144,15 @@ export function SettingsPanel({
const [widthError, setWidthError] = useState<string | null>(null);
const [seedError, setSeedError] = useState<string | null>(null);

// Check if resolution needs adjustment
const scaleFactor = getResolutionScaleFactor(pipelineId);
const resolutionWarning =
scaleFactor &&
(resolution.height % scaleFactor !== 0 ||
resolution.width % scaleFactor !== 0)
? `Resolution will be adjusted to ${adjustResolutionForPipeline(pipelineId, resolution).resolution.width}×${adjustResolutionForPipeline(pipelineId, resolution).resolution.height} when starting the stream (must be divisible by ${scaleFactor})`
: null;

const handlePipelineIdChange = (value: string) => {
if (value in PIPELINES) {
onPipelineIdChange?.(value as PipelineId);
Expand Down Expand Up @@ -500,55 +513,63 @@ export function SettingsPanel({
<p className="text-xs text-red-500 ml-16">{widthError}</p>
)}
</div>
{resolutionWarning && (
<div className="flex items-start gap-1">
<Info className="h-3.5 w-3.5 mt-0.5 shrink-0 text-amber-600 dark:text-amber-500" />
<p className="text-xs text-amber-600 dark:text-amber-500">
{resolutionWarning}
</p>
</div>
)}
</div>

<div className="space-y-1">
<div className="flex items-center gap-2">
<LabelWithTooltip
label={PARAMETER_METADATA.seed.label}
tooltip={PARAMETER_METADATA.seed.tooltip}
className="text-sm text-foreground w-14"
<div className="space-y-1">
<div className="flex items-center gap-2">
<LabelWithTooltip
label={PARAMETER_METADATA.seed.label}
tooltip={PARAMETER_METADATA.seed.tooltip}
className="text-sm text-foreground w-14"
/>
<div
className={`flex-1 flex items-center border rounded-full overflow-hidden h-8 ${seedError ? "border-red-500" : ""}`}
>
<Button
variant="ghost"
size="icon"
className="h-8 w-8 shrink-0 rounded-none hover:bg-accent"
onClick={decrementSeed}
disabled={isStreaming}
>
<Minus className="h-3.5 w-3.5" />
</Button>
<Input
type="number"
value={seed}
onChange={e => {
const value = parseInt(e.target.value);
if (!isNaN(value)) {
handleSeedChange(value);
}
}}
disabled={isStreaming}
className="text-center border-0 focus-visible:ring-0 focus-visible:ring-offset-0 h-8 [appearance:textfield] [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none"
min={0}
max={2147483647}
/>
<div
className={`flex-1 flex items-center border rounded-full overflow-hidden h-8 ${seedError ? "border-red-500" : ""}`}
<Button
variant="ghost"
size="icon"
className="h-8 w-8 shrink-0 rounded-none hover:bg-accent"
onClick={incrementSeed}
disabled={isStreaming}
>
<Button
variant="ghost"
size="icon"
className="h-8 w-8 shrink-0 rounded-none hover:bg-accent"
onClick={decrementSeed}
disabled={isStreaming}
>
<Minus className="h-3.5 w-3.5" />
</Button>
<Input
type="number"
value={seed}
onChange={e => {
const value = parseInt(e.target.value);
if (!isNaN(value)) {
handleSeedChange(value);
}
}}
disabled={isStreaming}
className="text-center border-0 focus-visible:ring-0 focus-visible:ring-offset-0 h-8 [appearance:textfield] [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none"
min={0}
max={2147483647}
/>
<Button
variant="ghost"
size="icon"
className="h-8 w-8 shrink-0 rounded-none hover:bg-accent"
onClick={incrementSeed}
disabled={isStreaming}
>
<Plus className="h-3.5 w-3.5" />
</Button>
</div>
<Plus className="h-3.5 w-3.5" />
</Button>
</div>
{seedError && (
<p className="text-xs text-red-500 ml-16">{seedError}</p>
)}
</div>
{seedError && (
<p className="text-xs text-red-500 ml-16">{seedError}</p>
)}
</div>
</div>
</div>
Expand Down
50 changes: 50 additions & 0 deletions frontend/src/lib/utils.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,56 @@
import { type ClassValue, clsx } from "clsx";
import { twMerge } from "tailwind-merge";
import type { PipelineId } from "../types";

/**
 * Combines the given class values with `clsx` and passes the combined result
 * through `twMerge`.
 *
 * @param inputs - Any number of class values to combine.
 * @returns Whatever `twMerge` returns for the combined value.
 */
export function cn(...inputs: ClassValue[]) {
  const combined = clsx(inputs);
  return twMerge(combined);
}

/**
 * Looks up the divisor that both resolution dimensions must be a multiple of
 * for the given pipeline.
 *
 * @param pipelineId - Identifier of the pipeline to look up.
 * @returns 16 for "longlive", "streamdiffusionv2", "krea-realtime-video" and
 *   "reward-forcing"; `null` for every other pipeline, meaning no resolution
 *   adjustment is required.
 */
export function getResolutionScaleFactor(
  pipelineId: PipelineId
): number | null {
  switch (pipelineId) {
    case "longlive":
    case "streamdiffusionv2":
    case "krea-realtime-video":
    case "reward-forcing":
      // VAE downsample (8) * patch embedding downsample (2) = 16
      return 16;
    default:
      return null;
  }
}

/**
 * Snaps a resolution to the nearest size the given pipeline accepts.
 *
 * Each dimension is rounded to the nearest multiple of the pipeline's scale
 * factor (as reported by `getResolutionScaleFactor`) and is never allowed to
 * drop below one full multiple, so a very small input cannot collapse to 0.
 * Pipelines without a scale factor get the input object back unchanged.
 *
 * @param pipelineId - Pipeline whose divisibility requirement applies.
 * @param resolution - Requested height and width.
 * @returns The resolution to use, plus `wasAdjusted`, which is true when
 *   either returned dimension differs from the requested one.
 */
export function adjustResolutionForPipeline(
  pipelineId: PipelineId,
  resolution: { height: number; width: number }
): {
  resolution: { height: number; width: number };
  wasAdjusted: boolean;
} {
  const scaleFactor = getResolutionScaleFactor(pipelineId);
  if (!scaleFactor) {
    return { resolution, wasAdjusted: false };
  }

  // Math.round alone maps any value below half a step (e.g. 7 with a factor
  // of 16) to 0, which is not a usable dimension; clamp to one full step.
  const snapToMultiple = (value: number): number =>
    Math.max(scaleFactor, Math.round(value / scaleFactor) * scaleFactor);

  const adjustedHeight = snapToMultiple(resolution.height);
  const adjustedWidth = snapToMultiple(resolution.width);

  const wasAdjusted =
    adjustedHeight !== resolution.height || adjustedWidth !== resolution.width;

  return {
    resolution: { height: adjustedHeight, width: adjustedWidth },
    wasAdjusted,
  };
}
15 changes: 14 additions & 1 deletion frontend/src/pages/StreamPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import {
getDefaultPromptForMode,
pipelineSupportsVACE,
} from "../data/pipelines";
import { adjustResolutionForPipeline } from "../lib/utils";
import type {
InputMode,
PipelineId,
Expand Down Expand Up @@ -688,7 +689,19 @@ export function StreamPage() {
let loadParams = null;

// Use settings.resolution if available, otherwise fall back to videoResolution
const resolution = settings.resolution || videoResolution;
let resolution = settings.resolution || videoResolution;

// Adjust resolution to be divisible by required scale factor for the pipeline
if (resolution) {
const { resolution: adjustedResolution, wasAdjusted } =
adjustResolutionForPipeline(pipelineIdToUse, resolution);

if (wasAdjusted) {
// Update settings with adjusted resolution
updateSettings({ resolution: adjustedResolution });
resolution = adjustedResolution;
}
}

// Compute VACE enabled state once - enabled by default for text mode on VACE-supporting pipelines
const vaceEnabled =
Expand Down
10 changes: 9 additions & 1 deletion src/scope/core/pipelines/krea_realtime_video/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ..interface import Pipeline, Requirements
from ..process import postprocess_chunk
from ..schema import KreaRealtimeVideoConfig
from ..utils import Quantization, load_model_config
from ..utils import Quantization, load_model_config, validate_resolution
from ..wan2_1.components import WanDiffusionWrapper, WanTextEncoderWrapper
from ..wan2_1.lora.mixin import LoRAEnabledPipeline
from ..wan2_1.vae import WanVAEWrapper
Expand Down Expand Up @@ -49,6 +49,14 @@ def __init__(
device: torch.device | None = None,
dtype: torch.dtype = torch.bfloat16,
):
# Validate resolution requirements
# VAE downsample (8) * patch embedding downsample (2) = 16
validate_resolution(
height=config.height,
width=config.width,
scale_factor=16,
)

model_dir = getattr(config, "model_dir", None)
generator_path = getattr(config, "generator_path", None)
text_encoder_path = getattr(config, "text_encoder_path", None)
Expand Down
10 changes: 9 additions & 1 deletion src/scope/core/pipelines/longlive/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ..interface import Pipeline, Requirements
from ..process import postprocess_chunk
from ..schema import LongLiveConfig
from ..utils import Quantization, load_model_config
from ..utils import Quantization, load_model_config, validate_resolution
from ..wan2_1.components import WanDiffusionWrapper, WanTextEncoderWrapper
from ..wan2_1.lora.mixin import LoRAEnabledPipeline
from ..wan2_1.lora.strategies.module_targeted_lora import ModuleTargetedLoRAStrategy
Expand Down Expand Up @@ -45,6 +45,14 @@ def __init__(
device: torch.device | None = None,
dtype: torch.dtype = torch.bfloat16,
):
# Validate resolution requirements
# VAE downsample (8) * patch embedding downsample (2) = 16
validate_resolution(
height=config.height,
width=config.width,
scale_factor=16,
)

model_dir = getattr(config, "model_dir", None)
generator_path = getattr(config, "generator_path", None)
lora_path = getattr(config, "lora_path", None)
Expand Down
9 changes: 8 additions & 1 deletion src/scope/core/pipelines/reward_forcing/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ..interface import Pipeline, Requirements
from ..process import postprocess_chunk
from ..schema import RewardForcingConfig
from ..utils import Quantization, load_model_config
from ..utils import Quantization, load_model_config, validate_resolution
from ..wan2_1.components import WanDiffusionWrapper, WanTextEncoderWrapper
from ..wan2_1.lora.mixin import LoRAEnabledPipeline
from ..wan2_1.vace.mixin import VACEEnabledPipeline
Expand Down Expand Up @@ -50,6 +50,13 @@ def __init__(
tokenizer_path = getattr(config, "tokenizer_path", None)

model_config = load_model_config(config, __file__)

validate_resolution(
height=config.height,
width=config.width,
scale_factor=16,
)

base_model_name = getattr(model_config, "base_model_name", "Wan2.1-T2V-1.3B")
base_model_kwargs = getattr(model_config, "base_model_kwargs", {})

Expand Down
10 changes: 9 additions & 1 deletion src/scope/core/pipelines/streamdiffusionv2/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ..interface import Pipeline, Requirements
from ..process import postprocess_chunk
from ..schema import StreamDiffusionV2Config
from ..utils import Quantization, load_model_config
from ..utils import Quantization, load_model_config, validate_resolution
from ..wan2_1.components import WanDiffusionWrapper, WanTextEncoderWrapper
from ..wan2_1.lora.mixin import LoRAEnabledPipeline
from ..wan2_1.vace import VACEEnabledPipeline
Expand Down Expand Up @@ -44,6 +44,14 @@ def __init__(
device: torch.device | None = None,
dtype: torch.dtype = torch.bfloat16,
):
# Validate resolution requirements
# VAE downsample (8) * patch embedding downsample (2) = 16
validate_resolution(
height=config.height,
width=config.width,
scale_factor=16,
)

model_dir = getattr(config, "model_dir", None)
generator_path = getattr(config, "generator_path", None)
text_encoder_path = getattr(config, "text_encoder_path", None)
Expand Down
27 changes: 27 additions & 0 deletions src/scope/core/pipelines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,30 @@ def load_model_config(config, pipeline_file_path: str | Path) -> OmegaConf:
model_yaml_path = Path(pipeline_file_path).parent / "model.yaml"
model_config = OmegaConf.load(model_yaml_path)
return model_config


def validate_resolution(
    height: int,
    width: int,
    scale_factor: int,
) -> None:
    """
    Validate that resolution dimensions are divisible by the required scale factor.

    Args:
        height: Height of the resolution
        width: Width of the resolution
        scale_factor: The positive factor that both dimensions must be divisible by

    Raises:
        ValueError: If scale_factor is not a positive integer, or if height or
            width is not divisible by scale_factor
    """
    # A zero factor would raise ZeroDivisionError below and a negative one would
    # silently validate against its absolute value; reject both explicitly.
    if scale_factor <= 0:
        raise ValueError(
            f"scale_factor must be a positive integer, got {scale_factor}."
        )

    if height % scale_factor != 0 or width % scale_factor != 0:
        # Suggest the next lower multiple, but never less than one multiple:
        # plain floor division would propose 0 for dimensions below scale_factor.
        adjusted_width = max(scale_factor, (width // scale_factor) * scale_factor)
        adjusted_height = max(scale_factor, (height // scale_factor) * scale_factor)
        # NOTE(review): the "8 × 2" explanation in the message only matches
        # scale_factor == 16, which is what the callers visible here pass.
        raise ValueError(
            f"Invalid resolution {width}×{height}. "
            f"Both width and height must be divisible by {scale_factor} "
            f"(VAE downsample factor 8 × patch embedding downsample factor 2 = {scale_factor}). "
            f"Please adjust to a valid resolution, e.g., {adjusted_width}×{adjusted_height}."
        )
Loading