Commit 2dc7848

Merge pull request #4 from rupeshs/add-wuerstchen-support
Add wuerstchen support
2 parents 68d3e59 + bd4062b commit 2dc7848

23 files changed: +292 −39 lines changed

Readme.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -12,6 +12,7 @@ We can run StableDiffusion XL 1.0 on Google Colab
 [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eEZ_O-Fw87hoEsfSxUnGZhdqvMFEO5iV?usp=sharing)

 ## Features
+- Supports Würstchen
 - Supports Stable diffusion XL
 - Supports various Stable Diffusion workflows
 - Text to Image
```

configs/stable_diffusion_models.txt

Lines changed: 2 additions & 1 deletion

```diff
@@ -17,4 +17,5 @@ lllyasviel/sd-controlnet-openpose
 lllyasviel/sd-controlnet-depth
 lllyasviel/sd-controlnet-scribble
 lllyasviel/sd-controlnet-seg
-stabilityai/stable-diffusion-xl-base-1.0
+stabilityai/stable-diffusion-xl-base-1.0
+warp-ai/wuerstchen
```

environment.yml

Lines changed: 7 additions & 7 deletions

```diff
@@ -6,23 +6,23 @@ channels:
   - defaults
 dependencies:
   - python=3.8.5
-  - pip=20.3
+  - pip=23.2.1
   - pytorch-cuda=11.7
-  - pytorch=2.0.0
-  - torchvision=0.15.0
+  - pytorch=2.0.1
+  - torchvision=0.15.2
   - numpy=1.19.2
   - pip:
-    - accelerate==0.21.0
-    - diffusers==0.19.3
+    - accelerate==0.23.0
+    - diffusers==0.21.1
     - gradio==3.39.0
     - safetensors==0.3.1
     - scipy==1.10.0
-    - transformers==4.31.0
+    - transformers==4.33.2
     - pydantic==1.10.4
     - mypy==1.0.0
     - black==23.1.0
     - flake8==6.0.0
-    - markupsafe==2.0.1
+    - markupsafe==2.1.3
     - opencv-contrib-python==4.7.0.72
     - controlnet-aux==0.0.1
     - invisible-watermark==0.2.0
```
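
The dependency bumps are what unlock the new model: the Würstchen pipelines ship with diffusers 0.21.x, so the previous 0.19.3 pin could not load them. A quick sanity check of an installed environment (a minimal sketch; the import paths match the ones used in src/backend/wuerstchen/wuerstchen.py below):

```python
import diffusers

print(diffusers.__version__)  # expected: 0.21.1 with this environment.yml

# These imports resolve only on diffusers releases that include Würstchen;
# on the old 0.19.3 pin the second one raises ImportError.
from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
```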

src/app.py

Lines changed: 26 additions & 4 deletions

```diff
@@ -5,8 +5,13 @@
 from frontend.web.ui import diffusionmagic_web_ui
 from settings import AppSettings

-# mypy --ignore-missing-imports --explicit-package-bases .
-# flake8 --max-line-length=100 .
+
+def _get_model(model_id: str) -> str:
+    if model_id == "":
+        model_id = AppSettings().get_settings().model_settings.model_id
+    return model_id
+
+
 if __name__ == "__main__":
     try:
         app_settings = AppSettings()
@@ -19,10 +24,27 @@
         parser.add_argument(
             "-s", "--share", help="Shareable link", action="store_true", default=False
         )
+        parser.add_argument(
+            "-m",
+            "--model",
+            help="Model identifier,E.g. runwayml/stable-diffusion-v1-5",
+            default="",
+        )
         args = parser.parse_args()
         compute = Computing()
-        generate = Generate(compute)
-        dm_web_ui = diffusionmagic_web_ui(generate)
+        model_id = _get_model(args.model)
+
+        print(f"Model : {model_id}")
+
+        generate = Generate(
+            compute,
+            model_id,
+        )
+
+        dm_web_ui = diffusionmagic_web_ui(
+            generate,
+            model_id,
+        )
         if args.share:
             dm_web_ui.queue().launch(share=True)
         else:
```
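
With this change the model can be picked at launch time instead of only through the settings file: for example `python src/app.py --model warp-ai/wuerstchen` (entry-point path assumed from the repo layout shown here). When `-m/--model` is omitted, `_get_model` falls back to the `model_id` stored in the app settings.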

src/backend/controlnet/controls/normal_control.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -18,10 +18,10 @@ def get_control_image(self, image: Image) -> Image:

         bg_threhold = 0.4

-        x = cv2.Sobel(image, cv2.CV_32F, 1, 0, ksize=3)
+        x = cv2.Sobel(image, cv2.CV_32F, 1, 0, ksize=3)  # type: ignore
         x[image_depth < bg_threhold] = 0

-        y = cv2.Sobel(image, cv2.CV_32F, 0, 1, ksize=3)
+        y = cv2.Sobel(image, cv2.CV_32F, 0, 1, ksize=3)  # type: ignore
         y[image_depth < bg_threhold] = 0

         z = np.ones_like(x) * np.pi * 2.0
```

src/backend/generate.py

Lines changed: 45 additions & 2 deletions

```diff
@@ -13,14 +13,20 @@
     StableDiffusionImageInstructPixToPixSetting,
     StableDiffusionControlnetSetting,
 )
+from backend.wuerstchen.models.setting import WurstchenSetting
 from backend.controlnet.ControlContext import ControlnetContext
 from backend.stablediffusion.stablediffusion import StableDiffusion
 from backend.stablediffusion.stablediffusionxl import StableDiffusionXl
+from backend.wuerstchen.wuerstchen import Wuerstchen
 from settings import AppSettings


 class Generate:
-    def __init__(self, compute: Computing):
+    def __init__(
+        self,
+        compute: Computing,
+        model_id: str,
+    ):
         self.pipe_initialized = False
         self.inpaint_pipe_initialized = False
         self.depth_pipe_initialized = False
@@ -33,8 +39,9 @@ def __init__(self, compute: Computing):
         self.controlnet = ControlnetContext(compute)
         self.stable_diffusion_xl = StableDiffusionXl(compute)
         self.app_settings = AppSettings().get_settings()
-        self.model_id = self.app_settings.model_settings.model_id
+        self.model_id = model_id
         self.low_vram_mode = self.app_settings.low_memory_mode
+        self.wuerstchen = Wuerstchen(compute)

     def diffusion_text_to_image(
         self,
@@ -89,6 +96,15 @@ def _init_stable_diffusion_xl(self):
             )
             self.pipe_initialized = True

+    def _init_wuerstchen(self):
+        if not self.pipe_initialized:
+            print("Initializing wuerstchen pipeline")
+            self.wuerstchen.get_text_to_image_wuerstchen_pipleline(
+                self.model_id,
+                self.low_vram_mode,
+            )
+            self.pipe_initialized = True
+
     def diffusion_image_to_image(
         self,
         image,
@@ -479,3 +495,30 @@ def diffusion_image_variations_xl(
             "ImageVariations",
         )
         return images
+
+    def diffusion_text_to_image_wuerstchen(
+        self,
+        prompt,
+        neg_prompt,
+        image_height,
+        image_width,
+        guidance_scale,
+        num_images,
+        seed,
+    ) -> Any:
+        wurstchen_settings = WurstchenSetting(
+            prompt=prompt,
+            negative_prompt=neg_prompt,
+            image_height=image_height,
+            image_width=image_width,
+            prior_guidance_scale=guidance_scale,
+            number_of_images=num_images,
+            seed=seed,
+        )
+        self._init_wuerstchen()
+        images = self.wuerstchen.text_to_image_wuerstchen(wurstchen_settings)
+        self._save_images(
+            images,
+            "TextToImage",
+        )
+        return images
```
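
Taken together with the src/app.py changes, the new method is reachable as below. A minimal sketch, assuming `Computing()` takes no arguments as in src/app.py; all generation values are example inputs:

```python
from backend.computing import Computing
from backend.generate import Generate

generate = Generate(Computing(), "warp-ai/wuerstchen")

# Signature as added in this diff; _init_wuerstchen() loads the pipeline
# lazily on the first call, so construction stays cheap.
images = generate.diffusion_text_to_image_wuerstchen(
    prompt="an astronaut riding a horse, detailed illustration",
    neg_prompt="blurry, low quality",
    image_height=1024,
    image_width=1024,
    guidance_scale=4.0,
    num_images=1,
    seed=42,
)
images[0].save("out.png")
```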

src/backend/stablediffusion/stable_diffusion_types.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -17,12 +17,14 @@ class StableDiffusionType(str, Enum):
     controlnet_scribble = "controlnet_scribble"
     controlnet_seg = "controlnet_seg"
     stable_diffusion_xl = "StableDiffusionXl"
+    wuerstchen = "Wuerstchen"


 def get_diffusion_type(
     model_id: str,
 ) -> StableDiffusionType:
     stable_diffusion_type = StableDiffusionType.base
+    model_id = model_id.lower()
     if "inpainting" in model_id:
         stable_diffusion_type = StableDiffusionType.inpainting
     elif "instruct-pix2pix" in model_id:
@@ -47,4 +49,6 @@ def get_diffusion_type(
         stable_diffusion_type = StableDiffusionType.controlnet_seg
     elif "stable-diffusion-xl" in model_id:
         stable_diffusion_type = StableDiffusionType.stable_diffusion_xl
+    elif "wuerstchen" in model_id:
+        stable_diffusion_type = StableDiffusionType.wuerstchen
     return stable_diffusion_type
```
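
Note that the added `model_id.lower()` also makes the routing case-insensitive. A quick illustration of the dispatch, grounded in the branches above:

```python
from backend.stablediffusion.stable_diffusion_types import get_diffusion_type

print(get_diffusion_type("warp-ai/wuerstchen"))  # StableDiffusionType.wuerstchen
print(get_diffusion_type("Warp-AI/Wuerstchen"))  # same result, thanks to .lower()
print(get_diffusion_type("stabilityai/stable-diffusion-xl-base-1.0"))
# StableDiffusionType.stable_diffusion_xl
```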

src/backend/stablediffusion/stablediffusionxl.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -1,6 +1,6 @@
 from time import time

-import torch
+from torch import Generator
 from diffusers import (
     DiffusionPipeline,
     StableDiffusionXLImg2ImgPipeline,
@@ -72,7 +72,7 @@ def text_to_image_xl(self, setting: StableDiffusionSetting):
         generator = None
         if setting.seed != -1:
             print(f"Using seed {setting.seed}")
-            generator = torch.Generator(self.device).manual_seed(setting.seed)
+            generator = Generator(self.device).manual_seed(setting.seed)

         # if setting.attention_slicing:
         #     self.pipeline.enable_attention_slicing()
@@ -149,7 +149,7 @@ def image_to_image(self, setting: StableDiffusionImageToImageSetting):
         generator = None
         if setting.seed != -1 and setting.seed:
             print(f"Using seed {setting.seed}")
-            generator = torch.Generator(self.device).manual_seed(setting.seed)
+            generator = Generator(self.device).manual_seed(setting.seed)

         if setting.attention_slicing:
             self.img_to_img_pipeline.enable_attention_slicing()  # type: ignore
```
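
This is a pure import refactor; the underlying pattern, a device-bound seeded RNG for reproducible sampling, is unchanged. In isolation (device string and seed are example values):

```python
from torch import Generator

# The same seed on the same device reproduces the same latents/image.
generator = Generator("cuda").manual_seed(42)
# ...then passed into the pipeline call as generator=generator
```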
src/backend/wuerstchen/models/setting.py

Lines changed: 12 additions & 0 deletions

```diff
@@ -0,0 +1,12 @@
+from pydantic import BaseModel
+from typing import Optional
+
+
+class WurstchenSetting(BaseModel):
+    prompt: str
+    negative_prompt: Optional[str]
+    image_height: Optional[int] = 512
+    image_width: Optional[int] = 512
+    prior_guidance_scale: Optional[float] = 4.0
+    number_of_images: Optional[int] = 1
+    seed: Optional[int] = -1
```
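
Since environment.yml pins pydantic 1.10.4, the `Optional` field without an explicit `= ...` defaults to None under pydantic v1 rules, and the rest carry the defaults shown. For example:

```python
from backend.wuerstchen.models.setting import WurstchenSetting

s = WurstchenSetting(prompt="a red sports car")  # only prompt is required
print(s.image_height, s.image_width)   # 512 512
print(s.prior_guidance_scale, s.seed)  # 4.0 -1
print(s.negative_prompt)               # None (implicit Optional default)
```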

src/backend/wuerstchen/wuerstchen.py

Lines changed: 84 additions & 0 deletions

```diff
@@ -0,0 +1,84 @@
+from time import time
+
+from backend.computing import Computing
+from backend.wuerstchen.models.setting import WurstchenSetting
+from torch import Generator
+from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
+from diffusers import AutoPipelineForText2Image
+
+
+class Wuerstchen:
+    def __init__(self, compute: Computing):
+        self.compute = compute
+        self.pipeline = None
+        self.device = self.compute.name
+        super().__init__()
+
+    def get_text_to_image_wuerstchen_pipleline(
+        self,
+        model_id: str = "warp-ai/wuerstchen",
+        low_vram_mode: bool = False,
+    ):
+        self.model_id = model_id
+
+        self.low_vram_mode = low_vram_mode
+        print(f"Wuerstchen - {self.compute.name},{self.compute.datatype}")
+        print(f"using model {model_id}")
+        tic = time()
+        self._load_model()
+        self._pipeline_to_device()
+        delta = time() - tic
+        print(f"Model loaded in {delta:.2f}s ")
+
+    def text_to_image_wuerstchen(self, setting: WurstchenSetting):
+        if self.pipeline is None:
+            raise Exception("Text to image pipeline not initialized")
+
+        generator = None
+        if setting.seed != -1:
+            print(f"Using seed {setting.seed}")
+            generator = Generator(self.device).manual_seed(setting.seed)
+
+        images = self.pipeline(
+            setting.prompt,
+            negative_prompt=setting.negative_prompt,
+            height=setting.image_height,
+            width=setting.image_width,
+            prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
+            prior_guidance_scale=setting.prior_guidance_scale,
+            num_images_per_prompt=setting.number_of_images,
+            generator=generator,
+        ).images
+
+        return images
+
+    def _pipeline_to_device(self):
+        if self.low_vram_mode:
+            print("Running in low VRAM mode,slower to generate images")
+            self.pipeline.enable_sequential_cpu_offload()
+        else:
+            if self.compute.name == "cuda":
+                self.pipeline = self.pipeline.to("cuda")
+            elif self.compute.name == "mps":
+                self.pipeline = self.pipeline.to("mps")
+
+    def _load_full_precision_model(self):
+        self.pipeline = AutoPipelineForText2Image.from_pretrained(
+            self.model_id,
+            torch_dtype=self.compute.datatype,
+        )
+
+    def _load_model(self):
+        if self.compute.name == "cuda":
+            try:
+                self.pipeline = AutoPipelineForText2Image.from_pretrained(
+                    self.model_id,
+                    torch_dtype=self.compute.datatype,
+                )
+            except Exception as ex:
+                print(
+                    f" The fp16 of the model not found using full precision model, {ex}"
+                )
+                self._load_full_precision_model()
+        else:
+            self._load_full_precision_model()
```
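
For context, this class wraps the combined Würstchen pipeline that diffusers exposes: a stage C prior plus the stage A/B decoder behind one call. It can be exercised standalone, independent of this repo; a minimal sketch assuming a CUDA machine with enough VRAM (prompt and output file name are examples):

```python
import torch
from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS

# Loads prior and decoder together as one text-to-image pipeline.
pipe = AutoPipelineForText2Image.from_pretrained(
    "warp-ai/wuerstchen", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    "anthropomorphic cat dressed as a firefighter",
    height=1024,
    width=1024,
    prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
    prior_guidance_scale=4.0,
).images[0]
image.save("cat_firefighter.png")
```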
