fix an error that happens when you type into prompt while switching model, put queue stuff into separate file
AUTOMATIC1111 committed Nov 28, 2022
1 parent 0376da1 commit 0b5dcb3
Showing 3 changed files with 104 additions and 91 deletions.
98 changes: 98 additions & 0 deletions modules/call_queue.py
@@ -0,0 +1,98 @@
+import html
+import sys
+import threading
+import traceback
+import time
+
+from modules import shared
+
+queue_lock = threading.Lock()
+
+
+def wrap_queued_call(func):
+    def f(*args, **kwargs):
+        with queue_lock:
+            res = func(*args, **kwargs)
+
+        return res
+
+    return f
+
+
+def wrap_gradio_gpu_call(func, extra_outputs=None):
+    def f(*args, **kwargs):
+
+        shared.state.begin()
+
+        with queue_lock:
+            res = func(*args, **kwargs)
+
+        shared.state.end()
+
+        return res
+
+    return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
+
+
+def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
+    def f(*args, extra_outputs_array=extra_outputs, **kwargs):
+        run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
+        if run_memmon:
+            shared.mem_mon.monitor()
+        t = time.perf_counter()
+
+        try:
+            res = list(func(*args, **kwargs))
+        except Exception as e:
+            # When printing out our debug argument list, do not print out more than a MB of text
+            max_debug_str_len = 131072 # (1024*1024)/8
+
+            print("Error completing request", file=sys.stderr)
+            argStr = f"Arguments: {str(args)} {str(kwargs)}"
+            print(argStr[:max_debug_str_len], file=sys.stderr)
+            if len(argStr) > max_debug_str_len:
+                print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
+
+            print(traceback.format_exc(), file=sys.stderr)
+
+            shared.state.job = ""
+            shared.state.job_count = 0
+
+            if extra_outputs_array is None:
+                extra_outputs_array = [None, '']
+
+            res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
+
+        shared.state.skipped = False
+        shared.state.interrupted = False
+        shared.state.job_count = 0
+
+        if not add_stats:
+            return tuple(res)
+
+        elapsed = time.perf_counter() - t
+        elapsed_m = int(elapsed // 60)
+        elapsed_s = elapsed % 60
+        elapsed_text = f"{elapsed_s:.2f}s"
+        if elapsed_m > 0:
+            elapsed_text = f"{elapsed_m}m "+elapsed_text
+
+        if run_memmon:
+            mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
+            active_peak = mem_stats['active_peak']
+            reserved_peak = mem_stats['reserved_peak']
+            sys_peak = mem_stats['system_peak']
+            sys_total = mem_stats['total']
+            sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
+
+            vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
+        else:
+            vram_html = ''
+
+        # last item is always HTML
+        res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
+
+        return tuple(res)
+
+    return f

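The heart of the new module is a single process-wide queue_lock that serializes everything touching the GPU. Below is a minimal, self-contained sketch of how that serialization behaves; the stand-in names (switch_model, count_tokens) are illustrative and not from the commit, while the wrapper mirrors wrap_queued_call from modules/call_queue.py above:

import threading
import time

queue_lock = threading.Lock()

def wrap_queued_call(func):
    # same pattern as modules/call_queue.py: every wrapped call must
    # acquire queue_lock before it runs
    def f(*args, **kwargs):
        with queue_lock:
            return func(*args, **kwargs)
    return f

def switch_model():
    time.sleep(2)  # stand-in for a slow checkpoint swap

def count_tokens(prompt):
    return len(prompt.split())  # stand-in for update_token_counter

switch_model = wrap_queued_call(switch_model)
count_tokens = wrap_queued_call(count_tokens)

t = threading.Thread(target=switch_model)
t.start()
time.sleep(0.1)  # let the swap acquire the lock first
print(count_tokens("a photo of a cat"))  # blocks until the swap releases the lock
t.join()

One small detail in wrap_gradio_call's stats path: -(v // -(1024*1024)) is ceiling division, e.g. for v = 3*1024*1024 + 1 bytes it yields 4, so peak byte counts are rounded up to whole MiB.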
67 changes: 3 additions & 64 deletions modules/ui.py
@@ -17,7 +17,7 @@
 import gradio.utils
 import numpy as np
 from PIL import Image, PngImagePlugin
-
+from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
 
 from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
 from modules.paths import script_path
@@ -158,67 +158,6 @@ def __init__(self, d=None):
     return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
 
 
-def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
-    def f(*args, extra_outputs_array=extra_outputs, **kwargs):
-        run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
-        if run_memmon:
-            shared.mem_mon.monitor()
-        t = time.perf_counter()
-
-        try:
-            res = list(func(*args, **kwargs))
-        except Exception as e:
-            # When printing out our debug argument list, do not print out more than a MB of text
-            max_debug_str_len = 131072 # (1024*1024)/8
-
-            print("Error completing request", file=sys.stderr)
-            argStr = f"Arguments: {str(args)} {str(kwargs)}"
-            print(argStr[:max_debug_str_len], file=sys.stderr)
-            if len(argStr) > max_debug_str_len:
-                print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
-
-            print(traceback.format_exc(), file=sys.stderr)
-
-            shared.state.job = ""
-            shared.state.job_count = 0
-
-            if extra_outputs_array is None:
-                extra_outputs_array = [None, '']
-
-            res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
-
-        shared.state.skipped = False
-        shared.state.interrupted = False
-        shared.state.job_count = 0
-
-        if not add_stats:
-            return tuple(res)
-
-        elapsed = time.perf_counter() - t
-        elapsed_m = int(elapsed // 60)
-        elapsed_s = elapsed % 60
-        elapsed_text = f"{elapsed_s:.2f}s"
-        if elapsed_m > 0:
-            elapsed_text = f"{elapsed_m}m "+elapsed_text
-
-        if run_memmon:
-            mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
-            active_peak = mem_stats['active_peak']
-            reserved_peak = mem_stats['reserved_peak']
-            sys_peak = mem_stats['system_peak']
-            sys_total = mem_stats['total']
-            sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
-
-            vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
-        else:
-            vram_html = ''
-
-        # last item is always HTML
-        res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
-
-        return tuple(res)
-
-    return f
 
 
 def calc_time_left(progress, threshold, label, force_display):
@@ -666,7 +605,7 @@ def open_folder(f):
     return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info
 
 
-def create_ui(wrap_gradio_gpu_call):
+def create_ui():
     import modules.img2img
     import modules.txt2img
 
@@ -826,7 +765,7 @@ def create_ui(wrap_gradio_gpu_call):
                 height,
             ]
 
-            token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
+            token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
 
     modules.scripts.scripts_current = modules.scripts.scripts_img2img
     modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
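The two hunks above are the user-visible fix: update_token_counter runs as you type into the prompt, and before this commit it executed immediately, even while a queue_lock-guarded operation such as a model switch was in flight. Wrapping it in wrap_queued_call makes it wait for the lock instead. A reduced, runnable sketch of that wiring under Gradio follows; the component names and the stand-in counter are illustrative, not code from the repo:

import threading

import gradio as gr

queue_lock = threading.Lock()  # stand-in for modules.call_queue.queue_lock

def wrap_queued_call(func):
    def f(*args, **kwargs):
        with queue_lock:
            return func(*args, **kwargs)
    return f

def update_token_counter(text, steps):
    # stand-in: the real counter tokenizes the prompt with the active model,
    # which is unsafe to touch mid-swap
    return str(len(text.split()))

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    steps = gr.Slider(1, 150, value=20, label="Steps")
    counter = gr.HTML("0")
    button = gr.Button("Count tokens")
    # the fix: route the counter through the queue lock so it cannot run
    # concurrently with a model switch
    button.click(fn=wrap_queued_call(update_token_counter), inputs=[prompt, steps], outputs=[counter])

demo.launch()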
30 changes: 3 additions & 27 deletions webui.py
@@ -8,6 +8,7 @@
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
 
+from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
 from modules.paths import script_path
 
 from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
@@ -32,38 +33,12 @@
 import modules.hypernetworks.hypernetwork
 
 
-queue_lock = threading.Lock()
 if cmd_opts.server_name:
     server_name = cmd_opts.server_name
 else:
     server_name = "0.0.0.0" if cmd_opts.listen else None
 
 
-def wrap_queued_call(func):
-    def f(*args, **kwargs):
-        with queue_lock:
-            res = func(*args, **kwargs)
-
-        return res
-
-    return f
-
-
-def wrap_gradio_gpu_call(func, extra_outputs=None):
-    def f(*args, **kwargs):
-
-        shared.state.begin()
-
-        with queue_lock:
-            res = func(*args, **kwargs)
-
-        shared.state.end()
-
-        return res
-
-    return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
-
-
 def initialize():
     extensions.list_extensions()
     localization.list_localizations(cmd_opts.localizations_dir)
@@ -159,7 +134,7 @@ def webui():
         if shared.opts.clean_temp_dir_at_start:
             ui_tempdir.cleanup_tmpdr()
 
-        shared.demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
+        shared.demo = modules.ui.create_ui()
 
         app, local_url, share_url = shared.demo.launch(
             share=cmd_opts.share,
@@ -189,6 +164,7 @@ def webui():
         create_api(app)
 
         modules.script_callbacks.app_started_callback(shared.demo, app)
+        modules.script_callbacks.app_started_callback(shared.demo, app)
 
         wait_on_server(shared.demo)

