Skip to content

Add a launch flag for --latent-preview-api to get image previews through json payload #4993

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion comfy/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class LatentPreviewMethod(enum.Enum):
TAESD = "taesd"

parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)

parser.add_argument("--latent-preview-api", action="store_true", help="Converts latent previews to base64 to be sent through the websocket json payload under message['type'] == 'latent_preview'")
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")

cache_group = parser.add_mutually_exclusive_group()
Expand Down
6 changes: 5 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,11 @@ def hook(value, total, preview_image):

server.send_sync("progress", progress, server.client_id)
if preview_image is not None:
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
if args.latent_preview_api:
server.send_sync("latent_preview", preview_image, server.client_id)
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
else:
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
comfy.utils.set_progress_bar_global_hook(hook)


Expand Down
23 changes: 23 additions & 0 deletions script_examples/websockets_api_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,29 @@ def get_images(ws, prompt):
data = message['data']
if data['node'] is None and data['prompt_id'] == prompt_id:
break #Execution is done
#elif message["type"] == "latent_preview": #if you want previews of the latent images add --latent-preview-api to the launch flags
#img_base64 = message['data']

#recommend moving the following to a function that updates on an interval and use a global variable to store img_base64
#img_data = base64.b64decode(img_base64)
#buffered = io.BytesIO(img_data)
#preview_image = Image.open(buffered)
#preview_image = ImageOps.contain(preview_image, (512, 512))

#you'll need the following imports:
#from io import BytesIO
#import base64
#from PIL import Image, ImageOps

#while there's no point in using previews in this example, this is helpful to have for using with frontends like gradio
#if using this in an app that will make multiple generations in a row: make sure to ws.close() after the get_images call
#for example, if you call the following function, that then calls this get_images(ws, prompt) function:
#def get_prompt_images(prompt):
#ws = websocket.WebSocket()
#ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
#images = get_images(ws, prompt)
#ws.close()
#the ws.close() will prevent connection timeouts that might randomly occur
else:
continue #previews are binary data

Expand Down
7 changes: 7 additions & 0 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
import asyncio
import traceback
import base64

import nodes
import folder_paths
Expand Down Expand Up @@ -741,6 +742,12 @@ def get_queue_info(self):
async def send(self, event, data, sid=None):
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
await self.send_image(data, sid=sid)
elif event == "latent_preview":
img = data[1]
buffered = BytesIO()
img.save(buffered, format="JPEG")
img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
await self.send_json("latent_preview", img_base64, sid=sid)
elif isinstance(data, (bytes, bytearray)):
await self.send_bytes(event, data, sid)
else:
Expand Down