@@ -411,6 +411,7 @@ def prep_control_data(
411
411
context : InvocationContext ,
412
412
control_input : ControlField | list [ControlField ] | None ,
413
413
latents_shape : List [int ],
414
+ device : torch .device ,
414
415
exit_stack : ExitStack ,
415
416
do_classifier_free_guidance : bool = True ,
416
417
) -> list [ControlNetData ] | None :
@@ -452,7 +453,7 @@ def prep_control_data(
452
453
height = control_height_resize ,
453
454
# batch_size=batch_size * num_images_per_prompt,
454
455
# num_images_per_prompt=num_images_per_prompt,
455
- device = control_model . device ,
456
+ device = device ,
456
457
dtype = control_model .dtype ,
457
458
control_mode = control_info .control_mode ,
458
459
resize_mode = control_info .resize_mode ,
@@ -605,6 +606,7 @@ def run_t2i_adapters(
605
606
context : InvocationContext ,
606
607
t2i_adapter : Optional [Union [T2IAdapterField , list [T2IAdapterField ]]],
607
608
latents_shape : list [int ],
609
+ device : torch .device ,
608
610
do_classifier_free_guidance : bool ,
609
611
) -> Optional [list [T2IAdapterData ]]:
610
612
if t2i_adapter is None :
@@ -655,7 +657,7 @@ def run_t2i_adapters(
655
657
width = control_width_resize ,
656
658
height = control_height_resize ,
657
659
num_channels = t2i_adapter_model .config ["in_channels" ], # mypy treats this as a FrozenDict
658
- device = t2i_adapter_model . device ,
660
+ device = device ,
659
661
dtype = t2i_adapter_model .dtype ,
660
662
resize_mode = t2i_adapter_field .resize_mode ,
661
663
)
@@ -946,6 +948,7 @@ def step_callback(state: PipelineIntermediateState) -> None:
946
948
@torch .no_grad ()
947
949
@SilenceWarnings () # This quenches the NSFW nag from diffusers.
948
950
def _old_invoke (self , context : InvocationContext ) -> LatentsOutput :
951
+ device = TorchDevice .choose_torch_device ()
949
952
seed , noise , latents = self .prepare_noise_and_latents (context , self .noise , self .latents )
950
953
951
954
mask , masked_latents , gradient_mask = self .prep_inpaint_mask (context , latents )
@@ -960,6 +963,7 @@ def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
960
963
context ,
961
964
self .t2i_adapter ,
962
965
latents .shape ,
966
+ device = device ,
963
967
do_classifier_free_guidance = True ,
964
968
)
965
969
@@ -1006,13 +1010,13 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
1006
1010
),
1007
1011
):
1008
1012
assert isinstance (unet , UNet2DConditionModel )
1009
- latents = latents .to (device = unet . device , dtype = unet .dtype )
1013
+ latents = latents .to (device = device , dtype = unet .dtype )
1010
1014
if noise is not None :
1011
- noise = noise .to (device = unet . device , dtype = unet .dtype )
1015
+ noise = noise .to (device = device , dtype = unet .dtype )
1012
1016
if mask is not None :
1013
- mask = mask .to (device = unet . device , dtype = unet .dtype )
1017
+ mask = mask .to (device = device , dtype = unet .dtype )
1014
1018
if masked_latents is not None :
1015
- masked_latents = masked_latents .to (device = unet . device , dtype = unet .dtype )
1019
+ masked_latents = masked_latents .to (device = device , dtype = unet .dtype )
1016
1020
1017
1021
scheduler = get_scheduler (
1018
1022
context = context ,
@@ -1028,7 +1032,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
1028
1032
context = context ,
1029
1033
positive_conditioning_field = self .positive_conditioning ,
1030
1034
negative_conditioning_field = self .negative_conditioning ,
1031
- device = unet . device ,
1035
+ device = device ,
1032
1036
dtype = unet .dtype ,
1033
1037
latent_height = latent_height ,
1034
1038
latent_width = latent_width ,
@@ -1041,6 +1045,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
1041
1045
context = context ,
1042
1046
control_input = self .control ,
1043
1047
latents_shape = latents .shape ,
1048
+ device = device ,
1044
1049
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
1045
1050
do_classifier_free_guidance = True ,
1046
1051
exit_stack = exit_stack ,
@@ -1058,7 +1063,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
1058
1063
1059
1064
timesteps , init_timestep , scheduler_step_kwargs = self .init_scheduler (
1060
1065
scheduler ,
1061
- device = unet . device ,
1066
+ device = device ,
1062
1067
steps = self .steps ,
1063
1068
denoising_start = self .denoising_start ,
1064
1069
denoising_end = self .denoising_end ,
0 commit comments