Skip to content

Commit

Permalink
added a button to run hires fix on selected image in the gallery
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Jan 1, 2024
1 parent 5d7d182 commit 501993e
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 86 deletions.
8 changes: 8 additions & 0 deletions javascript/ui.js
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@ function submit() {
return res;
}

function submit_txt2img_upscale() {
res = submit.apply(null, arguments);

Check failure on line 154 in javascript/ui.js

View workflow job for this annotation

GitHub Actions / eslint

'res' is not defined

res[2] = selected_gallery_index();

Check failure on line 156 in javascript/ui.js

View workflow job for this annotation

GitHub Actions / eslint

'res' is not defined

return res;

Check failure on line 158 in javascript/ui.js

View workflow job for this annotation

GitHub Actions / eslint

'res' is not defined
}

function submit_img2img() {
showSubmitButtons('img2img', false);

Expand Down
46 changes: 37 additions & 9 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ class StableDiffusionProcessing:
token_merging_ratio = 0
token_merging_ratio_hr = 0
disable_extra_networks: bool = False
firstpass_image: Image = None

scripts_value: scripts.ScriptRunner = field(default=None, init=False)
script_args_value: list = field(default=None, init=False)
Expand Down Expand Up @@ -1238,18 +1239,45 @@ def init(self, all_prompts, all_seeds, all_subseeds):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

x = self.rng.next()
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
del x
if self.firstpass_image is not None and self.enable_hr:
# here we don't need to generate image, we just take self.firstpass_image and prepare it for hires fix

if not self.enable_hr:
return samples
devices.torch_gc()
if self.latent_scale_mode is None:
image = np.array(self.firstpass_image).astype(np.float32) / 255.0 * 2.0 - 1.0
image = np.moveaxis(image, 2, 0)

samples = None
decoded_samples = torch.asarray(np.expand_dims(image, 0))

else:
image = np.array(self.firstpass_image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
image = torch.from_numpy(np.expand_dims(image, axis=0))
image = image.to(shared.device, dtype=devices.dtype_vae)

if opts.sd_vae_encode_method != 'Full':
self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method

samples = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
decoded_samples = None
devices.torch_gc()

if self.latent_scale_mode is None:
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
else:
decoded_samples = None
# here we generate an image normally

x = self.rng.next()
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
del x

if not self.enable_hr:
return samples

devices.torch_gc()

if self.latent_scale_mode is None:
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
else:
decoded_samples = None

with sd_models.SkipWritingToConfig():
sd_models.reload_model_weights(info=self.hr_checkpoint_info)
Expand Down
19 changes: 17 additions & 2 deletions modules/txt2img.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,31 @@
from contextlib import closing

import modules.scripts
from modules import processing
from modules import processing, infotext_utils
from modules.infotext_utils import create_override_settings_dict
from modules.shared import opts
import modules.shared as shared
from modules.ui import plaintext_to_html
import gradio as gr


def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, *args):
assert len(gallery) > 0, 'No image to upscale'

image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0]
image = infotext_utils.image_from_url_text(image_info)

return txt2img(id_task, request, *args, firstpass_image=image)


def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, firstpass_image=None):
override_settings = create_override_settings_dict(override_settings_texts)

if firstpass_image is not None:
enable_hr = True
batch_size = 1
n_iter = 1

p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
Expand All @@ -38,6 +52,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
hr_prompt=hr_prompt,
hr_negative_prompt=hr_negative_prompt,
override_settings=override_settings,
firstpass_image=firstpass_image,
)

p.scripts = modules.scripts.scripts_txt2img
Expand Down
108 changes: 59 additions & 49 deletions modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,61 +375,71 @@ def create_ui():
show_progress=False,
)

txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow)
output_panel = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow)

txt2img_inputs = [
dummy_component,
toprow.prompt,
toprow.negative_prompt,
toprow.ui_styles.dropdown,
steps,
sampler_name,
batch_count,
batch_size,
cfg_scale,
height,
width,
enable_hr,
denoising_strength,
hr_scale,
hr_upscaler,
hr_second_pass_steps,
hr_resize_x,
hr_resize_y,
hr_checkpoint_name,
hr_sampler_name,
hr_prompt,
hr_negative_prompt,
override_settings,
] + custom_inputs

txt2img_outputs = [
output_panel.gallery,
output_panel.infotext,
output_panel.html_info,
output_panel.html_log,
]

txt2img_args = dict(
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
_js="submit",
inputs=[
dummy_component,
toprow.prompt,
toprow.negative_prompt,
toprow.ui_styles.dropdown,
steps,
sampler_name,
batch_count,
batch_size,
cfg_scale,
height,
width,
enable_hr,
denoising_strength,
hr_scale,
hr_upscaler,
hr_second_pass_steps,
hr_resize_x,
hr_resize_y,
hr_checkpoint_name,
hr_sampler_name,
hr_prompt,
hr_negative_prompt,
override_settings,

] + custom_inputs,

outputs=[
txt2img_gallery,
generation_info,
html_info,
html_log,
],
inputs=txt2img_inputs,
outputs=txt2img_outputs,
show_progress=False,
)

toprow.prompt.submit(**txt2img_args)
toprow.submit.click(**txt2img_args)

output_panel.button_upscale.click(
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img_upscale, extra_outputs=[None, '', '']),
_js="submit_txt2img_upscale",
inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component] + txt2img_inputs[1:],
outputs=txt2img_outputs,
show_progress=False,
)

res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)

toprow.restore_progress_button.click(
fn=progress.restore_progress,
_js="restoreProgressTxt2img",
inputs=[dummy_component],
outputs=[
txt2img_gallery,
generation_info,
html_info,
html_log,
output_panel.gallery,
output_panel.infotext,
output_panel.html_info,
output_panel.html_log,
],
show_progress=False,
)
Expand Down Expand Up @@ -479,7 +489,7 @@ def create_ui():
toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])

extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
ui_extra_networks.setup_ui(extra_networks_ui, output_panel.gallery)

extra_tabs.__exit__()

Expand Down Expand Up @@ -710,7 +720,7 @@ def select_img2img_tab(tab):
outputs=[inpaint_controls, mask_alpha],
)

img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples, toprow)
output_panel = create_output_panel("img2img", opts.outdir_img2img_samples, toprow)

img2img_args = dict(
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
Expand Down Expand Up @@ -755,10 +765,10 @@ def select_img2img_tab(tab):
img2img_batch_png_info_dir,
] + custom_inputs,
outputs=[
img2img_gallery,
generation_info,
html_info,
html_log,
output_panel.gallery,
output_panel.infotext,
output_panel.html_info,
output_panel.html_log,
],
show_progress=False,
)
Expand Down Expand Up @@ -796,10 +806,10 @@ def select_img2img_tab(tab):
_js="restoreProgressImg2img",
inputs=[dummy_component],
outputs=[
img2img_gallery,
generation_info,
html_info,
html_log,
output_panel.gallery,
output_panel.infotext,
output_panel.html_info,
output_panel.html_log,
],
show_progress=False,
)
Expand Down Expand Up @@ -839,7 +849,7 @@ def select_img2img_tab(tab):
))

extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img')
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
ui_extra_networks.setup_ui(extra_networks_ui_img2img, output_panel.gallery)

extra_tabs.__exit__()

Expand Down
Loading

2 comments on commit 501993e

@4lt3r3go
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all i ever wanted! thanks for this

@nothingness6
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So....does this code fix the issue of the Generate button? Then what should I do? I just found the ui_postprocessing.py file and have no idea what to do. Should I copy/paste this code somewhere?
Sorry, but I'm very beginner for this.

Please sign in to comment.