Skip to content

Commit b175fc3

Browse files
authored
[DINOv2] Update pooler output (#25392)
Update pooler output
1 parent d0c1aeb commit b175fc3

File tree

1 file changed: 5 additions and 22 deletions

src/transformers/models/dinov2/modeling_dinov2.py

Lines changed: 5 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -583,15 +583,14 @@ def _set_gradient_checkpointing(self, module: Dinov2Encoder, value: bool = False
583583
DINOV2_START_DOCSTRING,
584584
)
585585
class Dinov2Model(Dinov2PreTrainedModel):
586-
def __init__(self, config: Dinov2Config, add_pooling_layer: bool = True):
586+
def __init__(self, config: Dinov2Config):
587587
super().__init__(config)
588588
self.config = config
589589

590590
self.embeddings = Dinov2Embeddings(config)
591591
self.encoder = Dinov2Encoder(config)
592592

593593
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
594-
self.pooler = Dinov2Pooler(config) if add_pooling_layer else None
595594

596595
# Initialize weights and apply final processing
597596
self.post_init()
@@ -651,10 +650,10 @@ def forward(
651650
)
652651
sequence_output = encoder_outputs[0]
653652
sequence_output = self.layernorm(sequence_output)
654-
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
653+
pooled_output = sequence_output[:, 0, :]
655654

656655
if not return_dict:
657-
head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
656+
head_outputs = (sequence_output, pooled_output)
658657
return head_outputs + encoder_outputs[1:]
659658

660659
return BaseModelOutputWithPooling(
@@ -665,22 +664,6 @@ def forward(
665664
)
666665

667666

668-
# Copied from transformers.models.vit.modeling_vit.ViTPooler with ViT->Dinov2
669-
class Dinov2Pooler(nn.Module):
670-
def __init__(self, config: Dinov2Config):
671-
super().__init__()
672-
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
673-
self.activation = nn.Tanh()
674-
675-
def forward(self, hidden_states):
676-
# We "pool" the model by simply taking the hidden state corresponding
677-
# to the first token.
678-
first_token_tensor = hidden_states[:, 0]
679-
pooled_output = self.dense(first_token_tensor)
680-
pooled_output = self.activation(pooled_output)
681-
return pooled_output
682-
683-
684667
@add_start_docstrings(
685668
"""
686669
Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
@@ -693,7 +676,7 @@ def __init__(self, config: Dinov2Config) -> None:
693676
super().__init__(config)
694677

695678
self.num_labels = config.num_labels
696-
self.dinov2 = Dinov2Model(config, add_pooling_layer=False)
679+
self.dinov2 = Dinov2Model(config)
697680

698681
# Classifier head
699682
self.classifier = (
@@ -770,7 +753,7 @@ def forward(
770753
loss = loss_fct(logits, labels)
771754

772755
if not return_dict:
773-
output = (logits,) + outputs[1:]
756+
output = (logits,) + outputs[2:]
774757
return ((loss,) + output) if loss is not None else output
775758

776759
return ImageClassifierOutput(

0 commit comments

Comments (0)