Adding SegGPT #27735

Merged
merged 122 commits into from
Feb 26, 2024
Commits
fecf251
First commit
EduardoPach Nov 27, 2023
2fbf69b
Improvements
EduardoPach Nov 27, 2023
155e486
More improvements
EduardoPach Nov 27, 2023
2de229a
Converted original checkpoint to HF checkpoint
EduardoPach Nov 28, 2023
b3d5049
Fix style
EduardoPach Nov 28, 2023
fe53c92
Fixed forward
EduardoPach Nov 29, 2023
051b6c8
More improvements
EduardoPach Nov 30, 2023
70d0290
More improvements
EduardoPach Dec 5, 2023
951ac86
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Dec 7, 2023
14b70a7
Remove asserts
EduardoPach Dec 7, 2023
dce5f4a
Merge branch 'adding-seggpt' of https://github.com/EduardoPach/transf…
EduardoPach Dec 7, 2023
3c8e12a
Remove unnecessary attributes
EduardoPach Dec 7, 2023
449fa80
Changed model name to camel case
EduardoPach Dec 7, 2023
41e409b
Improve forward doc
EduardoPach Dec 7, 2023
3e0a77a
Improve tests
EduardoPach Dec 7, 2023
b039594
More improvements
EduardoPach Dec 7, 2023
5c604d4
Fix copies
EduardoPach Dec 7, 2023
00c8bda
Fix doc
EduardoPach Dec 7, 2023
dfc48dd
Make SegGptImageProcessor more flexible
EduardoPach Dec 7, 2023
4152a68
Added few-shot test
EduardoPach Dec 8, 2023
c04a177
Fix merge
NielsRogge Dec 10, 2023
acc58cc
Fix style
NielsRogge Dec 10, 2023
8bb30d5
Update READMEs and docs
NielsRogge Dec 10, 2023
6ec868b
Update READMEs
NielsRogge Dec 10, 2023
88e8144
Make inputs required
NielsRogge Dec 10, 2023
2a066e9
Add SegGptForImageSegmentation
NielsRogge Dec 11, 2023
09187e0
Make tests pass
EduardoPach Dec 12, 2023
5190205
Rename to out_indicies
EduardoPach Dec 12, 2023
bf0bab9
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Dec 13, 2023
c38de07
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Dec 13, 2023
52de2a7
Fixed naming convention
EduardoPach Dec 13, 2023
bebb958
Copying SegGptMlp from modeling_sam.py
EduardoPach Dec 13, 2023
2c7c311
Some minor improvements
NielsRogge Dec 12, 2023
75b2d90
Remove mlp_ratio
NielsRogge Dec 13, 2023
a612330
Fix docstrings
NielsRogge Dec 14, 2023
74383a8
Fixed docstring match
Jan 7, 2024
932a01f
Objects defined before use
Jan 7, 2024
f54d036
Storing only patch_size and beta for SegGptLoss
Jan 7, 2024
b283608
removed _prepare_inputs method
Jan 7, 2024
0fcfbcf
Removed modified from headers
Jan 7, 2024
64d2a90
Renamed to output_indicies
Jan 7, 2024
559c5be
Removed unnecessary einsums
Jan 7, 2024
45ce96b
Update tests/models/seggpt/test_modeling_seggpt.py
EduardoPach Jan 7, 2024
6f982aa
Update tests/models/seggpt/test_modeling_seggpt.py
EduardoPach Jan 7, 2024
c4d5c00
Update tests/models/seggpt/test_modeling_seggpt.py
EduardoPach Jan 7, 2024
6bc0571
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Jan 7, 2024
33c3f4d
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Jan 7, 2024
a435033
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Jan 7, 2024
798a7d3
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Jan 7, 2024
3b443dc
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Jan 7, 2024
e24c369
Fixing issues
Jan 7, 2024
bd0b552
Raise error as soon as possible
Jan 7, 2024
cca0937
More fixes
Jan 7, 2024
39e2767
Jan 7, 2024
3545672
Fix merge
NielsRogge Jan 14, 2024
7228221
Fix merge
NielsRogge Jan 14, 2024
6133e40
Added palette to SegGptImageProcessor
Jan 25, 2024
cf93d25
Merge branch 'adding-seggpt' of https://github.com/EduardoPach/transf…
Jan 25, 2024
1837324
Fixed typo
Jan 25, 2024
dc3cf80
Fixed shape typo
Jan 25, 2024
fc2304c
Added permute before doing palette to class mapping
Jan 25, 2024
b086384
Fixed style
Jan 26, 2024
0c9ad32
Fixed and added tests
Jan 26, 2024
beab961
Fixed docstrings
Jan 26, 2024
bd35b95
Matching SegFormer API for post_processing_semantic_segmentation
Feb 5, 2024
48d7fc3
Merge remote-tracking branch 'upstream/main' into adding-seggpt
Feb 5, 2024
3bcdf52
Fixed copies
Feb 5, 2024
900fac9
Fixed SegGptImageProcessor to handle both binary and RGB masks
Feb 5, 2024
abfc78a
Updated docstrings of SegGptImageProcessor
Feb 5, 2024
ba3f9cb
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Feb 7, 2024
09cdd0e
Update docs/source/en/model_doc/seggpt.md
EduardoPach Feb 7, 2024
690de06
Update src/transformers/models/seggpt/configuration_seggpt.py
EduardoPach Feb 7, 2024
37fabe2
Update src/transformers/models/seggpt/convert_seggpt_to_hf.py
EduardoPach Feb 7, 2024
a2a45cd
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Feb 7, 2024
59787c2
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Feb 7, 2024
cf3c1da
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Feb 7, 2024
706285f
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Feb 7, 2024
755a7b5
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Feb 7, 2024
c6257f5
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Feb 7, 2024
041c400
Update tests/models/seggpt/test_image_processing_seggpt.py
EduardoPach Feb 7, 2024
3ffecbc
Update tests/models/seggpt/test_modeling_seggpt.py
EduardoPach Feb 7, 2024
2d1d77c
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Feb 7, 2024
6354a60
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Feb 7, 2024
f5e23c2
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Feb 7, 2024
d347473
Object definitions above & fix style
Feb 7, 2024
9e35aa6
Renamed output_indices to intermediate_feature_indices
Feb 7, 2024
f6f068c
Removed unnecessary check on bool_masked_pos
Feb 7, 2024
a0cbe9b
Loss first in the outputs
Feb 7, 2024
f1dc953
Added validation for do_normalize
Feb 7, 2024
b8b1d5e
Improved SegGptImageProcessor and added new tests
Feb 10, 2024
d172ca0
Added comment
Feb 10, 2024
88db53f
Added docstrings to SegGptLoss
Feb 10, 2024
db06f21
Reimplemented ensemble condition logic in SegGptEncoder
Feb 10, 2024
5514650
Merge remote-tracking branch 'upstream/main' into adding-seggpt
Feb 10, 2024
7c4805e
Update src/transformers/models/seggpt/__init__.py
EduardoPach Feb 10, 2024
6ad819b
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Feb 10, 2024
9dab61f
Update src/transformers/models/seggpt/convert_seggpt_to_hf.py
EduardoPach Feb 10, 2024
1b87260
Update src/transformers/models/seggpt/configuration_seggpt.py
EduardoPach Feb 10, 2024
c74ca80
Updated docstrings to use post_process_semantic_segmentation
Feb 10, 2024
e5f2c8c
Merge branch 'adding-seggpt' of https://github.com/EduardoPach/transf…
Feb 10, 2024
71dfbf2
Merge remote-tracking branch 'upstream/main' into adding-seggpt
Feb 12, 2024
af21937
Fixed typo on docstrings
Feb 12, 2024
62a82eb
moved pixel values test to test_image_processing_seggpt
Feb 13, 2024
4d425df
Merge remote-tracking branch 'upstream/main' into adding-seggpt
Feb 15, 2024
460f3fa
Addressed comments
Feb 15, 2024
f62b21e
Update src/transformers/models/seggpt/configuration_seggpt.py
EduardoPach Feb 15, 2024
620381d
Update src/transformers/models/seggpt/image_processing_seggpt.py
EduardoPach Feb 15, 2024
9999d0b
Update src/transformers/models/seggpt/configuration_seggpt.py
EduardoPach Feb 15, 2024
b73b21e
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Feb 15, 2024
01c9f7e
Updated docstrings for SegGptLoss
Feb 15, 2024
c507557
Merge branch 'adding-seggpt' of https://github.com/EduardoPach/transf…
Feb 15, 2024
3373cbb
Address comments
Feb 15, 2024
f701039
Added SegGpt example to model docs
Feb 15, 2024
0e46681
Update src/transformers/models/seggpt/modeling_seggpt.py
EduardoPach Feb 21, 2024
43f2d34
moved patchify and unpatchify
EduardoPach Feb 21, 2024
ebc96f4
Merge remote-tracking branch 'upstream/main' into adding-seggpt
EduardoPach Feb 21, 2024
90f911d
Rename checkpoint
EduardoPach Feb 22, 2024
efb85d6
Renamed intermediate_features to intermediate_hidden_states for consi…
EduardoPach Feb 22, 2024
afeb9f2
Update src/transformers/models/seggpt/configuration_seggpt.py
EduardoPach Feb 22, 2024
7e22958
Replaced post_process_masks for post_process_semantic_segmentation in…
EduardoPach Feb 26, 2024
32dd142
Merge branch 'adding-seggpt' of https://github.com/EduardoPach/transf…
EduardoPach Feb 26, 2024
05f0a85
Merge remote-tracking branch 'upstream/main' into adding-seggpt
EduardoPach Feb 26, 2024
Fixed and added tests
Eduardo Pacheco committed Jan 26, 2024
commit 0c9ad32791eab41505d75589ebc091679a7646e7
2 changes: 2 additions & 0 deletions docs/source/en/model_doc/seggpt.md
@@ -26,6 +26,8 @@ The abstract from the paper is the following:

 Tips:
 - One can use [`SegGptImageProcessor`] to prepare image input, prompt and mask to the model.
+- It's highly advisable to instantiate your own [`SegGptImageProcessor`] with the appropriate `num_labels` (not considering background) for your use case.
+- When doing inference with [`SegGptForImageSegmentation`], if your `batch_size` is greater than 1, you can use feature ensemble across your images by passing `feature_ensemble=True` in the forward method.

 This model was contributed by [EduardoPacheco](https://huggingface.co/EduardoPacheco).
 The original code can be found [here](https://github.com/baaivision/Painter/tree/main).
64 changes: 11 additions & 53 deletions src/transformers/models/seggpt/image_processing_seggpt.py
@@ -44,13 +44,13 @@

 # See https://arxiv.org/pdf/2212.02499.pdf at 3.1 Redefining Output Spaces as "Images" - Semantic Segmentation from PAINTER paper
 # Taken from https://github.com/Abdullah-Meda/Painter/blob/main/Painter/data/coco_semseg/gen_color_coco_panoptic_segm.py#L31
-def build_palette(num_classes: int) -> List[Tuple[int, int]]:
-    base = int(num_classes ** (1 / 3)) + 1  # 19
+def build_palette(num_labels: int) -> List[Tuple[int, int]]:
+    base = int(num_labels ** (1 / 3)) + 1  # 19
     margin = 256 // base

     # we assume that class_idx 0 is the background which is mapped to black
     color_list = [(0, 0, 0)]
-    for location in range(num_classes):
+    for location in range(num_labels):
         num_seq_r = location // base**2
         num_seq_g = (location % base**2) // base
         num_seq_b = location % base
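
The hunk above stops before the colors are assembled, so the tail of `build_palette` is not visible here. The following is a minimal sketch of how the whole function might behave, assuming each channel is computed as `255 - digit * margin` (that last step is an assumption, not shown in this diff). The name `build_palette_sketch` is hypothetical and used only for illustration.

```python
from typing import List, Tuple


def build_palette_sketch(num_labels: int) -> List[Tuple[int, int, int]]:
    # Number of steps per channel; base**3 is at least num_labels.
    base = int(num_labels ** (1 / 3)) + 1
    margin = 256 // base

    # Index 0 is the background, mapped to black (as in the diff).
    color_list = [(0, 0, 0)]
    for location in range(num_labels):
        # Decompose the location into three base-`base` digits, one per channel.
        num_seq_r = location // base**2
        num_seq_g = (location % base**2) // base
        num_seq_b = location % base
        # ASSUMPTION: not visible in the hunk; each digit steps down from 255.
        color_list.append(
            (255 - num_seq_r * margin, 255 - num_seq_g * margin, 255 - num_seq_b * margin)
        )
    return color_list


palette = build_palette_sketch(3)
print(palette)
```

Under these assumptions the palette has `num_labels + 1` entries with black first, which is what `test_image_processor_palette` in this commit checks.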
@@ -111,7 +111,7 @@ class SegGptImageProcessor(BaseImageProcessor):
         image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
             Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
             number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
-        num_classes (`int`, *optional*, defaults to `None`):
+        num_labels (`int`, *optional*):
             Number of classes in the segmentation task (excluding the background). If specified, a palette will be built,
             assuming that class_idx 0 is the background, to map the prompt mask from a single class_idx channel to a 3 channel RGB.
             Not specifying this will result in the prompt mask being passed through as is.
@@ -129,7 +129,7 @@ def __init__(
         do_normalize: bool = True,
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
-        num_classes: Optional[int] = None,
+        num_labels: Optional[int] = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -143,8 +143,8 @@ def __init__(
         self.rescale_factor = rescale_factor
         self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
         self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
-        self.num_classes = num_classes
-        self.palette = build_palette(self.num_classes) if num_classes is not None else None
+        self.num_labels = num_labels
+        self.palette = build_palette(self.num_labels) if num_labels is not None else None

     def resize(
         self,
@@ -219,7 +219,7 @@ def _preprocess_step(
                 passing in images with pixel values between 0 and 1, set `do_rescale=False`.
             is_mask (`bool`, *optional*, defaults to `False`):
                 Whether the image is a mask. If True, the image is converted to RGB using the palette if
-                `self.num_classes` is specified otherwise RGB is achieved by duplicating the channel.
+                `self.num_labels` is specified otherwise RGB is achieved by duplicating the channel.
             do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                 Whether to resize the image.
             size (`Dict[str, int]`, *optional*, defaults to `self.size`):
@@ -286,7 +286,7 @@ def _preprocess_step(
         images = [to_numpy_array(image) for image in images]

         if is_mask:
-            if self.num_labels is not None:
+            if self.num_labels is not None:
                 images = [mask_to_rgb(image, self.palette) for image in images]
             else:
                 images = [np.repeat(image[..., None], 3, axis=-1) for image in images]
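
The two branches above differ only in how a single-channel mask becomes three channels. A rough pure-Python sketch of that distinction follows; the real `mask_to_rgb` works on NumPy arrays and its exact signature and channel layout are not shown in this diff, so `mask_to_rgb_sketch` and the toy palette are hypothetical stand-ins.

```python
from typing import List, Optional, Sequence, Tuple

Color = Tuple[int, int, int]


def mask_to_rgb_sketch(
    mask: Sequence[Sequence[int]], palette: Optional[List[Color]] = None
) -> List[List[Color]]:
    """Turn a 2D mask of class indices into a 2D grid of RGB triples."""
    if palette is not None:
        # Palette path: look up the color assigned to each class index.
        return [[palette[value] for value in row] for row in mask]
    # Fallback path: duplicate the single channel three times.
    return [[(value, value, value) for value in row] for row in mask]


toy_palette = [(0, 0, 0), (255, 255, 255), (255, 255, 127)]  # hypothetical 2-label palette
toy_mask = [[0, 1], [2, 0]]

with_palette = mask_to_rgb_sketch(toy_mask, toy_palette)
without_palette = mask_to_rgb_sketch(toy_mask)
print(with_palette)
print(without_palette)
```

The palette path keeps classes far apart in color space, which is presumably why the docs tip recommends setting `num_labels` for multi-class prompts.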
@@ -457,48 +457,6 @@ def preprocess(

         return BatchFeature(data=data, tensor_type=return_tensors)

-    def post_process_masks(
-        self, outputs, target_sizes: Optional[List[Tuple[int, int]]] = None
-    ) -> List[Dict[str, TensorType]]:
-        """
-        Converts the output of [`SegGptImageSegmentationOutput`] into segmentation maps. Only supports
-        PyTorch.
-
-        Args:
-            outputs ([`SegGptImageSegmentationOutput`]):
-                Raw outputs of the model.
-            target_sizes (`List[Tuple[int, int]]`, *optional*):
-                List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested
-                final size (height, width) of each prediction. If left to None, predictions will not be resized.
-        Returns:
-            `List[Dict[str, TensorType]]`: A list of dictionaries, each dictionary containing the mask for an image
-            in the batch as predicted by the model.
-        """
-        requires_backends(self, ["torch"])
-        # batch_size x num_channels x 2*height x width
-        masks = outputs.pred_masks
-        # Take predicted mask as input and prompt are concatenated in the height dimension
-        masks = masks[:, :, masks.shape[2] // 2 :, :]  # batch_size x num_channels x height x width
-        # To unnormalize since we have channel first we need to permute to channel last and then unnormalize
-        # batch_size x height x width x num_channels
-        masks = masks.permute(0, 2, 3, 1) * torch.tensor(self.image_std) + torch.tensor(self.image_mean)
-        # batch_size x num_channels x height x width
-        masks = masks.permute(0, 3, 1, 2)
-
-        results = []
-
-        for idx, mask in enumerate(masks):
-            if target_sizes is not None:
-                mask = torch.nn.functional.interpolate(
-                    mask.unsqueeze(0),
-                    size=target_sizes[idx],
-                    mode="nearest",
-                )[0]
-
-            results.append({"mask": mask})
-
-        return results
-
     def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[List[Tuple[int, int]]] = None):
         """
         Converts the output of [`SegGptImageSegmentationOutput`] into segmentation maps. Only supports
@@ -546,10 +504,10 @@ def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[Lis
                     mode="nearest",
                 )[0]

-            if self.num_classes is not None:
+            if self.num_labels is not None:
                 channels, height, width = mask.shape
                 dist = mask.permute(1, 2, 0).view(height, width, 1, channels) - palette_tensor.view(
-                    1, 1, self.num_classes + 1, channels
+                    1, 1, self.num_labels + 1, channels
                 )
                 dist = torch.pow(dist, 2)
                 dist = torch.sum(dist, dim=-1)
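
The lines above compute, for every pixel, the squared distance to each of the `num_labels + 1` palette colors. A small pure-Python sketch of that idea is given below, assuming the class index is then taken as the position of the smallest distance (the argmin step lies outside this hunk and is an assumption). The helper names and toy values are hypothetical.

```python
from typing import List, Sequence, Tuple

Color = Tuple[float, float, float]


def nearest_palette_index(pixel: Color, palette: Sequence[Color]) -> int:
    # Squared Euclidean distance from the pixel to every palette entry.
    dists = [sum((p - c) ** 2 for p, c in zip(pixel, color)) for color in palette]
    # ASSUMPTION: the class index is the position of the closest color.
    return min(range(len(palette)), key=dists.__getitem__)


def rgb_to_class_map(rgb: Sequence[Sequence[Color]], palette: Sequence[Color]) -> List[List[int]]:
    return [[nearest_palette_index(pixel, palette) for pixel in row] for row in rgb]


toy_palette = [(0, 0, 0), (255, 255, 255), (255, 255, 127)]  # background + 2 labels
noisy_prediction = [
    [(3.0, -2.0, 5.0), (250.0, 251.0, 249.0)],
    [(252.0, 248.0, 130.0), (10.0, 4.0, 0.0)],
]

class_map = rgb_to_class_map(noisy_prediction, toy_palette)
# Same style of check as the updated test: count non-background pixels.
area = sum(value > 0 for row in class_map for value in row)
print(class_map, area)
```

Mapping to the nearest color rather than requiring an exact match tolerates the small deviations a regression-style RGB output will have.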
2 changes: 1 addition & 1 deletion src/transformers/models/seggpt/modeling_seggpt.py
@@ -808,7 +808,7 @@ def forward(

         >>> image_input = Image.open(requests.get(image_input_url, stream=True).raw)
         >>> image_prompt = Image.open(requests.get(image_prompt_url, stream=True).raw)
-        >>> mask_prompt = Image.open(requests.get(mask_prompt_url, stream=True).raw)
+        >>> mask_prompt = Image.open(requests.get(mask_prompt_url, stream=True).raw).convert("L")

         >>> checkpoint = "EduardoPacheco/seggpt-vit-large"
         >>> model = SegGptModel.from_pretrained(checkpoint)
113 changes: 113 additions & 0 deletions tests/models/seggpt/test_image_processing_seggpt.py
@@ -0,0 +1,113 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs


if is_vision_available():
    from transformers import SegGptImageProcessor


class SegGptImageProcessingTester(unittest.TestCase):
    def __init__(
        self,
        parent,
        batch_size=7,
        num_channels=3,
        image_size=18,
        min_resolution=30,
        max_resolution=400,
        do_resize=True,
        size=None,
        do_normalize=True,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5],
    ):
        size = size if size is not None else {"height": 18, "width": 18}
        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.do_resize = do_resize
        self.size = size
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std

    def prepare_image_processor_dict(self):
        return {
            "image_mean": self.image_mean,
            "image_std": self.image_std,
            "do_normalize": self.do_normalize,
            "do_resize": self.do_resize,
            "size": self.size,
        }

    def expected_output_image_shape(self, images):
        return self.num_channels, self.size["height"], self.size["width"]

    def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
        return prepare_image_inputs(
            batch_size=self.batch_size,
            num_channels=self.num_channels,
            min_resolution=self.min_resolution,
            max_resolution=self.max_resolution,
            equal_resolution=equal_resolution,
            numpify=numpify,
            torchify=torchify,
        )


@require_torch
@require_vision
class SegGptImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    image_processing_class = SegGptImageProcessor if is_vision_available() else None

    def setUp(self):
        self.image_processor_tester = SegGptImageProcessingTester(self)

    @property
    def image_processor_dict(self):
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processor_properties(self):
        image_processing = self.image_processing_class(**self.image_processor_dict)
        self.assertTrue(hasattr(image_processing, "image_mean"))
        self.assertTrue(hasattr(image_processing, "image_std"))
        self.assertTrue(hasattr(image_processing, "do_normalize"))
        self.assertTrue(hasattr(image_processing, "do_resize"))
        self.assertTrue(hasattr(image_processing, "size"))

    def test_image_processor_from_dict_with_kwargs(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        self.assertEqual(image_processor.size, {"height": 18, "width": 18})

        image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
        self.assertEqual(image_processor.size, {"height": 42, "width": 42})

    def test_image_processor_palette(self):
        num_labels = 3
        image_processing = SegGptImageProcessor(num_labels=num_labels)
        self.assertEqual(image_processing.num_labels, num_labels)
        self.assertEqual(len(image_processing.palette), num_labels + 1)
        self.assertEqual(image_processing.palette[0], (0, 0, 0))
12 changes: 6 additions & 6 deletions tests/models/seggpt/test_modeling_seggpt.py
@@ -278,7 +278,7 @@ def test_one_shot_inference(self):
         images, masks = prepare_img()
         input_image = images[1]
         prompt_image = images[0]
-        prompt_mask = masks[0]
+        prompt_mask = masks[0].convert("L")

         inputs = image_processor(
             images=input_image, prompt_images=prompt_image, prompt_masks=prompt_mask, return_tensors="pt"
@@ -334,11 +334,11 @@ def test_one_shot_inference(self):

         self.assertTrue(torch.allclose(outputs.pred_masks[0, :, :3, :3], expected_slice, atol=1e-4))

-        result = image_processor.post_process_masks(outputs, [input_image.size[::-1]])[0]["mask"]
+        result = image_processor.post_process_semantic_segmentation(outputs, [input_image.size[::-1]])[0]["mask"]

-        result_expected_shape = torch.Size((1, 3, 170, 297))
-        expected_area = 26654
-        area = (torch.clip(result * 255, 0, 255).mean(dim=1) > 0).sum().item()
+        result_expected_shape = torch.Size((170, 297))
+        expected_area = 1082
+        area = (result > 0).sum().item()
         self.assertEqual(result.shape, result_expected_shape)
         self.assertEqual(area, expected_area)

@@ -350,7 +350,7 @@ def test_few_shot_inference(self):
         images, masks = prepare_img()
         input_images = [images[1]] * 2
         prompt_images = [images[0], images[2]]
-        prompt_masks = [masks[0], masks[2]]
+        prompt_masks = [masks[0].convert("L"), masks[2].convert("L")]

         inputs = image_processor(
             images=input_images, prompt_images=prompt_images, prompt_masks=prompt_masks, return_tensors="pt"