diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index d2a629ae067ffb..89072ae5de799e 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -267,7 +267,7 @@ Flax), PyTorch, and/or TensorFlow. | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | REALM | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ❌ | ❌ | +| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ❌ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/regnet.mdx b/docs/source/en/model_doc/regnet.mdx index 666a9ee39675d0..1f87ccd051bbcc 100644 --- a/docs/source/en/model_doc/regnet.mdx +++ b/docs/source/en/model_doc/regnet.mdx @@ -27,7 +27,8 @@ Tips: - One can use [`AutoFeatureExtractor`] to prepare images for the model. - The huge 10B model from [Self-supervised Pretraining of Visual Features in the Wild](https://arxiv.org/abs/2103.01988), trained on one billion Instagram images, is available on the [hub](https://huggingface.co/facebook/regnet-y-10b-seer) -This model was contributed by [Francesco](https://huggingface.co/Francesco). +This model was contributed by [Francesco](https://huggingface.co/Francesco). The TensorFlow version of the model +was contributed by [sayakpaul](https://huggingface.com/sayakpaul) and [ariG23498](https://huggingface.com/ariG23498). The original code can be found [here](https://github.com/facebookresearch/pycls). @@ -45,4 +46,15 @@ The original code can be found [here](https://github.com/facebookresearch/pycls) ## RegNetForImageClassification [[autodoc]] RegNetForImageClassification - - forward \ No newline at end of file + - forward + +## TFRegNetModel + +[[autodoc]] TFRegNetModel + - call + + +## TFRegNetForImageClassification + +[[autodoc]] TFRegNetForImageClassification + - call \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 6299af90a706e7..17670395f3c47b 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2334,6 +2334,14 @@ "TFRagTokenForGeneration", ] ) + _import_structure["models.regnet"].extend( + [ + "TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFRegNetForImageClassification", + "TFRegNetModel", + "TFRegNetPreTrainedModel", + ] + ) _import_structure["models.rembert"].extend( [ "TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -4649,6 +4657,12 @@ from .models.opt import TFOPTForCausalLM, TFOPTModel, TFOPTPreTrainedModel from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel from .models.rag import TFRagModel, TFRagPreTrainedModel, TFRagSequenceForGeneration, TFRagTokenForGeneration + from .models.regnet import ( + TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST, + TFRegNetForImageClassification, + TFRegNetModel, + TFRegNetPreTrainedModel, + ) from .models.rembert import ( TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, TFRemBertForCausalLM, diff --git a/src/transformers/modeling_tf_outputs.py b/src/transformers/modeling_tf_outputs.py index 1e71556ec1a1f7..5aedace141c141 100644 --- a/src/transformers/modeling_tf_outputs.py +++ b/src/transformers/modeling_tf_outputs.py @@ -46,6 +46,25 @@ class TFBaseModelOutput(ModelOutput): attentions: Optional[Tuple[tf.Tensor]] = None +@dataclass +class TFBaseModelOutputWithNoAttention(ModelOutput): + """ + Base class for model's outputs, with potential hidden states. + + Args: + last_hidden_state (`tf.Tensor` shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: tf.Tensor = None + hidden_states: Optional[Tuple[tf.Tensor]] = None + + @dataclass class TFBaseModelOutputWithPooling(ModelOutput): """ @@ -80,6 +99,28 @@ class TFBaseModelOutputWithPooling(ModelOutput): attentions: Optional[Tuple[tf.Tensor]] = None +@dataclass +class TFBaseModelOutputWithPoolingAndNoAttention(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state after a pooling operation on the spatial dimensions. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each layer) of shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: tf.Tensor = None + pooler_output: tf.Tensor = None + hidden_states: Optional[Tuple[tf.Tensor]] = None + + @dataclass class TFBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): """ @@ -825,3 +866,24 @@ class TFSequenceClassifierOutputWithPast(ModelOutput): past_key_values: Optional[List[tf.Tensor]] = None hidden_states: Optional[Tuple[tf.Tensor]] = None attentions: Optional[Tuple[tf.Tensor]] = None + + +@dataclass +class TFImageClassifierOutputWithNoAttention(ModelOutput): + """ + Base class for outputs of image classification models. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also called + feature maps) of the model at the output of each stage. + """ + + loss: Optional[tf.Tensor] = None + logits: tf.Tensor = None + hidden_states: Optional[Tuple[tf.Tensor]] = None diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index 9c889597e59f6e..aaf91bdb24ec33 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -62,6 +62,7 @@ ("openai-gpt", "TFOpenAIGPTModel"), ("opt", "TFOPTModel"), ("pegasus", "TFPegasusModel"), + ("regnet", "TFRegNetModel"), ("rembert", "TFRemBertModel"), ("roberta", "TFRobertaModel"), ("roformer", "TFRoFormerModel"), @@ -173,6 +174,7 @@ # Model for Image-classsification ("convnext", "TFConvNextForImageClassification"), ("data2vec-vision", "TFData2VecVisionForImageClassification"), + ("regnet", "TFRegNetForImageClassification"), ("swin", "TFSwinForImageClassification"), ("vit", "TFViTForImageClassification"), ] diff --git a/src/transformers/models/regnet/__init__.py b/src/transformers/models/regnet/__init__.py index 2de85e0cc19b18..5399cb3f3be08e 100644 --- a/src/transformers/models/regnet/__init__.py +++ b/src/transformers/models/regnet/__init__.py @@ -18,8 +18,7 @@ from typing import TYPE_CHECKING # rely on isort to merge the imports -from ...file_utils import _LazyModule, is_torch_available -from ...utils import OptionalDependencyNotAvailable +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available _import_structure = {"configuration_regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"]} @@ -37,6 +36,19 @@ "RegNetPreTrainedModel", ] +try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_tf_regnet"] = [ + "TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFRegNetForImageClassification", + "TFRegNetModel", + "TFRegNetPreTrainedModel", + ] + if TYPE_CHECKING: from .configuration_regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig @@ -54,6 +66,19 @@ RegNetPreTrainedModel, ) + try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_tf_regnet import ( + TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST, + TFRegNetForImageClassification, + TFRegNetModel, + TFRegNetPreTrainedModel, + ) + else: import sys diff --git a/src/transformers/models/regnet/modeling_tf_regnet.py b/src/transformers/models/regnet/modeling_tf_regnet.py new file mode 100644 index 00000000000000..b6a8187744a70a --- /dev/null +++ b/src/transformers/models/regnet/modeling_tf_regnet.py @@ -0,0 +1,508 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" TensorFlow RegNet model.""" + +from typing import Dict, Optional, Tuple, Union + +import tensorflow as tf + +from ...activations_tf import ACT2FN +from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithNoAttention, + TFBaseModelOutputWithPoolingAndNoAttention, + TFSequenceClassifierOutput, +) +from ...modeling_tf_utils import TFPreTrainedModel, TFSequenceClassificationLoss, keras_serializable, unpack_inputs +from ...tf_utils import shape_list +from ...utils import logging +from .configuration_regnet import RegNetConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "RegNetConfig" +_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor" + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/regnet-y-040" +_EXPECTED_OUTPUT_SHAPE = [1, 1088, 7, 7] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "facebook/regnet-y-040" +_IMAGE_CLASS_EXPECTED_OUTPUT = "'tabby, tabby cat'" + +TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/regnet-y-040", + # See all regnet models at https://huggingface.co/models?filter=regnet +] + + +class TFRegNetConvLayer(tf.keras.layers.Layer): + def __init__( + self, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + groups: int = 1, + activation: Optional[str] = "relu", + **kwargs, + ): + super().__init__(**kwargs) + # The padding and conv has been verified in + # https://colab.research.google.com/gist/sayakpaul/854bc10eeaf21c9ee2119e0b9f3841a7/scratchpad.ipynb + self.padding = tf.keras.layers.ZeroPadding2D(padding=kernel_size // 2) + self.convolution = tf.keras.layers.Conv2D( + filters=out_channels, + kernel_size=kernel_size, + strides=stride, + padding="VALID", + groups=groups, + use_bias=False, + name="convolution", + ) + self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization") + self.activation = ACT2FN[activation] if activation is not None else tf.identity + + def call(self, hidden_state): + hidden_state = self.convolution(self.padding(hidden_state)) + hidden_state = self.normalization(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class TFRegNetEmbeddings(tf.keras.layers.Layer): + """ + RegNet Embeddings (stem) composed of a single aggressive convolution. + """ + + def __init__(self, config: RegNetConfig, **kwargs): + super().__init__(**kwargs) + self.num_channels = config.num_channels + self.embedder = TFRegNetConvLayer( + out_channels=config.embedding_size, + kernel_size=3, + stride=2, + activation=config.hidden_act, + name="embedder", + ) + + def call(self, pixel_values): + num_channels = shape_list(pixel_values)[1] + if tf.executing_eagerly() and num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + + # When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format. + # So change the input format from `NCHW` to `NHWC`. + # shape = (batch_size, in_height, in_width, in_channels=num_channels) + pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) + hidden_state = self.embedder(pixel_values) + return hidden_state + + +class TFRegNetShortCut(tf.keras.layers.Layer): + """ + RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to + downsample the input using `stride=2`. + """ + + def __init__(self, out_channels: int, stride: int = 2, **kwargs): + super().__init__(**kwargs) + self.convolution = tf.keras.layers.Conv2D( + filters=out_channels, kernel_size=1, strides=stride, use_bias=False, name="convolution" + ) + self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization") + + def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor: + return self.normalization(self.convolution(inputs), training=training) + + +class TFRegNetSELayer(tf.keras.layers.Layer): + """ + Squeeze and Excitation layer (SE) proposed in [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507). + """ + + def __init__(self, in_channels: int, reduced_channels: int, **kwargs): + super().__init__(**kwargs) + self.pooler = tf.keras.layers.GlobalAveragePooling2D(keepdims=True, name="pooler") + self.attention = [ + tf.keras.layers.Conv2D(filters=reduced_channels, kernel_size=1, activation="relu", name="attention.0"), + tf.keras.layers.Conv2D(filters=in_channels, kernel_size=1, activation="sigmoid", name="attention.2"), + ] + + def call(self, hidden_state): + # [batch_size, h, w, num_channels] -> [batch_size, 1, 1, num_channels] + pooled = self.pooler(hidden_state) + for layer_module in self.attention: + pooled = layer_module(pooled) + hidden_state = hidden_state * pooled + return hidden_state + + +class TFRegNetXLayer(tf.keras.layers.Layer): + """ + RegNet's layer composed by three `3x3` convolutions, same as a ResNet bottleneck layer with reduction = 1. + """ + + def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1, **kwargs): + super().__init__(**kwargs) + should_apply_shortcut = in_channels != out_channels or stride != 1 + groups = max(1, out_channels // config.groups_width) + self.shortcut = ( + TFRegNetShortCut(out_channels, stride=stride, name="shortcut") + if should_apply_shortcut + else tf.keras.layers.Activation("linear", name="shortcut") + ) + # `self.layers` instead of `self.layer` because that is a reserved argument. + self.layers = [ + TFRegNetConvLayer(out_channels, kernel_size=1, activation=config.hidden_act, name="layer.0"), + TFRegNetConvLayer( + out_channels, stride=stride, groups=groups, activation=config.hidden_act, name="layer.1" + ), + TFRegNetConvLayer(out_channels, kernel_size=1, activation=None, name="layer.2"), + ] + self.activation = ACT2FN[config.hidden_act] + + def call(self, hidden_state): + residual = hidden_state + for layer_module in self.layers: + hidden_state = layer_module(hidden_state) + residual = self.shortcut(residual) + hidden_state += residual + hidden_state = self.activation(hidden_state) + return hidden_state + + +class TFRegNetYLayer(tf.keras.layers.Layer): + """ + RegNet's Y layer: an X layer with Squeeze and Excitation. + """ + + def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1, **kwargs): + super().__init__(**kwargs) + should_apply_shortcut = in_channels != out_channels or stride != 1 + groups = max(1, out_channels // config.groups_width) + self.shortcut = ( + TFRegNetShortCut(out_channels, stride=stride, name="shortcut") + if should_apply_shortcut + else tf.keras.layers.Activation("linear", name="shortcut") + ) + self.layers = [ + TFRegNetConvLayer(out_channels, kernel_size=1, activation=config.hidden_act, name="layer.0"), + TFRegNetConvLayer( + out_channels, stride=stride, groups=groups, activation=config.hidden_act, name="layer.1" + ), + TFRegNetSELayer(out_channels, reduced_channels=int(round(in_channels / 4)), name="layer.2"), + TFRegNetConvLayer(out_channels, kernel_size=1, activation=None, name="layer.3"), + ] + self.activation = ACT2FN[config.hidden_act] + + def call(self, hidden_state): + residual = hidden_state + for layer_module in self.layers: + hidden_state = layer_module(hidden_state) + residual = self.shortcut(residual) + hidden_state += residual + hidden_state = self.activation(hidden_state) + return hidden_state + + +class TFRegNetStage(tf.keras.layers.Layer): + """ + A RegNet stage composed by stacked layers. + """ + + def __init__( + self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 2, depth: int = 2, **kwargs + ): + super().__init__(**kwargs) + + layer = TFRegNetXLayer if config.layer_type == "x" else TFRegNetYLayer + self.layers = [ + # downsampling is done in the first layer with stride of 2 + layer(config, in_channels, out_channels, stride=stride, name="layers.0"), + *[layer(config, out_channels, out_channels, name=f"layers.{i+1}") for i in range(depth - 1)], + ] + + def call(self, hidden_state): + for layer_module in self.layers: + hidden_state = layer_module(hidden_state) + return hidden_state + + +class TFRegNetEncoder(tf.keras.layers.Layer): + def __init__(self, config: RegNetConfig, **kwargs): + super().__init__(**kwargs) + self.stages = list() + # based on `downsample_in_first_stage`, the first layer of the first stage may or may not downsample the input + self.stages.append( + TFRegNetStage( + config, + config.embedding_size, + config.hidden_sizes[0], + stride=2 if config.downsample_in_first_stage else 1, + depth=config.depths[0], + name="stages.0", + ) + ) + in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:]) + for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, config.depths[1:])): + self.stages.append(TFRegNetStage(config, in_channels, out_channels, depth=depth, name=f"stages.{i+1}")) + + def call( + self, hidden_state: tf.Tensor, output_hidden_states: bool = False, return_dict: bool = True + ) -> TFBaseModelOutputWithNoAttention: + hidden_states = () if output_hidden_states else None + + for stage_module in self.stages: + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + hidden_state = stage_module(hidden_state) + + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states] if v is not None) + + return TFBaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states) + + +@keras_serializable +class TFRegNetMainLayer(tf.keras.layers.Layer): + config_class = RegNetConfig + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.config = config + self.embedder = TFRegNetEmbeddings(config, name="embedder") + self.encoder = TFRegNetEncoder(config, name="encoder") + self.pooler = tf.keras.layers.GlobalAveragePooling2D(keepdims=True, name="pooler") + + @unpack_inputs + def call( + self, + pixel_values: tf.Tensor, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> TFBaseModelOutputWithPoolingAndNoAttention: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + embedding_output = self.embedder(pixel_values, training=training) + + encoder_outputs = self.encoder( + embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = self.pooler(last_hidden_state) + + # Change to NCHW output format have uniformity in the modules + pooled_output = tf.transpose(pooled_output, perm=(0, 3, 1, 2)) + last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2)) + + # Change the other hidden state outputs to NCHW as well + if output_hidden_states: + hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]]) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return TFBaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states, + ) + + +class TFRegNetPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RegNetConfig + base_model_prefix = "regnet" + main_input_name = "pixel_values" + + @property + def dummy_inputs(self) -> Dict[str, tf.Tensor]: + """ + Dummy inputs to build the network. + + Returns: + `Dict[str, tf.Tensor]`: The dummy inputs. + """ + VISION_DUMMY_INPUTS = tf.random.uniform(shape=(3, self.config.num_channels, 224, 224), dtype=tf.float32) + return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)} + + @tf.function( + input_signature=[ + { + "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"), + } + ] + ) + def serving(self, inputs): + """ + Method used for serving the model. + + Args: + inputs (`Dict[str, tf.Tensor]`): + The input of the saved model as a dictionary of tensors. + """ + return self.call(inputs) + + +REGNET_START_DOCSTRING = r""" + Parameters: + This model is a Tensorflow + [tf.keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) sub-class. Use it as a + regular Tensorflow Module and refer to the Tensorflow documentation for all matter related to general usage and + behavior. + config ([`RegNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + +REGNET_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See + [`AutoFeatureExtractor.__call__`] for details. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare RegNet model outputting raw features without any specific head on top.", + REGNET_START_DOCSTRING, +) +class TFRegNetModel(TFRegNetPreTrainedModel): + def __init__(self, config: RegNetConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.regnet = TFRegNetMainLayer(config, name="regnet") + + @unpack_inputs + @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPoolingAndNoAttention, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def call( + self, + pixel_values: tf.Tensor, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training=False, + ) -> Union[TFBaseModelOutputWithPoolingAndNoAttention, Tuple[tf.Tensor]]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.regnet( + pixel_values=pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + if not return_dict: + return (outputs[0],) + outputs[1:] + + return TFBaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=outputs.last_hidden_state, + pooler_output=outputs.pooler_output, + hidden_states=outputs.hidden_states, + ) + + +@add_start_docstrings( + """ + RegNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """, + REGNET_START_DOCSTRING, +) +class TFRegNetForImageClassification(TFRegNetPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: RegNetConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + self.regnet = TFRegNetMainLayer(config, name="regnet") + # classification head + self.classifier = [ + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(config.num_labels, name="classifier.1") if config.num_labels > 0 else tf.identity, + ] + + @unpack_inputs + @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def call( + self, + pixel_values: tf.Tensor = None, + labels: tf.Tensor = None, + output_hidden_states: bool = None, + return_dict: bool = None, + training=False, + ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.regnet( + pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training + ) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + flattened_output = self.classifier[0](pooled_output) + logits = self.classifier[1](flattened_output) + + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states) diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 4eb40113e76cf0..01412ba82efaa1 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -1696,6 +1696,30 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class TFRegNetForImageClassification(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFRegNetModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFRegNetPreTrainedModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/models/regnet/test_modeling_tf_regnet.py b/tests/models/regnet/test_modeling_tf_regnet.py new file mode 100644 index 00000000000000..c7504c92fa35b0 --- /dev/null +++ b/tests/models/regnet/test_modeling_tf_regnet.py @@ -0,0 +1,289 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the TensorFlow RegNet model. """ + +import inspect +import unittest +from typing import List, Tuple + +from transformers import RegNetConfig +from transformers.testing_utils import require_tf, require_vision, slow +from transformers.utils import cached_property, is_tf_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor + + +if is_tf_available(): + import tensorflow as tf + + from transformers import TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST, TFRegNetForImageClassification, TFRegNetModel + + +if is_vision_available(): + from PIL import Image + + from transformers import AutoFeatureExtractor + + +class TFRegNetModelTester: + def __init__( + self, + parent, + batch_size=3, + image_size=32, + num_channels=3, + embeddings_size=10, + hidden_sizes=[10, 20, 30, 40], + depths=[1, 1, 2, 1], + is_training=True, + use_labels=True, + hidden_act="relu", + num_labels=3, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.num_channels = num_channels + self.embeddings_size = embeddings_size + self.hidden_sizes = hidden_sizes + self.depths = depths + self.is_training = is_training + self.use_labels = use_labels + self.hidden_act = hidden_act + self.num_labels = num_labels + self.scope = scope + self.num_stages = len(hidden_sizes) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size], self.num_labels) + + config = self.get_config() + return config, pixel_values, labels + + def get_config(self): + return RegNetConfig( + num_channels=self.num_channels, + embeddings_size=self.embeddings_size, + hidden_sizes=self.hidden_sizes, + depths=self.depths, + hidden_act=self.hidden_act, + num_labels=self.num_labels, + ) + + def create_and_check_model(self, config, pixel_values, labels): + model = TFRegNetModel(config=config) + result = model(pixel_values, training=False) + # expected last hidden states: B, C, H // 32, W // 32 + self.parent.assertEqual( + result.last_hidden_state.shape, + (self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32), + ) + + def create_and_check_for_image_classification(self, config, pixel_values, labels): + config.num_labels = self.num_labels + model = TFRegNetForImageClassification(config) + result = model(pixel_values, labels=labels, training=False) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, labels = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_tf +class TFRegNetModelTest(TFModelTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as RegNet does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = (TFRegNetModel, TFRegNetForImageClassification) if is_tf_available() else () + + test_pruning = False + test_onnx = False + test_resize_embeddings = False + test_head_masking = False + has_attentions = False + + def setUp(self): + self.model_tester = TFRegNetModelTester(self) + self.config_tester = ConfigTester(self, config_class=RegNetConfig, has_text_modality=False) + + def create_and_test_config_common_properties(self): + return + + @unittest.skip(reason="RegNet does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skipIf( + not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0, + reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.", + ) + def test_keras_fit(self): + pass + + @unittest.skip(reason="RegNet does not support input and output embeddings") + def test_model_common_attributes(self): + pass + + @unittest.skip(reason="Model doesn't have attention layers") + def test_attention_outputs(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.call) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_stages = self.model_tester.num_stages + self.assertEqual(len(hidden_states), expected_num_stages + 1) + + # RegNet's feature maps are of shape (batch_size, num_channels, height, width) + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [self.model_tester.image_size // 2, self.model_tester.image_size // 2], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + layers_type = ["basic", "bottleneck"] + for model_class in self.all_model_classes: + for layer_type in layers_type: + config.layer_type = layer_type + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + # Since RegNet does not have any attention we need to rewrite this test. + def test_model_outputs_equivalence(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): + tuple_output = model(tuple_inputs, return_dict=False, **additional_kwargs) + dict_output = model(dict_inputs, return_dict=True, **additional_kwargs).to_tuple() + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + all(tf.equal(tuple_object, dict_object)), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}" + ), + ) + + recursive_check(tuple_output, dict_output) + + for model_class in self.all_model_classes: + model = model_class(config) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = TFRegNetModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_tf +@require_vision +class RegNetModelIntegrationTest(unittest.TestCase): + @cached_property + def default_feature_extractor(self): + return ( + AutoFeatureExtractor.from_pretrained(TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST[0]) + if is_vision_available() + else None + ) + + @slow + def test_inference_image_classification_head(self): + model = TFRegNetForImageClassification.from_pretrained(TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST[0]) + + feature_extractor = self.default_feature_extractor + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="tf") + + # forward pass + outputs = model(**inputs, training=False) + + # verify the logits + expected_shape = tf.TensorShape((1, 1000)) + self.assertEqual(outputs.logits.shape, expected_shape) + + expected_slice = tf.constant([-0.4180, -1.5051, -3.4836]) + + tf.debugging.assert_near(outputs.logits[0, :3], expected_slice, atol=1e-4) diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index 6297e328912c36..a977f47783f566 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -50,6 +50,8 @@ src/transformers/models/pegasus/modeling_pegasus.py src/transformers/models/plbart/modeling_plbart.py src/transformers/models/poolformer/modeling_poolformer.py src/transformers/models/reformer/modeling_reformer.py +src/transformers/models/regnet/modeling_regnet.py +src/transformers/models/regnet/modeling_tf_regnet.py src/transformers/models/resnet/modeling_resnet.py src/transformers/models/roberta/modeling_roberta.py src/transformers/models/roberta/modeling_tf_roberta.py