2 changes: 2 additions & 0 deletions mindone/transformers/__init__.py
@@ -445,6 +445,8 @@
Dinov2WithRegistersModel,
Dinov2WithRegistersPreTrainedModel,
)
from .models.dinov3_convnext import DINOv3ConvNextModel, DINOv3ConvNextPreTrainedModel
from .models.dinov3_vit import DINOv3ViTModel, DINOv3ViTPreTrainedModel, DINOv3ViTImageProcessorFast
from .models.distilbert import (
DistilBertForMaskedLM,
DistilBertForMultipleChoice,
2 changes: 2 additions & 0 deletions mindone/transformers/models/__init__.py
@@ -64,6 +64,8 @@
diffllama,
dinov2,
dinov2_with_registers,
dinov3_convnext,
dinov3_vit,
distilbert,
dpr,
dpt,
4 changes: 4 additions & 0 deletions mindone/transformers/models/auto/configuration_auto.py
@@ -86,6 +86,8 @@
("diffllama", "DiffLlamaConfig"),
("dinov2", "Dinov2Config"),
("dinov2_with_registers", "Dinov2WithRegistersConfig"),
("dinov3_convnext", "DINOv3ConvNextConfig"),
("dinov3_vit", "DINOv3ViTConfig"),
("deit", "DeiTConfig"),
("distilbert", "DistilBertConfig"),
("dpr", "DPRConfig"),
@@ -355,6 +357,8 @@
("diffllama", "DiffLlama"),
("dinov2", "DINOv2"),
("dinov2_with_registers", "DINOv2 with Registers"),
("dinov3_convnext", "DINOv3 ConvNext"),
("dinov3_vit", "DINOv3 ViT"),
("distilbert", "DistilBERT"),
("dpr", "DPR"),
("dpt", "DPT"),
mindone/transformers/models/auto/image_processing_auto.py
@@ -61,6 +61,7 @@
("depth_anything", ("DPTImageProcessor",)),
("depth_pro", ("DepthProImageProcessor",)),
("dinov2", ("BitImageProcessor",)),
("dinov3_vit", ("DINOv3ViTImageProcessorFast",)),
("dpt", ("DPTImageProcessor",)),
("efficientnet", ("EfficientNetImageProcessor",)),
("flava", ("FlavaImageProcessor",)),
4 changes: 4 additions & 0 deletions mindone/transformers/models/auto/modeling_auto.py
@@ -84,6 +84,8 @@
("diffllama", "DiffLlamaModel"),
("dinov2", "Dinov2Model"),
("dinov2_with_registers", "Dinov2WithRegistersModel"),
("dinov3_convnext", "DINOv3ConvNextModel"),
("dinov3_vit", "DINOv3ViTModel"),
("distilbert", "DistilBertModel"),
("dpr", "DPRQuestionEncoder"),
("dpt", "DPTModel"),
@@ -514,6 +516,8 @@
("depth_pro", "DepthProModel"),
("dinov2", "Dinov2Model"),
("dinov2_with_registers", "Dinov2WithRegistersModel"),
("dinov3_convnext", "DINOv3ConvNextModel"),
("dinov3_vit", "DINOv3ViTModel"),
("dpt", "DPTModel"),
("efficientnet", "EfficientNetModel"),
("focalnet", "FocalNetModel"),
17 changes: 17 additions & 0 deletions mindone/transformers/models/dinov3_convnext/__init__.py
@@ -0,0 +1,17 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# This code is adapted from https://github.com/huggingface/transformers
# with modifications to run transformers on mindspore.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .modeling_dinov3_convnext import *
Contributor (medium):

Wildcard imports (from ... import *) are discouraged by PEP 8 as they make it unclear which names are present in the namespace. It's better to explicitly import the required names. Based on __all__ in modeling_dinov3_convnext.py, you should import DINOv3ConvNextModel and DINOv3ConvNextPreTrainedModel.

Suggested change
from .modeling_dinov3_convnext import *
from .modeling_dinov3_convnext import DINOv3ConvNextModel, DINOv3ConvNextPreTrainedModel
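
As a quick sanity check of the explicit-import suggestion above, the package would then expose exactly the names listed in __all__. A minimal sketch, assuming a mindone build that includes this PR:

# Sketch only: confirm the explicit exports suggested above resolve.
# The two names come from __all__ in modeling_dinov3_convnext.py.
from mindone.transformers.models.dinov3_convnext import (
    DINOv3ConvNextModel,
    DINOv3ConvNextPreTrainedModel,
)

assert issubclass(DINOv3ConvNextModel, DINOv3ConvNextPreTrainedModel)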

254 changes: 254 additions & 0 deletions mindone/transformers/models/dinov3_convnext/modeling_dinov3_convnext.py
@@ -0,0 +1,254 @@
# coding=utf-8
# Copyright 2025 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# This code is adapted from https://github.com/huggingface/transformers
# with modifications to run transformers on mindspore.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MindSpore ConvNext model."""

from typing import Optional

import numpy as np
import mindspore as ms
from mindspore import mint, nn

from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutputWithPoolingAndNoAttention,
)
from ...modeling_utils import PreTrainedModel
from transformers.models.dinov3_convnext.configuration_dinov3_convnext import DINOv3ConvNextConfig


# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: ms.Tensor, drop_prob: float = 0.0, training: bool = False) -> ms.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
argument.
"""
if drop_prob == 0.0 or not training:
return input
keep_prob = 1 - drop_prob
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + mint.rand(shape, dtype=input.dtype)
random_tensor.floor_() # binarize
output = input.div(keep_prob) * random_tensor
return output


# Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath with ConvNext->DINOv3ConvNext
class DINOv3ConvNextDropPath(nn.Cell):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

def __init__(self, drop_prob: Optional[float] = None) -> None:
super().__init__()
self.drop_prob = drop_prob

def construct(self, hidden_states: ms.Tensor) -> ms.Tensor:
return drop_path(hidden_states, self.drop_prob, self.training)

def extra_repr(self) -> str:
return f"p={self.drop_prob}"


class DINOv3ConvNextLayerNorm(mint.nn.LayerNorm):
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
"""

def __init__(self, *args, data_format="channels_last", **kwargs):
super().__init__(*args, **kwargs)
if data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError(f"Unsupported data format: {data_format}")
self.data_format = data_format

def construct(self, features: ms.Tensor) -> ms.Tensor:
"""
Args:
features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
"""
if self.data_format == "channels_first":
features = features.permute(0, 2, 3, 1)
features = super().construct(features)
features = features.permute(0, 3, 1, 2)
else:
features = super().construct(features)
return features


class DINOv3ConvNextLayer(nn.Cell):
"""This corresponds to the `Block` class in the original implementation.

There are two equivalent implementations:
1) DwConv, LayerNorm (channels_first), Conv, GELU, Conv (all in (N, C, H, W) format)
2) DwConv, Permute, LayerNorm (channels_last), Linear, GELU, Linear, Permute

The authors used (2) as they find it slightly faster in PyTorch.

Args:
config ([`DINOv3ConvNextConfig`]):
Model config.
channels (`int`):
Number of input (and output) channels.
drop_path (`float`):
Drop path rate. Default: 0.0.
"""

def __init__(self, config: DINOv3ConvNextConfig, channels: int, drop_path: float = 0.0):
super().__init__()
self.depthwise_conv = mint.nn.Conv2d(channels, channels, kernel_size=7, padding=3, groups=channels)
self.layer_norm = DINOv3ConvNextLayerNorm(channels, eps=config.layer_norm_eps)
self.pointwise_conv1 = mint.nn.Linear(channels, 4 * channels) # can be seen as a 1x1 conv
self.activation_fn = ACT2FN[config.hidden_act]
self.pointwise_conv2 = mint.nn.Linear(4 * channels, channels) # can be seen as a 1x1 conv
self.gamma = ms.Parameter(mint.full((channels,), config.layer_scale_init_value), requires_grad=True)
self.drop_path = DINOv3ConvNextDropPath(drop_path) if drop_path > 0.0 else mint.nn.Identity()

def construct(self, features: ms.Tensor) -> ms.Tensor:
"""
Args:
features: Tensor of shape (batch_size, channels, height, width)
"""
residual = features
features = self.depthwise_conv(features)
features = features.permute(0, 2, 3, 1) # to channels last
features = self.layer_norm(features)
features = self.pointwise_conv1(features)
features = self.activation_fn(features)
features = self.pointwise_conv2(features)
features = features * self.gamma
features = features.permute(0, 3, 1, 2) # back to channels first
features = residual + self.drop_path(features)
return features


class DINOv3ConvNextStage(nn.Cell):
""" """

def __init__(self, config: DINOv3ConvNextConfig, stage_idx: int):
super().__init__()

in_channels = config.hidden_sizes[stage_idx - 1] if stage_idx > 0 else config.num_channels
out_channels = config.hidden_sizes[stage_idx]

if stage_idx == 0:
self.downsample_layers = nn.CellList(
[
mint.nn.Conv2d(config.num_channels, out_channels, kernel_size=4, stride=4),
DINOv3ConvNextLayerNorm(out_channels, eps=config.layer_norm_eps, data_format="channels_first"),
]
)
else:
self.downsample_layers = nn.CellList(
[
DINOv3ConvNextLayerNorm(in_channels, eps=config.layer_norm_eps, data_format="channels_first"),
mint.nn.Conv2d(in_channels, out_channels, kernel_size=2, stride=2),
]
)

num_stage_layers = config.depths[stage_idx]
num_previous_layers = sum(config.depths[:stage_idx])
num_total_layers = sum(config.depths)
drop_path_rates = np.linspace(0, config.drop_path_rate, num_total_layers).tolist()

self.layers = nn.CellList(
[
DINOv3ConvNextLayer(config, channels=out_channels, drop_path=drop_path_rates[i])
for i in range(num_previous_layers, num_previous_layers + num_stage_layers)
]
)

def construct(self, features: ms.Tensor) -> ms.Tensor:
"""
Args:
features: Tensor of shape (batch_size, channels, height, width)
"""
for layer in self.downsample_layers:
features = layer(features)
for layer in self.layers:
features = layer(features)
return features


class DINOv3ConvNextPreTrainedModel(PreTrainedModel):
config: DINOv3ConvNextConfig
base_model_prefix = "dinov3_convnext"
main_input_name = "pixel_values"
_no_split_modules = ["DINOv3ConvNextLayer"]

def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (mint.nn.Linear, mint.nn.Conv2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (mint.nn.LayerNorm, DINOv3ConvNextLayerNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, DINOv3ConvNextLayer):
if module.gamma is not None:
module.gamma.data.fill_(self.config.layer_scale_init_value)
Comment on lines +196 to +209
Contributor (critical):

The weight initialization method _init_weights uses PyTorch-style in-place modification on .data, which is not supported for mindspore.Parameter. You should use helper functions like normal_, zeros_, and constant_ from mindone.models.utils to initialize the parameters correctly. Please also add from mindone.models.utils import constant_, normal_, zeros_ to the imports at the top of the file.

Suggested change
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (mint.nn.Linear, mint.nn.Conv2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (mint.nn.LayerNorm, DINOv3ConvNextLayerNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, DINOv3ConvNextLayer):
if module.gamma is not None:
module.gamma.data.fill_(self.config.layer_scale_init_value)
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (mint.nn.Linear, mint.nn.Conv2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
zeros_(module.bias)
elif isinstance(module, (mint.nn.LayerNorm, DINOv3ConvNextLayerNorm)):
zeros_(module.bias)
constant_(module.weight, 1.0)
elif isinstance(module, DINOv3ConvNextLayer):
if module.gamma is not None:
constant_(module.gamma, self.config.layer_scale_init_value)
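
For completeness, the suggestion above also implies one extra line near the top of modeling_dinov3_convnext.py; a minimal sketch of that import, assuming mindone.models.utils exposes these helpers as the comment states:

# Assumed import to accompany the suggested change above.
from mindone.models.utils import constant_, normal_, zeros_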



class DINOv3ConvNextModel(DINOv3ConvNextPreTrainedModel):
def __init__(self, config: DINOv3ConvNextConfig):
super().__init__(config)
self.config = config
self.stages = nn.CellList([DINOv3ConvNextStage(config, stage_idx) for stage_idx in range(config.num_stages)])
self.layer_norm = mint.nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps) # final norm layer
self.pool = mint.nn.AdaptiveAvgPool2d(1)
self.post_init()

def construct(
self, pixel_values: ms.Tensor, output_hidden_states: Optional[bool] = None
) -> BaseModelOutputWithPoolingAndNoAttention:
hidden_states = pixel_values

output_hidden_states = output_hidden_states or self.config.output_hidden_states
all_hidden_states = [hidden_states] if output_hidden_states else []

for stage in self.stages:
hidden_states = stage(hidden_states)

# store intermediate stage outputs
if output_hidden_states:
all_hidden_states.append(hidden_states)

# make global representation, a.k.a [CLS] token
pooled_output = self.pool(hidden_states)

# (batch_size, channels, height, width) -> (batch_size, height * width, channels)
pooled_output = pooled_output.flatten(2).transpose(1, 2)
hidden_states = hidden_states.flatten(2).transpose(1, 2)

# concat "cls" and "patch tokens" as (batch_size, 1 + height * width, channels)
hidden_states = mint.cat([pooled_output, hidden_states], dim=1)
hidden_states = self.layer_norm(hidden_states)

return BaseModelOutputWithPoolingAndNoAttention(
last_hidden_state=hidden_states,
pooler_output=hidden_states[:, 0],
hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
)


__all__ = ["DINOv3ConvNextModel", "DINOv3ConvNextPreTrainedModel"]
18 changes: 18 additions & 0 deletions mindone/transformers/models/dinov3_vit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# This code is adapted from https://github.com/huggingface/transformers
# with modifications to run transformers on mindspore.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .image_processing_dinov3_vit_fast import DINOv3ViTImageProcessorFast
from .modeling_dinov3_vit import *
Contributor (medium):

Wildcard imports (from ... import *) are discouraged by PEP 8 as they make it unclear which names are present in the namespace. It's better to explicitly import the required names. Based on __all__ in modeling_dinov3_vit.py, you should import DINOv3ViTModel and DINOv3ViTPreTrainedModel.

Suggested change
from .modeling_dinov3_vit import *
from .modeling_dinov3_vit import DINOv3ViTModel, DINOv3ViTPreTrainedModel
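
Beyond the explicit-import fix, the auto mappings added earlier in this PR register "dinov3_vit" for both AutoModel and the fast image processor, so the new ViT variant is reachable without naming its classes. A hedged sketch, assuming mindone's AutoModel mirrors the upstream from_config entry point and that DINOv3ViTConfig ships with usable defaults:

# Sketch: resolve the new model through the "dinov3_vit" auto mapping added above.
import mindspore as ms
from mindspore import mint
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig

from mindone.transformers import AutoModel

config = DINOv3ViTConfig()       # upstream defaults, illustrative only
model = AutoModel.from_config(config)
print(type(model).__name__)      # expected to resolve to DINOv3ViTModel per the mapping above

pixel_values = mint.rand((1, 3, 224, 224), dtype=ms.float32)  # assumed RGB input at 224x224
outputs = model(pixel_values)
print(outputs.last_hidden_state.shape)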
