
Commit 5d36c1c

We should not trust the value of model.device, since the model could be partially-loaded.
1 parent 6b18f27 commit 5d36c1c
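
Every hunk below applies the same fix: stop reading the execution device off a loaded model (unet.device, vae.device, control_model.device, and so on) and choose it explicitly instead, either via TorchDevice.choose_torch_device() or via a device argument threaded in from the caller. A minimal sketch of the rationale in plain torch; the choose_torch_device() helper here is an illustrative stand-in, not InvokeAI's implementation:

# Sketch only: with partial model loading, some weights may still sit on the CPU,
# so module.device can disagree with where inference will actually run.
import torch

def choose_torch_device() -> torch.device:
    # Pick the execution device directly instead of inferring it from a model.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

# Before: inputs followed the (possibly partially-loaded) model.
#     latents = latents.to(device=unet.device, dtype=unet.dtype)
# After: choose the device once and move inputs to it explicitly.
#     device = choose_torch_device()
#     latents = latents.to(device=device, dtype=unet.dtype)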

9 files changed (+36, -24 lines)

invokeai/app/invocations/denoise_latents.py

Lines changed: 13 additions & 8 deletions
@@ -411,6 +411,7 @@ def prep_control_data(
         context: InvocationContext,
         control_input: ControlField | list[ControlField] | None,
         latents_shape: List[int],
+        device: torch.device,
         exit_stack: ExitStack,
         do_classifier_free_guidance: bool = True,
     ) -> list[ControlNetData] | None:
@@ -452,7 +453,7 @@ def prep_control_data(
                 height=control_height_resize,
                 # batch_size=batch_size * num_images_per_prompt,
                 # num_images_per_prompt=num_images_per_prompt,
-                device=control_model.device,
+                device=device,
                 dtype=control_model.dtype,
                 control_mode=control_info.control_mode,
                 resize_mode=control_info.resize_mode,
@@ -605,6 +606,7 @@ def run_t2i_adapters(
         context: InvocationContext,
         t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
         latents_shape: list[int],
+        device: torch.device,
         do_classifier_free_guidance: bool,
     ) -> Optional[list[T2IAdapterData]]:
         if t2i_adapter is None:
@@ -655,7 +657,7 @@ def run_t2i_adapters(
                 width=control_width_resize,
                 height=control_height_resize,
                 num_channels=t2i_adapter_model.config["in_channels"],  # mypy treats this as a FrozenDict
-                device=t2i_adapter_model.device,
+                device=device,
                 dtype=t2i_adapter_model.dtype,
                 resize_mode=t2i_adapter_field.resize_mode,
             )
@@ -946,6 +948,7 @@ def step_callback(state: PipelineIntermediateState) -> None:
     @torch.no_grad()
     @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
     def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
+        device = TorchDevice.choose_torch_device()
         seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
 
         mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
@@ -960,6 +963,7 @@ def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
             context,
             self.t2i_adapter,
             latents.shape,
+            device=device,
             do_classifier_free_guidance=True,
         )
 
@@ -1006,13 +1010,13 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             ),
         ):
             assert isinstance(unet, UNet2DConditionModel)
-            latents = latents.to(device=unet.device, dtype=unet.dtype)
+            latents = latents.to(device=device, dtype=unet.dtype)
             if noise is not None:
-                noise = noise.to(device=unet.device, dtype=unet.dtype)
+                noise = noise.to(device=device, dtype=unet.dtype)
             if mask is not None:
-                mask = mask.to(device=unet.device, dtype=unet.dtype)
+                mask = mask.to(device=device, dtype=unet.dtype)
             if masked_latents is not None:
-                masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
+                masked_latents = masked_latents.to(device=device, dtype=unet.dtype)
 
             scheduler = get_scheduler(
                 context=context,
@@ -1028,7 +1032,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
                 context=context,
                 positive_conditioning_field=self.positive_conditioning,
                 negative_conditioning_field=self.negative_conditioning,
-                device=unet.device,
+                device=device,
                 dtype=unet.dtype,
                 latent_height=latent_height,
                 latent_width=latent_width,
@@ -1041,6 +1045,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
                 context=context,
                 control_input=self.control,
                 latents_shape=latents.shape,
+                device=device,
                 # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                 do_classifier_free_guidance=True,
                 exit_stack=exit_stack,
@@ -1058,7 +1063,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
 
             timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
                 scheduler,
-                device=unet.device,
+                device=device,
                 steps=self.steps,
                 denoising_start=self.denoising_start,
                 denoising_end=self.denoising_end,
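
Condensed for illustration, the threading pattern this file introduces: the device is chosen once in _old_invoke and handed to the helpers, which previously consulted control_model.device / t2i_adapter_model.device. Argument lists other than device are abbreviated, so this is a sketch of the call shape rather than a verbatim excerpt:

# Sketch of the device threading in DenoiseLatentsInvocation._old_invoke
# (names come from the diff above; unrelated arguments are omitted).
device = TorchDevice.choose_torch_device()

t2i_adapter_data = self.run_t2i_adapters(
    context,
    self.t2i_adapter,
    latents.shape,
    device=device,  # helpers no longer read t2i_adapter_model.device
    do_classifier_free_guidance=True,
)

controlnet_data = self.prep_control_data(
    context=context,
    control_input=self.control,
    latents_shape=latents.shape,
    device=device,  # likewise replaces control_model.device
    do_classifier_free_guidance=True,
    exit_stack=exit_stack,
)

# Input tensors move to the same chosen device, keeping the UNet's dtype.
latents = latents.to(device=device, dtype=unet.dtype)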

invokeai/app/invocations/flux_denoise.py

Lines changed: 4 additions & 3 deletions
@@ -276,7 +276,7 @@ def _run_diffusion(
         # TODO(ryand): We should really do this in a separate invocation to benefit from caching.
         ip_adapter_fields = self._normalize_ip_adapter_fields()
         pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds = self._prep_ip_adapter_image_prompt_clip_embeds(
-            ip_adapter_fields, context
+            ip_adapter_fields, context, device=x.device
         )
 
         cfg_scale = self.prep_cfg_scale(
@@ -626,6 +626,7 @@ def _prep_ip_adapter_image_prompt_clip_embeds(
         self,
         ip_adapter_fields: list[IPAdapterField],
         context: InvocationContext,
+        device: torch.device,
     ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
         """Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
         clip_image_processor = CLIPImageProcessor()
@@ -665,11 +666,11 @@ def _prep_ip_adapter_image_prompt_clip_embeds(
             assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
 
             clip_image: torch.Tensor = clip_image_processor(images=pos_images, return_tensors="pt").pixel_values
-            clip_image = clip_image.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype)
+            clip_image = clip_image.to(device=device, dtype=image_encoder_model.dtype)
             pos_clip_image_embeds = image_encoder_model(clip_image).image_embeds
 
             clip_image = clip_image_processor(images=neg_images, return_tensors="pt").pixel_values
-            clip_image = clip_image.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype)
+            clip_image = clip_image.to(device=device, dtype=image_encoder_model.dtype)
             neg_clip_image_embeds = image_encoder_model(clip_image).image_embeds
 
             pos_image_prompt_clip_embeds.append(pos_clip_image_embeds)

invokeai/app/invocations/image_to_latents.py

Lines changed: 2 additions & 1 deletion
@@ -26,6 +26,7 @@
 from invokeai.backend.model_manager import LoadedModel
 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
 from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
+from invokeai.backend.util.devices import TorchDevice
 
 
 @invocation(
@@ -98,7 +99,7 @@ def vae_encode(
         )
 
         # non_noised_latents_from_image
-        image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
+        image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
         with torch.inference_mode(), tiling_context:
             latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
 

invokeai/app/invocations/sd3_image_to_latents.py

Lines changed: 2 additions & 1 deletion
@@ -16,6 +16,7 @@
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.model_manager.load.load_base import LoadedModel
 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
+from invokeai.backend.util.devices import TorchDevice
 
 
 @invocation(
@@ -39,7 +40,7 @@ def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tenso
 
         vae.disable_tiling()
 
-        image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
+        image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
         with torch.inference_mode():
             image_tensor_dist = vae.encode(image_tensor).latent_dist
             # TODO: Use seed to make sampling reproducible.

invokeai/app/invocations/spandrel_image_to_image.py

Lines changed: 3 additions & 4 deletions
@@ -22,6 +22,7 @@
 from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
 from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
 from invokeai.backend.tiles.utils import TBLR, Tile
+from invokeai.backend.util.devices import TorchDevice
 
 
 @invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.3.0")
@@ -102,7 +103,7 @@ def upscale_image(
             (height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
         )
 
-        image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
+        image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=spandrel_model.dtype)
 
         # Run the model on each tile.
         pbar = tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles")
@@ -116,9 +117,7 @@ def upscale_image(
                 raise CanceledException
 
             # Extract the current tile from the input tensor.
-            input_tile = image_tensor[
-                :, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
-            ].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
+            input_tile = image_tensor[:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right]
 
             # Run the model on the tile.
             output_tile = spandrel_model.run(input_tile)

invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py

Lines changed: 6 additions & 4 deletions
@@ -201,6 +201,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
                 yield (lora_info.model, lora.weight)
                 del lora_info
 
+        device = TorchDevice.choose_torch_device()
         with (
             ExitStack() as exit_stack,
             context.models.load(self.unet.unet) as unet,
@@ -209,9 +210,9 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             ),
         ):
             assert isinstance(unet, UNet2DConditionModel)
-            latents = latents.to(device=unet.device, dtype=unet.dtype)
+            latents = latents.to(device=device, dtype=unet.dtype)
             if noise is not None:
-                noise = noise.to(device=unet.device, dtype=unet.dtype)
+                noise = noise.to(device=device, dtype=unet.dtype)
             scheduler = get_scheduler(
                 context=context,
                 scheduler_info=self.unet.scheduler,
@@ -225,7 +226,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
                 context=context,
                 positive_conditioning_field=self.positive_conditioning,
                 negative_conditioning_field=self.negative_conditioning,
-                device=unet.device,
+                device=device,
                 dtype=unet.dtype,
                 latent_height=latent_tile_height,
                 latent_width=latent_tile_width,
@@ -238,6 +239,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
                 context=context,
                 control_input=self.control,
                 latents_shape=list(latents.shape),
+                device=device,
                 # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                 do_classifier_free_guidance=True,
                 exit_stack=exit_stack,
@@ -263,7 +265,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
 
             timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
                 scheduler,
-                device=unet.device,
+                device=device,
                 steps=self.steps,
                 denoising_start=self.denoising_start,
                 denoising_end=self.denoising_end,

invokeai/backend/flux/extensions/xlabs_ip_adapter_extension.py

Lines changed: 2 additions & 1 deletion
@@ -8,6 +8,7 @@
 
 from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
 from invokeai.backend.flux.modules.layers import DoubleStreamBlock
+from invokeai.backend.util.devices import TorchDevice
 
 
 class XLabsIPAdapterExtension:
@@ -45,7 +46,7 @@ def run_clip_image_encoder(
     ) -> torch.Tensor:
         clip_image_processor = CLIPImageProcessor()
         clip_image: torch.Tensor = clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-        clip_image = clip_image.to(device=image_encoder.device, dtype=image_encoder.dtype)
+        clip_image = clip_image.to(device=TorchDevice.choose_torch_device(), dtype=image_encoder.dtype)
         clip_image_embeds = image_encoder(clip_image).image_embeds
         return clip_image_embeds
 

invokeai/backend/model_patcher.py

Lines changed: 2 additions & 1 deletion
@@ -14,6 +14,7 @@
 from invokeai.app.shared.models import FreeUConfig
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
+from invokeai.backend.util.devices import TorchDevice
 
 
 class ModelPatcher:
@@ -122,7 +123,7 @@ def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionMod
                 )
 
                 model_embeddings.weight.data[token_id] = embedding.to(
-                    device=text_encoder.device, dtype=text_encoder.dtype
+                    device=TorchDevice.choose_torch_device(), dtype=text_encoder.dtype
                 )
                 ti_tokens.append(token_id)
 

invokeai/backend/stable_diffusion/extensions/t2i_adapter.py

Lines changed: 2 additions & 1 deletion
@@ -12,6 +12,7 @@
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
 from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
+from invokeai.backend.util.devices import TorchDevice
 
 if TYPE_CHECKING:
     from invokeai.app.invocations.model import ModelIdentifierField
@@ -89,7 +90,7 @@ def _run_model(
             width=input_width,
             height=input_height,
             num_channels=model.config["in_channels"],
-            device=model.device,
+            device=TorchDevice.choose_torch_device(),
             dtype=model.dtype,
             resize_mode=self._resize_mode,
         )
