
Commit 9c3d4b3

refactor: estimate working vae memory during encode/decode
- Move the estimation logic to utility functions
- Estimate memory _within_ the encode and decode methods, ensuring we _always_ estimate working memory when running a VAE
1 parent 1f8a60d commit 9c3d4b3

9 files changed (+171, -152 lines)
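The new shared helper module referenced by these imports (one of the 9 changed files, but not shown in this view) centralizes the per-invocation estimators that are deleted below. The following is a minimal sketch of what such a module could look like, reconstructed from the removed code and the call sites in this diff; the function names and keyword arguments match the new imports and calls, but the bodies, the `LATENT_SCALE_FACTOR` value, and the assumption that the FLUX and SD3 variants mirror the CogView4 one are reconstructions, not the actual file contents.

```python
# Hypothetical sketch of invokeai/backend/util/vae_working_memory.py, reconstructed from
# the estimators deleted in this commit. The real module may differ in structure.
from typing import Literal, Optional

import torch

LATENT_SCALE_FACTOR = 8  # Assumed to match invokeai.app.invocations.constants.

# Peak working memory scales roughly linearly with pixel count and element size.
# The decode constant (2200) was determined experimentally; encode uses ~50% of it.
_SCALING_CONSTANTS = {"decode": 2200, "encode": 1100}


def estimate_vae_working_memory_cogview4(
    operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: torch.nn.Module
) -> int:
    """Estimate working memory (bytes) for a CogView4 VAE encode/decode."""
    element_size = next(vae.parameters()).element_size()
    if operation == "decode":
        # For decode, image_tensor is a latent tensor; the output image is larger per dim.
        h = LATENT_SCALE_FACTOR * image_tensor.shape[-2]
        w = LATENT_SCALE_FACTOR * image_tensor.shape[-1]
    else:
        h = image_tensor.shape[-2]
        w = image_tensor.shape[-1]
    return int(h * w * element_size * _SCALING_CONSTANTS[operation])


# estimate_vae_working_memory_flux and estimate_vae_working_memory_sd3 presumably follow
# the same pattern as the CogView4 estimator, based on the code removed in this diff.


def estimate_vae_working_memory_sd15_sdxl(
    operation: Literal["encode", "decode"],
    image_tensor: torch.Tensor,
    vae: torch.nn.Module,
    tile_size: Optional[int],
    fp32: bool,
) -> int:
    """Estimate working memory (bytes) for an SD1.5/SDXL VAE, with optional tiling."""
    element_size = 4 if fp32 else 2
    scaling_constant = _SCALING_CONSTANTS[operation]

    if tile_size is not None:
        # tile_size == 0 means "use the VAE's default tile size".
        if tile_size == 0:
            tile_size = vae.tile_sample_min_size
        working_memory = tile_size * tile_size * element_size * scaling_constant
        # Add 25% to account for factors like tile overlap and the number of tiles.
        working_memory *= 1.25
    else:
        if operation == "decode":
            h = LATENT_SCALE_FACTOR * image_tensor.shape[-2]
            w = LATENT_SCALE_FACTOR * image_tensor.shape[-1]
        else:
            h = image_tensor.shape[-2]
            w = image_tensor.shape[-1]
        working_memory = h * w * element_size * scaling_constant

    if fp32:
        # Account for the likely increase in model size when running in FP32 (~250 MB).
        working_memory += 250 * 2**20

    return int(working_memory)
```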

invokeai/app/invocations/cogview4_image_to_latents.py

Lines changed: 7 additions & 15 deletions

```diff
@@ -17,6 +17,7 @@
 from invokeai.backend.model_manager.load.load_base import LoadedModel
 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_cogview4
 
 # TODO(ryand): This is effectively a copy of SD3ImageToLatentsInvocation and a subset of ImageToLatentsInvocation. We
 # should refactor to avoid this duplication.
@@ -36,18 +37,12 @@ class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
     image: ImageField = InputField(description="The image to encode.")
     vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
 
-    def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoencoderKL) -> int:
-        """Estimate the working memory required by the invocation in bytes."""
-        # Encode operations use approximately 50% of the memory required for decode operations
-        h = image_tensor.shape[-2]
-        w = image_tensor.shape[-1]
-        element_size = next(vae.parameters()).element_size()
-        scaling_constant = 1100  # 50% of decode scaling constant (2200)
-        working_memory = h * w * element_size * scaling_constant
-        return int(working_memory)
-
     @staticmethod
-    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
+    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
+        assert isinstance(vae_info.model, AutoencoderKL)
+        estimated_working_memory = estimate_vae_working_memory_cogview4(
+            operation="encode", image_tensor=image_tensor, vae=vae_info.model
+        )
         with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, AutoencoderKL)
 
@@ -74,10 +69,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
         vae_info = context.models.load(self.vae.vae)
         assert isinstance(vae_info.model, AutoencoderKL)
 
-        estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
-        latents = self.vae_encode(
-            vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory
-        )
+        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
 
         latents = latents.to("cpu")
         name = context.tensors.save(tensor=latents)
```

invokeai/app/invocations/cogview4_latents_to_image.py

Lines changed: 4 additions & 11 deletions

```diff
@@ -6,7 +6,6 @@
 from PIL import Image
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
-from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
 from invokeai.app.invocations.fields import (
     FieldDescriptions,
     Input,
@@ -20,6 +19,7 @@
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_cogview4
 
 # TODO(ryand): This is effectively a copy of SD3LatentsToImageInvocation and a subset of LatentsToImageInvocation. We
 # should refactor to avoid this duplication.
@@ -39,22 +39,15 @@ class CogView4LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
     vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
 
-    def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoencoderKL) -> int:
-        """Estimate the working memory required by the invocation in bytes."""
-        out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
-        out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
-        element_size = next(vae.parameters()).element_size()
-        scaling_constant = 2200  # Determined experimentally.
-        working_memory = out_h * out_w * element_size * scaling_constant
-        return int(working_memory)
-
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.tensors.load(self.latents.latents_name)
 
         vae_info = context.models.load(self.vae.vae)
         assert isinstance(vae_info.model, (AutoencoderKL))
-        estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
+        estimated_working_memory = estimate_vae_working_memory_cogview4(
+            operation="decode", image_tensor=latents, vae=vae_info.model
+        )
         with (
             SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
             vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
```

invokeai/app/invocations/flux_vae_decode.py

Lines changed: 5 additions & 11 deletions

```diff
@@ -3,7 +3,6 @@
 from PIL import Image
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
 from invokeai.app.invocations.fields import (
     FieldDescriptions,
     Input,
@@ -18,6 +17,7 @@
 from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.model_manager.load.load_base import LoadedModel
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux
 
 
 @invocation(
@@ -39,17 +39,11 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
         input=Input.Connection,
     )
 
-    def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoEncoder) -> int:
-        """Estimate the working memory required by the invocation in bytes."""
-        out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
-        out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
-        element_size = next(vae.parameters()).element_size()
-        scaling_constant = 2200  # Determined experimentally.
-        working_memory = out_h * out_w * element_size * scaling_constant
-        return int(working_memory)
-
     def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
-        estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
+        assert isinstance(vae_info.model, AutoEncoder)
+        estimated_working_memory = estimate_vae_working_memory_flux(
+            operation="decode", image_tensor=latents, vae=vae_info.model
+        )
         with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, AutoEncoder)
             vae_dtype = next(iter(vae.parameters())).dtype
```

invokeai/app/invocations/flux_vae_encode.py

Lines changed: 7 additions & 15 deletions

```diff
@@ -15,6 +15,7 @@
 from invokeai.backend.model_manager import LoadedModel
 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux
 
 
 @invocation(
@@ -35,22 +36,16 @@ class FluxVaeEncodeInvocation(BaseInvocation):
         input=Input.Connection,
     )
 
-    def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoEncoder) -> int:
-        """Estimate the working memory required by the invocation in bytes."""
-        # Encode operations use approximately 50% of the memory required for decode operations
-        h = image_tensor.shape[-2]
-        w = image_tensor.shape[-1]
-        element_size = next(vae.parameters()).element_size()
-        scaling_constant = 1100  # 50% of decode scaling constant (2200)
-        working_memory = h * w * element_size * scaling_constant
-        return int(working_memory)
-
     @staticmethod
-    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
+    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
         # TODO(ryand): Expose seed parameter at the invocation level.
         # TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
         # There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
         # should be used for VAE encode sampling.
+        assert isinstance(vae_info.model, AutoEncoder)
+        estimated_working_memory = estimate_vae_working_memory_flux(
+            operation="encode", image_tensor=image_tensor, vae=vae_info.model
+        )
         generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
         with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, AutoEncoder)
@@ -70,10 +65,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
         image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
 
         context.util.signal_progress("Running VAE")
-        estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
-        latents = self.vae_encode(
-            vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory
-        )
+        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
 
         latents = latents.to("cpu")
         name = context.tensors.save(tensor=latents)
```

invokeai/app/invocations/image_to_latents.py

Lines changed: 12 additions & 39 deletions

```diff
@@ -27,6 +27,7 @@
 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
 from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd15_sdxl
 
 
 @invocation(
@@ -52,47 +53,23 @@ class ImageToLatentsInvocation(BaseInvocation):
     tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
     fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
 
-    def _estimate_working_memory(
-        self, image_tensor: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
-    ) -> int:
-        """Estimate the working memory required by the invocation in bytes."""
-        # Encode operations use approximately 50% of the memory required for decode operations
-        element_size = 4 if self.fp32 else 2
-        scaling_constant = 1100  # 50% of decode scaling constant (2200)
-
-        if use_tiling:
-            tile_size = self.tile_size
-            if tile_size == 0:
-                tile_size = vae.tile_sample_min_size
-            assert isinstance(tile_size, int)
-            h = tile_size
-            w = tile_size
-            working_memory = h * w * element_size * scaling_constant
-
-            # We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
-            # and number of tiles. We could make this more precise in the future, but this should be good enough for
-            # most use cases.
-            working_memory = working_memory * 1.25
-        else:
-            h = image_tensor.shape[-2]
-            w = image_tensor.shape[-1]
-            working_memory = h * w * element_size * scaling_constant
-
-        if self.fp32:
-            # If we are running in FP32, then we should account for the likely increase in model size (~250MB).
-            working_memory += 250 * 2**20
-
-        return int(working_memory)
-
-    @staticmethod
+    @classmethod
     def vae_encode(
+        cls,
         vae_info: LoadedModel,
         upcast: bool,
         tiled: bool,
         image_tensor: torch.Tensor,
         tile_size: int = 0,
-        estimated_working_memory: int = 0,
     ) -> torch.Tensor:
+        assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
+        estimated_working_memory = estimate_vae_working_memory_sd15_sdxl(
+            operation="encode",
+            image_tensor=image_tensor,
+            vae=vae_info.model,
+            tile_size=tile_size if tiled else None,
+            fp32=upcast,
+        )
         with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
             orig_dtype = vae.dtype
@@ -156,17 +133,13 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
         if image_tensor.dim() == 3:
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
 
-        use_tiling = self.tiled or context.config.get().force_tiled_decode
-        estimated_working_memory = self._estimate_working_memory(image_tensor, use_tiling, vae_info.model)
-
         context.util.signal_progress("Running VAE encoder")
         latents = self.vae_encode(
             vae_info=vae_info,
             upcast=self.fp32,
-            tiled=self.tiled,
+            tiled=self.tiled or context.config.get().force_tiled_decode,
             image_tensor=image_tensor,
             tile_size=self.tile_size,
-            estimated_working_memory=estimated_working_memory,
         )
 
         latents = latents.to("cpu")
```
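With the estimate now computed inside `vae_encode` itself, any caller of `ImageToLatentsInvocation.vae_encode` (now a classmethod) no longer needs to compute and thread an `estimated_working_memory` value through. A hypothetical call under the new signature, assuming `vae_info` and `image_tensor` are already prepared:

```python
# Hypothetical caller of the refactored classmethod; vae_info and image_tensor
# are assumed to already exist. The working-memory estimate happens internally.
latents = ImageToLatentsInvocation.vae_encode(
    vae_info=vae_info,
    upcast=False,  # run the VAE at its loaded precision
    tiled=False,   # no VAE tiling
    image_tensor=image_tensor,
)
```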

invokeai/app/invocations/latents_to_image.py

Lines changed: 8 additions & 35 deletions

```diff
@@ -27,6 +27,7 @@
 from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
 from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd15_sdxl
 
 
 @invocation(
@@ -53,39 +54,6 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
     fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
 
-    def _estimate_working_memory(
-        self, latents: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
-    ) -> int:
-        """Estimate the working memory required by the invocation in bytes."""
-        # It was found experimentally that the peak working memory scales linearly with the number of pixels and the
-        # element size (precision). This estimate is accurate for both SD1 and SDXL.
-        element_size = 4 if self.fp32 else 2
-        scaling_constant = 2200  # Determined experimentally.
-
-        if use_tiling:
-            tile_size = self.tile_size
-            if tile_size == 0:
-                tile_size = vae.tile_sample_min_size
-            assert isinstance(tile_size, int)
-            out_h = tile_size
-            out_w = tile_size
-            working_memory = out_h * out_w * element_size * scaling_constant
-
-            # We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
-            # and number of tiles. We could make this more precise in the future, but this should be good enough for
-            # most use cases.
-            working_memory = working_memory * 1.25
-        else:
-            out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
-            out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
-            working_memory = out_h * out_w * element_size * scaling_constant
-
-        if self.fp32:
-            # If we are running in FP32, then we should account for the likely increase in model size (~250MB).
-            working_memory += 250 * 2**20
-
-        return int(working_memory)
-
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.tensors.load(self.latents.latents_name)
@@ -94,8 +62,13 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
 
         vae_info = context.models.load(self.vae.vae)
         assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
-
-        estimated_working_memory = self._estimate_working_memory(latents, use_tiling, vae_info.model)
+        estimated_working_memory = estimate_vae_working_memory_sd15_sdxl(
+            operation="decode",
+            image_tensor=latents,
+            vae=vae_info.model,
+            tile_size=self.tile_size if use_tiling else None,
+            fp32=self.fp32,
+        )
         with (
             SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
             vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
```

invokeai/app/invocations/sd3_image_to_latents.py

Lines changed: 7 additions & 15 deletions

```diff
@@ -17,6 +17,7 @@
 from invokeai.backend.model_manager.load.load_base import LoadedModel
 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd3
 
 
 @invocation(
@@ -32,18 +33,12 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
     image: ImageField = InputField(description="The image to encode")
     vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
 
-    def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoencoderKL) -> int:
-        """Estimate the working memory required by the invocation in bytes."""
-        # Encode operations use approximately 50% of the memory required for decode operations
-        h = image_tensor.shape[-2]
-        w = image_tensor.shape[-1]
-        element_size = next(vae.parameters()).element_size()
-        scaling_constant = 1100  # 50% of decode scaling constant (2200)
-        working_memory = h * w * element_size * scaling_constant
-        return int(working_memory)
-
     @staticmethod
-    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
+    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
+        assert isinstance(vae_info.model, AutoencoderKL)
+        estimated_working_memory = estimate_vae_working_memory_sd3(
+            operation="encode", image_tensor=image_tensor, vae=vae_info.model
+        )
         with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, AutoencoderKL)
 
@@ -70,10 +65,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
         vae_info = context.models.load(self.vae.vae)
         assert isinstance(vae_info.model, AutoencoderKL)
 
-        estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
-        latents = self.vae_encode(
-            vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory
-        )
+        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
 
         latents = latents.to("cpu")
         name = context.tensors.save(tensor=latents)
```
