|
18 | 18 | import comfy.ldm.hunyuan3d.vae |
19 | 19 | import comfy.ldm.ace.vae.music_dcae_pipeline |
20 | 20 | import comfy.ldm.hunyuan_video.vae |
| 21 | +import comfy.pixel_space_convert |
21 | 22 | import yaml |
22 | 23 | import math |
23 | 24 | import os |
@@ -516,6 +517,15 @@ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2): |
516 | 517 |                 self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
517 | 518 |                 self.disable_offload = True
518 | 519 |                 self.extra_1d_channel = 16
| 520 | +            elif "pixel_space_vae" in sd:
| 521 | +                self.first_stage_model = comfy.pixel_space_convert.PixelspaceConversionVAE()
| 522 | +                self.memory_used_encode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
| 523 | +                self.memory_used_decode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
| 524 | +                self.downscale_ratio = 1
| 525 | +                self.upscale_ratio = 1
| 526 | +                self.latent_channels = 3
| 527 | +                self.latent_dim = 2
| 528 | +                self.output_channels = 3
519 | 529 |             else:
520 | 530 |                 logging.warning("WARNING: No VAE weights detected, VAE not initialized.")
521 | 531 |                 self.first_stage_model = None
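
For context, a minimal sketch of the round trip this new branch enables, assuming comfy.pixel_space_convert keeps the interface of the PixelspaceConversionVAE class removed in the next hunk (shapes and value ranges are taken from that class):

    import torch
    import comfy.pixel_space_convert

    vae = comfy.pixel_space_convert.PixelspaceConversionVAE()
    image = torch.rand(1, 64, 64, 3)    # IMAGE layout: B, H, W, C, values in 0..1
    latent = vae.encode(image)          # LATENT layout: B, C, H, W, values in -1..1
    assert latent.shape == (1, 3, 64, 64)
    restored = vae.decode(latent)       # back to B, H, W, C in 0..1
    assert torch.allclose(restored, image, atol=1e-6)

Since the conversion is a pure rescale/permute with no model weights, downscale_ratio and upscale_ratio are 1 and the "latent" keeps the image's 3 channels.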
@@ -785,65 +795,6 @@ def temporal_compression_decode(self): |
785 | 795 |         except:
786 | 796 |             return None
787 | 797 |
|
788 | | -# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1
789 | | -# to LATENT B, C, H, W and values on the scale of -1..1.
790 | | -class PixelspaceConversionVAE:
791 | | -    def __init__(self, size_increment: int = 16):
792 | | -        self.intermediate_device = comfy.model_management.intermediate_device()
793 | | -        self.size_increment = size_increment
794 | | -
795 | | -    def vae_encode_crop_pixels(self, pixels: torch.Tensor) -> torch.Tensor:
796 | | -        if self.size_increment == 1:
797 | | -            return pixels
798 | | -        dims = pixels.shape[1:-1]
799 | | -        for d in range(len(dims)):
800 | | -            d_adj = (dims[d] // self.size_increment) * self.size_increment
801 | | -            if d_adj == dims[d]:
802 | | -                continue
803 | | -            d_offset = (dims[d] % self.size_increment) // 2
804 | | -            pixels = pixels.narrow(d + 1, d_offset, d_adj)
805 | | -        return pixels
806 | | -
807 | | -    def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
808 | | -        if pixels.ndim == 3:
809 | | -            pixels = pixels.unsqueeze(0)
810 | | -        elif pixels.ndim != 4:
811 | | -            raise ValueError("Unexpected input image shape")
812 | | -        # Crop so the spatial dimensions are multiples of size_increment (16 by default).
813 | | -        pixels = self.vae_encode_crop_pixels(pixels)
814 | | -        h, w, c = pixels.shape[1:]
815 | | -        if h < self.size_increment or w < self.size_increment:
816 | | -            raise ValueError(f"Image inputs must have height/width of at least {self.size_increment} pixel(s).")
817 | | -        pixels = pixels[..., :3]
818 | | -        if c == 1:
819 | | -            pixels = pixels.expand(-1, -1, -1, 3)
820 | | -        elif c != 3:
821 | | -            raise ValueError("Unexpected number of channels in input image")
822 | | -        # Rescale to -1..1 and move the channel dimension to position 1.
823 | | -        latent = pixels.to(device=self.intermediate_device, dtype=torch.float32, copy=True)
824 | | -        latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous()
825 | | -        latent -= 0.5
826 | | -        latent *= 2
827 | | -        return latent.clamp_(-1, 1)
828 | | -
829 | | -    def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
830 | | -        # Rescale to 0..1 and move the channel dimension to the end.
831 | | -        img = samples.to(device=self.intermediate_device, dtype=torch.float32, copy=True)
832 | | -        img = img.clamp_(-1, 1).movedim(1, -1).contiguous()
833 | | -        img += 1.0
834 | | -        img *= 0.5
835 | | -        return img.clamp_(0, 1)
836 | | -
837 | | -    encode_tiled = encode
838 | | -    decode_tiled = decode
839 | | -
840 | | -    @classmethod
841 | | -    def spacial_compression_decode(cls) -> int:
842 | | -        # This just exists so the tiled VAE nodes don't crash.
843 | | -        return 1
844 | | -
845 | | -    spacial_compression_encode = spacial_compression_decode
846 | | -    temporal_compression_decode = spacial_compression_decode
847 | 798 |
|
848 | 799 | class StyleModel: |
849 | 800 |     def __init__(self, model, device="cpu"):
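
Continuing the earlier sketch: the vae_encode_crop_pixels helper in the removed class above center-crops each spatial dimension down to a multiple of size_increment, e.g. with a hypothetical 70x70 input:

    x = torch.rand(1, 70, 70, 3)       # 70 is not a multiple of 16
    y = vae.vae_encode_crop_pixels(x)  # narrows both dims to 64, starting at offset (70 % 16) // 2 == 3
    assert y.shape == (1, 64, 64, 3)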
|