7 changes: 4 additions & 3 deletions comfy/ldm/chroma_radiance/model.py
@@ -306,8 +306,9 @@ def _forward(

params = self.radiance_get_override_params(transformer_options.get("chroma_radiance_options", {}))

h_len = ((h + (self.patch_size // 2)) // self.patch_size)
w_len = ((w + (self.patch_size // 2)) // self.patch_size)
h_len = (img.shape[-2] // self.patch_size)
w_len = (img.shape[-1] // self.patch_size)

img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
@@ -325,4 +326,4 @@ def _forward(
transformer_options,
attn_mask=kwargs.get("attention_mask", None),
)
return self.forward_nerf(img, img_out, params)
return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
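
Note on this hunk: the patch grid is now derived from the (already padded) image tensor instead of being re-derived from the raw h/w with a rounding heuristic, and the NeRF output is cropped back to the caller's original size. Below is a minimal, standalone sketch of that bookkeeping, assuming the padding to a multiple of patch_size happens earlier in _forward; pad_to_multiple and the concrete sizes are illustrative only, not the model's actual API.

import math
import torch
import torch.nn.functional as F

def pad_to_multiple(img: torch.Tensor, patch_size: int) -> torch.Tensor:
    # Assumption: _forward pads H and W up to the next multiple of patch_size
    # somewhere before this hunk; this helper only mimics that.
    h, w = img.shape[-2:]
    pad_h = math.ceil(h / patch_size) * patch_size - h
    pad_w = math.ceil(w / patch_size) * patch_size - w
    return F.pad(img, (0, pad_w, 0, pad_h))

patch_size = 16
h, w = 300, 200                                  # original, non-multiple-of-16 sizes
img = pad_to_multiple(torch.zeros(1, 3, h, w), patch_size)

old_h_len = (h + patch_size // 2) // patch_size  # 19, round-to-nearest heuristic
new_h_len = img.shape[-2] // patch_size          # 304 // 16 = 19, exact by construction
# The two can disagree (e.g. h=295 gives 18 vs 19); the new form always matches
# the padded tensor the transformer actually sees.

out = torch.zeros(1, 3, img.shape[-2], img.shape[-1])  # stand-in for the forward_nerf output
out = out[:, :, :h, :w]                          # mirrors the new `[:, :, :h, :w]` crop
assert out.shape[-2:] == (h, w)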
16 changes: 16 additions & 0 deletions comfy/pixel_space_convert.py
@@ -0,0 +1,16 @@
import torch


# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1
# to LATENT B, C, H, W and values on the scale of -1..1.
class PixelspaceConversionVAE(torch.nn.Module):
def __init__(self):
super().__init__()
self.pixel_space_vae = torch.nn.Parameter(torch.tensor(1.0))

def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
return pixels

def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
return samples
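
The new module is deliberately an identity at the module level: encode and decode return their inputs unchanged, and the single pixel_space_vae parameter appears to exist only so there is a recognizable key in the state dict. A short sketch of the intended round trip; the 0..1 to -1..1 scaling and the channel move shown here are assumptions about what the enclosing comfy.sd.VAE wrapper does, not behavior of this file.

import torch
from comfy.pixel_space_convert import PixelspaceConversionVAE

vae = PixelspaceConversionVAE()

image = torch.rand(1, 64, 64, 3)                 # IMAGE layout: B, H, W, C, values in 0..1
pixels = image.movedim(-1, 1) * 2.0 - 1.0        # assumed wrapper-side step: B, C, H, W, -1..1
latent = vae.encode(pixels)                      # identity pass-through
decoded = vae.decode(latent)                     # identity pass-through
restored = (decoded.movedim(1, -1) + 1.0) / 2.0  # assumed wrapper-side inverse
assert torch.allclose(restored, image)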

69 changes: 10 additions & 59 deletions comfy/sd.py
@@ -18,6 +18,7 @@
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.hunyuan_video.vae
import comfy.pixel_space_convert
import yaml
import math
import os
@@ -516,6 +517,15 @@ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.disable_offload = True
self.extra_1d_channel = 16
elif "pixel_space_vae" in sd:
self.first_stage_model = comfy.pixel_space_convert.PixelspaceConversionVAE()
self.memory_used_encode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.downscale_ratio = 1
self.upscale_ratio = 1
self.latent_channels = 3
self.latent_dim = 2
self.output_channels = 3
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@@ -785,65 +795,6 @@ def temporal_compression_decode(self):
except:
return None

# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1
# to LATENT B, C, H, W and values on the scale of -1..1.
class PixelspaceConversionVAE:
def __init__(self, size_increment: int=16):
self.intermediate_device = comfy.model_management.intermediate_device()
self.size_increment = size_increment

def vae_encode_crop_pixels(self, pixels: torch.Tensor) -> torch.Tensor:
if self.size_increment == 1:
return pixels
dims = pixels.shape[1:-1]
for d in range(len(dims)):
d_adj = (dims[d] // self.size_increment) * self.size_increment
if d_adj == d:
continue
d_offset = (dims[d] % self.size_increment) // 2
pixels = pixels.narrow(d + 1, d_offset, d_adj)
return pixels

def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
if pixels.ndim == 3:
pixels = pixels.unsqueeze(0)
elif pixels.ndim != 4:
raise ValueError("Unexpected input image shape")
# Ensure the image has spatial dimensions that are multiples of 16.
pixels = self.vae_encode_crop_pixels(pixels)
h, w, c = pixels.shape[1:]
if h < self.size_increment or w < self.size_increment:
raise ValueError(f"Image inputs must have height/width of at least {self.size_increment} pixel(s).")
pixels= pixels[..., :3]
if c == 1:
pixels = pixels.expand(-1, -1, -1, 3)
elif c != 3:
raise ValueError("Unexpected number of channels in input image")
# Rescale to -1..1 and move the channel dimension to position 1.
latent = pixels.to(device=self.intermediate_device, dtype=torch.float32, copy=True)
latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous()
latent -= 0.5
latent *= 2
return latent.clamp_(-1, 1)

def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
# Rescale to 0..1 and move the channel dimension to the end.
img = samples.to(device=self.intermediate_device, dtype=torch.float32, copy=True)
img = img.clamp_(-1, 1).movedim(1, -1).contiguous()
img += 1.0
img *= 0.5
return img.clamp_(0, 1)

encode_tiled = encode
decode_tiled = decode

@classmethod
def spacial_compression_decode(cls) -> int:
# This just exists so the tiled VAE nodes don't crash.
return 1

spacial_compression_encode = spacial_compression_decode
temporal_compression_decode = spacial_compression_decode

class StyleModel:
def __init__(self, model, device="cpu"):
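For reference, the new branch in comfy/sd.py is keyed purely on the dummy "pixel_space_vae" entry, and both memory estimates collapse to a flat per-pixel cost. A rough worked example of that estimate; the lambda is copied from the hunk above, while the concrete shape and dtype are illustrative.

import torch
import comfy.model_management as model_management

memory_used_decode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)

# For a 1024x1024 input in float16 this is 1024 * 1024 * 2 = 2_097_152 bytes (~2 MiB):
# no compression factor is applied because the "latent" here is the pixel grid itself.
print(memory_used_decode((1, 3, 1024, 1024), torch.float16))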
2 changes: 1 addition & 1 deletion comfy/supported_models.py
@@ -1213,7 +1213,7 @@ class ChromaRadiance(Chroma):
latent_format = comfy.latent_formats.ChromaRadiance

# Pixel-space model, no spatial compression for model input.
memory_usage_factor = 0.0325
memory_usage_factor = 0.038

def get_model(self, state_dict, prefix="", device=None):
return model_base.ChromaRadiance(self, device=device)
7 changes: 4 additions & 3 deletions nodes.py
@@ -730,7 +730,7 @@ def vae_list():
vaes.append("taesd3")
if f1_taesd_dec and f1_taesd_enc:
vaes.append("taef1")
vaes.append("chroma_radiance")
vaes.append("pixel_space")
return vaes

@staticmethod
@@ -773,8 +773,9 @@ def INPUT_TYPES(s):

#TODO: scale factor?
def load_vae(self, vae_name):
if vae_name == "chroma_radiance":
return (comfy.sd.PixelspaceConversionVAE(),)
if vae_name == "pixel_space":
sd = {}
sd["pixel_space_vae"] = torch.tensor(1.0)
elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
sd = self.load_taesd(vae_name)
else:
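
With the menu entry renamed to the model-agnostic "pixel_space", selecting it simply builds the one-key state dict that the comfy/sd.py branch above looks for. A minimal usage sketch, assuming the remainder of load_vae (below the visible hunk) still wraps sd in comfy.sd.VAE and returns it as before.

import nodes

loader = nodes.VAELoader()
(vae,) = loader.load_vae("pixel_space")  # internally builds sd = {"pixel_space_vae": torch.tensor(1.0)}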