Skip to content

Commit 46e48e8

Browse files
authored
Merge pull request huggingface#3 from dotieuthien/add-convert-tensorrt
Fix code quality
2 parents c2b584e + 105cc40 commit 46e48e8

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

scripts/convert_stable_diffusion_controlnet_to_onnx.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from packaging import version
2525
from torch.onnx import export
2626
from polygraphy.backend.onnx.loader import fold_constants
27-
from diffusers.models.cross_attention import CrossAttnProcessor
27+
from diffusers.models.attention_processor import AttnProcessor
2828
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffusionControlNetImg2ImgPipeline, ControlNetModel
2929

3030

@@ -192,14 +192,14 @@ def convert_models(model_path: str, controlnet_path: list, output_path: str, ops
192192
for path in controlnet_path:
193193
controlnet = ControlNetModel.from_pretrained(path, torch_dtype=dtype).to(device)
194194
if is_torch_2_0_1:
195-
controlnet.set_attn_processor(CrossAttnProcessor())
195+
controlnet.set_attn_processor(AttnProcessor())
196196
controlnets.append(controlnet)
197197

198198
pipeline = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(model_path, controlnet=controlnets, torch_dtype=dtype).to(device)
199199
output_path = Path(output_path)
200200
if is_torch_2_0_1:
201-
pipeline.unet.set_attn_processor(CrossAttnProcessor())
202-
pipeline.vae.set_attn_processor(CrossAttnProcessor())
201+
pipeline.unet.set_attn_processor(AttnProcessor())
202+
pipeline.vae.set_attn_processor(AttnProcessor())
203203

204204
# TEXT ENCODER
205205
num_tokens = pipeline.text_encoder.config.max_position_embeddings

0 commit comments

Comments
 (0)