Skip to content

Commit 2b02225

Browse files
[From single file] Make accelerate optional (huggingface#4132)
* Make accelerate optional
* Make accelerate optional
1 parent 50b73eb commit 2b02225

File tree

1 file changed

+31
-19
lines changed

1 file changed

+31
-19
lines changed

pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
""" Conversion script for the Stable Diffusion checkpoints."""
1616

1717
import re
18+
from contextlib import nullcontext
1819
from io import BytesIO
1920
from typing import Optional
2021

@@ -779,7 +780,8 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
779780
config_name = "openai/clip-vit-large-patch14"
780781
config = CLIPTextConfig.from_pretrained(config_name)
781782

782-
with init_empty_weights():
783+
ctx = init_empty_weights if is_accelerate_available() else nullcontext
784+
with ctx():
783785
text_model = CLIPTextModel(config)
784786

785787
keys = list(checkpoint.keys())
@@ -793,8 +795,11 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
793795
if key.startswith(prefix):
794796
text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]
795797

796-
for param_name, param in text_model_dict.items():
797-
set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
798+
if is_accelerate_available():
799+
for param_name, param in text_model_dict.items():
800+
set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
801+
else:
802+
text_model.load_state_dict(text_model_dict)
798803

799804
return text_model
800805

@@ -900,7 +905,8 @@ def convert_open_clip_checkpoint(
900905
# )
901906
config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs)
902907

903-
with init_empty_weights():
908+
ctx = init_empty_weights if is_accelerate_available() else nullcontext
909+
with ctx():
904910
text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)
905911

906912
keys = list(checkpoint.keys())
@@ -950,8 +956,11 @@ def convert_open_clip_checkpoint(
950956

951957
text_model_dict[new_key] = checkpoint[key]
952958

953-
for param_name, param in text_model_dict.items():
954-
set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
959+
if is_accelerate_available():
960+
for param_name, param in text_model_dict.items():
961+
set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
962+
else:
963+
text_model.load_state_dict(text_model_dict)
955964

956965
return text_model
957966

@@ -1172,11 +1181,6 @@ def download_from_original_stable_diffusion_ckpt(
11721181
StableUnCLIPPipeline,
11731182
)
11741183

1175-
if not is_accelerate_available():
1176-
raise ImportError(
1177-
"To correctly use `from_single_file`, please make sure that `accelerate` is installed. You can install it with `pip install accelerate`."
1178-
)
1179-
11801184
if pipeline_class is None:
11811185
pipeline_class = StableDiffusionPipeline
11821186

@@ -1346,15 +1350,19 @@ def download_from_original_stable_diffusion_ckpt(
13461350
# Convert the UNet2DConditionModel model.
13471351
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
13481352
unet_config["upcast_attention"] = upcast_attention
1349-
with init_empty_weights():
1350-
unet = UNet2DConditionModel(**unet_config)
1351-
13521353
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
13531354
checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
13541355
)
13551356

1356-
for param_name, param in converted_unet_checkpoint.items():
1357-
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
1357+
ctx = init_empty_weights if is_accelerate_available() else nullcontext
1358+
with ctx():
1359+
unet = UNet2DConditionModel(**unet_config)
1360+
1361+
if is_accelerate_available():
1362+
for param_name, param in converted_unet_checkpoint.items():
1363+
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
1364+
else:
1365+
unet.load_state_dict(converted_unet_checkpoint)
13581366

13591367
# Convert the VAE model.
13601368
if vae_path is None:
@@ -1372,11 +1380,15 @@ def download_from_original_stable_diffusion_ckpt(
13721380

13731381
vae_config["scaling_factor"] = vae_scaling_factor
13741382

1375-
with init_empty_weights():
1383+
ctx = init_empty_weights if is_accelerate_available() else nullcontext
1384+
with ctx():
13761385
vae = AutoencoderKL(**vae_config)
13771386

1378-
for param_name, param in converted_vae_checkpoint.items():
1379-
set_module_tensor_to_device(vae, param_name, "cpu", value=param)
1387+
if is_accelerate_available():
1388+
for param_name, param in converted_vae_checkpoint.items():
1389+
set_module_tensor_to_device(vae, param_name, "cpu", value=param)
1390+
else:
1391+
vae.load_state_dict(converted_vae_checkpoint)
13801392
else:
13811393
vae = AutoencoderKL.from_pretrained(vae_path)
13821394

0 commit comments

Comments (0)