Add Fast SamImageProcessor #36999
base: main
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the "Ready for review" button.
cc @yonigozlan
Hi @sushmanthreddy, thanks for your contribution. I hope you don't mind that I've edited the initial message to make sure the issue is not closed when the PR is merged 🤗
Hi @sushmanthreddy! Thanks a lot for working on this. Looks great for a first draft! I left some comments, mainly on things that don't need to be overridden as they are handled by the parent class.

I added a section on testing in the community issue #36978, and it would be great to add an image processing test file for this model. It looks like the image processor functions are currently tested in `test_processor_sam.py`, but we would need them in a `SamImageProcessingTest` class in `test_image_processing.py`.

Indeed, `SamImageProcessingTest` will inherit from `ImageProcessingTestMixin`, which adds some tests comparing the slow and the fast image processor outputs, which is the best way to make sure the fast image processor is correct! You might want to add or override some tests comparing fast and slow image processors, to compare all outputs and not just `pixel_values`.
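For illustration, a minimal sketch of such an override (the `image_processor_tester` helper and the tolerances are assumed conventions from `ImageProcessingTestMixin`-style tests, not part of this PR):

```python
import torch

def test_slow_fast_equivalence(self):
    # Sketch only: compare all SAM outputs, not just pixel_values
    image_processor_slow = self.image_processing_class(**self.image_processor_dict)
    image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
    image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)

    encoding_slow = image_processor_slow(image_inputs, return_tensors="pt")
    encoding_fast = image_processor_fast(image_inputs, return_tensors="pt")

    torch.testing.assert_close(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1, rtol=1e-3)
    # SAM additionally returns original_sizes and reshaped_input_sizes
    torch.testing.assert_close(
        torch.as_tensor(encoding_slow.original_sizes), torch.as_tensor(encoding_fast.original_sizes)
    )
    torch.testing.assert_close(
        torch.as_tensor(encoding_slow.reshaped_input_sizes), torch.as_tensor(encoding_fast.reshaped_input_sizes)
    )
```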
```python
class SamFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    mask_size: Optional[Dict[str, int]]
    mask_pad_size: Optional[Dict[str, int]]
```
You should also have `pad_size` here.
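Applied, the suggestion would look something like this (sketch):

```python
from typing import Dict, Optional

class SamFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    mask_size: Optional[Dict[str, int]]
    mask_pad_size: Optional[Dict[str, int]]
    pad_size: Optional[Dict[str, int]]  # added per the review comment
```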
```python
def __init__(self, **kwargs) -> None:
    super().__init__(**kwargs)

    # Initialize size dictionaries properly
    self.size = self.size if self.size is not None else {"longest_edge": 1024}
    self.pad_size = self.pad_size if self.pad_size is not None else {"height": 1024, "width": 1024}
    self.mask_size = self.mask_size if self.mask_size is not None else {"longest_edge": 256}
    self.mask_pad_size = self.mask_pad_size if self.mask_pad_size is not None else {"height": 256, "width": 256}
```
No need to override as the attributes are already set as class attributes above
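In other words, the defaults can live as class attributes, e.g. (sketch, values taken from the `__init__` above):

```python
class SamImageProcessorFast(BaseImageProcessorFast):
    # Class-level defaults; BaseImageProcessorFast picks these up,
    # so no __init__ override is needed
    size = {"longest_edge": 1024}
    pad_size = {"height": 1024, "width": 1024}
    mask_size = {"longest_edge": 256}
    mask_pad_size = {"height": 256, "width": 256}
```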
```python
if isinstance(size, dict) and "longest_edge" in size:
    target_size = self._get_preprocess_shape(image.shape[-2:], size["longest_edge"])
    resized_image = F.resize(image, target_size, interpolation=interpolation)
else:
    resized_image = self.resize(image=image, size=size, interpolation=interpolation)
```
Here `size` will necessarily be a `SizeDict`. To be consistent with the slow image processor, you can raise an error if `size.longest_edge` is None, and have this instead:
```diff
- if isinstance(size, dict) and "longest_edge" in size:
-     target_size = self._get_preprocess_shape(image.shape[-2:], size["longest_edge"])
-     resized_image = F.resize(image, target_size, interpolation=interpolation)
- else:
-     resized_image = self.resize(image=image, size=size, interpolation=interpolation)
+ target_size = self._get_preprocess_shape(image.shape[-2:], size.longest_edge)
+ resized_image = F.resize(image, target_size, interpolation=interpolation)
```
Force-pushed from cdac607 to 6dca4f2
@yonigozlan can you review once?
Hi @sushmanthreddy, can you clean up the PR with
Force-pushed from 1453b03 to 274884a
Merge branch 'main' of https://github.com/sushmanthreddy/transformers into samfastimageprocessor
@yonigozlan now quality checks are passing
Hi @sushmanthreddy, great work, thanks for contributing! Definitely not an easy processor to tackle. There are a few things to change, mostly to be more consistent with fast image processor standards.

Also, thanks a lot for adding tests; this processor really needs them. However, please have a look at how image processing tests are structured for other models, and try to refactor this one to fit the standards.

Thanks again for the huge work!
extras["audio"] = deps_list( | ||
"librosa", | ||
"pyctcdecode", | ||
"phonemizer", | ||
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5", | ||
) |
extras["audio"] = deps_list( | |
"librosa", | |
"pyctcdecode", | |
"phonemizer", | |
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5", | |
) | |
extras["audio"] = deps_list("librosa", "pyctcdecode", "phonemizer", "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5")``` |
```python
def _rescale_and_normalize(
    self,
    image: torch.Tensor,
    do_rescale: bool,
    rescale_factor: float,
    do_normalize: bool,
    image_mean: Optional[Union[float, List[float]]],
    image_std: Optional[Union[float, List[float]]],
) -> torch.Tensor:
    """
    Apply rescaling and normalization to images.

    Args:
        image: Input image tensor
        do_rescale: Whether to apply rescaling
        rescale_factor: Factor to use for rescaling
        do_normalize: Whether to apply normalization
        image_mean: Mean values for normalization
        image_std: Standard deviation values for normalization

    Returns:
        Processed image tensor
    """
    if do_rescale:
        image = image * rescale_factor

    if do_normalize:
        mean = torch.tensor(image_mean, device=image.device).view(-1, 1, 1)
        std = torch.tensor(image_std, device=image.device).view(-1, 1, 1)
        image = (image - mean) / std

    return image
```
Why not use `rescale_and_normalize` from `BaseImageProcessorFast`?
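That is, the inherited helper could be called directly, something like this sketch:

```python
# Delegate to the helper inherited from BaseImageProcessorFast
# instead of reimplementing rescaling and normalization
processed_image = self.rescale_and_normalize(
    resized_image, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
```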
""" | ||
Preprocess an image or batch of images for the SAM model. | ||
|
||
Args: | ||
images (List[torch.Tensor]): | ||
List of images to process. | ||
do_resize (bool): | ||
Whether to resize the input. | ||
size (SizeDict): | ||
Size dictionary for image resizing. | ||
interpolation (Optional[F.InterpolationMode]): | ||
Interpolation mode for resizing. | ||
do_center_crop (Optional[bool], defaults to None): | ||
Whether to center crop the input. | ||
crop_size (Optional[SizeDict], defaults to None): | ||
Size dictionary for center cropping. | ||
do_rescale (Optional[bool], defaults to None): | ||
Whether to rescale the input. | ||
rescale_factor (Optional[float], defaults to None): | ||
Factor to use for rescaling. | ||
do_normalize (Optional[bool], defaults to None): | ||
Whether to normalize the input. | ||
image_mean (Optional[Union[float, List[float]]], defaults to None): | ||
Mean values for normalization. | ||
image_std (Optional[Union[float, List[float]]], defaults to None): | ||
Standard deviation values for normalization. | ||
return_tensors (Optional[Union[str, TensorType]], defaults to None): | ||
Output tensor type. | ||
mask_size (Optional[Dict[str, int]], defaults to None): | ||
Size dictionary for mask resizing. | ||
mask_pad_size (Optional[Dict[str, int]], defaults to None): | ||
Size dictionary for mask padding. | ||
do_pad (Optional[bool], defaults to None): | ||
Whether to pad the input. | ||
pad_size (Optional[Dict[str, int]], defaults to None): | ||
Size dictionary for input padding. | ||
segmentation_maps (Optional[List[torch.Tensor]], defaults to None): | ||
Optional list of segmentation maps to process. | ||
|
||
Returns: | ||
BatchFeature: | ||
A BatchFeature object containing the processed inputs with the following fields: | ||
- pixel_values (torch.Tensor): Processed image tensors | ||
- original_sizes (List[Tuple]): Original image sizes | ||
- reshaped_input_sizes (List[Tuple]): Resized input sizes | ||
- labels (torch.Tensor, optional): Processed segmentation maps if provided | ||
""" |
No need for that, as the docstring should be on the `preprocess` function and generated with `add_start_docstrings`, as in other fast image processors.
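For reference, a sketch of that pattern (the docstring constant is the one fast image processors used at the time; the extra-kwargs text shown is illustrative):

```python
from transformers.image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS
from transformers.utils import add_start_docstrings

@add_start_docstrings(
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
    """
    segmentation_maps (`ImageInput`, *optional*):
        The segmentation maps to preprocess.
    """,
)
def preprocess(self, images, segmentation_maps=None, **kwargs):
    # validation and dispatch to _preprocess, as in BaseImageProcessorFast
    ...
```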
```python
resize_interpolation = interpolation
if resize_interpolation is None:
    if is_torchvision_v2_available():
        resize_interpolation = F.InterpolationMode.BILINEAR
    else:
        resize_interpolation = F.InterpolationMode.BILINEAR
```
No need for that, just use `interpolation` directly.
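Since the base class resolves the interpolation mode before calling `_preprocess` (and both branches above pick BILINEAR anyway), this collapses to:

```python
resized_image = F.resize(image, target_size, interpolation=interpolation)
```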
```python
original_sizes = []
reshaped_input_sizes = []
processed_images = []

for image in images:
    original_sizes.append(image.shape[-2:])

    if do_resize:
        target_size = self._get_preprocess_shape(image.shape[-2:], size["longest_edge"])

        resize_interpolation = interpolation
        if resize_interpolation is None:
            if is_torchvision_v2_available():
                resize_interpolation = F.InterpolationMode.BILINEAR
            else:
                resize_interpolation = F.InterpolationMode.BILINEAR
        resized_image = F.resize(image, target_size, interpolation=resize_interpolation)
        reshaped_input_sizes.append(resized_image.shape[-2:])
    else:
        resized_image = image
        reshaped_input_sizes.append(image.shape[-2:])

    processed_image = self._rescale_and_normalize(
        resized_image, do_rescale, rescale_factor, do_normalize, image_mean, image_std
    )

    if do_pad:
        padded_height, padded_width = pad_size["height"], pad_size["width"]
        input_height, input_width = processed_image.shape[-2:]
        pad_bottom = max(0, padded_height - input_height)
        pad_right = max(0, padded_width - input_width)
        padding = (0, 0, pad_right, pad_bottom)
        processed_image = F.pad(processed_image, padding, fill=0)

    processed_images.append(processed_image)

processed_masks = None
if segmentation_maps is not None:
    processed_masks = []

    if len(segmentation_maps) != len(images):
        raise ValueError(
            f"Number of segmentation maps ({len(segmentation_maps)}) does not match "
            f"number of images ({len(images)})"
        )

    for i, mask in enumerate(segmentation_maps):
        if mask.dim() == 2:
            mask = mask.unsqueeze(0)

        mask_h, mask_w = mask.shape[-2:]
        img_h, img_w = original_sizes[i]
        if mask_h != img_h or mask_w != img_w:
            raise ValueError(
                f"Segmentation map size ({mask_h}, {mask_w}) does not match image size ({img_h}, {img_w})"
            )

        if do_resize and mask_size is not None:
            mask_target_size = self._get_preprocess_shape(mask.shape[-2:], mask_size["longest_edge"])

            mask_interpolation = F.InterpolationMode.NEAREST
            resized_mask = F.resize(mask.float(), mask_target_size, interpolation=mask_interpolation)
        else:
            resized_mask = mask

        if do_pad and mask_pad_size is not None:
            mask_pad_h, mask_pad_w = mask_pad_size["height"], mask_pad_size["width"]
            mask_h, mask_w = resized_mask.shape[-2:]
            pad_bottom = max(0, mask_pad_h - mask_h)
            pad_right = max(0, mask_pad_w - mask_w)
            padding = (0, 0, pad_right, pad_bottom)
            resized_mask = F.pad(resized_mask, padding, fill=0)

        processed_masks.append(resized_mask.long())

if return_tensors:
    processed_images = torch.stack(processed_images, dim=0)
    if processed_masks is not None:
        processed_masks = torch.stack(processed_masks, dim=0)
```
It would be great to introduce some batch processing, using `group_images_by_shape` and `reorder_images`, as in `BaseImageProcessorFast`.
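A minimal sketch of that pattern, assuming the `group_images_by_shape`/`reorder_images` helpers from `transformers.image_processing_utils_fast` (the resize step shown is illustrative):

```python
from transformers.image_processing_utils_fast import group_images_by_shape, reorder_images

# Group images by (height, width) so each group is processed as one
# stacked tensor instead of looping image by image
grouped_images, grouped_images_index = group_images_by_shape(images)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
    if do_resize:
        target_size = self._get_preprocess_shape(stacked_images.shape[-2:], size.longest_edge)
        stacked_images = F.resize(stacked_images, target_size, interpolation=interpolation)
    resized_images_grouped[shape] = stacked_images
# Restore the original per-image ordering of the batch
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
```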
```python
    return out

def _rle_to_mask(self, rle: Dict[str, Any]) -> torch.Tensor:
```
Let's have an optional `device` arg here.
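One way the suggestion could look, mirroring the slow processor's numpy implementation of uncompressed column-major RLE decoding (sketch):

```python
from typing import Any, Dict, Optional

import torch

def _rle_to_mask(self, rle: Dict[str, Any], device: Optional[torch.device] = None) -> torch.Tensor:
    """Decode an uncompressed (column-major) RLE into a binary mask on `device`."""
    height, width = rle["size"]
    mask = torch.empty(height * width, dtype=torch.bool, device=device)
    idx = 0
    parity = False
    for count in rle["counts"]:
        # Runs alternate between background (False) and foreground (True)
        mask[idx : idx + count] = parity
        idx += count
        parity = not parity
    # The RLE is column-major, so reshape to (width, height) and transpose
    return mask.reshape(width, height).transpose(0, 1)
```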
```python
mask_boxes = all_boxes[keep_by_nms]

# Convert RLE back to binary masks
masks = [self._rle_to_mask(rle) for rle in rle_masks]
```
Pass the device if it is specified as an arg; otherwise get it from the input tensors.
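Sketched:

```python
# Honor an explicit `device` argument, otherwise infer it from the inputs
device = device if device is not None else all_boxes.device
masks = [self._rle_to_mask(rle, device=device) for rle in rle_masks]
```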
```python
# Import SamImageProcessorFast if available
try:
    from transformers.models.sam.image_processing_sam_fast import SamImageProcessorFast

    # Create a wrapper class that inherits from SamImageProcessor to satisfy type checking
    class SamImageProcessorFastWrapper(SamImageProcessor):
        """
        Wrapper class for SamImageProcessorFast that inherits from SamImageProcessor
        to satisfy type checking in SamProcessor
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.fast_processor = SamImageProcessorFast(**kwargs)

            # Copy attributes from fast processor to wrapper
            for attr_name in dir(self.fast_processor):
                if not attr_name.startswith("_") and not hasattr(self, attr_name):
                    setattr(self, attr_name, getattr(self.fast_processor, attr_name))

        def __call__(self, *args, **kwargs):
            return self.fast_processor(*args, **kwargs)

        def post_process_masks(self, *args, **kwargs):
            return self.fast_processor.post_process_masks(*args, **kwargs)

        def generate_crop_boxes(self, *args, **kwargs):
            return self.fast_processor.generate_crop_boxes(*args, **kwargs)

        def filter_masks(self, *args, **kwargs):
            return self.fast_processor.filter_masks(*args, **kwargs)

        def post_process_for_mask_generation(self, *args, **kwargs):
            return self.fast_processor.post_process_for_mask_generation(*args, **kwargs)

        def to_dict(self):
            return self.fast_processor.to_dict()

        def _preprocess(self, *args, **kwargs):
            return self.fast_processor._preprocess(*args, **kwargs)
except ImportError:
    SamImageProcessorFast = None
    SamImageProcessorFastWrapper = None
```
No need for that, let's just import `SamImageProcessorFast` under `if is_torchvision_available():`.
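That is, the usual conditional-import pattern (sketch):

```python
from transformers.utils import is_torchvision_available

if is_torchvision_available():
    from transformers.models.sam.image_processing_sam_fast import SamImageProcessorFast
else:
    SamImageProcessorFast = None
```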
```python
SamImageProcessorFast = None
SamImageProcessorFastWrapper = None
```
It would be great to have a `SamImageProcessingTester`; check how it's done in other image processor tests.
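As a rough sketch of the shape such a tester usually has (all attribute names and default values below are illustrative, not from this PR):

```python
class SamImageProcessingTester:
    def __init__(
        self,
        parent,
        batch_size=7,
        num_channels=3,
        min_resolution=30,
        max_resolution=400,
        do_resize=True,
        size=None,
        do_pad=True,
        pad_size=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.do_resize = do_resize
        self.size = size if size is not None else {"longest_edge": 1024}
        self.do_pad = do_pad
        self.pad_size = pad_size if pad_size is not None else {"height": 1024, "width": 1024}

    def prepare_image_processor_dict(self):
        # Kwargs used to instantiate both the slow and the fast processor in tests
        return {
            "do_resize": self.do_resize,
            "size": self.size,
            "do_pad": self.do_pad,
            "pad_size": self.pad_size,
        }
```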
```python
@require_vision
@require_torch
@require_torchvision
class SamImageProcessorFastTest(unittest.TestCase):
```
Here we should have a unique `SamImageProcessingTest` with attributes:

```python
image_processing_class = SamImageProcessor if is_vision_available() else None
fast_image_processing_class = SamImageProcessorFast if is_torchvision_available() else None
```

and each test should be done on both processors, by iterating on `self.image_processor_list`:

```python
for image_processing_class in self.image_processor_list:
    image_processing = image_processing_class(**self.image_processor_dict)
    # ... rest of the test
```

Please have a look at how it's done for other image processor tests.
Related to #36978

Adds a fast image processor for the SAM model.