
[Model] Add UVDoc Model Support #43385

Merged
vasqu merged 73 commits into huggingface:main from XingweiDeng:feat/uvdoc
Mar 20, 2026

Conversation

@XingweiDeng
Contributor

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Member

@yonigozlan yonigozlan left a comment

Hi @XingweiDeng ! Thank you for working on this. And don't hesitate if you need any clarification on the review!

Comment thread src/transformers/models/uvdoc/modular_uvdoc.py
Comment on lines +77 to +90
def __init__(
    self,
    num_filter: int = 32,
    in_channels: int = 3,
    kernel_size: int = 5,
    block_stride_values: list | None = None,
    feature_map_multipliers: list | None = None,
    block_counts_per_stage: list | None = None,
    dilation_values: dict | None = None,
    padding_mode: str = "reflect",
    upsample_size: list | None = None,
    upsample_mode: str = "bilinear",
    **kwargs,
):
Member

We shouldn't have None values in the config. Instantiating the config without specifying args should return a valid model config with a checkpoint on the hub. If some of these parameters are always None, we can remove them altogether
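The point above can be sketched as follows; this is an illustrative, framework-free example (the class name `UVDocStyleConfig` and the concrete default values are hypothetical, not the merged code): every field gets a concrete default matching a released checkpoint, so instantiating the config with no arguments already describes a valid model.

```python
# Illustrative sketch (not the merged UVDoc code; class name and concrete
# default values are hypothetical): instead of `list | None = None`
# placeholders filled in later, every field gets a concrete default that
# matches a released checkpoint, so Config() alone describes a valid model.
class UVDocStyleConfig:
    def __init__(
        self,
        num_filter: int = 32,
        in_channels: int = 3,
        kernel_size: int = 5,
        block_stride_values: tuple = (1, 2, 2),   # concrete, not None
        upsample_size: tuple = (712, 488),        # concrete, not None
        padding_mode: str = "reflect",
        **kwargs,
    ):
        self.num_filter = num_filter
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        # tuples as defaults avoid the mutable-default pitfall; store as lists
        self.block_stride_values = list(block_stride_values)
        self.upsample_size = list(upsample_size)
        self.padding_mode = padding_mode

config = UVDocStyleConfig()  # valid without any arguments
```

Using tuples as defaults (converted to lists in the body) also sidesteps the mutable-default-argument pitfall that plain `list` defaults would reintroduce.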

Contributor Author

done

Comment on lines +39 to +40
The strides for downsampling operations in the backbone network, corresponding to the scale factor between
consecutive stages of the model. Smaller strides reduce the spatial dimension of feature maps while retaining
Member

Looks like there's something wrong with this docstring

Contributor Author

done

Comment on lines +102 to +104
# For image feature extraction pipeline compatibility: single class "image"
self.id2label = {0: "image"}
self.num_labels = 1
Member

This doesn't make sense here

Contributor Author

done

@torch.no_grad()
def _init_weights(self, module):
    super()._init_weights(module)
    """Initialize the weights."""
Member

Remove or put above

Contributor Author

done

Comment on lines +505 to +525
@auto_docstring(
    custom_intro=r"""
    The model takes raw document images (pixel values) as input, processes them through the UVDoc backbone to predict spatial transformation parameters,
    and outputs the rectified (corrected) document image tensor.
    """
)
class UVDocForDocumentRectification(UVDocPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["num_batches_tracked"]

    def __init__(self, config: UVDocConfig):
        super().__init__(config)
        self.model = UVDocModel(config)
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.FloatTensor] | BaseModelOutputWithNoAttention:
        return self.model(pixel_values)
Member

We don't have a ForDocumentRectification task in the library, so let's just keep the UVDocModel and add the custom_intro to it

Contributor Author

done

Comment on lines +58 to +61
dilation_values (`Dict[str, Union[int, List[int]]]`, *optional*, defaults to `None`):
A dictionary of dilation rates for dilated convolutional layers in bridge modules.
Dilated convolution expands the receptive field without increasing kernel size,
critical for capturing long-range geometric dependencies in distorted documents.
Member

Can we have this be a nested list instead of a dict?

Contributor Author

done

)
else:
    for dilation in dilation_values:
        self.blocks.append(UVDocConvLayer(in_channels, in_channels, padding=dilation, dilation=dilation))
Member

Can we have a nested list (and lists with one element for the first case) to avoid needing to separate these two cases?

Contributor Author

done
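The nested-list refactor agreed on above can be sketched like this; `build_bridge_specs` is a hypothetical illustration, not the merged code. Single-dilation bridge blocks become one-element inner lists, so a single loop covers every block and the if/else split disappears.

```python
# Sketch of the nested-list idea (hypothetical helper, not the merged code):
# single-dilation bridge blocks become one-element inner lists, so one code
# path handles every block and the if/else split disappears.
dilation_values = [[1], [2], [5], [8, 3, 2], [12, 7, 4], [18, 12, 6]]

def build_bridge_specs(dilation_values):
    blocks = []
    for dilations in dilation_values:  # one loop for all bridge blocks
        # one (padding, dilation) pair per conv layer, mirroring the
        # `padding=dilation` construction in the PR's UVDocConvLayer calls
        blocks.append([(d, d) for d in dilations])
    return blocks

specs = build_bridge_specs(dilation_values)
```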

Comment on lines +157 to +176
if isinstance(scale, (str, float, int)):
    scale = torch.tensor(float(scale), device=images.device)
else:
    scale = torch.tensor(255.0, device=images.device)

results = []
for image in images:
    image = image[0] if isinstance(image, tuple) else image
    image = image.squeeze().permute(1, 2, 0)
    image = image * scale
    image = image.flip(dims=[-1]).to(dtype=torch.uint8, non_blocking=True, copy=False)

    results.append(
        {
            "images": image,
            "labels": torch.zeros(1, dtype=torch.long, device=image.device),  # Single class: image
        }
    )

return results
Member

This should really only revert the scaling (and the scale should have 255.0 as default in the signature), on top of the logic that is for now in the model.
The rest should be up to the user (changing channel position, reverting if needed to BGR, squeezing etc.)

Contributor Author

Do you mean that we should remove the operations that generate rectified_images and directly return them, so users can process the output however they like, or that we should provide a parameter letting users choose whether to perform these operations?

Contributor Author

about changing channel position, reverting if needed to BGR, squeezing etc.
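The reviewer's suggestion can be sketched like this; `post_process_rectification` is a hypothetical, framework-free illustration (plain nested lists stand in for tensors), not the merged processor code. Post-processing only reverts the rescaling applied during preprocessing, with 255.0 as the default in the signature; channel reordering, RGB-to-BGR flips, and squeezing are left to the user.

```python
# Minimal sketch of the reviewer's suggestion (hypothetical helper, no torch):
# post-processing only undoes the 1/255 rescaling from preprocessing, with
# scale=255.0 as the default in the signature. Layout changes (channel order,
# RGB->BGR, squeezing) are left to the user.
def post_process_rectification(images, scale: float = 255.0):
    rectified = []
    for image in images:  # each image: nested lists of floats in [0, 1]
        rectified.append([[round(pixel * scale) for pixel in row] for row in image])
    return rectified

out = post_process_rectification([[[0.0, 0.5], [1.0, 0.2]]])  # one 2x2 "image"
```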

self.post_init()

@capture_outputs
@can_return_tuple
Member

no need for can_return_tuple if we have capture_outputs. And we should have auto_docstring on the forward as well

Contributor Author

done

Contributor

@vasqu vasqu left a comment

Looks overall much better now, just my last comments

Comment on lines +47 to +76
This is the configuration class to store the configuration of a [`UVDocModel`]. It is used to instantiate
a UVDoc model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the UVDoc
[PaddlePaddle/UVDoc_safetensors](https://huggingface.co/PaddlePaddle/UVDoc_safetensors) architecture.

Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.

Args:
    kernel_size (`int`, *optional*, defaults to 5):
        Kernel size for convolutional layers in the backbone network.
    resnet_head (`Sequence[list[int] | tuple[int, ...]]`, *optional*, defaults to `((3, 32), (32, 32))`):
        Configuration for the ResNet head layers in format [in_channels, out_channels].
    resnet_down (`Sequence[list[int] | tuple[int, ...]]`, *optional*, defaults to `((32, 32), (32, 64), (64, 128))`):
        Configuration for the ResNet downsampling stages in format [in_channels, out_channels].
    stage_configs (`Sequence[Sequence[tuple[int, int, int, bool] | list[int | bool]]]`, *optional*, defaults to `(((32, 32, 1, False), (32, 32, 3, False), (32, 32, 3, False)), ((32, 64, 1, True), (64, 64, 3, False), (64, 64, 3, False), (64, 64, 3, False)), ((64, 128, 1, True), (128, 128, 3, False), (128, 128, 3, False), (128, 128, 3, False), (128, 128, 3, False), (128, 128, 3, False)))`):
        Configuration for the ResNet stages in format [in_channels, out_channels, dilation_value, downsample].
    bridge_connector (`list[int] | tuple[int, ...]`, *optional*, defaults to `(128, 128)`):
        Configuration for the bridge connector in format [in_channels, out_channels].
    out_point_positions2D (`Sequence[list[int] | tuple[int, ...]]`, *optional*, defaults to `((128, 32), (32, 2))`):
        Configuration for the output point positions 2D layer in format [in_channels, out_channels].
    dilation_values (`list[list[int]] | tuple[tuple[int, ...], ...]`, *optional*, defaults to `((1,), (2,), (5,), (8, 3, 2), (12, 7, 4), (18, 12, 6))`):
        Dilation rates for dilated convolutional layers in bridge modules. Each inner tuple/list contains dilation
        rates for a single bridge block.
    padding_mode (`str`, *optional*, defaults to `"reflect"`):
        Padding mode for convolutional layers. Supported modes are `"reflect"`, `"constant"`, and `"replicate"`.
    hidden_act (`str`, *optional*, defaults to `"prelu"`):
        Activation function for hidden layers.
Contributor

Need to change to new style docstrings, no intro needed and only args need to be added that don't exist in autodoc already


bridge_connector: list[int] | tuple[int, ...] = (128, 128)
out_point_positions2D: Sequence[list[int] | tuple[int, ...]] = ((128, 32), (32, 2))
dilation_values: list[list[int]] | tuple[tuple[int, ...], ...] = (
Contributor

I think we are missing a sequence here

Comment on lines +467 to +473
outputs = []
for layer in self.bridge:
    output = layer(hidden_states)
    outputs.append(output)
hidden_states = torch.cat(outputs, dim=1)

return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states)
Contributor

Suggested change
- outputs = []
- for layer in self.bridge:
-     output = layer(hidden_states)
-     outputs.append(output)
- hidden_states = torch.cat(outputs, dim=1)
- return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states)
+ for layer in self.bridge:
+     hidden_states = layer(hidden_states)
+ return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states)

We provide what you want in hidden_states then, will give another comment below with what I mean

Comment on lines +476 to +483
@dataclass
class UVDocBackboneOutput(BackboneOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model.
    """

    last_hidden_state: torch.FloatTensor | None = None
Contributor

Won't be needed then

pixel_values: torch.FloatTensor,
**kwargs,
) -> UVDocBackboneOutput:
kwargs["output_hidden_states"] = True
Contributor

Suggested change
- kwargs["output_hidden_states"] = True
+ kwargs["output_hidden_states"] = True  # required to extract layers for the stages

nit

Contributor

rebump

def forward(
self,
hidden_states: torch.Tensor,
**kwargs,
Contributor

Suggested change
- **kwargs,
+ **kwargs: Unpack[TransformersKwargs],

Let's add them everywhere

Contributor

Ok tbh, here we don't need them at all, but it doesn't hurt

Comment on lines +578 to +579
outputs = self.backbone(pixel_values, **kwargs)
head_outputs = self.head(outputs.last_hidden_state, **kwargs)
Contributor

Suggested change
- outputs = self.backbone(pixel_values, **kwargs)
- head_outputs = self.head(outputs.last_hidden_state, **kwargs)
+ backbone_outputs = self.backbone(pixel_values, **kwargs)
+ fused_outputs = torch.cat(backbone_outputs.feature_maps, dim=1)
+ last_hidden_state = self.head(fused_outputs, **kwargs)

We do the fusing (cat) here as we want to avoid the manual loop collection
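The division of labor suggested here can be sketched without torch; `bridge` and `fuse` below are hypothetical helpers, with feature maps modeled as plain lists of channels. The bridge applies its layers sequentially, and the caller performs the channel-wise fusion that `torch.cat(feature_maps, dim=1)` does in the PR.

```python
# Sketch of the suggested restructuring (hypothetical helpers, no torch;
# feature maps are modeled as plain lists of channels): the bridge applies
# its layers sequentially, and the caller fuses the per-stage feature maps
# by channel-wise concatenation, i.e. torch.cat(feature_maps, dim=1).
def bridge(hidden_states, layers):
    for layer in layers:  # sequential refinement, no manual output collection
        hidden_states = layer(hidden_states)
    return hidden_states

def fuse(feature_maps):
    fused = []
    for channels in feature_maps:  # concatenate along the "channel" axis
        fused.extend(channels)
    return fused

stage1 = [[1.0], [2.0]]         # a 2-channel feature map
stage2 = [[3.0]]                # a 1-channel feature map
fused = fuse([stage1, stage2])  # 3 channels after fusion
refined = bridge(0, [lambda x: x + 1, lambda x: x * 2])
```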

for image_processing_class in self.image_processing_classes.values():
image_processor = image_processing_class(**self.image_processor_dict)

batch_size = 2
Contributor

We need to indent here to have this whole block under the loop, same for the other tests here

class UVDocBackboneTest(BackboneTesterMixin, unittest.TestCase):
    all_model_classes = (UVDocBackbone,) if is_torch_available() else ()
    has_attentions = False
    config_class = UVDocConfig
Contributor

It's a bit weird to share the same config for both - could we filter a bit out?

Comment on lines +243 to +246
@unittest.skip(reason="Large number of hidden layers but small spatial dimensions")
def test_num_layers_is_small(self):
    pass

Contributor

Can we reduce the parameters here please

Contributor

@vasqu vasqu left a comment

Took a quick glance 🤗

(
128,
128,
1,
Contributor

We repeat this (128, 128) tuple - can we simplify this a bit

pixel_values: torch.FloatTensor,
**kwargs,
) -> UVDocBackboneOutput:
kwargs["output_hidden_states"] = True
Contributor

rebump



@auto_docstring
class UVDocPreTrainedModel(PPOCRV5ServerDetPreTrainedModel):
Contributor

Let's add a base model prefix with "backbone"; this should reference our base model within the model, i.e. base_model_prefix = "backbone"

Comment thread src/transformers/models/uvdoc/modular_uvdoc.py
Comment on lines +567 to +577
kwargs["output_hidden_states"] = True
hidden_states = self.resnet(pixel_values)
outputs = self.bridge(hidden_states, **kwargs)
feature_maps = ()
for idx, stage in enumerate(self.stage_names):
    if stage in self.out_features:
        feature_maps += (outputs.hidden_states[idx],)
return BackboneOutput(
    feature_maps=feature_maps,
    hidden_states=outputs.hidden_states,
)
Contributor

Suggested change
  kwargs["output_hidden_states"] = True
  hidden_states = self.resnet(pixel_values)
  outputs = self.bridge(hidden_states, **kwargs)
+
  feature_maps = ()
  for idx, stage in enumerate(self.stage_names):
      if stage in self.out_features:
          feature_maps += (outputs.hidden_states[idx],)
+
  return BackboneOutput(
      feature_maps=feature_maps,
      hidden_states=outputs.hidden_states,
  )

nit: just some spaces for readability

Comment on lines +238 to +240
# @unittest.skip(reason="Large number of hidden layers but small spatial dimensions")
# def test_num_layers_is_small(self):
# pass
Contributor

Suggested change
- # @unittest.skip(reason="Large number of hidden layers but small spatial dimensions")
- # def test_num_layers_is_small(self):
- #     pass

probably known, just in case if missed

Comment thread utils/check_repo.py Outdated
"VibeVoiceAcousticTokenizerEncoderModel", # Tested through VibeVoiceAcousticTokenizerModel
"VibeVoiceAcousticTokenizerDecoderModel", # Tested through VibeVoiceAcousticTokenizerModel
"PI0Model", # special arch, tested through PI0ForConditionalGeneration
"UVDocBridge", # Building part of a bigger model
Contributor

Can you add a comment that it is tested through the UVDocModel itself (like the other comments above)


@auto_docstring(checkpoint="PaddlePaddle/UVDoc_safetensors")
@strict(accept_kwargs=True)
class UVDocConfig(BackboneConfigMixin, PreTrainedConfig):
Contributor

Yea like I mentioned internally, we likely should split this into two configs tbh

  • One for the backbone only
  • One general that has that as subconfig and the rest
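The proposed two-config split can be sketched as follows; the class names are illustrative placeholders, not the final transformers API. A backbone-only config is nested as a sub-config of the general one, and a default sub-config keeps the general config instantiable without arguments.

```python
# Hypothetical sketch of the proposed split (class names are illustrative,
# not the final transformers API): a backbone-only config nested as a
# sub-config of the general config, with a default sub-config so the
# general config stays instantiable without arguments.
class UVDocBackboneConfigSketch:
    def __init__(self, kernel_size: int = 5, padding_mode: str = "reflect"):
        self.kernel_size = kernel_size
        self.padding_mode = padding_mode

class UVDocConfigSketch:
    def __init__(self, backbone_config=None, upsample_mode: str = "bilinear"):
        # build the default backbone sub-config when none is given
        self.backbone_config = backbone_config or UVDocBackboneConfigSketch()
        self.upsample_mode = upsample_mode

config = UVDocConfigSketch()
```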

@huggingface huggingface deleted a comment from github-actions Bot Mar 20, 2026
Contributor

@vasqu vasqu left a comment

LGTM, just see my 2 last points

Comment thread src/transformers/models/uvdoc/modular_uvdoc.py Outdated
Comment thread src/transformers/models/uvdoc/modular_uvdoc.py Outdated
@vasqu
Contributor

vasqu commented Mar 20, 2026

run-slow: uvdoc

@github-actions
Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/uvdoc"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN df2b9e17 workflow commit (merge commit)
PR 05eee3ee branch commit (from PR)
main e168f86e base commit (on main)

✅ No failing test specific to this PR 🎉 👏 !

@vasqu vasqu enabled auto-merge March 20, 2026 21:50
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, uvdoc

@vasqu vasqu added this pull request to the merge queue Mar 20, 2026
Merged via the queue into huggingface:main with commit 52bc9b7 Mar 20, 2026
28 checks passed
@vasqu
Contributor

vasqu commented Mar 20, 2026

And that's the last one!!! Huge work and thanks @XingweiDeng for sticking with the last touches 🫡

7 participants