@@ -53,7 +53,7 @@ def validate_and_cast_response (response):
5353 raise Exception ("Failed to download the image" )
5454 img = Image .open (io .BytesIO (img_response .content ))
5555
56- img = img .convert ("RGB" ) # Ensure RGB format
56+ img = img .convert ("RGBA" )
5757
5858 # Convert to numpy array, normalize to float32 between 0 and 1
5959 img_array = np .array (img ).astype (np .float32 ) / 255.0
@@ -339,25 +339,38 @@ def api_call(self, prompt, seed=0, quality="low", background="opaque", image=Non
339339 model = "gpt-image-1"
340340 path = "/proxy/openai/images/generations"
341341 request_class = OpenAIImageGenerationRequest
342- img_binary = None
342+ img_binaries = []
343343 mask_binary = None
344-
344+ files = []
345345
346346 if image is not None :
347347 path = "/proxy/openai/images/edits"
348348 request_class = OpenAIImageEditRequest
349349
350- scaled_image = downscale_input ( image ). squeeze ()
350+ batch_size = image . shape [ 0 ]
351351
352- image_np = (scaled_image .numpy () * 255 ).astype (np .uint8 )
353- img = Image .fromarray (image_np )
354- img_byte_arr = io .BytesIO ()
355- img .save (img_byte_arr , format = 'PNG' )
356- img_byte_arr .seek (0 )
357- img_binary = img_byte_arr #.getvalue()
358- img_binary .name = "image.png"
352+
353+ for i in range (batch_size ):
354+ single_image = image [i :i + 1 ]
355+ scaled_image = downscale_input (single_image ).squeeze ()
356+
357+ image_np = (scaled_image .numpy () * 255 ).astype (np .uint8 )
358+ img = Image .fromarray (image_np )
359+ img_byte_arr = io .BytesIO ()
360+ img .save (img_byte_arr , format = 'PNG' )
361+ img_byte_arr .seek (0 )
362+ img_binary = img_byte_arr
363+ img_binary .name = f"image_{ i } .png"
364+
365+ img_binaries .append (img_binary )
366+ if batch_size == 1 :
367+ files .append (("image" , img_binary ))
368+ else :
369+ files .append (("image[]" , img_binary ))
359370
360371 if mask is not None :
372+ if image .shape [0 ] != 1 :
373+ raise Exception ("Cannot use a mask with multiple image" )
361374 if image is None :
362375 raise Exception ("Cannot use a mask without an input image" )
363376 if mask .shape [1 :] != image .shape [1 :- 1 ]:
@@ -373,14 +386,10 @@ def api_call(self, prompt, seed=0, quality="low", background="opaque", image=Non
373386 mask_img_byte_arr = io .BytesIO ()
374387 mask_img .save (mask_img_byte_arr , format = 'PNG' )
375388 mask_img_byte_arr .seek (0 )
376- mask_binary = mask_img_byte_arr #.getvalue()
389+ mask_binary = mask_img_byte_arr
377390 mask_binary .name = "mask.png"
391+ files .append (("mask" , mask_binary ))
378392
379- files = {}
380- if img_binary :
381- files ["image" ] = img_binary
382- if mask_binary :
383- files ["mask" ] = mask_binary
384393
385394 # Build the operation
386395 operation = SynchronousOperation (
0 commit comments