Skip to content

Commit 2c1d686

Browse files
implement multi image prompting for gpt-image-1 and fix transparency in outputs (Comfy-Org#7763)
* implement multi image prompting for GPTI Image 1 * fix transparency not working * fix ruff
1 parent e8ddc2b commit 2c1d686

File tree

1 file changed

+26
-17
lines changed

1 file changed

+26
-17
lines changed

comfy_api_nodes/nodes_api.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def validate_and_cast_response (response):
5353
raise Exception("Failed to download the image")
5454
img = Image.open(io.BytesIO(img_response.content))
5555

56-
img = img.convert("RGB") # Ensure RGB format
56+
img = img.convert("RGBA")
5757

5858
# Convert to numpy array, normalize to float32 between 0 and 1
5959
img_array = np.array(img).astype(np.float32) / 255.0
@@ -339,25 +339,38 @@ def api_call(self, prompt, seed=0, quality="low", background="opaque", image=Non
339339
model = "gpt-image-1"
340340
path = "/proxy/openai/images/generations"
341341
request_class = OpenAIImageGenerationRequest
342-
img_binary = None
342+
img_binaries = []
343343
mask_binary = None
344-
344+
files = []
345345

346346
if image is not None:
347347
path = "/proxy/openai/images/edits"
348348
request_class = OpenAIImageEditRequest
349349

350-
scaled_image = downscale_input(image).squeeze()
350+
batch_size = image.shape[0]
351351

352-
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
353-
img = Image.fromarray(image_np)
354-
img_byte_arr = io.BytesIO()
355-
img.save(img_byte_arr, format='PNG')
356-
img_byte_arr.seek(0)
357-
img_binary = img_byte_arr#.getvalue()
358-
img_binary.name = "image.png"
352+
353+
for i in range(batch_size):
354+
single_image = image[i:i+1]
355+
scaled_image = downscale_input(single_image).squeeze()
356+
357+
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
358+
img = Image.fromarray(image_np)
359+
img_byte_arr = io.BytesIO()
360+
img.save(img_byte_arr, format='PNG')
361+
img_byte_arr.seek(0)
362+
img_binary = img_byte_arr
363+
img_binary.name = f"image_{i}.png"
364+
365+
img_binaries.append(img_binary)
366+
if batch_size == 1:
367+
files.append(("image", img_binary))
368+
else:
369+
files.append(("image[]", img_binary))
359370

360371
if mask is not None:
372+
if image.shape[0] != 1:
373+
raise Exception("Cannot use a mask with multiple image")
361374
if image is None:
362375
raise Exception("Cannot use a mask without an input image")
363376
if mask.shape[1:] != image.shape[1:-1]:
@@ -373,14 +386,10 @@ def api_call(self, prompt, seed=0, quality="low", background="opaque", image=Non
373386
mask_img_byte_arr = io.BytesIO()
374387
mask_img.save(mask_img_byte_arr, format='PNG')
375388
mask_img_byte_arr.seek(0)
376-
mask_binary = mask_img_byte_arr#.getvalue()
389+
mask_binary = mask_img_byte_arr
377390
mask_binary.name = "mask.png"
391+
files.append(("mask", mask_binary))
378392

379-
files = {}
380-
if img_binary:
381-
files["image"] = img_binary
382-
if mask_binary:
383-
files["mask"] = mask_binary
384393

385394
# Build the operation
386395
operation = SynchronousOperation(

0 commit comments

Comments
 (0)