Skip to content

Commit d062fcc

Browse files
[feat] Add ImageStitch node for concatenating images (Comfy-Org#8369)
* [feat] Add ImageStitch node for concatenating images with borders Add ImageStitch node that concatenates images in four directions with optional borders and intelligent size handling. Features include optional second image input, configurable borders with color selection, automatic batch size matching, and dimension alignment via padding or resizing. Upstreamed from https://github.com/kijai/ComfyUI-KJNodes with enhancements for better error handling and comprehensive test coverage. * [fix] Fix CI issues with CUDA dependencies and linting - Mock CUDA-dependent modules in tests to avoid CI failures on CPU-only runners - Fix ruff linting issues for code style compliance * [fix] Improve CI compatibility by mocking nodes module import Prevent CUDA initialization chain by mocking the nodes module at import time, which is cleaner than deep mocking of CUDA-specific functions. * [refactor] Clean up ImageStitch tests - Remove unnecessary sys.path manipulation (pythonpath set in pytest.ini) - Remove metadata tests that test framework internals rather than functionality - Rename complex scenario test to be more descriptive of what it tests * [refactor] Rename 'border' to 'spacing' for semantic accuracy - Change border_width/border_color to spacing_width/spacing_color in API - Update all tests to use spacing terminology - Update comments and variable names throughout - More accurately describes the gap/separator between images
1 parent 456abad commit d062fcc

File tree

4 files changed

+423
-0
lines changed

4 files changed

+423
-0
lines changed

comfy_extras/nodes_images.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from io import BytesIO
1515
from inspect import cleandoc
1616
import torch
17+
import comfy.utils
1718

1819
from comfy.comfy_types import FileLocator
1920

@@ -229,6 +230,186 @@ def combine_all(svgs: list['SVG']) -> 'SVG':
229230
all_svgs_list.extend(svg_item.data)
230231
return SVG(all_svgs_list)
231232

233+
234+
class ImageStitch:
235+
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
236+
237+
@classmethod
238+
def INPUT_TYPES(s):
239+
return {
240+
"required": {
241+
"image1": ("IMAGE",),
242+
"direction": (["right", "down", "left", "up"], {"default": "right"}),
243+
"match_image_size": ("BOOLEAN", {"default": True}),
244+
"spacing_width": (
245+
"INT",
246+
{"default": 0, "min": 0, "max": 1024, "step": 2},
247+
),
248+
"spacing_color": (
249+
["white", "black", "red", "green", "blue"],
250+
{"default": "white"},
251+
),
252+
},
253+
"optional": {
254+
"image2": ("IMAGE",),
255+
},
256+
}
257+
258+
RETURN_TYPES = ("IMAGE",)
259+
FUNCTION = "stitch"
260+
CATEGORY = "image/transform"
261+
DESCRIPTION = """
262+
Stitches image2 to image1 in the specified direction.
263+
If image2 is not provided, returns image1 unchanged.
264+
Optional spacing can be added between images.
265+
"""
266+
267+
def stitch(
268+
self,
269+
image1,
270+
direction,
271+
match_image_size,
272+
spacing_width,
273+
spacing_color,
274+
image2=None,
275+
):
276+
if image2 is None:
277+
return (image1,)
278+
279+
# Handle batch size differences
280+
if image1.shape[0] != image2.shape[0]:
281+
max_batch = max(image1.shape[0], image2.shape[0])
282+
if image1.shape[0] < max_batch:
283+
image1 = torch.cat(
284+
[image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)]
285+
)
286+
if image2.shape[0] < max_batch:
287+
image2 = torch.cat(
288+
[image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)]
289+
)
290+
291+
# Match image sizes if requested
292+
if match_image_size:
293+
h1, w1 = image1.shape[1:3]
294+
h2, w2 = image2.shape[1:3]
295+
aspect_ratio = w2 / h2
296+
297+
if direction in ["left", "right"]:
298+
target_h, target_w = h1, int(h1 * aspect_ratio)
299+
else: # up, down
300+
target_w, target_h = w1, int(w1 / aspect_ratio)
301+
302+
image2 = comfy.utils.common_upscale(
303+
image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
304+
).movedim(1, -1)
305+
306+
# When not matching sizes, pad to align non-concat dimensions
307+
if not match_image_size:
308+
h1, w1 = image1.shape[1:3]
309+
h2, w2 = image2.shape[1:3]
310+
311+
if direction in ["left", "right"]:
312+
# For horizontal concat, pad heights to match
313+
if h1 != h2:
314+
target_h = max(h1, h2)
315+
if h1 < target_h:
316+
pad_h = target_h - h1
317+
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
318+
image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
319+
if h2 < target_h:
320+
pad_h = target_h - h2
321+
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
322+
image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
323+
else: # up, down
324+
# For vertical concat, pad widths to match
325+
if w1 != w2:
326+
target_w = max(w1, w2)
327+
if w1 < target_w:
328+
pad_w = target_w - w1
329+
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
330+
image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
331+
if w2 < target_w:
332+
pad_w = target_w - w2
333+
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
334+
image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
335+
336+
# Ensure same number of channels
337+
if image1.shape[-1] != image2.shape[-1]:
338+
max_channels = max(image1.shape[-1], image2.shape[-1])
339+
if image1.shape[-1] < max_channels:
340+
image1 = torch.cat(
341+
[
342+
image1,
343+
torch.ones(
344+
*image1.shape[:-1],
345+
max_channels - image1.shape[-1],
346+
device=image1.device,
347+
),
348+
],
349+
dim=-1,
350+
)
351+
if image2.shape[-1] < max_channels:
352+
image2 = torch.cat(
353+
[
354+
image2,
355+
torch.ones(
356+
*image2.shape[:-1],
357+
max_channels - image2.shape[-1],
358+
device=image2.device,
359+
),
360+
],
361+
dim=-1,
362+
)
363+
364+
# Add spacing if specified
365+
if spacing_width > 0:
366+
spacing_width = spacing_width + (spacing_width % 2) # Ensure even
367+
368+
color_map = {
369+
"white": 1.0,
370+
"black": 0.0,
371+
"red": (1.0, 0.0, 0.0),
372+
"green": (0.0, 1.0, 0.0),
373+
"blue": (0.0, 0.0, 1.0),
374+
}
375+
color_val = color_map[spacing_color]
376+
377+
if direction in ["left", "right"]:
378+
spacing_shape = (
379+
image1.shape[0],
380+
max(image1.shape[1], image2.shape[1]),
381+
spacing_width,
382+
image1.shape[-1],
383+
)
384+
else:
385+
spacing_shape = (
386+
image1.shape[0],
387+
spacing_width,
388+
max(image1.shape[2], image2.shape[2]),
389+
image1.shape[-1],
390+
)
391+
392+
spacing = torch.full(spacing_shape, 0.0, device=image1.device)
393+
if isinstance(color_val, tuple):
394+
for i, c in enumerate(color_val):
395+
if i < spacing.shape[-1]:
396+
spacing[..., i] = c
397+
if spacing.shape[-1] == 4: # Add alpha
398+
spacing[..., 3] = 1.0
399+
else:
400+
spacing[..., : min(3, spacing.shape[-1])] = color_val
401+
if spacing.shape[-1] == 4:
402+
spacing[..., 3] = 1.0
403+
404+
# Concatenate images
405+
images = [image2, image1] if direction in ["left", "up"] else [image1, image2]
406+
if spacing_width > 0:
407+
images.insert(1, spacing)
408+
409+
concat_dim = 2 if direction in ["left", "right"] else 1
410+
return (torch.cat(images, dim=concat_dim),)
411+
412+
232413
class SaveSVGNode:
233414
"""
234415
Save SVG files on disk.
@@ -318,4 +499,5 @@ def replacement(match):
318499
"SaveAnimatedWEBP": SaveAnimatedWEBP,
319500
"SaveAnimatedPNG": SaveAnimatedPNG,
320501
"SaveSVGNode": SaveSVGNode,
502+
"ImageStitch": ImageStitch,
321503
}

nodes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2061,6 +2061,7 @@ def expand_image(self, image, left, top, right, bottom, feathering):
20612061
"ImagePadForOutpaint": "Pad Image for Outpainting",
20622062
"ImageBatch": "Batch Images",
20632063
"ImageCrop": "Image Crop",
2064+
"ImageStitch": "Image Stitch",
20642065
"ImageBlend": "Image Blend",
20652066
"ImageBlur": "Image Blur",
20662067
"ImageQuantize": "Image Quantize",

tests-unit/comfy_extras_test/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)