
Commit 80b7c94

Changes to the previous radiance commit. (#9851)
1 parent c1297f4 commit 80b7c94

File tree: 5 files changed (+35, -66 lines)

comfy/ldm/chroma_radiance/model.py

Lines changed: 4 additions & 3 deletions

@@ -306,8 +306,9 @@ def _forward(

         params = self.radiance_get_override_params(transformer_options.get("chroma_radiance_options", {}))

-        h_len = ((h + (self.patch_size // 2)) // self.patch_size)
-        w_len = ((w + (self.patch_size // 2)) // self.patch_size)
+        h_len = (img.shape[-2] // self.patch_size)
+        w_len = (img.shape[-1] // self.patch_size)
+
         img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
         img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
         img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)

@@ -325,4 +326,4 @@ def _forward(
             transformer_options,
             attn_mask=kwargs.get("attention_mask", None),
         )
-        return self.forward_nerf(img, img_out, params)
+        return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
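
The first hunk derives the patch-grid size directly from the model input tensor instead of rounding the raw h/w, and the second hunk crops the output back to the requested h x w. Below is a minimal standalone sketch (plain PyTorch, not ComfyUI code) of why the two changes belong together; the padding step shown here is an assumption about what happens before _forward and is not part of this diff:

import torch

patch_size = 16
h, w = 100, 90                                     # requested output size
pad_h = (patch_size - h % patch_size) % patch_size
pad_w = (patch_size - w % patch_size) % patch_size
img = torch.zeros(1, 3, h + pad_h, w + pad_w)      # hypothetical input padded to the patch grid

old_h_len = (h + patch_size // 2) // patch_size    # old rounding: 108 // 16 = 6
new_h_len = img.shape[-2] // patch_size            # rows of patches actually present: 112 // 16 = 7
print(old_h_len, new_h_len)                        # 6 7

# The network output covers the padded area, so the caller crops back to (h, w):
img_out = torch.zeros(1, 3, img.shape[-2], img.shape[-1])
print(img_out[:, :, :h, :w].shape)                 # torch.Size([1, 3, 100, 90])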

comfy/pixel_space_convert.py

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+import torch
+
+
+# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1
+# to LATENT B, C, H, W and values on the scale of -1..1.
+class PixelspaceConversionVAE(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.pixel_space_vae = torch.nn.Parameter(torch.tensor(1.0))
+
+    def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
+        return pixels
+
+    def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
+        return samples
+
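
The new module is a pure passthrough: unlike the old helper removed from comfy/sd.py below, it does no cropping, rescaling, or channel reordering, which is presumably left to the comfy.sd.VAE wrapper around first_stage_model. Its single dummy parameter exists so a state dict carries the recognizable "pixel_space_vae" key. A quick sanity check, assuming a ComfyUI checkout is importable:

import torch
from comfy.pixel_space_convert import PixelspaceConversionVAE

vae = PixelspaceConversionVAE()
x = torch.rand(1, 3, 64, 64)
assert torch.equal(vae.encode(x), x)      # encode is an identity passthrough
assert torch.equal(vae.decode(x), x)      # decode is an identity passthrough
print(list(vae.state_dict().keys()))      # ['pixel_space_vae'] -- the sentinel key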

comfy/sd.py

Lines changed: 10 additions & 59 deletions

@@ -18,6 +18,7 @@
 import comfy.ldm.hunyuan3d.vae
 import comfy.ldm.ace.vae.music_dcae_pipeline
 import comfy.ldm.hunyuan_video.vae
+import comfy.pixel_space_convert
 import yaml
 import math
 import os

@@ -516,6 +517,15 @@ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):
             self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
             self.disable_offload = True
             self.extra_1d_channel = 16
+        elif "pixel_space_vae" in sd:
+            self.first_stage_model = comfy.pixel_space_convert.PixelspaceConversionVAE()
+            self.memory_used_encode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+            self.memory_used_decode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+            self.downscale_ratio = 1
+            self.upscale_ratio = 1
+            self.latent_channels = 3
+            self.latent_dim = 2
+            self.output_channels = 3
         else:
             logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
             self.first_stage_model = None

@@ -785,65 +795,6 @@ def temporal_compression_decode(self):
         except:
             return None

-# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1
-# to LATENT B, C, H, W and values on the scale of -1..1.
-class PixelspaceConversionVAE:
-    def __init__(self, size_increment: int=16):
-        self.intermediate_device = comfy.model_management.intermediate_device()
-        self.size_increment = size_increment
-
-    def vae_encode_crop_pixels(self, pixels: torch.Tensor) -> torch.Tensor:
-        if self.size_increment == 1:
-            return pixels
-        dims = pixels.shape[1:-1]
-        for d in range(len(dims)):
-            d_adj = (dims[d] // self.size_increment) * self.size_increment
-            if d_adj == d:
-                continue
-            d_offset = (dims[d] % self.size_increment) // 2
-            pixels = pixels.narrow(d + 1, d_offset, d_adj)
-        return pixels
-
-    def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
-        if pixels.ndim == 3:
-            pixels = pixels.unsqueeze(0)
-        elif pixels.ndim != 4:
-            raise ValueError("Unexpected input image shape")
-        # Ensure the image has spatial dimensions that are multiples of 16.
-        pixels = self.vae_encode_crop_pixels(pixels)
-        h, w, c = pixels.shape[1:]
-        if h < self.size_increment or w < self.size_increment:
-            raise ValueError(f"Image inputs must have height/width of at least {self.size_increment} pixel(s).")
-        pixels= pixels[..., :3]
-        if c == 1:
-            pixels = pixels.expand(-1, -1, -1, 3)
-        elif c != 3:
-            raise ValueError("Unexpected number of channels in input image")
-        # Rescale to -1..1 and move the channel dimension to position 1.
-        latent = pixels.to(device=self.intermediate_device, dtype=torch.float32, copy=True)
-        latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous()
-        latent -= 0.5
-        latent *= 2
-        return latent.clamp_(-1, 1)
-
-    def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
-        # Rescale to 0..1 and move the channel dimension to the end.
-        img = samples.to(device=self.intermediate_device, dtype=torch.float32, copy=True)
-        img = img.clamp_(-1, 1).movedim(1, -1).contiguous()
-        img += 1.0
-        img *= 0.5
-        return img.clamp_(0, 1)
-
-    encode_tiled = encode
-    decode_tiled = decode
-
-    @classmethod
-    def spacial_compression_decode(cls) -> int:
-        # This just exists so the tiled VAE nodes don't crash.
-        return 1
-
-    spacial_compression_encode = spacial_compression_decode
-    temporal_compression_decode = spacial_compression_decode

 class StyleModel:
     def __init__(self, model, device="cpu"):
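
In short: the hand-rolled pixel-space VAE class is removed from comfy/sd.py and the VAE constructor instead recognizes the "pixel_space_vae" sentinel key in the state dict it is given. With downscale/upscale ratios of 1 and 3 latent channels the "latent" is simply the image tensor, and the memory-estimate lambdas charge roughly one element per pixel. A hedged sketch of reaching the new branch, assuming a working ComfyUI environment (VAE(sd=...) is existing API, not part of this diff):

import torch
import comfy.sd

# Any state dict containing the "pixel_space_vae" key selects the passthrough branch.
sd = {"pixel_space_vae": torch.tensor(1.0)}
vae = comfy.sd.VAE(sd=sd)
print(type(vae.first_stage_model).__name__)        # PixelspaceConversionVAE
print(vae.latent_channels, vae.downscale_ratio)    # 3 1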

comfy/supported_models.py

Lines changed: 1 addition & 1 deletion

@@ -1213,7 +1213,7 @@ class ChromaRadiance(Chroma):
     latent_format = comfy.latent_formats.ChromaRadiance

     # Pixel-space model, no spatial compression for model input.
-    memory_usage_factor = 0.0325
+    memory_usage_factor = 0.038

     def get_model(self, state_dict, prefix="", device=None):
         return model_base.ChromaRadiance(self, device=device)

nodes.py

Lines changed: 4 additions & 3 deletions

@@ -730,7 +730,7 @@ def vae_list():
             vaes.append("taesd3")
         if f1_taesd_dec and f1_taesd_enc:
             vaes.append("taef1")
-        vaes.append("chroma_radiance")
+        vaes.append("pixel_space")
         return vaes

     @staticmethod

@@ -773,8 +773,9 @@ def INPUT_TYPES(s):

     #TODO: scale factor?
     def load_vae(self, vae_name):
-        if vae_name == "chroma_radiance":
-            return (comfy.sd.PixelspaceConversionVAE(),)
+        if vae_name == "pixel_space":
+            sd = {}
+            sd["pixel_space_vae"] = torch.tensor(1.0)
         elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
             sd = self.load_taesd(vae_name)
         else:
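
The loader entry is renamed from "chroma_radiance" to "pixel_space", and instead of constructing the old comfy.sd.PixelspaceConversionVAE directly, load_vae now builds a one-key state dict and presumably falls through to the shared comfy.sd.VAE construction further down, which this hunk does not show. A rough end-to-end sketch under those assumptions:

import torch
from nodes import VAELoader

image = torch.rand(1, 128, 128, 3)                # IMAGE: B, H, W, C in 0..1
(vae,) = VAELoader().load_vae("pixel_space")      # assumed to return the wrapped passthrough VAE
latent = vae.encode(image)                        # expected LATENT: B, 3, H, W in -1..1 (handled by the VAE wrapper)
restored = vae.decode(latent)                     # expected IMAGE back in 0..1
print(latent.shape, restored.shape)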
