22 changes: 7 additions & 15 deletions invokeai/app/invocations/cogview4_image_to_latents.py
@@ -17,6 +17,7 @@
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_cogview4

# TODO(ryand): This is effectively a copy of SD3ImageToLatentsInvocation and a subset of ImageToLatentsInvocation. We
# should refactor to avoid this duplication.
@@ -36,18 +37,12 @@ class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
image: ImageField = InputField(description="The image to encode.")
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoencoderKL) -> int:
"""Estimate the working memory required by the invocation in bytes."""
# Encode operations use approximately 50% of the memory required for decode operations
h = image_tensor.shape[-2]
w = image_tensor.shape[-1]
element_size = next(vae.parameters()).element_size()
scaling_constant = 1100 # 50% of decode scaling constant (2200)
working_memory = h * w * element_size * scaling_constant
return int(working_memory)

@staticmethod
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
assert isinstance(vae_info.model, AutoencoderKL)
estimated_working_memory = estimate_vae_working_memory_cogview4(
operation="encode", image_tensor=image_tensor, vae=vae_info.model
)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
assert isinstance(vae, AutoencoderKL)

@@ -74,10 +69,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, AutoencoderKL)

estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
latents = self.vae_encode(
vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory
)
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)

latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
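The shared module this PR introduces, invokeai/backend/util/vae_working_memory.py, is not included in the diff. A minimal sketch of what estimate_vae_working_memory_cogview4 likely looks like, assembled from the per-invocation estimators deleted in this PR (the module contents and the fixed LATENT_SCALE_FACTOR of 8 are assumptions; the real implementation may differ):

```python
from typing import Literal

import torch
from diffusers import AutoencoderKL

LATENT_SCALE_FACTOR = 8  # Assumed to match invokeai.app.invocations.constants.LATENT_SCALE_FACTOR.


def estimate_vae_working_memory_cogview4(
    operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoencoderKL
) -> int:
    """Estimate the VAE working memory required, in bytes (hypothetical reconstruction)."""
    element_size = next(vae.parameters()).element_size()
    if operation == "decode":
        # For decode, image_tensor holds latents; the output is LATENT_SCALE_FACTOR x larger per side.
        h = LATENT_SCALE_FACTOR * image_tensor.shape[-2]
        w = LATENT_SCALE_FACTOR * image_tensor.shape[-1]
        scaling_constant = 2200  # Determined experimentally (from the deleted decode estimator).
    else:
        # Encode uses roughly 50% of the memory required for decode.
        h = image_tensor.shape[-2]
        w = image_tensor.shape[-1]
        scaling_constant = 1100
    return int(h * w * element_size * scaling_constant)
```

Centralizing the heuristic keeps the experimentally determined scaling constants in one place instead of several near-identical copies spread across the invocations.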
15 changes: 4 additions & 11 deletions invokeai/app/invocations/cogview4_latents_to_image.py
@@ -6,7 +6,6 @@
from PIL import Image

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
@@ -20,6 +19,7 @@
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_cogview4

# TODO(ryand): This is effectively a copy of SD3LatentsToImageInvocation and a subset of LatentsToImageInvocation. We
# should refactor to avoid this duplication.
@@ -39,22 +39,15 @@ class CogView4LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoencoderKL) -> int:
"""Estimate the working memory required by the invocation in bytes."""
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
element_size = next(vae.parameters()).element_size()
scaling_constant = 2200 # Determined experimentally.
working_memory = out_h * out_w * element_size * scaling_constant
return int(working_memory)

@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)

vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL))
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
estimated_working_memory = estimate_vae_working_memory_cogview4(
operation="decode", image_tensor=latents, vae=vae_info.model
)
with (
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
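For a sense of scale under the assumed constants: decoding 128×128 latents (a 1024×1024 output) with an fp16 VAE budgets roughly 4.3 GiB of working memory.

```python
# Magnitude check under the assumed constants (fp16 VAE, decode constant 2200).
out_h = out_w = 8 * 128               # 128x128 latents -> 1024x1024 pixels
estimate = out_h * out_w * 2 * 2200   # element_size = 2 bytes for fp16
print(f"{estimate / 2**30:.1f} GiB")  # ~4.3 GiB
```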
16 changes: 5 additions & 11 deletions invokeai/app/invocations/flux_vae_decode.py
@@ -3,7 +3,6 @@
from PIL import Image

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
@@ -18,6 +17,7 @@
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux


@invocation(
@@ -39,17 +39,11 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
)

def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoEncoder) -> int:
"""Estimate the working memory required by the invocation in bytes."""
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
element_size = next(vae.parameters()).element_size()
scaling_constant = 2200 # Determined experimentally.
working_memory = out_h * out_w * element_size * scaling_constant
return int(working_memory)

def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
assert isinstance(vae_info.model, AutoEncoder)
estimated_working_memory = estimate_vae_working_memory_flux(
operation="decode", image_tensor=latents, vae=vae_info.model
)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
assert isinstance(vae, AutoEncoder)
vae_dtype = next(iter(vae.parameters())).dtype
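A comparable sketch for estimate_vae_working_memory_flux, assuming it mirrors the deleted FLUX estimators with the VAE typed as InvokeAI's AutoEncoder (again a reconstruction, not the actual module):

```python
from typing import Literal

import torch

from invokeai.backend.flux.modules.autoencoder import AutoEncoder

LATENT_SCALE_FACTOR = 8  # Assumed latent-to-pixel scale, as in the deleted decode estimator.


def estimate_vae_working_memory_flux(
    operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoEncoder
) -> int:
    """Estimate the FLUX VAE working memory required, in bytes (hypothetical reconstruction)."""
    element_size = next(vae.parameters()).element_size()
    if operation == "decode":
        h = LATENT_SCALE_FACTOR * image_tensor.shape[-2]
        w = LATENT_SCALE_FACTOR * image_tensor.shape[-1]
        return int(h * w * element_size * 2200)  # Decode constant, determined experimentally.
    # Encode uses ~50% of the decode budget.
    return int(image_tensor.shape[-2] * image_tensor.shape[-1] * element_size * 1100)
```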
22 changes: 7 additions & 15 deletions invokeai/app/invocations/flux_vae_encode.py
@@ -15,6 +15,7 @@
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux


@invocation(
@@ -35,22 +36,16 @@ class FluxVaeEncodeInvocation(BaseInvocation):
input=Input.Connection,
)

def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoEncoder) -> int:
"""Estimate the working memory required by the invocation in bytes."""
# Encode operations use approximately 50% of the memory required for decode operations
h = image_tensor.shape[-2]
w = image_tensor.shape[-1]
element_size = next(vae.parameters()).element_size()
scaling_constant = 1100 # 50% of decode scaling constant (2200)
working_memory = h * w * element_size * scaling_constant
return int(working_memory)

@staticmethod
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
# TODO(ryand): Expose seed parameter at the invocation level.
# TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
# There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
# should be used for VAE encode sampling.
assert isinstance(vae_info.model, AutoEncoder)
estimated_working_memory = estimate_vae_working_memory_flux(
operation="encode", image_tensor=image_tensor, vae=vae_info.model
)
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
assert isinstance(vae, AutoEncoder)
@@ -70,10 +65,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

context.util.signal_progress("Running VAE")
estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
latents = self.vae_encode(
vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory
)
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)

latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
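Because the assumed encode constant (1100) is half the decode constant (2200), the encode budget for a given resolution is half the corresponding decode budget, e.g. at 512×512 with an fp16 VAE:

```python
# Encode vs. decode budgets at 512x512, fp16 (assumes the constants from the sketches above).
h = w = 512
element_size = 2
encode_estimate = h * w * element_size * 1100   # ~0.54 GiB
decode_estimate = h * w * element_size * 2200   # ~1.07 GiB
assert decode_estimate == 2 * encode_estimate
```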
51 changes: 12 additions & 39 deletions invokeai/app/invocations/image_to_latents.py
@@ -27,6 +27,7 @@
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd15_sdxl


@invocation(
@@ -52,47 +53,23 @@ class ImageToLatentsInvocation(BaseInvocation):
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)

def _estimate_working_memory(
self, image_tensor: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
) -> int:
"""Estimate the working memory required by the invocation in bytes."""
# Encode operations use approximately 50% of the memory required for decode operations
element_size = 4 if self.fp32 else 2
scaling_constant = 1100 # 50% of decode scaling constant (2200)

if use_tiling:
tile_size = self.tile_size
if tile_size == 0:
tile_size = vae.tile_sample_min_size
assert isinstance(tile_size, int)
h = tile_size
w = tile_size
working_memory = h * w * element_size * scaling_constant

# We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
# and number of tiles. We could make this more precise in the future, but this should be good enough for
# most use cases.
working_memory = working_memory * 1.25
else:
h = image_tensor.shape[-2]
w = image_tensor.shape[-1]
working_memory = h * w * element_size * scaling_constant

if self.fp32:
# If we are running in FP32, then we should account for the likely increase in model size (~250MB).
working_memory += 250 * 2**20

return int(working_memory)

@staticmethod
@classmethod
def vae_encode(
cls,
vae_info: LoadedModel,
upcast: bool,
tiled: bool,
image_tensor: torch.Tensor,
tile_size: int = 0,
estimated_working_memory: int = 0,
) -> torch.Tensor:
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
estimated_working_memory = estimate_vae_working_memory_sd15_sdxl(
operation="encode",
image_tensor=image_tensor,
vae=vae_info.model,
tile_size=tile_size if tiled else None,
fp32=upcast,
)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
orig_dtype = vae.dtype
@@ -156,17 +133,13 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

use_tiling = self.tiled or context.config.get().force_tiled_decode
estimated_working_memory = self._estimate_working_memory(image_tensor, use_tiling, vae_info.model)

context.util.signal_progress("Running VAE encoder")
latents = self.vae_encode(
vae_info=vae_info,
upcast=self.fp32,
tiled=self.tiled,
tiled=self.tiled or context.config.get().force_tiled_decode,
image_tensor=image_tensor,
tile_size=self.tile_size,
estimated_working_memory=estimated_working_memory,
)

latents = latents.to("cpu")
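The SD1.5/SDXL helper also has to cover tiling and fp32. A hedged sketch of estimate_vae_working_memory_sd15_sdxl, reconstructed from the estimators removed from image_to_latents.py and latents_to_image.py (the real module may differ):

```python
from typing import Literal, Optional

import torch
from diffusers import AutoencoderKL, AutoencoderTiny

LATENT_SCALE_FACTOR = 8  # Assumed to match the constant previously imported by these invocations.


def estimate_vae_working_memory_sd15_sdxl(
    operation: Literal["encode", "decode"],
    image_tensor: torch.Tensor,
    vae: AutoencoderKL | AutoencoderTiny,
    tile_size: Optional[int],
    fp32: bool,
) -> int:
    """Estimate working memory in bytes; peak usage scales linearly with pixel count and element size."""
    element_size = 4 if fp32 else 2
    scaling_constant = 2200 if operation == "decode" else 1100  # Encode needs ~50% of decode.

    if tile_size is not None:
        if tile_size == 0:
            tile_size = vae.tile_sample_min_size  # Fall back to the VAE's default tile size.
        # Add 25% headroom for tile overlap and tile count.
        working_memory = tile_size * tile_size * element_size * scaling_constant * 1.25
    elif operation == "decode":
        out_h = LATENT_SCALE_FACTOR * image_tensor.shape[-2]
        out_w = LATENT_SCALE_FACTOR * image_tensor.shape[-1]
        working_memory = out_h * out_w * element_size * scaling_constant
    else:
        working_memory = image_tensor.shape[-2] * image_tensor.shape[-1] * element_size * scaling_constant

    if fp32:
        working_memory += 250 * 2**20  # Allow ~250 MB for the larger fp32 model weights.

    return int(working_memory)
```

Judging from the call sites in this diff, tile_size=None expresses "no tiling" while tile_size=0 means "tile with the VAE default", so the old separate use_tiling flag collapses into a single parameter.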
43 changes: 8 additions & 35 deletions invokeai/app/invocations/latents_to_image.py
@@ -27,6 +27,7 @@
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd15_sdxl


@invocation(
@@ -53,39 +54,6 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)

def _estimate_working_memory(
self, latents: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
) -> int:
"""Estimate the working memory required by the invocation in bytes."""
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
# element size (precision). This estimate is accurate for both SD1 and SDXL.
element_size = 4 if self.fp32 else 2
scaling_constant = 2200 # Determined experimentally.

if use_tiling:
tile_size = self.tile_size
if tile_size == 0:
tile_size = vae.tile_sample_min_size
assert isinstance(tile_size, int)
out_h = tile_size
out_w = tile_size
working_memory = out_h * out_w * element_size * scaling_constant

# We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
# and number of tiles. We could make this more precise in the future, but this should be good enough for
# most use cases.
working_memory = working_memory * 1.25
else:
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
working_memory = out_h * out_w * element_size * scaling_constant

if self.fp32:
# If we are running in FP32, then we should account for the likely increase in model size (~250MB).
working_memory += 250 * 2**20

return int(working_memory)

@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
@@ -94,8 +62,13 @@ def invoke(self, context: InvocationContext) -> ImageOutput:

vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))

estimated_working_memory = self._estimate_working_memory(latents, use_tiling, vae_info.model)
estimated_working_memory = estimate_vae_working_memory_sd15_sdxl(
operation="decode",
image_tensor=latents,
vae=vae_info.model,
tile_size=self.tile_size if use_tiling else None,
fp32=self.fp32,
)
with (
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
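The tiled path caps the budget at the tile size rather than the full output size, which is why it helps on low-VRAM systems. A rough comparison under the assumed constants (fp16, decode):

```python
# Full-frame vs. tiled decode budgets (assumes the sd15/sdxl sketch above).
full_decode = 1024 * 1024 * 2 * 2200        # ~4.3 GiB for a full 1024x1024 decode
tiled_decode = 512 * 512 * 2 * 2200 * 1.25  # ~1.3 GiB with 512px tiles (25% overlap headroom)
```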
22 changes: 7 additions & 15 deletions invokeai/app/invocations/sd3_image_to_latents.py
@@ -17,6 +17,7 @@
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd3


@invocation(
@@ -32,18 +33,12 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
image: ImageField = InputField(description="The image to encode")
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoencoderKL) -> int:
"""Estimate the working memory required by the invocation in bytes."""
# Encode operations use approximately 50% of the memory required for decode operations
h = image_tensor.shape[-2]
w = image_tensor.shape[-1]
element_size = next(vae.parameters()).element_size()
scaling_constant = 1100 # 50% of decode scaling constant (2200)
working_memory = h * w * element_size * scaling_constant
return int(working_memory)

@staticmethod
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
assert isinstance(vae_info.model, AutoencoderKL)
estimated_working_memory = estimate_vae_working_memory_sd3(
operation="encode", image_tensor=image_tensor, vae=vae_info.model
)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
assert isinstance(vae, AutoencoderKL)

@@ -70,10 +65,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, AutoencoderKL)

estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
latents = self.vae_encode(
vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory
)
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)

latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
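The SD3 diff only exercises the encode path, but estimate_vae_working_memory_sd3 presumably mirrors the CogView4 variant, since the deleted per-invocation estimators were identical. A hedged sketch:

```python
from typing import Literal

import torch
from diffusers import AutoencoderKL


def estimate_vae_working_memory_sd3(
    operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoencoderKL
) -> int:
    """Estimate the SD3 VAE working memory required, in bytes (hypothetical reconstruction)."""
    element_size = next(vae.parameters()).element_size()
    scaling_constant = 2200 if operation == "decode" else 1100  # Encode ~50% of decode.
    h, w = image_tensor.shape[-2], image_tensor.shape[-1]
    if operation == "decode":
        h, w = 8 * h, 8 * w  # Latents -> pixels; the factor 8 is an assumed LATENT_SCALE_FACTOR.
    return int(h * w * element_size * scaling_constant)
```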