
Conversation

@qubvel
Contributor

@qubvel qubvel commented Mar 3, 2025

What does this PR do?

Improves code readability for object-detection models; should be merged after

Note:

RT-DETR and RT-DETR-v2 are refactored; for the other models, only the frozen BatchNorm2d module is updated.

RT-DETR refactoring:

  • Make the code closer to transformers standards (no one-letter variables, common variable names, ...)
  • Remove some unused outputs (they don't seem necessary, but they add overhead to the implementation and docs)
  • Remove unused arguments from signatures (where possible)
  • Better comments and shape comments
  • Add default docstring for object-detection
  • Unify RT-DETR and RT-DETRv2 implementations (simpler modular file for v2)

In addition:

  • Fix CoreML export for RT-DETR (avoid 6D tensors in deformable_attention; see the sketch below)

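For context, a rough sketch of the reshaping idea behind the CoreML fix (illustrative shapes and a plain nn.Linear, not the actual PR code): keep the levels and points axes merged so that no intermediate tensor in the deformable attention exceeds 5 dimensions.

import torch
import torch.nn as nn

# Illustrative shapes only; the actual modeling code may differ.
batch_size, num_queries, hidden_dim = 2, 300, 256
num_heads, num_levels, num_points = 8, 4, 4

sampling_offsets = nn.Linear(hidden_dim, num_heads * num_levels * num_points * 2)
hidden_states = torch.rand(batch_size, num_queries, hidden_dim)

# (batch, queries, heads, levels * points, 2): 5D, instead of the usual
# (batch, queries, heads, levels, points, 2) 6D layout
offsets = sampling_offsets(hidden_states).view(
    batch_size, num_queries, num_heads, num_levels * num_points, 2
)

# per-level chunks can still be recovered along the merged axis
per_level_offsets = offsets.split(num_points, dim=3)
assert len(per_level_offsets) == num_levels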
Mostly, the changes are backward compatible, but some internal modules' signatures have changed. I suspect RT-DETR internal modules are used on their own somewhere outside the transformers repository. A quick search on GitHub shows only transformers forks, and I have never seen custom code samples on HF Hub for object detection. Therefore, I think we can change them to establish better standards for upcoming models such as D-Fine (a state-of-the-art real-time object detector) and RT-DETRv3, both of which are based on RT-DETR.

Test changes

I tried to keep tests untouched, but there are a few modifications:

  • the number of outputs in ModelOutput is reduced

Who can review?

cc @SangbumChoi @jadechoghari as model contributors, in case you want to have a look. Please let me know if anything is missing / incorrectly re-implemented

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qubvel qubvel added the Vision label Mar 3, 2025
Contributor

@SangbumChoi SangbumChoi left a comment

Good :) (In general, I think the PR is still in draft.)

backbone = load_backbone(config)

# replace batch norm by frozen batch norm
with torch.no_grad():
Contributor

Does this replace_batch_norm already have no_grad inside its definition?

Contributor Author

No, we are just copying tensor data, so there is no need for no_grad. I tested with no_grad removed, and replace_batch_norm works fine.
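For illustration, a simplified sketch of such a replacement helper (names and details are illustrative; the real helper in the modeling file differs): copying into .data only moves values and does not build an autograd graph, so a no_grad context adds nothing.

import torch
from torch import nn

class FrozenBatchNorm2d(nn.Module):
    # Simplified stand-in: stats and affine parameters are plain buffers, never updated.
    def __init__(self, num_features):
        super().__init__()
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        scale = self.weight * (self.running_var + 1e-5).rsqrt()
        shift = self.bias - self.running_mean * scale
        return x * scale[None, :, None, None] + shift[None, :, None, None]

def replace_batch_norm(model):
    # .data.copy_ only copies tensor values and does not track gradients,
    # so no torch.no_grad() context is needed here.
    for name, module in model.named_children():
        if isinstance(module, nn.BatchNorm2d):
            frozen = FrozenBatchNorm2d(module.num_features)
            frozen.weight.data.copy_(module.weight.data)
            frozen.bias.data.copy_(module.bias.data)
            frozen.running_mean.data.copy_(module.running_mean)
            frozen.running_var.data.copy_(module.running_var)
            setattr(model, name, frozen)
        else:
            replace_batch_norm(module)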

@@ -326,141 +224,411 @@ class RTDetrObjectDetectionOutput(ModelOutput):
pred_boxes: torch.FloatTensor = None
auxiliary_outputs: Optional[List[Dict]] = None
last_hidden_state: torch.FloatTensor = None
intermediate_hidden_states: torch.FloatTensor = None
Contributor

I also agree with removing these variables. Just for the record, I thought this was legacy code, since the other detection models all share this style. The HF team might want to refactor this across implementations, since it is not used for inference. However, some of the transformer-based models might require it for fine-tuning, e.g. for auxiliary or intermediate losses.

Contributor Author

@qubvel qubvel Mar 13, 2025

Thanks, I will double-check, but it seems redundant. As far as I recall, it is similar to the decoder hidden states, shifted by one because it is taken after each layer is applied.
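A toy loop (not the RT-DETR decoder) showing the shift-by-one relationship between the states collected before and after each layer:

import torch
from torch import nn

decoder_layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(3)])
hidden_states = torch.rand(1, 16)

all_hidden_states, intermediate = [], []
for layer in decoder_layers:
    all_hidden_states.append(hidden_states)  # state *before* the layer
    hidden_states = layer(hidden_states)
    intermediate.append(hidden_states)       # state *after* the layer
all_hidden_states.append(hidden_states)

# intermediate[i] matches all_hidden_states[i + 1] for every layer i
for i in range(len(decoder_layers)):
    assert torch.equal(intermediate[i], all_hidden_states[i + 1])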

# add position embeddings to the hidden states before projecting to queries and keys
if position_embeddings is not None:
hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
params = {"kernel_size": 1, "stride": 1, "activation": config.activation_function}
Contributor

Is there a reason for writing params as a dictionary rather than just writing all the arguments explicitly in the corresponding function call?

Contributor Author

I agree it might not be the best approach, I did it just for code readability:

  • No line breaks for module definitions (in my opinion, it's a bit easier to read)
  • It also shows that parameters are identical across modules

However, it's very subjective.
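To make the comparison concrete, an illustrative example with plain nn.Conv2d (the actual RT-DETR layers take more arguments): both spellings build identical modules; the dict just keeps the shared arguments in one place and the constructor calls on one line.

from torch import nn

# shared arguments collected once
params = {"kernel_size": 1, "stride": 1}
conv1 = nn.Conv2d(256, 256, **params)
conv2 = nn.Conv2d(256, 512, **params)

# equivalent explicit form
conv1_explicit = nn.Conv2d(256, 256, kernel_size=1, stride=1)
conv2_explicit = nn.Conv2d(256, 512, kernel_size=1, stride=1)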

num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
embed_dim=config.encoder_hidden_dim,
num_heads=config.num_attention_heads,
Contributor

Are you sure about this change? Even though the default config values are the same, there are two different configuration options for num_heads, one for the decoder layer and one for the encoder layer.

The same comment applies below: attention_dropout and dropout are also quite different.

Contributor Author

Nice catch! That indeed should be reverted, thanks a lot.

Contributor Author

Ah, that's actually a git diff issue! These are two different modules, DecoderLayer (red) and EncoderLayer (green), so there were no changes actually.

level_anchors = torch.concat([grid_xy, grid_wh], dim=-1).reshape(height * width, 4)
anchors.append(level_anchors)
anchors = torch.concat(anchors).unsqueeze(0)

# define the valid range for anchor coordinates
eps = 1e-2
Contributor

maybe define this value as an additional argument?
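For reference, a sketch of what exposing eps as an argument could look like, based on a simplified version of the anchor generation above (defaults and shapes are illustrative, not the PR's actual values):

import torch

def generate_anchors(spatial_shapes, grid_size=0.05, eps=1e-2, dtype=torch.float32, device=None):
    anchors = []
    for level, (height, width) in enumerate(spatial_shapes):
        grid_y, grid_x = torch.meshgrid(
            torch.arange(height, dtype=dtype, device=device),
            torch.arange(width, dtype=dtype, device=device),
            indexing="ij",
        )
        # normalized box centers and per-level box sizes
        grid_xy = torch.stack([grid_x, grid_y], dim=-1)
        grid_xy = (grid_xy + 0.5) / torch.tensor([width, height], dtype=dtype, device=device)
        grid_wh = torch.ones_like(grid_xy) * grid_size * (2.0 ** level)
        level_anchors = torch.concat([grid_xy, grid_wh], dim=-1).reshape(height * width, 4)
        anchors.append(level_anchors)
    anchors = torch.concat(anchors).unsqueeze(0)

    # the valid range for anchor coordinates is now controlled by the caller
    valid_mask = ((anchors > eps) & (anchors < 1 - eps)).all(dim=-1, keepdim=True)
    return anchors, valid_mask

A call like generate_anchors([(80, 80), (40, 40), (20, 20)], eps=1e-2) would then make the clamping range explicit at the call site.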
