@@ -304,10 +304,23 @@ def stitch(
304304 image2 .movedim (- 1 , 1 ), target_w , target_h , "lanczos" , "disabled"
305305 ).movedim (1 , - 1 )
306306
307+ color_map = {
308+ "white" : 1.0 ,
309+ "black" : 0.0 ,
310+ "red" : (1.0 , 0.0 , 0.0 ),
311+ "green" : (0.0 , 1.0 , 0.0 ),
312+ "blue" : (0.0 , 0.0 , 1.0 ),
313+ }
314+
315+ color_val = color_map [spacing_color ]
316+
307317 # When not matching sizes, pad to align non-concat dimensions
308318 if not match_image_size :
309319 h1 , w1 = image1 .shape [1 :3 ]
310320 h2 , w2 = image2 .shape [1 :3 ]
321+ pad_value = 0.0
322+ if not isinstance (color_val , tuple ):
323+ pad_value = color_val
311324
312325 if direction in ["left" , "right" ]:
313326 # For horizontal concat, pad heights to match
@@ -316,23 +329,23 @@ def stitch(
316329 if h1 < target_h :
317330 pad_h = target_h - h1
318331 pad_top , pad_bottom = pad_h // 2 , pad_h - pad_h // 2
319- image1 = torch .nn .functional .pad (image1 , (0 , 0 , 0 , 0 , pad_top , pad_bottom ), mode = 'constant' , value = 0.0 )
332+ image1 = torch .nn .functional .pad (image1 , (0 , 0 , 0 , 0 , pad_top , pad_bottom ), mode = 'constant' , value = pad_value )
320333 if h2 < target_h :
321334 pad_h = target_h - h2
322335 pad_top , pad_bottom = pad_h // 2 , pad_h - pad_h // 2
323- image2 = torch .nn .functional .pad (image2 , (0 , 0 , 0 , 0 , pad_top , pad_bottom ), mode = 'constant' , value = 0.0 )
336+ image2 = torch .nn .functional .pad (image2 , (0 , 0 , 0 , 0 , pad_top , pad_bottom ), mode = 'constant' , value = pad_value )
324337 else : # up, down
325338 # For vertical concat, pad widths to match
326339 if w1 != w2 :
327340 target_w = max (w1 , w2 )
328341 if w1 < target_w :
329342 pad_w = target_w - w1
330343 pad_left , pad_right = pad_w // 2 , pad_w - pad_w // 2
331- image1 = torch .nn .functional .pad (image1 , (0 , 0 , pad_left , pad_right ), mode = 'constant' , value = 0.0 )
344+ image1 = torch .nn .functional .pad (image1 , (0 , 0 , pad_left , pad_right ), mode = 'constant' , value = pad_value )
332345 if w2 < target_w :
333346 pad_w = target_w - w2
334347 pad_left , pad_right = pad_w // 2 , pad_w - pad_w // 2
335- image2 = torch .nn .functional .pad (image2 , (0 , 0 , pad_left , pad_right ), mode = 'constant' , value = 0.0 )
348+ image2 = torch .nn .functional .pad (image2 , (0 , 0 , pad_left , pad_right ), mode = 'constant' , value = pad_value )
336349
337350 # Ensure same number of channels
338351 if image1 .shape [- 1 ] != image2 .shape [- 1 ]:
@@ -366,15 +379,6 @@ def stitch(
366379 if spacing_width > 0 :
367380 spacing_width = spacing_width + (spacing_width % 2 ) # Ensure even
368381
369- color_map = {
370- "white" : 1.0 ,
371- "black" : 0.0 ,
372- "red" : (1.0 , 0.0 , 0.0 ),
373- "green" : (0.0 , 1.0 , 0.0 ),
374- "blue" : (0.0 , 0.0 , 1.0 ),
375- }
376- color_val = color_map [spacing_color ]
377-
378382 if direction in ["left" , "right" ]:
379383 spacing_shape = (
380384 image1 .shape [0 ],
0 commit comments