Commit 9a19171

Add GLPNImageProcessorFast (#41725)
* Add GLPNImageProcessorFast for torch backend

* Address review feedback
  - Simplified to_dict() method
  - Keep tensors as torch instead of converting to numpy for heterogeneous shapes
  - Removed unnecessary shape guards in post_process_depth_estimation
  - Improved variable names (tgt -> target_size, d -> resized)
  - Removed unnecessary GLPNImageProcessorKwargs class

* Commits after 2nd review

* Address all review feedback and add explicit batched test
  - Simplified to_dict() with descriptive variable names (d -> output_dict)
  - Fixed resize operation: changed from crop to proper resize with interpolation
  - Added padding for heterogeneous batch shapes in both slow and fast processors
  - Fused rescale and normalize operations for efficiency
  - Improved all variable names (tgt -> target_size, d -> depth_4d -> resized)
  - Added GLPNImageProcessorKwargs class in slow processor and imported it in the fast one
  - Renamed test_equivalence_slow_fast to test_slow_fast_equivalence
  - Added explicit test_slow_fast_equivalence_batched test
  - All 20 tests passing

* Using padding from utils

* Simplify GLPN image processor fast

* Fix docstring

---------

Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co>
Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
1 parent 26fca86 commit 9a19171

File tree: 7 files changed (+228 −8 lines)

docs/source/en/model_doc/glpn.md

Lines changed: 5 additions & 0 deletions
```diff
@@ -61,6 +61,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
 [[autodoc]] GLPNImageProcessor
     - preprocess
 
+## GLPNImageProcessorFast
+
+[[autodoc]] GLPNImageProcessorFast
+    - preprocess
+
 ## GLPNModel
 
 [[autodoc]] GLPNModel
```

src/transformers/image_processing_utils_fast.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -306,6 +306,8 @@ def resize(
                 Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
             interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                 `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
+            antialias (`bool`, *optional*, defaults to `True`):
+                Whether to use antialiasing.
 
         Returns:
             `torch.Tensor`: The resized image.
```

src/transformers/models/auto/image_processing_auto.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -103,7 +103,7 @@
         ("gemma3n", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
         ("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
         ("glm4v", ("Glm4vImageProcessor", "Glm4vImageProcessorFast")),
-        ("glpn", ("GLPNImageProcessor", None)),
+        ("glpn", ("GLPNImageProcessor", "GLPNImageProcessorFast")),
         ("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),
         ("grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")),
         ("groupvit", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
```

src/transformers/models/glpn/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -21,6 +21,7 @@
     from .configuration_glpn import *
     from .feature_extraction_glpn import *
     from .image_processing_glpn import *
+    from .image_processing_glpn_fast import *
     from .modeling_glpn import *
 else:
     import sys
```

src/transformers/models/glpn/image_processing_glpn.py

Lines changed: 22 additions & 1 deletion
```diff
@@ -39,6 +39,7 @@
     valid_images,
     validate_preprocess_arguments,
 )
+from ...processing_utils import ImagesKwargs
 from ...utils import TensorType, filter_out_non_signature_kwargs, logging, requires_backends
 
 
@@ -49,6 +50,17 @@
 logger = logging.get_logger(__name__)
 
 
+class GLPNImageProcessorKwargs(ImagesKwargs, total=False):
+    """
+    size_divisor (`int`, *optional*, defaults to 32):
+        When `do_resize` is `True`, images are resized so their height and width are rounded down to the closest
+        multiple of `size_divisor`.
+    """
+
+    size_divisor: int
+    resample: PILImageResampling
+
+
 @requires(backends=("vision",))
 class GLPNImageProcessor(BaseImageProcessor):
     r"""
@@ -66,22 +78,27 @@ class GLPNImageProcessor(BaseImageProcessor):
         do_rescale (`bool`, *optional*, defaults to `True`):
             Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Can be
             overridden by `do_rescale` in `preprocess`.
+        rescale_factor (`float`, *optional*, defaults to `1 / 255`):
+            The scaling factor to apply to the pixel values. Can be overridden by `rescale_factor` in `preprocess`.
     """
 
     model_input_names = ["pixel_values"]
+    valid_kwargs = GLPNImageProcessorKwargs
 
     def __init__(
         self,
         do_resize: bool = True,
         size_divisor: int = 32,
         resample=PILImageResampling.BILINEAR,
         do_rescale: bool = True,
+        rescale_factor: Optional[float] = 1 / 255,
         **kwargs,
     ) -> None:
         self.do_resize = do_resize
         self.do_rescale = do_rescale
         self.size_divisor = size_divisor
         self.resample = resample
+        self.rescale_factor = rescale_factor
         super().__init__(**kwargs)
 
     def resize(
@@ -142,6 +159,7 @@ def preprocess(
         size_divisor: Optional[int] = None,
         resample=None,
         do_rescale: Optional[bool] = None,
+        rescale_factor: Optional[float] = None,
         return_tensors: Optional[Union[TensorType, str]] = None,
         data_format: ChannelDimension = ChannelDimension.FIRST,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
@@ -181,6 +199,7 @@ def preprocess(
         """
         do_resize = do_resize if do_resize is not None else self.do_resize
         do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        size_divisor = size_divisor if size_divisor is not None else self.size_divisor
         resample = resample if resample is not None else self.resample
 
@@ -217,7 +236,9 @@ def preprocess(
         ]
 
         if do_rescale:
-            images = [self.rescale(image, scale=1 / 255, input_data_format=input_data_format) for image in images]
+            images = [
+                self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) for image in images
+            ]
 
         images = [
             to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
```
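For illustration (not part of the diff), a minimal sketch of how the new `rescale_factor` default interacts with GLPN's divisor-based resizing; the random array is a stand-in for a real image:

```python
import numpy as np
from transformers import GLPNImageProcessor

processor = GLPNImageProcessor()  # do_resize=True, size_divisor=32, rescale_factor=1/255

# 65x97 input: height and width are rounded down to multiples of 32 -> 64x96
image = np.random.randint(0, 256, (65, 97, 3), dtype=np.uint8)
outputs = processor(image, return_tensors="pt")
print(outputs.pixel_values.shape)         # torch.Size([1, 3, 64, 96])
print(float(outputs.pixel_values.max()))  # <= 1.0 after rescaling by 1/255
```
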
src/transformers/models/glpn/image_processing_glpn_fast.py (new file)

Lines changed: 136 additions & 0 deletions

```python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for GLPN."""

from typing import Optional, Union

import torch
from torchvision.transforms.v2 import functional as F

from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images
from ...image_utils import (
    PILImageResampling,
    SizeDict,
)
from ...utils import (
    TensorType,
    auto_docstring,
    requires_backends,
)
from .image_processing_glpn import GLPNImageProcessorKwargs


@auto_docstring
class GLPNImageProcessorFast(BaseImageProcessorFast):
    do_resize = True
    do_rescale = True
    rescale_factor = 1 / 255
    resample = PILImageResampling.BILINEAR
    size_divisor = 32
    valid_kwargs = GLPNImageProcessorKwargs

    def _validate_preprocess_kwargs(self, **kwargs):
        # pop `do_resize` to not raise an error as `size` is not None
        kwargs.pop("do_resize", None)
        return super()._validate_preprocess_kwargs(**kwargs)

    def resize(
        self,
        image: "torch.Tensor",
        size_divisor: int,
        interpolation: Optional["F.InterpolationMode"] = None,
        antialias: bool = True,
        **kwargs,
    ) -> "torch.Tensor":
        """
        Resize an image so its height and width are rounded down to the closest multiple of `size_divisor`.

        Args:
            image (`torch.Tensor`):
                Image to resize.
            size_divisor (`int`):
                The height and width of the image are rounded down to the closest multiple of `size_divisor`.
            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
            antialias (`bool`, *optional*, defaults to `True`):
                Whether to use antialiasing.

        Returns:
            `torch.Tensor`: The resized image.
        """
        height, width = image.shape[-2:]
        # Rounds the height and width down to the closest multiple of size_divisor
        new_h = height // size_divisor * size_divisor
        new_w = width // size_divisor * size_divisor
        return super().resize(
            image, SizeDict(height=new_h, width=new_w), interpolation=interpolation, antialias=antialias
        )

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        size_divisor: Optional[int] = None,
        interpolation: Optional["F.InterpolationMode"] = None,
        do_rescale: bool = True,
        rescale_factor: Optional[float] = 1 / 255,
        do_normalize: bool = False,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        disable_grouping: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        resample: Optional[PILImageResampling] = None,
        **kwargs,
    ) -> BatchFeature:
        grouped_images, grouped_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        processed_groups = {}

        for shape, stacked_images in grouped_images.items():
            if do_resize:
                stacked_images = self.resize(stacked_images, size_divisor=size_divisor, interpolation=interpolation)
            stacked_images = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
            processed_groups[shape] = stacked_images

        processed_images = reorder_images(processed_groups, grouped_index)
        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

    def post_process_depth_estimation(self, outputs, target_sizes=None):
        """
        Convert raw model outputs to final depth predictions.
        Mirrors the slow GLPN processor: PyTorch bicubic interpolation with `align_corners=False`.
        """
        requires_backends(self, "torch")
        predicted_depth = outputs.predicted_depth

        results = []
        target_sizes = target_sizes or [None] * predicted_depth.shape[0]
        for depth, target_size in zip(predicted_depth, target_sizes):
            if target_size is not None:
                # Add batch and channel dimensions for interpolation
                depth_4d = depth[None, None, ...]
                resized = torch.nn.functional.interpolate(
                    depth_4d, size=target_size, mode="bicubic", align_corners=False
                )
                depth = resized.squeeze(0).squeeze(0)
            results.append({"predicted_depth": depth})

        return results


__all__ = ["GLPNImageProcessorFast"]
```

tests/models/glpn/test_image_processing_glpn.py

Lines changed: 61 additions & 6 deletions
```diff
@@ -18,7 +18,7 @@
 import numpy as np
 
 from transformers.testing_utils import require_torch, require_vision
-from transformers.utils import is_torch_available, is_vision_available
+from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
 
 from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
 
@@ -31,6 +31,9 @@
 
     from transformers import GLPNImageProcessor
 
+if is_torchvision_available():
+    from transformers import GLPNImageProcessorFast
+
 
 class GLPNImageProcessingTester:
     def __init__(
@@ -87,19 +90,32 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F
             torchify=torchify,
         )
 
+    def prepare_depth_outputs(self):
+        if not is_torch_available():
+            return None
+        depth_tensors = prepare_image_inputs(
+            batch_size=self.batch_size,
+            num_channels=1,
+            min_resolution=self.min_resolution,
+            max_resolution=self.max_resolution,
+            equal_resolution=True,
+            torchify=True,
+        )
+        depth_tensors = [depth_tensor.squeeze(0) for depth_tensor in depth_tensors]
+        stacked_depth_tensors = torch.stack(depth_tensors, dim=0)
+        return type("DepthOutput", (), {"predicted_depth": stacked_depth_tensors})
+
 
 @require_torch
 @require_vision
 class GLPNImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
     image_processing_class = GLPNImageProcessor if is_vision_available() else None
+    fast_image_processing_class = GLPNImageProcessorFast if is_torchvision_available() else None
 
     def setUp(self):
         super().setUp()
         self.image_processor_tester = GLPNImageProcessingTester(self)
-
-    @property
-    def image_processor_dict(self):
-        return self.image_processor_tester.prepare_image_processor_dict()
+        self.image_processor_dict = self.image_processor_tester.prepare_image_processor_dict()
 
     def test_image_processor_properties(self):
         image_processing = self.image_processing_class(**self.image_processor_dict)
@@ -115,7 +131,6 @@ def test_call_pil(self):
         image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
         for image in image_inputs:
             self.assertIsInstance(image, Image.Image)
-
         # Test not batched input (GLPNImageProcessor doesn't support batching)
         encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
         expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
@@ -161,3 +176,43 @@ def test_call_numpy_4_channels(self):
         expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
         self.assertTrue(tuple(encoded_images.shape) == (1, *expected_output_image_shape))
         self.image_processing_class.num_channels = 3
+
+    # override as glpn image processors don't support heterogeneous batching
+    @require_vision
+    @require_torch
+    def test_slow_fast_equivalence_batched(self):
+        if not self.test_slow_image_processor or not self.test_fast_image_processor:
+            self.skipTest(reason="Skipping slow/fast equivalence test")
+
+        if self.image_processing_class is None or self.fast_image_processing_class is None:
+            self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
+
+        dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
+        image_processor_slow = self.image_processing_class(**self.image_processor_dict)
+        image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
+
+        encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
+        encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
+
+        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
+
+    def test_post_process_depth_equivalence(self):
+        # Check that both processors produce equivalent post-processed depth maps
+        if self.fast_image_processing_class is None:
+            self.skipTest("TorchVision not available")
+
+        outputs = self.image_processor_tester.prepare_depth_outputs()
+        slow = self.image_processing_class(**self.image_processor_dict)
+        fast = self.fast_image_processing_class(**self.image_processor_dict)
+
+        # target_sizes simulate resized inference outputs
+        target_sizes = [(240, 320)] * self.image_processor_tester.batch_size
+        processed_slow = slow.post_process_depth_estimation(outputs, target_sizes=target_sizes)
+        processed_fast = fast.post_process_depth_estimation(outputs, target_sizes=target_sizes)
+
+        # Compare per-sample predicted depth tensors
+        for pred_slow, pred_fast in zip(processed_slow, processed_fast):
+            depth_slow = pred_slow["predicted_depth"]
+            depth_fast = pred_fast["predicted_depth"]
+            torch.testing.assert_close(depth_fast, depth_slow, atol=1e-1, rtol=1e-3)
+        self.assertLessEqual(torch.mean(torch.abs(depth_fast.float() - depth_slow.float())).item(), 5e-3)
```
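These tests can be run locally with `python -m pytest tests/models/glpn/test_image_processing_glpn.py`; per the commit message, all 20 pass.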
