 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
 from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd15_sdxl


 @invocation(
@@ -52,47 +53,23 @@ class ImageToLatentsInvocation(BaseInvocation):
     tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
     fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)

-    def _estimate_working_memory(
-        self, image_tensor: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
-    ) -> int:
-        """Estimate the working memory required by the invocation in bytes."""
-        # Encode operations use approximately 50% of the memory required for decode operations.
-        element_size = 4 if self.fp32 else 2
-        scaling_constant = 1100  # 50% of decode scaling constant (2200)
-
-        if use_tiling:
-            tile_size = self.tile_size
-            if tile_size == 0:
-                tile_size = vae.tile_sample_min_size
-            assert isinstance(tile_size, int)
-            h = tile_size
-            w = tile_size
-            working_memory = h * w * element_size * scaling_constant
-
-            # We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
-            # and number of tiles. We could make this more precise in the future, but this should be good enough for
-            # most use cases.
-            working_memory = working_memory * 1.25
-        else:
-            h = image_tensor.shape[-2]
-            w = image_tensor.shape[-1]
-            working_memory = h * w * element_size * scaling_constant
-
-        if self.fp32:
-            # If we are running in FP32, then we should account for the likely increase in model size (~250MB).
-            working_memory += 250 * 2**20
-
-        return int(working_memory)
-
-    @staticmethod
+    @classmethod
     def vae_encode(
+        cls,
         vae_info: LoadedModel,
         upcast: bool,
         tiled: bool,
         image_tensor: torch.Tensor,
         tile_size: int = 0,
-        estimated_working_memory: int = 0,
     ) -> torch.Tensor:
+        assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
+        estimated_working_memory = estimate_vae_working_memory_sd15_sdxl(
+            operation="encode",
+            image_tensor=image_tensor,
+            vae=vae_info.model,
+            tile_size=tile_size if tiled else None,
+            fp32=upcast,
+        )
         with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
             orig_dtype = vae.dtype
@@ -156,17 +133,13 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
         if image_tensor.dim() == 3:
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

-        use_tiling = self.tiled or context.config.get().force_tiled_decode
-        estimated_working_memory = self._estimate_working_memory(image_tensor, use_tiling, vae_info.model)
-
         context.util.signal_progress("Running VAE encoder")
         latents = self.vae_encode(
             vae_info=vae_info,
             upcast=self.fp32,
-            tiled=self.tiled,
+            tiled=self.tiled or context.config.get().force_tiled_decode,
             image_tensor=image_tensor,
             tile_size=self.tile_size,
-            estimated_working_memory=estimated_working_memory,
         )

         latents = latents.to("cpu")
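
The deleted `_estimate_working_memory` method shows the heuristic that the new shared `estimate_vae_working_memory_sd15_sdxl` helper is presumably built around: a per-pixel scaling constant (1100 for encode, 2200 for decode per the removed comment), a 25% surcharge when tiling, and a ~250 MB allowance for FP32 weights. A minimal sketch of what that helper might look like, assuming it keeps the same constants and only parameterizes the operation; the real implementation in `invokeai.backend.util.vae_working_memory` may differ:

```python
# Hypothetical reconstruction of the shared helper, derived from the removed
# per-invocation logic above. Not the actual library code.
from typing import Literal, Optional

import torch
from diffusers import AutoencoderKL, AutoencoderTiny


def estimate_vae_working_memory_sd15_sdxl(
    operation: Literal["encode", "decode"],
    image_tensor: torch.Tensor,
    vae: AutoencoderKL | AutoencoderTiny,
    tile_size: Optional[int],
    fp32: bool,
) -> int:
    """Estimate VAE working memory in bytes (assumed to mirror the removed heuristic)."""
    element_size = 4 if fp32 else 2
    # Encode uses roughly 50% of the memory of decode (constants from the deleted comment).
    scaling_constant = 1100 if operation == "encode" else 2200

    if tile_size is not None:
        # tile_size == 0 means "use the VAE's default tile size", as in the old code.
        if tile_size == 0:
            tile_size = vae.tile_sample_min_size
        working_memory = tile_size * tile_size * element_size * scaling_constant
        # Add 25% to account for tile overlap and the number of tiles.
        working_memory *= 1.25
    else:
        h, w = image_tensor.shape[-2], image_tensor.shape[-1]
        working_memory = h * w * element_size * scaling_constant

    if fp32:
        # Account for the larger FP32 model weights (~250 MB).
        working_memory += 250 * 2**20

    return int(working_memory)
```

Passing `tile_size=None` disables the tiled estimate, which matches the new call site's `tile_size=tile_size if tiled else None`.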