@@ -15,6 +15,7 @@
 """ Conversion script for the Stable Diffusion checkpoints."""
 
 import re
+from contextlib import nullcontext
 from io import BytesIO
 from typing import Optional
@@ -779,7 +780,8 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
     config_name = "openai/clip-vit-large-patch14"
     config = CLIPTextConfig.from_pretrained(config_name)
 
-    with init_empty_weights():
+    ctx = init_empty_weights if is_accelerate_available() else nullcontext
+    with ctx():
         text_model = CLIPTextModel(config)
 
     keys = list(checkpoint.keys())
@@ -793,8 +795,11 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
         if key.startswith(prefix):
             text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]
 
-    for param_name, param in text_model_dict.items():
-        set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
+    if is_accelerate_available():
+        for param_name, param in text_model_dict.items():
+            set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
+    else:
+        text_model.load_state_dict(text_model_dict)
 
     return text_model
 
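Note: the hunks above establish the pattern this patch applies throughout the file. With `accelerate` installed, the model is created under `init_empty_weights()` (parameters live on the meta device, so no memory is allocated at construction time) and the converted weights are then materialized tensor by tensor with `set_module_tensor_to_device`; without `accelerate`, the model is constructed normally under `nullcontext()` and the weights are loaded with a plain `load_state_dict`. A minimal self-contained sketch of that pattern, using an illustrative `TinyModel` and state dict that are not part of the patch:

from contextlib import nullcontext

import torch

try:
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    accelerate_available = True
except ImportError:
    accelerate_available = False


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)


# With accelerate: build the module on the meta device (no real allocations).
# Without accelerate: nullcontext() makes this an ordinary construction.
ctx = init_empty_weights if accelerate_available else nullcontext
with ctx():
    model = TinyModel()

state_dict = {"linear.weight": torch.randn(4, 4), "linear.bias": torch.zeros(4)}

if accelerate_available:
    # Replace each meta-device placeholder with the real tensor on CPU.
    for param_name, param in state_dict.items():
        set_module_tensor_to_device(model, param_name, "cpu", value=param)
else:
    # The module already holds real (randomly initialized) weights; overwrite them.
    model.load_state_dict(state_dict)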
@@ -900,7 +905,8 @@ def convert_open_clip_checkpoint(
     # )
     config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs)
 
-    with init_empty_weights():
+    ctx = init_empty_weights if is_accelerate_available() else nullcontext
+    with ctx():
         text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)
 
     keys = list(checkpoint.keys())
@@ -950,8 +956,11 @@ def convert_open_clip_checkpoint(
 
             text_model_dict[new_key] = checkpoint[key]
 
-    for param_name, param in text_model_dict.items():
-        set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
+    if is_accelerate_available():
+        for param_name, param in text_model_dict.items():
+            set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
+    else:
+        text_model.load_state_dict(text_model_dict)
 
     return text_model
 
@@ -1172,11 +1181,6 @@ def download_from_original_stable_diffusion_ckpt(
         StableUnCLIPPipeline,
     )
 
-    if not is_accelerate_available():
-        raise ImportError(
-            "To correctly use `from_single_file`, please make sure that `accelerate` is installed. You can install it with `pip install accelerate`."
-        )
-
     if pipeline_class is None:
         pipeline_class = StableDiffusionPipeline
 
@@ -1346,15 +1350,19 @@ def download_from_original_stable_diffusion_ckpt(
     # Convert the UNet2DConditionModel model.
     unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     unet_config["upcast_attention"] = upcast_attention
-    with init_empty_weights():
-        unet = UNet2DConditionModel(**unet_config)
-
     converted_unet_checkpoint = convert_ldm_unet_checkpoint(
         checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
     )
 
-    for param_name, param in converted_unet_checkpoint.items():
-        set_module_tensor_to_device(unet, param_name, "cpu", value=param)
+    ctx = init_empty_weights if is_accelerate_available() else nullcontext
+    with ctx():
+        unet = UNet2DConditionModel(**unet_config)
+
+    if is_accelerate_available():
+        for param_name, param in converted_unet_checkpoint.items():
+            set_module_tensor_to_device(unet, param_name, "cpu", value=param)
+    else:
+        unet.load_state_dict(converted_unet_checkpoint)
 
     # Convert the VAE model.
     if vae_path is None:
@@ -1372,11 +1380,15 @@ def download_from_original_stable_diffusion_ckpt(
 
         vae_config["scaling_factor"] = vae_scaling_factor
 
-        with init_empty_weights():
+        ctx = init_empty_weights if is_accelerate_available() else nullcontext
+        with ctx():
            vae = AutoencoderKL(**vae_config)
 
-        for param_name, param in converted_vae_checkpoint.items():
-            set_module_tensor_to_device(vae, param_name, "cpu", value=param)
+        if is_accelerate_available():
+            for param_name, param in converted_vae_checkpoint.items():
+                set_module_tensor_to_device(vae, param_name, "cpu", value=param)
+        else:
+            vae.load_state_dict(converted_vae_checkpoint)
     else:
         vae = AutoencoderKL.from_pretrained(vae_path)
 
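With the hard requirement on `accelerate` removed from `download_from_original_stable_diffusion_ckpt`, single-file loading degrades gracefully instead of raising an ImportError. A hedged usage sketch (the checkpoint path below is a placeholder, not from the patch):

from diffusers import StableDiffusionPipeline

# Works whether or not accelerate is installed; accelerate only changes how
# the weights are materialized (meta-device init vs. plain load_state_dict).
pipe = StableDiffusionPipeline.from_single_file("path/to/checkpoint.safetensors")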