Skip to content

Commit c60112e

Browse files
Update CLIP to a functional model (keras-team#2393)
* Refactor CLIP to a functional model update model input format update golden values update CLIP to functional model update tests code reformat use dict instead of list Update keras_cv/models/feature_extractor/clip/clip_model.py Co-authored-by: Tirth Patel <tirthasheshpatel@gmail.com> remove build and compute output shape update model input format update golden values Refactor CLIP Refactor includes: - CLIPProcessor is now a Keras layer and uses some utilities from KerasNLP to support all types of python types and array inputs - CLIPImageEncoder, CLIPTextEncoder, and CLIPEncoder now implement a `.compute_output_shape` method (required for CLIP to work with the functional API) - CLIPHead added to remove raw variables from the CLIP Task models; having variables in `keras.Model` class is tricky since functional API doesn't allow state. - CLIP checkpointing script has been updated to now work with the new API: new weights will be uploaded to Kaggle. TODO: attribute KerasNLP wherever relevant TODO: upload new weights to Kaggle TODO: refactor the CLIPProcessor class and the CLIP class to also pull tokenizer vocab and merges from Kaggle. remove build and compute output shape Some fixes for the refactor Fix the tests, update presets update to layers instead of models * Attempt to fix the Keras 2 error * remove initializers * Remove all initializers in ClipAttention * code reformat * update initializers * make this keras 3 only * update skip test * add reason for skip * skipping tests individually :| * Don't run tests for Keras 2 --------- Co-authored-by: Divyashree Sreepathihalli <divyashreepathihalli> Co-authored-by: Tirth Patel <tirthasheshpatel@gmail.com>
1 parent bfeba12 commit c60112e

File tree

10 files changed

+620
-739
lines changed

10 files changed

+620
-739
lines changed

.kokoro/github/ubuntu/gpu/build.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ then
6969
keras_cv/models/object_detection/yolo_v8 \
7070
keras_cv/models/object_detection_3d \
7171
keras_cv/models/segmentation \
72+
keras_cv/models/feature_extractor/clip \
7273
keras_cv/models/stable_diffusion
7374
else
7475
pytest --cache-clear --check_gpu --run_large --durations 0 \
@@ -83,5 +84,6 @@ else
8384
keras_cv/models/object_detection/yolo_v8 \
8485
keras_cv/models/object_detection_3d \
8586
keras_cv/models/segmentation \
87+
keras_cv/models/feature_extractor/clip \
8688
keras_cv/models/stable_diffusion
8789
fi

keras_cv/models/feature_extractor/clip/clip_encoder.py

Lines changed: 6 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,11 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import numpy as np
15-
1614
from keras_cv.api_export import keras_cv_export
1715
from keras_cv.backend import keras
1816
from keras_cv.backend import ops
1917

2018

21-
def get_initializer(initializer_range=0.02):
22-
"""
23-
Creates a `keras.initializers.TruncatedNormal` with the given range.
24-
25-
Args:
26-
initializer_range (*float*, defaults to 0.02): Standard deviation of the
27-
initializer range.
28-
29-
Returns:
30-
`keras.initializers.TruncatedNormal`: The truncated normal initializer.
31-
"""
32-
return keras.initializers.TruncatedNormal(stddev=initializer_range)
33-
34-
3519
@keras_cv_export("keras_cv.models.feature_extractor.QuickGELU")
3620
class QuickGELU(keras.layers.Layer):
3721
def __init__(self, **kwargs):
@@ -54,13 +38,6 @@ def __init__(
5438
self.proj_dim = proj_dim
5539
self.num_heads = num_heads
5640
self.num_hidden_layers = num_hidden_layers
57-
self.fc_std = np.power(2 * self.proj_dim, -0.5) * 0.02
58-
59-
self.in_proj_std = (
60-
np.power(self.proj_dim, -0.5)
61-
* (np.power(2 * self.num_hidden_layers, -0.5))
62-
* 0.02
63-
)
6441
self.attn = CLIPAttention(
6542
self.proj_dim,
6643
self.num_heads,
@@ -156,9 +133,14 @@ def __init__(self, width, num_layers, heads, **kwargs):
156133
]
157134

158135
def build(self, input_shape):
159-
super().build(input_shape)
160136
for block in self.resblocks:
161137
block.build(input_shape)
138+
self.built = True
139+
140+
def compute_output_shape(self, input_shape):
141+
for block in self.resblocks:
142+
input_shape = block.compute_output_shape(input_shape)
143+
return input_shape
162144

163145
def call(
164146
self,
@@ -174,9 +156,6 @@ def call(
174156
)
175157
return x
176158

177-
def compute_output_shape(self, inputs_shape):
178-
return inputs_shape
179-
180159
def get_config(self):
181160
config = super().get_config()
182161
config.update(
@@ -213,30 +192,20 @@ def __init__(
213192
)
214193

215194
self.scale = self.head_dim**-0.5
216-
in_proj_std = (
217-
(self.proj_dim**-0.5)
218-
* ((2 * self.num_hidden_layers) ** -0.5)
219-
* 0.02
220-
)
221-
out_proj_std = (self.proj_dim**-0.5) * 0.02
222195
self.q_proj = keras.layers.Dense(
223196
units=self.proj_dim,
224-
kernel_initializer=get_initializer(in_proj_std),
225197
name="q_proj",
226198
)
227199
self.k_proj = keras.layers.Dense(
228200
units=self.proj_dim,
229-
kernel_initializer=get_initializer(in_proj_std),
230201
name="k_proj",
231202
)
232203
self.v_proj = keras.layers.Dense(
233204
units=self.proj_dim,
234-
kernel_initializer=get_initializer(in_proj_std),
235205
name="v_proj",
236206
)
237207
self.out_proj = keras.layers.Dense(
238208
units=self.proj_dim,
239-
kernel_initializer=get_initializer(out_proj_std),
240209
name="out_proj",
241210
)
242211

keras_cv/models/feature_extractor/clip/clip_image_model.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,8 @@
1616
from keras_cv.backend import keras
1717
from keras_cv.backend import ops
1818
from keras_cv.models.feature_extractor.clip.clip_encoder import CLIPEncoder
19-
from keras_cv.models.feature_extractor.clip.clip_encoder import get_initializer
2019

2120

22-
@keras_cv_export("keras_cv.models.feature_extractor.CLIPPatchingAndEmbedding")
2321
class CLIPPatchingAndEmbedding(keras.layers.Layer):
2422
def __init__(
2523
self, width, patch_size, input_resolution, output_dim, **kwargs
@@ -33,7 +31,6 @@ def __init__(
3331
padding="valid",
3432
use_bias=False,
3533
data_format="channels_last",
36-
kernel_initializer=get_initializer(0.02),
3734
name="patch_embed.embedding",
3835
)
3936
self.width = width
@@ -42,17 +39,13 @@ def __init__(
4239
self.num_patches = ops.power(
4340
(self.input_resolution // self.patch_size), 2
4441
)
45-
self.class_embedding_initializer = get_initializer(
46-
ops.power(self.width, -0.5) * 0.02
47-
)
4842
self.output_dim = output_dim
4943

5044
def build(self, input_shape):
5145
super().build(input_shape)
5246
self.conv1.build(input_shape)
5347
self.class_embedding = self.add_weight(
5448
shape=((self.width,)),
55-
initializer=self.class_embedding_initializer,
5649
name="patch_embed.class_embedding",
5750
)
5851

@@ -67,6 +60,13 @@ def build(self, input_shape):
6760
name="patch_embed.positional_embedding",
6861
)
6962

63+
def compute_output_shape(self, input_shape):
64+
return [
65+
None,
66+
(self.input_resolution // self.patch_size) ** 2 + 1,
67+
self.width,
68+
]
69+
7070
def call(self, x):
7171
batch_size = ops.shape(x)[0]
7272
patch_embeddings = self.conv1(x) # shape = [*, grid, grid, channel]
@@ -143,12 +143,15 @@ def __init__(
143143
)
144144

145145
def build(self, input_shape):
146-
super().build(input_shape)
147146
self.embeddings.build(input_shape)
148147
self.pre_norm.build([None, None, self.width])
149148
self.encoder.build(None)
150149
self.post_norm.build([None, self.width])
151-
self.image_projector.build([None, None, self.width])
150+
self.image_projector.build([None, self.width])
151+
self.built = True
152+
153+
def compute_output_shape(self, input_shape):
154+
return [input_shape[0], self.output_dim]
152155

153156
def call(self, image):
154157
x = self.embeddings(image)

0 commit comments

Comments
 (0)