Add GLPN #16199

Merged: 21 commits, Mar 22, 2022
Changes from 1 commit
Update copied from statements and clean up
Niels Rogge authored and committed Mar 21, 2022
commit 243f12dd74ef6e4e2456faa8c0db3187ae020e06
12 changes: 0 additions & 12 deletions src/transformers/models/glpn/configuration_glpn.py
@@ -37,8 +37,6 @@ class GLPNConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information.

Args:
-        image_size (`int`, *optional*, defaults to 512):
-            The size (resolution) of each image.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
num_encoder_blocks (`int`, *optional*, defaults to 4):
@@ -49,8 +47,6 @@ class GLPNConfig(PretrainedConfig):
Sequence reduction ratios in each encoder block.
hidden_sizes (`List[int]`, *optional*, defaults to `[32, 64, 160, 256]`):
Dimension of each of the encoder blocks.
-        downsampling_rates (`List[int]`, *optional*, defaults to `[1, 4, 8, 16]`):
-            Downsample rate of the image resolution compared to the original image size before each encoder block.
patch_sizes (`List[int]`, *optional*, defaults to `[7, 3, 3, 3]`):
Patch size before each encoder block.
strides (`List[int]`, *optional*, defaults to `[4, 2, 2, 2]`):
@@ -67,8 +63,6 @@ class GLPNConfig(PretrainedConfig):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
-        classifier_dropout_prob (`float`, *optional*, defaults to 0.1):
-            The dropout probability before the classification head.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
drop_path_rate (`float`, *optional*, defaults to 0.1):
@@ -100,21 +94,18 @@ class GLPNConfig(PretrainedConfig):

def __init__(
self,
-        image_size=224,
num_channels=3,
num_encoder_blocks=4,
depths=[2, 2, 2, 2],
sr_ratios=[8, 4, 2, 1],
hidden_sizes=[32, 64, 160, 256],
-        downsampling_rates=[1, 4, 8, 16],
patch_sizes=[7, 3, 3, 3],
strides=[4, 2, 2, 2],
num_attention_heads=[1, 2, 5, 8],
mlp_ratios=[4, 4, 4, 4],
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
-        classifier_dropout_prob=0.1,
initializer_range=0.02,
drop_path_rate=0.1,
layer_norm_eps=1e-6,
@@ -126,21 +117,18 @@ def __init__(
):
super().__init__(**kwargs)

-        self.image_size = image_size
self.num_channels = num_channels
self.num_encoder_blocks = num_encoder_blocks
self.depths = depths
self.sr_ratios = sr_ratios
self.hidden_sizes = hidden_sizes
-        self.downsampling_rates = downsampling_rates
self.patch_sizes = patch_sizes
self.strides = strides
self.mlp_ratios = mlp_ratios
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
-        self.classifier_dropout_prob = classifier_dropout_prob
self.initializer_range = initializer_range
self.drop_path_rate = drop_path_rate
self.layer_norm_eps = layer_norm_eps
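
For reference, after this commit GLPNConfig is resolution-agnostic: it no longer takes image_size, downsampling_rates or classifier_dropout_prob. A minimal sketch of building the config (assuming this PR's branch is installed; the values simply mirror the documented defaults above):

from transformers import GLPNConfig

# The removed arguments are gone; everything below mirrors the defaults
# documented in the docstring above.
config = GLPNConfig(
    num_channels=3,
    num_encoder_blocks=4,
    depths=[2, 2, 2, 2],
    sr_ratios=[8, 4, 2, 1],
    hidden_sizes=[32, 64, 160, 256],
    patch_sizes=[7, 3, 3, 3],
    strides=[4, 2, 2, 2],
    num_attention_heads=[1, 2, 5, 8],
    mlp_ratios=[4, 4, 4, 4],
)
print(config.hidden_sizes)  # [32, 64, 160, 256]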
71 changes: 37 additions & 34 deletions src/transformers/models/glpn/modeling_glpn.py
@@ -15,7 +15,6 @@
""" PyTorch GLPN model."""


-import collections
import math
from typing import List, Optional, Tuple, Union

@@ -53,13 +52,6 @@
]


-# Copied from transformers.models.segformer.modeling_segformer.to_2tuple
-def to_2tuple(x):
-    if isinstance(x, collections.abc.Iterable):
-        return x
-    return (x, x)


# Copied from transformers.models.segformer.modeling_segformer.drop_path
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
"""
@@ -93,35 +85,36 @@ def forward(self, x):

# Copied from transformers.models.segformer.modeling_segformer.SegformerOverlapPatchEmbeddings
class GLPNOverlapPatchEmbeddings(nn.Module):
"""Construct the patch embeddings from an image."""
"""Construct the overlapping patch embeddings."""

-    def __init__(self, image_size, patch_size, stride, num_channels, hidden_size):
+    def __init__(self, patch_size, stride, num_channels, hidden_size):
super().__init__()
-        image_size = to_2tuple(image_size)
-        patch_size = to_2tuple(patch_size)
-        self.height, self.width = image_size[0] // patch_size[0], image_size[1] // patch_size[1]
-        self.num_patches = self.height * self.width
self.proj = nn.Conv2d(
num_channels,
hidden_size,
kernel_size=patch_size,
stride=stride,
-            padding=(patch_size[0] // 2, patch_size[1] // 2),
+            padding=patch_size // 2,
)

self.layer_norm = nn.LayerNorm(hidden_size)

def forward(self, pixel_values):
-        x = self.proj(pixel_values)
-        _, _, height, width = x.shape
-        x = x.flatten(2).transpose(1, 2)
-        x = self.layer_norm(x)
-        return x, height, width
+        embeddings = self.proj(pixel_values)
+        _, _, height, width = embeddings.shape
+        # (batch_size, num_channels, height, width) -> (batch_size, num_channels, height*width) -> (batch_size, height*width, num_channels)
+        # this can be fed to a Transformer layer
+        embeddings = embeddings.flatten(2).transpose(1, 2)
+        embeddings = self.layer_norm(embeddings)
+        return embeddings, height, width
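
To make the new forward concrete: the strided convolution downsamples spatially, then the feature map is flattened into a token sequence. A self-contained sketch (stage-1 values assumed from the config defaults above; the 224x224 input is hypothetical):

import torch
from torch import nn

# Stage-1 hyperparameters assumed from the defaults: patch_size=7, stride=4, hidden_size=32
proj = nn.Conv2d(3, 32, kernel_size=7, stride=4, padding=7 // 2)
layer_norm = nn.LayerNorm(32)

pixel_values = torch.randn(1, 3, 224, 224)          # hypothetical input
embeddings = proj(pixel_values)                     # (1, 32, 56, 56)
_, _, height, width = embeddings.shape
embeddings = embeddings.flatten(2).transpose(1, 2)  # (1, 3136, 32)
embeddings = layer_norm(embeddings)
print(embeddings.shape, height, width)              # torch.Size([1, 3136, 32]) 56 56

Returning height and width alongside the sequence lets later layers reshape the tokens back into a feature map, which is exactly what the sequence reduction below relies on.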


# Copied from transformers.models.segformer.modeling_segformer.SegformerEfficientSelfAttention
class GLPNEfficientSelfAttention(nn.Module):
-    def __init__(self, config, hidden_size, num_attention_heads, sr_ratio):
+    """SegFormer's efficient self-attention mechanism. Employs the sequence reduction process introduced in the [PvT
+    paper](https://arxiv.org/abs/2102.12122)."""
+
+    def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
super().__init__()
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
@@ -141,15 +134,17 @@ def __init__(self, config, hidden_size, num_attention_heads, sr_ratio):

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

-        self.sr_ratio = sr_ratio
-        if sr_ratio > 1:
-            self.sr = nn.Conv2d(hidden_size, hidden_size, kernel_size=sr_ratio, stride=sr_ratio)
+        self.sr_ratio = sequence_reduction_ratio
+        if sequence_reduction_ratio > 1:
+            self.sr = nn.Conv2d(
+                hidden_size, hidden_size, kernel_size=sequence_reduction_ratio, stride=sequence_reduction_ratio
+            )
self.layer_norm = nn.LayerNorm(hidden_size)

-    def transpose_for_scores(self, x):
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
-        return x.permute(0, 2, 1, 3)
+    def transpose_for_scores(self, hidden_states):
+        new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        hidden_states = hidden_states.view(*new_shape)
+        return hidden_states.permute(0, 2, 1, 3)

def forward(
self,
@@ -162,8 +157,11 @@ def forward(

if self.sr_ratio > 1:
batch_size, seq_len, num_channels = hidden_states.shape
+            # Reshape to (batch_size, num_channels, height, width)
hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
+            # Apply sequence reduction
hidden_states = self.sr(hidden_states)
+            # Reshape back to (batch_size, seq_len, num_channels)
hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1)
hidden_states = self.layer_norm(hidden_states)
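
The new comments describe the key trick: keys and values are computed on a spatially reduced copy of the sequence. A hedged sketch with assumed stage-1 shapes (hidden_size=32, sr_ratio=8; values not taken from this diff):

import torch
from torch import nn

batch_size, height, width, num_channels = 1, 56, 56, 32
sr = nn.Conv2d(num_channels, num_channels, kernel_size=8, stride=8)
layer_norm = nn.LayerNorm(num_channels)

hidden_states = torch.randn(batch_size, height * width, num_channels)    # (1, 3136, 32)
x = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
x = sr(x)                                                                # (1, 32, 7, 7)
x = x.reshape(batch_size, num_channels, -1).permute(0, 2, 1)             # (1, 49, 32)
x = layer_norm(x)
# keys/values now span 49 positions instead of 3136, cutting the cost of
# the attention matrix by roughly sr_ratio**2 = 64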

@@ -208,10 +206,13 @@ def forward(self, hidden_states, input_tensor):

# Copied from transformers.models.segformer.modeling_segformer.SegformerAttention with Segformer->GLPN
class GLPNAttention(nn.Module):
-    def __init__(self, config, hidden_size, num_attention_heads, sr_ratio):
+    def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
super().__init__()
self.self = GLPNEfficientSelfAttention(
-            config=config, hidden_size=hidden_size, num_attention_heads=num_attention_heads, sr_ratio=sr_ratio
+            config=config,
+            hidden_size=hidden_size,
+            num_attention_heads=num_attention_heads,
+            sequence_reduction_ratio=sequence_reduction_ratio,
)
self.output = GLPNSelfOutput(config, hidden_size=hidden_size)
self.pruned_heads = set()
@@ -285,11 +286,14 @@ def forward(self, hidden_states, height, width):
class GLPNLayer(nn.Module):
"""This corresponds to the Block class in the original implementation."""

-    def __init__(self, config, hidden_size, num_attention_heads, drop_path, sr_ratio, mlp_ratio):
+    def __init__(self, config, hidden_size, num_attention_heads, drop_path, sequence_reduction_ratio, mlp_ratio):
super().__init__()
self.layer_norm_1 = nn.LayerNorm(hidden_size)
self.attention = GLPNAttention(
-            config, hidden_size=hidden_size, num_attention_heads=num_attention_heads, sr_ratio=sr_ratio
+            config,
+            hidden_size=hidden_size,
+            num_attention_heads=num_attention_heads,
+            sequence_reduction_ratio=sequence_reduction_ratio,
)
self.drop_path = GLPNDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.layer_norm_2 = nn.LayerNorm(hidden_size)
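
(The rest of this __init__ and the layer's forward are collapsed in this diff. For orientation, a hedged sketch of the data flow, assumed from the SegFormer block this class is copied from: a pre-norm residual block.)

# hidden_states = hidden_states + drop_path(attention(layer_norm_1(hidden_states), height, width))
# hidden_states = hidden_states + drop_path(mlp(layer_norm_2(hidden_states), height, width))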
@@ -335,7 +339,6 @@ def __init__(self, config):
for i in range(config.num_encoder_blocks):
embeddings.append(
GLPNOverlapPatchEmbeddings(
-                    image_size=config.image_size // config.downsampling_rates[i],
patch_size=config.patch_sizes[i],
stride=config.strides[i],
num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
@@ -359,7 +362,7 @@ def __init__(self, config):
hidden_size=config.hidden_sizes[i],
num_attention_heads=config.num_attention_heads[i],
drop_path=dpr[cur + j],
-                        sr_ratio=config.sr_ratios[i],
+                        sequence_reduction_ratio=config.sr_ratios[i],
mlp_ratio=config.mlp_ratios[i],
)
)
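
For context on the dpr[cur + j] values used above: a hedged sketch of the stochastic depth schedule (the linear decay rule used by SegFormer-style encoders; the exact line sits outside this hunk):

import torch

drop_path_rate = 0.1   # config default, assumed from above
depths = [2, 2, 2, 2]  # config default, assumed from above

# one drop-path probability per transformer layer, increasing linearly with depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

cur = 0
for i, depth in enumerate(depths):
    print(f"encoder block {i}: {dpr[cur : cur + depth]}")
    cur += depth
# early layers are almost never dropped; the last layer uses the full 0.1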