Skip to content

Commit

Permalink
Improved sampling (Stability-AI#69)
Browse files Browse the repository at this point in the history
* New research features

* Add new model specs
---------

Co-authored-by: Dominik Lorenz <53151171+qp-qp@users.noreply.github.com>

* remove sd1.5 and change default refiner to 1.0

* remove asking second time for output

* adapt model names

* adjusted strength

* Correctly pass prompt

---------

Co-authored-by: Dominik Lorenz <53151171+qp-qp@users.noreply.github.com>
  • Loading branch information
2 people authored and LinearFalcon committed Jul 6, 2024
1 parent b8ec91e commit 4977d76
Show file tree
Hide file tree
Showing 2 changed files with 514 additions and 93 deletions.
129 changes: 72 additions & 57 deletions scripts/demo/sampling.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
import numpy as np
from pytorch_lightning import seed_everything

from scripts.demo.streamlit_helpers import *
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.inference.helpers import (
do_img2img,
do_sample,
get_unique_embedder_keys_from_conditioner,
perform_save_locally,
)

SAVE_PATH = "outputs/demo/txt2img/"

Expand Down Expand Up @@ -42,27 +34,34 @@
}

VERSION2SPECS = {
"SD-XL base": {
"SDXL-base-1.0": {
"H": 1024,
"W": 1024,
"C": 4,
"f": 8,
"is_legacy": False,
"config": "configs/inference/sd_xl_base.yaml",
"ckpt": "checkpoints/sd_xl_base_1.0.safetensors",
},
"SDXL-base-0.9": {
"H": 1024,
"W": 1024,
"C": 4,
"f": 8,
"is_legacy": False,
"config": "configs/inference/sd_xl_base.yaml",
"ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
"is_guided": True,
},
"sd-2.1": {
"SD-2.1": {
"H": 512,
"W": 512,
"C": 4,
"f": 8,
"is_legacy": True,
"config": "configs/inference/sd_2_1.yaml",
"ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
"is_guided": True,
},
"sd-2.1-768": {
"SD-2.1-768": {
"H": 768,
"W": 768,
"C": 4,
Expand All @@ -71,15 +70,23 @@
"config": "configs/inference/sd_2_1_768.yaml",
"ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
},
"SDXL-Refiner": {
"SDXL-refiner-0.9": {
"H": 1024,
"W": 1024,
"C": 4,
"f": 8,
"is_legacy": True,
"config": "configs/inference/sd_xl_refiner.yaml",
"ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
"is_guided": True,
},
"SDXL-refiner-1.0": {
"H": 1024,
"W": 1024,
"C": 4,
"f": 8,
"is_legacy": True,
"config": "configs/inference/sd_xl_refiner.yaml",
"ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors",
},
}

Expand All @@ -103,18 +110,19 @@ def load_img(display=True, key=None, device="cuda"):


def run_txt2img(
state, version, version_dict, is_legacy=False, return_latents=False, filter=None
state,
version,
version_dict,
is_legacy=False,
return_latents=False,
filter=None,
stage2strength=None,
):
if version == "SD-XL base":
ratio = st.sidebar.selectbox("Ratio:", list(SD_XL_BASE_RATIOS.keys()), 10)
W, H = SD_XL_BASE_RATIOS[ratio]
if version.startswith("SDXL-base"):
W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
else:
H = st.sidebar.number_input(
"H", value=version_dict["H"], min_value=64, max_value=2048
)
W = st.sidebar.number_input(
"W", value=version_dict["W"], min_value=64, max_value=2048
)
H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048)
W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048)
C = version_dict["C"]
F = version_dict["f"]

Expand All @@ -130,16 +138,11 @@ def run_txt2img(
prompt=prompt,
negative_prompt=negative_prompt,
)
num_rows, num_cols, sampler = init_sampling(
use_identity_guider=not version_dict["is_guided"]
)

sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength)
num_samples = num_rows * num_cols

if st.button("Sample"):
st.write(f"**Model I:** {version}")
outputs = st.empty()
st.text("Sampling")
out = do_sample(
state["model"],
sampler,
Expand All @@ -153,13 +156,16 @@ def run_txt2img(
return_latents=return_latents,
filter=filter,
)
show_samples(out, outputs)

return out


def run_img2img(
state, version_dict, is_legacy=False, return_latents=False, filter=None
state,
version_dict,
is_legacy=False,
return_latents=False,
filter=None,
stage2strength=None,
):
img = load_img()
if img is None:
Expand All @@ -175,19 +181,19 @@ def run_img2img(
value_dict = init_embedder_options(
get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
init_dict,
prompt=prompt,
negative_prompt=negative_prompt,
)
strength = st.number_input(
"**Img2Img Strength**", value=0.5, min_value=0.0, max_value=1.0
"**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
)
num_rows, num_cols, sampler = init_sampling(
sampler, num_rows, num_cols = init_sampling(
img2img_strength=strength,
use_identity_guider=not version_dict["is_guided"],
stage2strength=stage2strength,
)
num_samples = num_rows * num_cols

if st.button("Sample"):
outputs = st.empty()
st.text("Sampling")
out = do_img2img(
repeat(img, "1 ... -> n ...", n=num_samples),
state["model"],
Expand All @@ -198,7 +204,6 @@ def run_img2img(
return_latents=return_latents,
filter=filter,
)
show_samples(out, outputs)
return out


Expand All @@ -210,6 +215,7 @@ def apply_refiner(
prompt,
negative_prompt,
filter=None,
finish_denoising=False,
):
init_dict = {
"orig_width": input.shape[3] * 8,
Expand Down Expand Up @@ -237,6 +243,7 @@ def apply_refiner(
num_samples,
skip_encode=True,
filter=filter,
add_noise=not finish_denoising,
)

return samples
Expand All @@ -249,20 +256,22 @@ def apply_refiner(
mode = st.radio("Mode", ("txt2img", "img2img"), 0)
st.write("__________________________")

if version == "SD-XL base":
add_pipeline = st.checkbox("Load SDXL-Refiner?", False)
set_lowvram_mode(st.checkbox("Low vram mode", True))

if version.startswith("SDXL-base"):
add_pipeline = st.checkbox("Load SDXL-refiner?", False)
st.write("__________________________")
else:
add_pipeline = False

filter = DeepFloydDataFiltering(verbose=False)

seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
seed_everything(seed)

save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))

state = init_st(version_dict)
state = init_st(version_dict, load_filter=True)
if state["msg"]:
st.info(state["msg"])
model = state["model"]

is_legacy = version_dict["is_legacy"]
Expand All @@ -276,29 +285,34 @@ def apply_refiner(
else:
negative_prompt = "" # which is unused

stage2strength = None
finish_denoising = False

if add_pipeline:
st.write("__________________________")

version2 = "SDXL-Refiner"
version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"])
st.warning(
f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
)
st.write("**Refiner Options:**")

version_dict2 = VERSION2SPECS[version2]
state2 = init_st(version_dict2)
state2 = init_st(version_dict2, load_filter=False)
st.info(state2["msg"])

stage2strength = st.number_input(
"**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0
"**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
)

sampler2 = init_sampling(
sampler2, *_ = init_sampling(
key=2,
img2img_strength=stage2strength,
use_identity_guider=not version_dict2["is_guided"],
get_num_samples=False,
specify_num_samples=False,
)
st.write("__________________________")
finish_denoising = st.checkbox("Finish denoising with refiner.", True)
if not finish_denoising:
stage2strength = None

if mode == "txt2img":
out = run_txt2img(
Expand All @@ -307,15 +321,17 @@ def apply_refiner(
version_dict,
is_legacy=is_legacy,
return_latents=add_pipeline,
filter=filter,
filter=state.get("filter"),
stage2strength=stage2strength,
)
elif mode == "img2img":
out = run_img2img(
state,
version_dict,
is_legacy=is_legacy,
return_latents=add_pipeline,
filter=filter,
filter=state.get("filter"),
stage2strength=stage2strength,
)
else:
raise ValueError(f"unknown mode {mode}")
Expand All @@ -326,7 +342,6 @@ def apply_refiner(
samples_z = None

if add_pipeline and samples_z is not None:
outputs = st.empty()
st.write("**Running Refinement Stage**")
samples = apply_refiner(
samples_z,
Expand All @@ -335,9 +350,9 @@ def apply_refiner(
samples_z.shape[0],
prompt=prompt,
negative_prompt=negative_prompt if is_legacy else "",
filter=filter,
filter=state.get("filter"),
finish_denoising=finish_denoising,
)
show_samples(samples, outputs)

if save_locally and samples is not None:
perform_save_locally(save_path, samples)
Loading

0 comments on commit 4977d76

Please sign in to comment.