Commit fd6f977

Add ViTDetBackbone (#1776)
* add vit det vit_det_backbone
* update docstring
* code reformat
* fix tests
* address review comments
* bump year on all files
* address review comments
* rename backbone
* fix tests
* change back to ViT
* address review comments
* update image shape
1 parent ececd14 commit fd6f977

4 files changed: +824 -0 lines changed

keras_nlp/api/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -212,6 +212,7 @@
 from keras_nlp.src.models.task import Task
 from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone
 from keras_nlp.src.models.vgg.vgg_image_classifier import VGGImageClassifier
+from keras_nlp.src.models.vit_det.vit_det_backbone import ViTDetBackbone
 from keras_nlp.src.models.whisper.whisper_audio_feature_extractor import (
     WhisperAudioFeatureExtractor,
 )
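
With this one-line export, `ViTDetBackbone` becomes reachable from the public `keras_nlp.models` namespace rather than only from `keras_nlp.src`. A minimal import check, as a sketch assuming a KerasNLP build that includes this commit:

```python
# Sketch: verify the backbone is exposed through the public API surface.
# Assumes keras_nlp is installed from a commit containing this change.
from keras_nlp.models import ViTDetBackbone

print(ViTDetBackbone.__name__)  # ViTDetBackbone
```
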
keras_nlp/src/models/vit_det/vit_det_backbone.py

Lines changed: 204 additions & 0 deletions
@@ -0,0 +1,204 @@
# Copyright 2024 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import keras
from keras import ops

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.backbone import Backbone
from keras_nlp.src.models.vit_det.vit_layers import AddPositionalEmbedding
from keras_nlp.src.models.vit_det.vit_layers import ViTDetPatchingAndEmbedding
from keras_nlp.src.models.vit_det.vit_layers import WindowedTransformerEncoder


@keras_nlp_export("keras_nlp.models.ViTDetBackbone")
class ViTDetBackbone(Backbone):
    """An implementation of the ViT image encoder.

    The ViTDetBackbone uses a windowed transformer encoder and relative
    positional encodings. The code has been adapted from the [Segment Anything
    paper](https://arxiv.org/abs/2304.02643), [Segment Anything GitHub](
    https://github.com/facebookresearch/segment-anything) and [Detectron2](
    https://github.com/facebookresearch/detectron2).

    Args:
        hidden_size (int): The latent dimensionality to be projected
            into in the output of each stacked windowed transformer encoder.
        num_layers (int): The number of transformer encoder layers to
            stack in the Vision Transformer.
        intermediate_dim (int): The dimensionality of the hidden Dense
            layer in the transformer MLP head.
        num_heads (int): The number of heads to use in the
            `MultiHeadAttentionWithRelativePE` layer of each transformer
            encoder.
        global_attention_layer_indices (list): Indexes for blocks using
            global attention.
        image_shape (tuple[int], optional): The size of the input image in
            `(H, W, C)` format. Defaults to `(1024, 1024, 3)`.
        include_rescaling (bool, optional): Whether to rescale the inputs. If
            set to `True`, inputs will be passed through a
            `Rescaling(1/255.0)` layer. Defaults to `True`.
        patch_size (int, optional): The patch size to be supplied to the
            Patching layer to turn input images into a flattened sequence of
            patches. Defaults to `16`.
        num_output_channels (int, optional): The number of channels (features)
            in the output (image encodings). Defaults to `256`.
        use_bias (bool, optional): Whether to use bias to project the keys,
            queries, and values in the attention layer. Defaults to `True`.
        use_abs_pos (bool, optional): Whether to add absolute positional
            embeddings to the output patches. Defaults to `True`.
        use_rel_pos (bool, optional): Whether to use relative positional
            encodings in the attention layer. Defaults to `True`.
        window_size (int, optional): The size of the window for windowed
            attention in the transformer encoder blocks. Defaults to `14`.
        layer_norm_epsilon (float, optional): The epsilon to use in the layer
            normalization blocks in the transformer encoder. Defaults to `1e-6`.

    Examples:
    ```python
    input_data = np.ones((2, 224, 224, 3), dtype="float32")

    # Pretrained ViTDetBackbone backbone.
    model = keras_nlp.models.ViTDetBackbone.from_preset("vit_det")
    model(input_data)

    # Randomly initialized ViTDetBackbone backbone with a custom config.
    model = keras_nlp.models.ViTDetBackbone(
        image_shape=(16, 16, 3),
        patch_size=2,
        hidden_size=4,
        num_layers=2,
        global_attention_layer_indices=[2, 5, 8, 11],
        intermediate_dim=4 * 4,
        num_heads=2,
        num_output_channels=2,
        window_size=2,
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        hidden_size,
        num_layers,
        intermediate_dim,
        num_heads,
        global_attention_layer_indices,
        include_rescaling=True,
        image_shape=(1024, 1024, 3),
        patch_size=16,
        num_output_channels=256,
        use_bias=True,
        use_abs_pos=True,
        use_rel_pos=True,
        window_size=14,
        layer_norm_epsilon=1e-6,
        **kwargs
    ):
        # === Functional model ===
        img_input = keras.layers.Input(shape=image_shape)
        # Check that the input image is well specified.
        if img_input.shape[-3] is None or img_input.shape[-2] is None:
            raise ValueError(
                "Height and width of the image must be specified"
                " in `image_shape`."
            )
        if img_input.shape[-3] != img_input.shape[-2]:
            raise ValueError(
                "Input image must be square i.e. the height must"
                " be equal to the width in the `image_shape`"
                " tuple/tensor."
            )
        img_size = img_input.shape[-3]
        x = img_input
        if include_rescaling:
            # Use common rescaling strategy across keras_cv
            x = keras.layers.Rescaling(1.0 / 255.0)(x)
            # VITDet scales inputs based on the standard ImageNet mean/stddev.
            x = (x - ops.array([0.485, 0.456, 0.406], dtype=x.dtype)) / (
                ops.array([0.229, 0.224, 0.225], dtype=x.dtype)
            )
        x = ViTDetPatchingAndEmbedding(
            kernel_size=(patch_size, patch_size),
            strides=(patch_size, patch_size),
            embed_dim=hidden_size,
        )(x)
        if use_abs_pos:
            x = AddPositionalEmbedding(img_size, patch_size, hidden_size)(x)
        for i in range(num_layers):
            x = WindowedTransformerEncoder(
                project_dim=hidden_size,
                intermediate_dim=intermediate_dim,
                num_heads=num_heads,
                use_bias=use_bias,
                use_rel_pos=use_rel_pos,
                window_size=(
                    window_size
                    if i not in global_attention_layer_indices
                    else 0
                ),
                input_size=(img_size // patch_size, img_size // patch_size),
            )(x)
        x = keras.layers.Conv2D(
            filters=num_output_channels, kernel_size=1, use_bias=False
        )(x)
        x = keras.layers.LayerNormalization(epsilon=1e-6)(x)
        x = keras.layers.Conv2D(
            filters=num_output_channels,
            kernel_size=3,
            padding="same",
            use_bias=False,
        )(x)
        x = keras.layers.LayerNormalization(epsilon=1e-6)(x)

        super().__init__(inputs=img_input, outputs=x, **kwargs)

        # === Config ===
        self.patch_size = patch_size
        self.image_shape = image_shape
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.intermediate_dim = intermediate_dim
        self.num_heads = num_heads
        self.num_output_channels = num_output_channels
        self.use_bias = use_bias
        self.use_rel_pos = use_rel_pos
        self.use_abs_pos = use_abs_pos
        self.window_size = window_size
        self.global_attention_layer_indices = global_attention_layer_indices
        self.layer_norm_epsilon = layer_norm_epsilon
        self.include_rescaling = include_rescaling

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "image_shape": self.image_shape,
                "include_rescaling": self.include_rescaling,
                "patch_size": self.patch_size,
                "hidden_size": self.hidden_size,
                "num_layers": self.num_layers,
                "intermediate_dim": self.intermediate_dim,
                "num_heads": self.num_heads,
                "num_output_channels": self.num_output_channels,
                "use_bias": self.use_bias,
                "use_abs_pos": self.use_abs_pos,
                "use_rel_pos": self.use_rel_pos,
                "window_size": self.window_size,
                "global_attention_layer_indices": self.global_attention_layer_indices,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config
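
Since `get_config` returns every constructor argument, a backbone defined in this file should rebuild from its own config. A minimal sketch, using the same tiny hyperparameters as the test file below (illustrative values, not a recommended configuration):

```python
from keras_nlp.src.models.vit_det.vit_det_backbone import ViTDetBackbone

# Build a tiny backbone, capture its config, and rebuild an equivalent
# (freshly initialized) backbone from that config.
backbone = ViTDetBackbone(
    image_shape=(16, 16, 3),
    patch_size=2,
    hidden_size=4,
    num_layers=2,
    global_attention_layer_indices=[2, 5, 8, 11],
    intermediate_dim=16,
    num_heads=2,
    num_output_channels=2,
    window_size=2,
)
config = backbone.get_config()
restored = ViTDetBackbone.from_config(config)
assert restored.get_config()["hidden_size"] == 4
```
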
keras_nlp/src/models/vit_det/vit_det_backbone_test.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

from keras_nlp.src.models.vit_det.vit_det_backbone import ViTDetBackbone
from keras_nlp.src.tests.test_case import TestCase


class ViTDetBackboneTest(TestCase):
    def setUp(self):
        self.init_kwargs = {
            "include_rescaling": True,
            "image_shape": (16, 16, 3),
            "patch_size": 2,
            "hidden_size": 4,
            "num_layers": 2,
            "global_attention_layer_indices": [2, 5, 8, 11],
            "intermediate_dim": 4 * 4,
            "num_heads": 2,
            "num_output_channels": 2,
            "window_size": 2,
        }
        self.input_data = np.ones((1, 16, 16, 3), dtype="float32")

    def test_backbone_basics(self):
        self.run_backbone_test(
            cls=ViTDetBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=(1, 8, 8, 2),
            run_mixed_precision_check=False,
            run_quantization_check=False,
        )

    @pytest.mark.large
    def test_saved_model(self):
        self.run_model_saving_test(
            cls=ViTDetBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )
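
For reference, the `expected_output_shape=(1, 8, 8, 2)` in `test_backbone_basics` follows from the backbone geometry: a 16x16 input with `patch_size=2` produces an 8x8 grid of patch features, projected to `num_output_channels=2`. A minimal sketch of the same forward pass, assuming this commit is installed:

```python
import numpy as np

from keras_nlp.src.models.vit_det.vit_det_backbone import ViTDetBackbone

# Same configuration as init_kwargs in the test above.
backbone = ViTDetBackbone(
    include_rescaling=True,
    image_shape=(16, 16, 3),
    patch_size=2,
    hidden_size=4,
    num_layers=2,
    global_attention_layer_indices=[2, 5, 8, 11],
    intermediate_dim=4 * 4,
    num_heads=2,
    num_output_channels=2,
    window_size=2,
)
features = backbone(np.ones((1, 16, 16, 3), dtype="float32"))
print(features.shape)  # (1, 8, 8, 2): 16 // 2 = 8 patches per side, 2 output channels.
```
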
