forked from lllyasviel/ControlNet-v1-1-nightly
Commit 22e64cf (1 parent: 738fab9), showing 6 changed files with 132 additions and 0 deletions.
File renamed without changes.
File renamed without changes.
@@ -0,0 +1,115 @@
from share import *
import config

import cv2
import einops
import gradio as gr
import numpy as np
import torch
import random

from pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3
from annotator.mlsd import MLSDdetector
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler


preprocessor = None

model_name = 'control_v11p_sd15_mlsd'
model = create_model(f'./models/{model_name}.yaml').cpu()
model.load_state_dict(load_state_dict('./models/v1-5-pruned.ckpt', location='cuda'), strict=False)
model.load_state_dict(load_state_dict(f'./models/{model_name}.pth', location='cuda'), strict=False)
model = model.cuda()
ddim_sampler = DDIMSampler(model)
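# The two load_state_dict(..., strict=False) calls above layer the checkpoints:
# the base Stable Diffusion 1.5 weights are loaded first, then the ControlNet
# MLSD weights are applied on top. strict=False lets each partial checkpoint
# fill in only the keys it actually contains.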
def process(det, input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, value_threshold, distance_threshold):
    global preprocessor

    if det == 'MLSD':
        if not isinstance(preprocessor, MLSDdetector):
            preprocessor = MLSDdetector()

    with torch.no_grad():
        input_image = HWC3(input_image)

        if det == 'None':
            detected_map = input_image.copy()
        else:
            detected_map = preprocessor(resize_image(input_image, detect_resolution), value_threshold, distance_threshold)
            detected_map = HWC3(detected_map)

        img = resize_image(input_image, image_resolution)
        H, W, C = img.shape

        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
        un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
        shape = (4, H // 8, W // 8)
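        # In guess mode the control map is dropped from the unconditional branch,
        # so classifier-free guidance pushes samples toward the control signal even
        # with an empty prompt; shape is the latent size: 4 channels at 1/8 of the
        # pixel resolution.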
        if config.save_memory:
            model.low_vram_shift(is_diffusing=True)

        model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)
        # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
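        # With guess mode on, the 13 scales ramp geometrically from about
        # 0.099 * strength (0.825 ** 12) up to strength itself; otherwise every
        # control output is weighted equally.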
        samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                     shape, cond, verbose=False, eta=eta,
                                                     unconditional_guidance_scale=scale,
                                                     unconditional_conditioning=un_cond)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        x_samples = model.decode_first_stage(samples)
        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

        results = [x_samples[i] for i in range(num_samples)]
    return [detected_map] + results
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Control Stable Diffusion with MLSD Lines")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source='upload', type="numpy")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(label="Run")
            num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
            seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=12345)
            det = gr.Radio(choices=["MLSD", "None"], type="value", value="MLSD", label="Preprocessor")
            with gr.Accordion("Advanced options", open=False):
                value_threshold = gr.Slider(label="Hough value threshold (MLSD)", minimum=0.01, maximum=2.0, value=0.1, step=0.01)
                distance_threshold = gr.Slider(label="Hough distance threshold (MLSD)", minimum=0.01, maximum=20.0, value=0.1, step=0.01)
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
                strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                guess_mode = gr.Checkbox(label='Guess Mode', value=False)
                detect_resolution = gr.Slider(label="Preprocessor Resolution", minimum=128, maximum=1024, value=512, step=1)
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                eta = gr.Slider(label="DDIM ETA", minimum=0.0, maximum=1.0, value=1.0, step=0.01)
                a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
                n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
    ips = [det, input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, value_threshold, distance_threshold]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])


block.launch(server_name='0.0.0.0')
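For comparison, a minimal headless sketch of calling process() directly, with the definitions above in scope and the final block.launch line skipped; the image path and settings below are illustrative assumptions:

    import cv2

    # Hypothetical input; MLSD is aimed at images with clear straight lines.
    bgr = cv2.imread('./test_imgs/room.png')
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    results = process('MLSD', rgb, 'a modern living room', 'best quality',
                      'lowres, bad anatomy, bad hands, cropped, worst quality',
                      num_samples=1, image_resolution=512, detect_resolution=512,
                      ddim_steps=20, guess_mode=False, strength=1.0, scale=9.0,
                      seed=12345, eta=1.0, value_threshold=0.1, distance_threshold=0.1)
    detected_map, images = results[0], results[1:]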
File renamed without changes.
File renamed without changes.