Commit f54cd3f: Merge branch 'dev' into torch

AUTOMATIC1111 authored Apr 29, 2023
2 parents: 7fb72ed + e55cb92
Showing 30 changed files with 336 additions and 103 deletions.
2 changes: 1 addition & 1 deletion extensions-builtin/Lora/extra_networks_lora.py
@@ -8,7 +8,7 @@ def __init__(self):
     def activate(self, p, params_list):
         additional = shared.opts.sd_lora

-        if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
+        if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
             p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
             params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
2 changes: 1 addition & 1 deletion extensions-builtin/Lora/scripts/lora_script.py
@@ -52,5 +52,5 @@ def before_ui():


 shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
-    "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
+    "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
 }))
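The two Lora hunks above work together: the dropdown now offers "None" as its first choice, so the activate() guard must treat the string "None" (not "") as "nothing selected". modules/extra_networks_hypernet.py below gets the same fix. A minimal runnable sketch of the invariant, with simplified stand-ins for the real option plumbing:

```python
# Illustrative only: available_loras and the guard mirror the diff's logic.
available_loras = {"detail_tweaker": object()}

def activate(sd_lora: str, prompt: str) -> str:
    # The dropdown's "no selection" value is now the string "None", not "".
    if sd_lora != "None" and sd_lora in available_loras:
        prompt += f"<lora:{sd_lora}:1.0>"
    return prompt

assert activate("None", "a cat") == "a cat"                # nothing appended
assert activate("detail_tweaker", "a cat").endswith(">")   # lora tag appended
```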
83 changes: 68 additions & 15 deletions extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -5,11 +5,15 @@
 import PIL.Image
 import numpy as np
 import torch
+from tqdm import tqdm

 from basicsr.utils.download_util import load_file_from_url

 import modules.upscaler
 from modules import devices, modelloader
 from scunet_model_arch import SCUNet as net
+from modules.shared import opts
+from modules import images


 class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -42,28 +46,78 @@ def __init__(self, dirname):
         scalers.append(scaler_data2)
         self.scalers = scalers

-    def do_upscale(self, img: PIL.Image, selected_file):
+    @staticmethod
+    @torch.no_grad()
+    def tiled_inference(img, model):
+        # test the image tile by tile
+        h, w = img.shape[2:]
+        tile = opts.SCUNET_tile
+        tile_overlap = opts.SCUNET_tile_overlap
+        if tile == 0:
+            return model(img)
+
+        device = devices.get_device_for('scunet')
+        assert tile % 8 == 0, "tile size should be a multiple of window_size"
+        sf = 1
+
+        stride = tile - tile_overlap
+        h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
+        w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
+        E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
+        W = torch.zeros_like(E, dtype=devices.dtype, device=device)
+
+        with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
+            for h_idx in h_idx_list:
+
+                for w_idx in w_idx_list:
+
+                    in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
+
+                    out_patch = model(in_patch)
+                    out_patch_mask = torch.ones_like(out_patch)
+
+                    E[
+                        ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
+                    ].add_(out_patch)
+                    W[
+                        ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
+                    ].add_(out_patch_mask)
+                    pbar.update(1)
+        output = E.div_(W)
+
+        return output
+
+    def do_upscale(self, img: PIL.Image.Image, selected_file):
+
         torch.cuda.empty_cache()

         model = self.load_model(selected_file)
         if model is None:
+            print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
             return img

         device = devices.get_device_for('scunet')
-        img = np.array(img)
-        img = img[:, :, ::-1]
-        img = np.moveaxis(img, 2, 0) / 255
-        img = torch.from_numpy(img).float()
-        img = img.unsqueeze(0).to(device)
-
-        with torch.no_grad():
-            output = model(img)
-        output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
-        output = 255. * np.moveaxis(output, 0, 2)
-        output = output.astype(np.uint8)
-        output = output[:, :, ::-1]
+        tile = opts.SCUNET_tile
+        h, w = img.height, img.width
+        np_img = np.array(img)
+        np_img = np_img[:, :, ::-1]  # RGB to BGR
+        np_img = np_img.transpose((2, 0, 1)) / 255  # HWC to CHW
+        torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device)  # type: ignore
+
+        if tile > h or tile > w:
+            _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
+            _img[:, :, :h, :w] = torch_img  # pad image
+            torch_img = _img
+
+        torch_output = self.tiled_inference(torch_img, model).squeeze(0)
+        torch_output = torch_output[:, :h * 1, :w * 1]  # remove padding, if any
+        np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
+        del torch_img, torch_output
         torch.cuda.empty_cache()
-        return PIL.Image.fromarray(output, 'RGB')
+
+        output = np_output.transpose((1, 2, 0))  # CHW to HWC
+        output = output[:, :, ::-1]  # BGR to RGB
+        return PIL.Image.fromarray((output * 255).astype(np.uint8))

     def load_model(self, path: str):
         device = devices.get_device_for('scunet')
@@ -84,4 +84,3 @@ def load_model(self, path: str):
         model = model.to(device)

         return model
-
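The new tiled_inference splits the input into overlapping tile-by-tile patches, sums each patch result into E, counts per-pixel coverage in W, and divides, so pixels covered by several tiles are averaged and seams disappear. A self-contained sketch of the same index arithmetic, using an identity model and toy sizes (sf, the scale factor, is 1 here as in the diff; it assumes tile <= h, w, since do_upscale pads smaller images first):

```python
import torch

def tiled(img: torch.Tensor, model, tile: int = 8, tile_overlap: int = 4) -> torch.Tensor:
    # Same scheme as UpscalerScuNET.tiled_inference, minus the progress bar.
    _, _, h, w = img.shape
    stride = tile - tile_overlap
    h_idx = list(range(0, h - tile, stride)) + [h - tile]  # last tile flush with the edge
    w_idx = list(range(0, w - tile, stride)) + [w - tile]
    E = torch.zeros_like(img)  # running sum of patch outputs
    W = torch.zeros_like(img)  # how many patches covered each pixel
    for i in h_idx:
        for j in w_idx:
            patch = img[..., i:i + tile, j:j + tile]
            E[..., i:i + tile, j:j + tile] += model(patch)
            W[..., i:i + tile, j:j + tile] += 1
    return E / W  # overlapping regions are averaged, hiding tile seams

x = torch.rand(1, 3, 20, 20)
# With an identity "model", tiling must reproduce the input.
assert torch.allclose(tiled(x, lambda p: p), x)
```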
8 changes: 0 additions & 8 deletions javascript/contextMenus.js
@@ -161,14 +161,6 @@ addContextMenuEventListener = initResponse[2];
     appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
     appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)

-    appendContextMenuOption('#roll','Roll three',
-        function(){
-            let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
-            setTimeout(function(){rollbutton.click()},100)
-            setTimeout(function(){rollbutton.click()},200)
-            setTimeout(function(){rollbutton.click()},300)
-        }
-    )
 })();
 //End example Context Menu Items
38 changes: 31 additions & 7 deletions javascript/edit-attention.js
@@ -17,7 +17,7 @@ function keyupEditAttention(event){
         // Find opening parenthesis around current cursor
         const before = text.substring(0, selectionStart);
         let beforeParen = before.lastIndexOf(OPEN);
-        if (beforeParen == -1) return false;
+        if (beforeParen == -1) return false;
         let beforeParenClose = before.lastIndexOf(CLOSE);
         while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
             beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
@@ -27,7 +27,7 @@ function keyupEditAttention(event){
         // Find closing parenthesis around current cursor
         const after = text.substring(selectionStart);
         let afterParen = after.indexOf(CLOSE);
-        if (afterParen == -1) return false;
+        if (afterParen == -1) return false;
         let afterParenOpen = after.indexOf(OPEN);
         while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
             afterParen = after.indexOf(CLOSE, afterParen + 1);
@@ -43,10 +43,28 @@ function keyupEditAttention(event){
         target.setSelectionRange(selectionStart, selectionEnd);
         return true;
     }

+    function selectCurrentWord(){
+        if (selectionStart !== selectionEnd) return false;
+        const delimiters = opts.keyedit_delimiters + " \r\n\t";
+
+        // seek backward to find the beginning of the word
+        while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
+            selectionStart--;
+        }
+
+        // seek forward to find the end of the word
+        while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) {
+            selectionEnd++;
+        }
+
-    // If the user hasn't selected anything, let's select their current parenthesis block
-    if(! selectCurrentParenthesisBlock('<', '>')){
-        selectCurrentParenthesisBlock('(', ')')
+        target.setSelectionRange(selectionStart, selectionEnd);
+        return true;
     }

+    // If the user hasn't selected anything, let's select their current parenthesis block or word
+    if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
+        selectCurrentWord();
     }

@@ -81,7 +99,13 @@ function keyupEditAttention(event){
         weight = parseFloat(weight.toPrecision(12));
         if(String(weight).length == 1) weight += ".0"

-        text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
+        if (closeCharacter == ')' && weight == 1) {
+            text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
+            selectionStart--;
+            selectionEnd--;
+        } else {
+            text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
+        }

         target.focus();
         target.value = text;
@@ -93,4 +93,4 @@ function keyupEditAttention(event){

 addEventListener('keydown', (event) => {
     keyupEditAttention(event);
-});
+});
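selectCurrentWord expands an empty selection outward until it hits a delimiter character (opts.keyedit_delimiters plus whitespace) on either side. The same scan re-expressed in Python for clarity; the default delimiter set below is illustrative, since the real one comes from the keyedit_delimiters option:

```python
def select_current_word(text: str, start: int, end: int,
                        delimiters: str = ".,\\/!?%^*;:{}=`~() \r\n\t") -> tuple[int, int]:
    # Mirrors the JS: only act on an empty selection (a bare cursor).
    if start != end:
        return start, end
    # Seek backward to the beginning of the word.
    while start > 0 and text[start - 1] not in delimiters:
        start -= 1
    # Seek forward to the end of the word.
    while end < len(text) and text[end] not in delimiters:
        end += 1
    return start, end

text = "masterpiece, best quality"
assert text[slice(*select_current_word(text, 15, 15))] == "best"
```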
2 changes: 1 addition & 1 deletion javascript/generationParams.js
@@ -16,7 +16,7 @@ onUiUpdate(function(){

 let modalObserver = new MutationObserver(function(mutations) {
     mutations.forEach(function(mutationRecord) {
-        let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
+        let selectedTab = gradioApp().querySelector('#tabs div button')?.innerText
         if (mutationRecord.target.style.display === 'none' && selectedTab === 'txt2img' || selectedTab === 'img2img')
             gradioApp().getElementById(selectedTab+"_generation_info_button").click()
     });
7 changes: 5 additions & 2 deletions javascript/imageviewer.js
@@ -251,8 +251,11 @@ document.addEventListener("DOMContentLoaded", function() {
     modal.appendChild(modalNext)

-    gradioApp().appendChild(modal)
-
+    try {
+        gradioApp().appendChild(modal);
+    } catch (e) {
+        gradioApp().body.appendChild(modal);
+    }

     document.body.appendChild(modal);
2 changes: 1 addition & 1 deletion javascript/progressbar.js
@@ -138,7 +138,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
         return
     }

-    if(elapsedFromStart > 5 && !res.queued && !res.active){
+    if(elapsedFromStart > 40 && !res.queued && !res.active){
         removeProgressBar()
         return
     }
12 changes: 3 additions & 9 deletions modules/api/api.py
@@ -6,7 +6,6 @@
 import gradio as gr
 from threading import Lock
 from io import BytesIO
-from gradio.processing_utils import decode_base64_to_file
 from fastapi import APIRouter, Depends, FastAPI, Request, Response
 from fastapi.security import HTTPBasic, HTTPBasicCredentials
 from fastapi.exceptions import HTTPException
@@ -395,16 +394,11 @@ def extras_single_image_api(self, req: ExtrasSingleImageRequest):
     def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
         reqDict = setUpscalers(req)

-        def prepareFiles(file):
-            file = decode_base64_to_file(file.data, file_path=file.name)
-            file.orig_name = file.name
-            return file
-
-        reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
-        reqDict.pop('imageList')
+        image_list = reqDict.pop('imageList', [])
+        image_folder = [decode_base64_to_image(x.data) for x in image_list]

         with self.queue_lock:
-            result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
+            result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)

         return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
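With decode_base64_to_file gone (it was removed in newer Gradio), the batch endpoint now decodes each imageList entry straight into an in-memory PIL image and passes the list to run_extras as image_folder; no temporary files are written. A hedged client-side sketch of posting to the batch endpoint, with the URL and payload shape as I read the API models; verify against a running instance:

```python
import base64
import requests

# Assumes a local webui launched with --api; route name per modules/api/api.py.
URL = "http://127.0.0.1:7860/sdapi/v1/extra-batch-images"

def to_b64(path: str) -> str:
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

payload = {
    "upscaling_resize": 2,
    "upscaler_1": "Lanczos",
    # Each entry matches the FileData model: base64 data plus a name.
    "imageList": [
        {"data": to_b64("a.png"), "name": "a.png"},
        {"data": to_b64("b.png"), "name": "b.png"},
    ],
}

r = requests.post(URL, json=payload, timeout=120)
r.raise_for_status()
images_b64 = r.json()["images"]  # upscaled results, base64-encoded
```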
8 changes: 6 additions & 2 deletions modules/devices.py
@@ -92,14 +92,18 @@ def cond_cast_float(input):


 def randn(seed, shape):
+    from modules.shared import opts
+
     torch.manual_seed(seed)
-    if device.type == 'mps':
+    if opts.randn_source == "CPU" or device.type == 'mps':
         return torch.randn(shape, device=cpu).to(device)
     return torch.randn(shape, device=device)


 def randn_without_seed(shape):
-    if device.type == 'mps':
+    from modules.shared import opts
+
+    if opts.randn_source == "CPU" or device.type == 'mps':
         return torch.randn(shape, device=cpu).to(device)
     return torch.randn(shape, device=device)
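Routing randn through the CPU generator makes a given seed produce identical noise regardless of which GPU is installed (the same trick the MPS path already used), at the cost of one host-to-device copy per tensor. A standalone illustration of the pattern, independent of webui's opts object:

```python
import torch

def seeded_noise(seed: int, shape, device: str = "cpu", randn_source: str = "CPU") -> torch.Tensor:
    # Mirrors modules.devices.randn: seed first, then choose where to draw.
    torch.manual_seed(seed)
    if randn_source == "CPU" or device == "mps":
        # Draw on the CPU generator, then transfer: the values no longer
        # depend on which GPU (or GPU RNG implementation) is present.
        return torch.randn(shape, device="cpu").to(device)
    return torch.randn(shape, device=device)

# Same seed, same tensor, whatever the target device:
assert torch.equal(seeded_noise(42, (2, 2)), seeded_noise(42, (2, 2)))
```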
2 changes: 1 addition & 1 deletion modules/extra_networks_hypernet.py
@@ -9,7 +9,7 @@ def __init__(self):
     def activate(self, p, params_list):
         additional = shared.opts.sd_hypernetwork

-        if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
+        if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
             p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
             params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
5 changes: 5 additions & 0 deletions modules/generation_parameters_copypaste.py
@@ -284,6 +284,10 @@ def parse_generation_parameters(x: str):

     restore_old_hires_fix_params(res)

+    # Missing RNG means the default was set, which is GPU RNG
+    if "RNG" not in res:
+        res["RNG"] = "GPU"
+
     return res

@@ -304,6 +308,7 @@ def parse_generation_parameters(x: str):
     ('UniPC skip type', 'uni_pc_skip_type'),
     ('UniPC order', 'uni_pc_order'),
     ('UniPC lower order final', 'uni_pc_lower_order_final'),
+    ('RNG', 'randn_source'),
 ]
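Backfilling "RNG": "GPU" keeps infotexts written before the option existed reproducible: the field's absence can only mean the old default. A tiny sketch of that rule in isolation (parse_generation_parameters returns a dict of infotext fields):

```python
def backfill_rng(res: dict) -> dict:
    if "RNG" not in res:
        res["RNG"] = "GPU"   # the only behavior that existed before the option
    return res

assert backfill_rng({"Steps": "20"})["RNG"] == "GPU"
assert backfill_rng({"RNG": "CPU"})["RNG"] == "CPU"  # an explicit value wins
```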
1 change: 1 addition & 0 deletions modules/images.py
@@ -352,6 +352,7 @@ class FilenameGenerator:
         'prompt_no_styles': lambda self: self.prompt_no_style(),
         'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
         'prompt_words': lambda self: self.prompt_words(),
+        'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
     }
     default_time_format = '%Y%m%d%H%M%S'
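The new [clip_skip] token exposes the Clip-skip value (CLIP_stop_at_last_layers) to the "Images filename pattern" setting; a pattern such as [seed]-[clip_skip]-[prompt_words] (illustrative) would then record the setting in every saved filename.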
2 changes: 1 addition & 1 deletion modules/img2img.py
@@ -151,7 +151,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
         override_settings=override_settings,
     )

-    p.scripts = modules.scripts.scripts_txt2img
+    p.scripts = modules.scripts.scripts_img2img
     p.script_args = args

     if shared.cmd_opts.enable_console_prompts:
4 changes: 2 additions & 2 deletions modules/interrogate.py
@@ -32,7 +32,7 @@ def download_default_clip_interrogate_categories(content_dir):
     category_types = ["artists", "flavors", "mediums", "movements"]

     try:
-        os.makedirs(tmpdir)
+        os.makedirs(tmpdir, exist_ok=True)
         for category_type in category_types:
             torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
         os.rename(tmpdir, content_dir)
@@ -41,7 +41,7 @@ def download_default_clip_interrogate_categories(content_dir):
         errors.display(e, "downloading default CLIP interrogate categories")
     finally:
         if os.path.exists(tmpdir):
-            os.remove(tmpdir)
+            os.removedirs(tmpdir)


 class InterrogateModels:
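The cleanup fix matters because os.remove() deletes only files and raises on a directory, so a leftover tmpdir could never be cleared, and the subsequent os.makedirs() would then crash without exist_ok=True. A short demonstration of both behaviors (paths here are throwaway temp directories):

```python
import os
import tempfile

base = tempfile.mkdtemp()
tmpdir = os.path.join(base, "category", "tmp")
os.makedirs(tmpdir, exist_ok=True)   # exist_ok makes a retry after a failed run safe

try:
    os.remove(tmpdir)                # what the old code did
except OSError as e:
    print("os.remove refuses directories:", e)

os.removedirs(tmpdir)                # removes tmp, then the now-empty parents
assert not os.path.exists(base)
```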
10 changes: 8 additions & 2 deletions modules/postprocessing.py
@@ -18,9 +18,15 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,

     if extras_mode == 1:
         for img in image_folder:
-            image = Image.open(img)
+            if isinstance(img, Image.Image):
+                image = img
+                fn = ''
+            else:
+                image = Image.open(os.path.abspath(img.name))
+                fn = os.path.splitext(img.orig_name)[0]
+
             image_data.append(image)
-            image_names.append(os.path.splitext(img.orig_name)[0])
+            image_names.append(fn)
     elif extras_mode == 2:
         assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
         assert input_dir, 'input directory not selected'
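run_postprocessing now receives either decoded PIL images (the rewritten API path above) or upload objects with name/orig_name attributes (the UI path), so the loop dispatches on type. A sketch of that dispatch with a stand-in upload type; the real object comes from Gradio:

```python
import os
from dataclasses import dataclass
from PIL import Image

@dataclass
class Upload:
    # Stand-in for the Gradio upload object the UI path produces:
    # a temp-file path plus the user's original filename.
    name: str
    orig_name: str

def load_batch(image_folder):
    image_data, image_names = [], []
    for img in image_folder:
        if isinstance(img, Image.Image):
            image_data.append(img)   # API path: already decoded in memory
            image_names.append('')   # no original filename to preserve
        else:
            image_data.append(Image.open(os.path.abspath(img.name)))
            image_names.append(os.path.splitext(img.orig_name)[0])
    return image_data, image_names

imgs, names = load_batch([Image.new("RGB", (8, 8))])
assert names == ['']
```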
(Diff truncated: the remaining changed files of the 30 in this commit are not shown.)