|
| 1 | +# coding=utf-8 |
| 2 | +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +"""Fast Image processor class for GLPN.""" |
| 16 | + |
| 17 | +from typing import Optional, Union |
| 18 | + |
| 19 | +import torch |
| 20 | +from torchvision.transforms.v2 import functional as F |
| 21 | + |
| 22 | +from ...image_processing_utils import BatchFeature |
| 23 | +from ...image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images |
| 24 | +from ...image_utils import ( |
| 25 | + PILImageResampling, |
| 26 | + SizeDict, |
| 27 | +) |
| 28 | +from ...utils import ( |
| 29 | + TensorType, |
| 30 | + auto_docstring, |
| 31 | + requires_backends, |
| 32 | +) |
| 33 | +from .image_processing_glpn import GLPNImageProcessorKwargs |
| 34 | + |
| 35 | + |
| 36 | +@auto_docstring |
| 37 | +class GLPNImageProcessorFast(BaseImageProcessorFast): |
| 38 | + do_resize = True |
| 39 | + do_rescale = True |
| 40 | + rescale_factor = 1 / 255 |
| 41 | + resample = PILImageResampling.BILINEAR |
| 42 | + size_divisor = 32 |
| 43 | + valid_kwargs = GLPNImageProcessorKwargs |
| 44 | + |
| 45 | + def _validate_preprocess_kwargs(self, **kwargs): |
| 46 | + # pop `do_resize` to not raise an error as `size` is not None |
| 47 | + kwargs.pop("do_resize", None) |
| 48 | + return super()._validate_preprocess_kwargs(**kwargs) |
| 49 | + |
| 50 | + def resize( |
| 51 | + self, |
| 52 | + image: "torch.Tensor", |
| 53 | + size_divisor: int, |
| 54 | + interpolation: Optional["F.InterpolationMode"] = None, |
| 55 | + antialias: bool = True, |
| 56 | + **kwargs, |
| 57 | + ) -> "torch.Tensor": |
| 58 | + """ |
| 59 | + Resize an image to `(size["height"], size["width"])`. |
| 60 | +
|
| 61 | + Args: |
| 62 | + image (`torch.Tensor`): |
| 63 | + Image to resize. |
| 64 | + size (`SizeDict`): |
| 65 | + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. |
| 66 | + interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`): |
| 67 | + `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`. |
| 68 | + antialias (`bool`, *optional*, defaults to `True`): |
| 69 | + Whether to use antialiasing. |
| 70 | +
|
| 71 | + Returns: |
| 72 | + `torch.Tensor`: The resized image. |
| 73 | + """ |
| 74 | + height, width = image.shape[-2:] |
| 75 | + # Rounds the height and width down to the closest multiple of size_divisor |
| 76 | + new_h = height // size_divisor * size_divisor |
| 77 | + new_w = width // size_divisor * size_divisor |
| 78 | + return super().resize( |
| 79 | + image, SizeDict(height=new_h, width=new_w), interpolation=interpolation, antialias=antialias |
| 80 | + ) |
| 81 | + |
| 82 | + def _preprocess( |
| 83 | + self, |
| 84 | + images: list["torch.Tensor"], |
| 85 | + do_resize: bool, |
| 86 | + size_divisor: Optional[int] = None, |
| 87 | + interpolation: Optional["F.InterpolationMode"] = None, |
| 88 | + do_rescale: bool = True, |
| 89 | + rescale_factor: Optional[float] = 1 / 255, |
| 90 | + do_normalize: bool = False, |
| 91 | + image_mean: Optional[Union[float, list[float]]] = None, |
| 92 | + image_std: Optional[Union[float, list[float]]] = None, |
| 93 | + disable_grouping: Optional[bool] = None, |
| 94 | + return_tensors: Optional[Union[str, TensorType]] = None, |
| 95 | + resample: Optional[PILImageResampling] = None, |
| 96 | + **kwargs, |
| 97 | + ) -> BatchFeature: |
| 98 | + grouped_images, grouped_index = group_images_by_shape(images, disable_grouping=disable_grouping) |
| 99 | + processed_groups = {} |
| 100 | + |
| 101 | + for shape, stacked_images in grouped_images.items(): |
| 102 | + if do_resize: |
| 103 | + stacked_images = self.resize(stacked_images, size_divisor=size_divisor, interpolation=interpolation) |
| 104 | + stacked_images = self.rescale_and_normalize( |
| 105 | + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std |
| 106 | + ) |
| 107 | + processed_groups[shape] = stacked_images |
| 108 | + |
| 109 | + processed_images = reorder_images(processed_groups, grouped_index) |
| 110 | + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images |
| 111 | + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) |
| 112 | + |
| 113 | + def post_process_depth_estimation(self, outputs, target_sizes=None): |
| 114 | + """ |
| 115 | + Convert raw model outputs to final depth predictions. |
| 116 | + Mirrors slow GLPN: PyTorch interpolate w/ bicubic, align_corners=False. |
| 117 | + """ |
| 118 | + requires_backends(self, "torch") |
| 119 | + predicted_depth = outputs.predicted_depth |
| 120 | + |
| 121 | + results = [] |
| 122 | + target_sizes = target_sizes or [None] * predicted_depth.shape[0] |
| 123 | + for depth, target_size in zip(predicted_depth, target_sizes): |
| 124 | + if target_size is not None: |
| 125 | + # Add batch and channel dimensions for interpolation |
| 126 | + depth_4d = depth[None, None, ...] |
| 127 | + resized = torch.nn.functional.interpolate( |
| 128 | + depth_4d, size=target_size, mode="bicubic", align_corners=False |
| 129 | + ) |
| 130 | + depth = resized.squeeze(0).squeeze(0) |
| 131 | + results.append({"predicted_depth": depth}) |
| 132 | + |
| 133 | + return results |
| 134 | + |
| 135 | + |
| 136 | +__all__ = ["GLPNImageProcessorFast"] |
0 commit comments