
Add Fast SamImageProcessor #36999

Open
wants to merge 22 commits into main

Conversation

@sushmanthreddy (Contributor) commented Mar 26, 2025

related #36978

Adds a fast image processor for the SAM model.

@github-actions bot marked this pull request as draft on March 26, 2025 11:49

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@Rocketknight1 (Member)

cc @NielsRogge @qubvel

@qubvel (Member) commented Mar 26, 2025

cc @yonigozlan

@qubvel changed the title from SamFASTIMAGEPROCESSOR to Add Fast SamImageProcessor on Mar 26, 2025
@qubvel (Member) commented Mar 26, 2025

Hi @sushmanthreddy, thanks for your contribution. I hope you don't mind that I've edited the initial message to make sure the issue is not closed when the PR is merged 🤗

@yonigozlan (Member) left a comment

Hi @sushmanthreddy! Thanks a lot for working on this. Looks great for a first draft! I left some comments, mainly on things that don't need to be overridden because they are handled by the parent class.

I added a section on testing in the community issue #36978, and it would be great to add an image processing test file for this model. It looks like the image processor functions are currently tested in test_processor_sam.py, but we would need them in a SamImageProcessingTest class in test_image_processing.py.

Indeed, SamImageProcessingTest will inherit from ImageProcessingTestMixin, which adds some tests comparing the slow and the fast image processor outputs; this is the best way to make sure the fast image processor is correct!
You might want to add or override some tests comparing fast and slow image processors, to compare all outputs and not just pixel_values.
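As an illustration of the kind of extended equivalence check the reviewer suggests, here is a framework-agnostic sketch (the helper names are hypothetical; in the real test the inputs would be the BatchFeature outputs of the slow and fast processors):

```python
def _flatten(x):
    """Yield scalar leaves from arbitrarily nested lists/tuples."""
    if isinstance(x, (list, tuple)):
        for item in x:
            yield from _flatten(item)
    else:
        yield x


def assert_outputs_close(slow_out, fast_out, atol=1e-1):
    """Compare every field of two processor outputs, not just pixel_values."""
    assert slow_out.keys() == fast_out.keys(), "output keys differ"
    for key in slow_out:
        for s, f in zip(_flatten(slow_out[key]), _flatten(fast_out[key])):
            assert abs(s - f) <= atol, f"field {key!r} differs: {s} vs {f}"
```

In an actual SamImageProcessingTest this would run over pixel_values, original_sizes, and reshaped_input_sizes alike, so a regression in any output field fails the test.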


class SamFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
mask_size: Optional[Dict[str, int]]
mask_pad_size: Optional[Dict[str, int]]
Member

You should also have pad_size here

Comment on lines 93 to 100
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)

# Initialize size dictionaries properly
self.size = self.size if self.size is not None else {"longest_edge": 1024}
self.pad_size = self.pad_size if self.pad_size is not None else {"height": 1024, "width": 1024}
self.mask_size = self.mask_size if self.mask_size is not None else {"longest_edge": 256}
self.mask_pad_size = self.mask_pad_size if self.mask_pad_size is not None else {"height": 256, "width": 256}
Member

No need to override as the attributes are already set as class attributes above

Comment on lines 147 to 151
if isinstance(size, dict) and "longest_edge" in size:
target_size = self._get_preprocess_shape(image.shape[-2:], size["longest_edge"])
resized_image = F.resize(image, target_size, interpolation=interpolation)
else:
resized_image = self.resize(image=image, size=size, interpolation=interpolation)
Member

Here size will necessarily be a SizeDict. To be consistent with the slow image processor, you can raise an error if size.longest_edge is None, and have this instead:

Suggested change
if isinstance(size, dict) and "longest_edge" in size:
target_size = self._get_preprocess_shape(image.shape[-2:], size["longest_edge"])
resized_image = F.resize(image, target_size, interpolation=interpolation)
else:
resized_image = self.resize(image=image, size=size, interpolation=interpolation)
target_size = self._get_preprocess_shape(image.shape[-2:], size.longest_edge)
resized_image = F.resize(image, target_size, interpolation=interpolation)
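For context, the longest-edge computation that _get_preprocess_shape performs can be sketched standalone (this mirrors the slow SamImageProcessor's logic; the function name here is illustrative):

```python
def get_preprocess_shape(old_height: int, old_width: int, longest_edge: int) -> tuple:
    """Scale (old_height, old_width) so the longer side equals longest_edge,
    preserving aspect ratio and rounding to the nearest integer."""
    scale = longest_edge / max(old_height, old_width)
    new_height = int(old_height * scale + 0.5)
    new_width = int(old_width * scale + 0.5)
    return new_height, new_width
```

For example, a 480x640 image with longest_edge=1024 is resized to 768x1024, which is the shape F.resize is then asked to produce.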

@sushmanthreddy sushmanthreddy marked this pull request as ready for review March 27, 2025 18:34
@sushmanthreddy (Contributor, Author)

@yonigozlan can you review once?

@yonigozlan (Member)

Hi @sushmanthreddy, can you clean up the PR with make style before I review? It looks like something went wrong with the formatting.

@sushmanthreddy force-pushed the samfastimageprocessor branch from 1453b03 to 274884a on March 31, 2025 17:13
@sushmanthreddy (Contributor, Author)

@yonigozlan the quality checks are passing now. Can you please review?

@yonigozlan (Member) left a comment

Hi @sushmanthreddy, great work, and thanks for contributing! Definitely not an easy processor to tackle. There are a few things to change, mostly to be more consistent with fast image processor standards.
Also, thanks a lot for adding tests; this processor really needs them. However, please have a look at how image processing tests are structured for other models, and try to refactor this one to fit the standards.
Thanks again for the huge work!

Comment on lines +310 to +315
extras["audio"] = deps_list(
"librosa",
"pyctcdecode",
"phonemizer",
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
)
Member

Suggested change
extras["audio"] = deps_list(
"librosa",
"pyctcdecode",
"phonemizer",
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
)
extras["audio"] = deps_list("librosa", "pyctcdecode", "phonemizer", "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5")

Comment on lines +152 to +183
def _rescale_and_normalize(
self,
image: torch.Tensor,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
) -> torch.Tensor:
"""
Apply rescaling and normalization to images.

Args:
image: Input image tensor
do_rescale: Whether to apply rescaling
rescale_factor: Factor to use for rescaling
do_normalize: Whether to apply normalization
image_mean: Mean values for normalization
image_std: Standard deviation values for normalization

Returns:
Processed image tensor
"""
if do_rescale:
image = image * rescale_factor

if do_normalize:
mean = torch.tensor(image_mean, device=image.device).view(-1, 1, 1)
std = torch.tensor(image_std, device=image.device).view(-1, 1, 1)
image = (image - mean) / std

return image
Member

Why not use rescale_and_normalize from BaseImageProcessorFast?
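For reference, per pixel the base-class helper reduces to a fused affine transform; a scalar sketch of the math (illustrative only, not the actual tensor implementation in BaseImageProcessorFast):

```python
def rescale_then_normalize(pixel: float, scale: float, mean: float, std: float) -> float:
    """Rescale (e.g. from [0, 255] to [0, 1]) then normalize to zero mean, unit std."""
    return (pixel * scale - mean) / std
```

Reusing the base-class method instead of reimplementing it keeps the dtype, device, and broadcasting handling in one place.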

Comment on lines +206 to +252
"""
Preprocess an image or batch of images for the SAM model.

Args:
images (List[torch.Tensor]):
List of images to process.
do_resize (bool):
Whether to resize the input.
size (SizeDict):
Size dictionary for image resizing.
interpolation (Optional[F.InterpolationMode]):
Interpolation mode for resizing.
do_center_crop (Optional[bool], defaults to None):
Whether to center crop the input.
crop_size (Optional[SizeDict], defaults to None):
Size dictionary for center cropping.
do_rescale (Optional[bool], defaults to None):
Whether to rescale the input.
rescale_factor (Optional[float], defaults to None):
Factor to use for rescaling.
do_normalize (Optional[bool], defaults to None):
Whether to normalize the input.
image_mean (Optional[Union[float, List[float]]], defaults to None):
Mean values for normalization.
image_std (Optional[Union[float, List[float]]], defaults to None):
Standard deviation values for normalization.
return_tensors (Optional[Union[str, TensorType]], defaults to None):
Output tensor type.
mask_size (Optional[Dict[str, int]], defaults to None):
Size dictionary for mask resizing.
mask_pad_size (Optional[Dict[str, int]], defaults to None):
Size dictionary for mask padding.
do_pad (Optional[bool], defaults to None):
Whether to pad the input.
pad_size (Optional[Dict[str, int]], defaults to None):
Size dictionary for input padding.
segmentation_maps (Optional[List[torch.Tensor]], defaults to None):
Optional list of segmentation maps to process.

Returns:
BatchFeature:
A BatchFeature object containing the processed inputs with the following fields:
- pixel_values (torch.Tensor): Processed image tensors
- original_sizes (List[Tuple]): Original image sizes
- reshaped_input_sizes (List[Tuple]): Resized input sizes
- labels (torch.Tensor, optional): Processed segmentation maps if provided
"""
Member

No need for that, as the docstring should be on the preprocess function, and generated with add_start_docstrings as in other fast image processors

Comment on lines +264 to +269
resize_interpolation = interpolation
if resize_interpolation is None:
if is_torchvision_v2_available():
resize_interpolation = F.InterpolationMode.BILINEAR
else:
resize_interpolation = F.InterpolationMode.BILINEAR
Member

No need for that, just use interpolation directly

Comment on lines +254 to +332
original_sizes = []
reshaped_input_sizes = []
processed_images = []

for image in images:
original_sizes.append(image.shape[-2:])

if do_resize:
target_size = self._get_preprocess_shape(image.shape[-2:], size["longest_edge"])

resize_interpolation = interpolation
if resize_interpolation is None:
if is_torchvision_v2_available():
resize_interpolation = F.InterpolationMode.BILINEAR
else:
resize_interpolation = F.InterpolationMode.BILINEAR
resized_image = F.resize(image, target_size, interpolation=resize_interpolation)
reshaped_input_sizes.append(resized_image.shape[-2:])
else:
resized_image = image
reshaped_input_sizes.append(image.shape[-2:])

processed_image = self._rescale_and_normalize(
resized_image, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)

if do_pad:
padded_height, padded_width = pad_size["height"], pad_size["width"]
input_height, input_width = processed_image.shape[-2:]
pad_bottom = max(0, padded_height - input_height)
pad_right = max(0, padded_width - input_width)
padding = (0, 0, pad_right, pad_bottom)
processed_image = F.pad(processed_image, padding, fill=0)

processed_images.append(processed_image)

processed_masks = None
if segmentation_maps is not None:
processed_masks = []

if len(segmentation_maps) != len(images):
raise ValueError(
f"Number of segmentation maps ({len(segmentation_maps)}) does not match "
f"number of images ({len(images)})"
)

for i, mask in enumerate(segmentation_maps):
if mask.dim() == 2:
mask = mask.unsqueeze(0)

mask_h, mask_w = mask.shape[-2:]
img_h, img_w = original_sizes[i]
if mask_h != img_h or mask_w != img_w:
raise ValueError(
f"Segmentation map size ({mask_h}, {mask_w}) does not match image size ({img_h}, {img_w})"
)

if do_resize and mask_size is not None:
mask_target_size = self._get_preprocess_shape(mask.shape[-2:], mask_size["longest_edge"])

mask_interpolation = F.InterpolationMode.NEAREST
resized_mask = F.resize(mask.float(), mask_target_size, interpolation=mask_interpolation)
else:
resized_mask = mask

if do_pad and mask_pad_size is not None:
mask_pad_h, mask_pad_w = mask_pad_size["height"], mask_pad_size["width"]
mask_h, mask_w = resized_mask.shape[-2:]
pad_bottom = max(0, mask_pad_h - mask_h)
pad_right = max(0, mask_pad_w - mask_w)
padding = (0, 0, pad_right, pad_bottom)
resized_mask = F.pad(resized_mask, padding, fill=0)

processed_masks.append(resized_mask.long())

if return_tensors:
processed_images = torch.stack(processed_images, dim=0)
if processed_masks is not None:
processed_masks = torch.stack(processed_masks, dim=0)
Member

It would be great to introduce some batch processing, using group_images_by_shape and reorder_images, as in BaseImageProcessorFast
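The grouping-and-reordering pattern the comment refers to can be sketched in plain Python (a simplified stand-in for transformers' group_images_by_shape/reorder_images utilities, which operate on tensors rather than the dicts used here):

```python
from collections import defaultdict


def group_images_by_shape(images):
    """Group images by shape so each group can be stacked and processed as one batch."""
    grouped = defaultdict(list)
    index = []  # remembers (shape, position-in-group) for each original image
    for image in images:
        shape = image["shape"]  # a real implementation would read tensor.shape
        index.append((shape, len(grouped[shape])))
        grouped[shape].append(image)
    return dict(grouped), index


def reorder_images(processed_grouped, index):
    """Restore the original input ordering after per-group batched processing."""
    return [processed_grouped[shape][pos] for shape, pos in index]
```

Same-shape images can then be stacked and resized/normalized in one batched call per group, instead of the per-image Python loop above, and reorder_images puts the results back in input order.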


return out

def _rle_to_mask(self, rle: Dict[str, Any]) -> torch.Tensor:
Member

Let's have an optional device arg here

mask_boxes = all_boxes[keep_by_nms]

# Convert RLE back to binary masks
masks = [self._rle_to_mask(rle) for rle in rle_masks]
Member

pass the device if specified as an arg, otherwise get it from the input tensors

Comment on lines +38 to +80
# Import SamImageProcessorFast if available
try:
from transformers.models.sam.image_processing_sam_fast import SamImageProcessorFast

# Create a wrapper class that inherits from SamImageProcessor to satisfy type checking
class SamImageProcessorFastWrapper(SamImageProcessor):
"""
Wrapper class for SamImageProcessorFast that inherits from SamImageProcessor
to satisfy type checking in SamProcessor
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.fast_processor = SamImageProcessorFast(**kwargs)

# Copy attributes from fast processor to wrapper
for attr_name in dir(self.fast_processor):
if not attr_name.startswith("_") and not hasattr(self, attr_name):
setattr(self, attr_name, getattr(self.fast_processor, attr_name))

def __call__(self, *args, **kwargs):
return self.fast_processor(*args, **kwargs)

def post_process_masks(self, *args, **kwargs):
return self.fast_processor.post_process_masks(*args, **kwargs)

def generate_crop_boxes(self, *args, **kwargs):
return self.fast_processor.generate_crop_boxes(*args, **kwargs)

def filter_masks(self, *args, **kwargs):
return self.fast_processor.filter_masks(*args, **kwargs)

def post_process_for_mask_generation(self, *args, **kwargs):
return self.fast_processor.post_process_for_mask_generation(*args, **kwargs)

def to_dict(self):
return self.fast_processor.to_dict()

def _preprocess(self, *args, **kwargs):
return self.fast_processor._preprocess(*args, **kwargs)
except ImportError:
SamImageProcessorFast = None
SamImageProcessorFastWrapper = None
Member

No need for that, let's just import SamImageProcessorFast under if is_torchvision_available():

SamImageProcessorFast = None
SamImageProcessorFastWrapper = None


Member

It would be great to have a SamImageProcessingTester, check how it's done in other image processor tests

@require_vision
@require_torch
@require_torchvision
class SamImageProcessorFastTest(unittest.TestCase):
Member

Here we should have a unique SamImageProcessingTest with attributes:

    image_processing_class = SamImageProcessor if is_vision_available() else None
    fast_image_processing_class = SamImageProcessorFast if is_torchvision_available() else None

and each test should be done on both processors, by iterating on self.image_processor_list:

        for image_processing_class in self.image_processor_list:
            image_processing = image_processing_class(**self.image_processor_dict)
            # ... rest of the test

Please have a look at how it's done for other image processor tests.
