Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions comfy/latent_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class LatentFormat:
latent_rgb_factors_bias = None
latent_rgb_factors_reshape = None
taesd_decoder_name = None
spacial_downscale_ratio = 8

def process_in(self, latent):
return latent * self.scale_factor
Expand Down Expand Up @@ -181,6 +182,7 @@ def process_out(self, latent):

class Flux2(LatentFormat):
latent_channels = 128
spacial_downscale_ratio = 16

def __init__(self):
self.latent_rgb_factors =[
Expand Down Expand Up @@ -749,6 +751,7 @@ class ACEAudio(LatentFormat):

class ChromaRadiance(LatentFormat):
latent_channels = 3
spacial_downscale_ratio = 1

def __init__(self):
self.latent_rgb_factors = [
Expand Down
12 changes: 9 additions & 3 deletions comfy/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,18 @@ def prepare_noise(latent_image, seed, noise_inds=None):

return noises

def fix_empty_latent_channels(model, latent_image):
def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None):
if latent_image.is_nested:
return latent_image
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
if torch.count_nonzero(latent_image) == 0:
if latent_format.latent_channels != latent_image.shape[1]:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
if downscale_ratio_spacial is not None:
if downscale_ratio_spacial != latent_format.spacial_downscale_ratio:
ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio
latent_image = comfy.utils.common_upscale(latent_image, round(latent_image.shape[-1] * ratio), round(latent_image.shape[-2] * ratio), "nearest-exact", crop="disabled")

if latent_format.latent_dimensions == 3 and latent_image.ndim == 4:
latent_image = latent_image.unsqueeze(2)
return latent_image
Expand Down
6 changes: 4 additions & 2 deletions comfy_extras/nodes_custom_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ def execute(cls, model, add_noise, noise_seed, cfg, positive, negative, sampler,
latent = latent_image
latent_image = latent["samples"]
latent = latent.copy()
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None))
latent["samples"] = latent_image

if not add_noise:
Expand All @@ -760,6 +760,7 @@ def execute(cls, model, add_noise, noise_seed, cfg, positive, negative, sampler,
samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)

out = latent.copy()
out.pop("downscale_ratio_spacial", None)
out["samples"] = samples
if "x0" in x0_output:
x0_out = model.model.process_latent_out(x0_output["x0"].cpu())
Expand Down Expand Up @@ -939,7 +940,7 @@ def execute(cls, noise, guider, sampler, sigmas, latent_image) -> io.NodeOutput:
latent = latent_image
latent_image = latent["samples"]
latent = latent.copy()
latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image)
latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image, latent.get("downscale_ratio_spacial", None))
latent["samples"] = latent_image

noise_mask = None
Expand All @@ -954,6 +955,7 @@ def execute(cls, noise, guider, sampler, sigmas, latent_image) -> io.NodeOutput:
samples = samples.to(comfy.model_management.intermediate_device())

out = latent.copy()
out.pop("downscale_ratio_spacial", None)
out["samples"] = samples
if "x0" in x0_output:
x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
Expand Down
2 changes: 1 addition & 1 deletion comfy_extras/nodes_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def define_schema(cls):
@classmethod
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples":latent})
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8})

generate = execute # TODO: remove

Expand Down
5 changes: 3 additions & 2 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,7 @@ def INPUT_TYPES(s):

def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
return ({"samples":latent}, )
return ({"samples": latent, "downscale_ratio_spacial": 8}, )


class LatentFromBatch:
Expand Down Expand Up @@ -1538,7 +1538,7 @@ def set_mask(self, samples, mask):

def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
latent_image = latent["samples"]
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None))

if disable_noise:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
Expand All @@ -1556,6 +1556,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
out = latent.copy()
out.pop("downscale_ratio_spacial", None)
out["samples"] = samples
return (out, )

Expand Down