
Commit 2620e4c

Replace manual implementation of CLIPAttention with MultiHeadAttention
1 parent 7ddf4ec commit 2620e4c

File tree

2 files changed: +53 −98 lines changed

keras_nlp/src/models/stable_diffusion_v3/clip_attention.py

Lines changed: 44 additions & 94 deletions
@@ -15,97 +15,47 @@
 from keras import ops
 
 
-class CLIPAttention(layers.Layer):
-    def __init__(self, num_heads, hidden_dim, dropout=0.0, **kwargs):
-        super().__init__(**kwargs)
-        if hidden_dim % num_heads != 0:
-            raise ValueError(
-                "`hidden_dim` must be divisible by num_heads. "
-                f"Received: num_heads={num_heads}, hidden_dim={hidden_dim}"
-            )
-        self.num_heads = num_heads
-        self.hidden_dim = hidden_dim
-        self.dropout = dropout
-        self.head_dim = self.hidden_dim // self.num_heads
-
-        self.dropout_layer = layers.Dropout(self.dropout)
-        self.scale = self.head_dim**-0.5
-        self.query_dense = layers.Dense(
-            units=self.hidden_dim, dtype=self.dtype_policy, name="query"
-        )
-        self.key_dense = layers.Dense(
-            units=self.hidden_dim, dtype=self.dtype_policy, name="key"
-        )
-        self.value_dense = layers.Dense(
-            units=self.hidden_dim, dtype=self.dtype_policy, name="value"
-        )
-        self.softmax = layers.Softmax(dtype="float32")
-        self.output_dense = layers.Dense(
-            units=self.hidden_dim,
-            dtype=self.dtype_policy,
-            name="attention_output",
-        )
-
-    def build(self, input_shape):
-        self.query_dense.build(input_shape)
-        self.key_dense.build(input_shape)
-        self.value_dense.build(input_shape)
-        self.output_dense.build([None, None, self.hidden_dim])
-
-    def compute_output_shape(self, input_shape):
-        output_shape = list(input_shape)
-        output_shape[-1] = self.hidden_dim
-        return output_shape
-
-    def _transpose_for_scores(self, inputs):
-        batch_size = ops.shape(inputs)[0]
-        inputs = ops.reshape(
-            inputs, (batch_size, -1, self.num_heads, self.head_dim)
-        )
-        return ops.transpose(inputs, axes=[0, 2, 1, 3])
-
-    def call(self, x, attention_mask=None, training=None):
-        batch_size = ops.shape(x)[0]
-        query = self.query_dense(x)
-        key = self.key_dense(x)
-        value = self.value_dense(x)
-        query = self._transpose_for_scores(query)
-        key = self._transpose_for_scores(key)
-        value = self._transpose_for_scores(value)
-
-        attention_logits = ops.matmul(
-            query, ops.transpose(key, axes=[0, 1, 3, 2])
-        )
-        dk = ops.cast(ops.sqrt(self.head_dim), dtype=attention_logits.dtype)
-        attention_logits = ops.divide(attention_logits, dk)
-
-        if attention_mask is not None:
-            attention_logits = ops.add(attention_logits, attention_mask)
-
-        orig_dtype = attention_logits.dtype
-        attention_softmax = self.softmax(attention_logits)
-        attention_softmax = ops.cast(attention_softmax, orig_dtype)
-
-        if self.dropout:
-            attention_softmax = self.dropout_layer(
-                attention_softmax, training=training
-            )
-
-        attention_output = ops.matmul(attention_softmax, value)
-        attention_output = ops.transpose(attention_output, axes=[0, 2, 1, 3])
-        attention_output = ops.reshape(
-            attention_output, (batch_size, -1, self.hidden_dim)
-        )
-        attention_output = self.output_dense(attention_output)
-        return attention_output
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "num_heads": self.num_heads,
-                "hidden_dim": self.hidden_dim,
-                "dropout": self.dropout,
-            }
-        )
-        return config
+class CLIPAttention(layers.MultiHeadAttention):
+    def __init__(
+        self,
+        num_heads,
+        key_dim,
+        value_dim=None,
+        dropout=0.0,
+        use_bias=True,
+        output_shape=None,
+        attention_axes=None,
+        kernel_initializer="glorot_uniform",
+        bias_initializer="zeros",
+        kernel_regularizer=None,
+        bias_regularizer=None,
+        activity_regularizer=None,
+        kernel_constraint=None,
+        bias_constraint=None,
+        seed=None,
+        **kwargs,
+    ):
+        super().__init__(
+            num_heads,
+            key_dim,
+            value_dim,
+            dropout,
+            use_bias,
+            output_shape,
+            attention_axes,
+            kernel_initializer,
+            bias_initializer,
+            kernel_regularizer,
+            bias_regularizer,
+            activity_regularizer,
+            kernel_constraint,
+            bias_constraint,
+            seed,
+            **kwargs,
+        )
+
+    def _masked_softmax(self, attention_scores, attention_mask=None):
+        # In CLIP model, the computation of `attention_mask` is performed
+        # differently from `MultiHeadAttention`.
+        attention_scores = ops.add(attention_scores, attention_mask)
+        return self._softmax(attention_scores)
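
The only behavioral change in the subclass is the `_masked_softmax` override: the stock `keras.layers.MultiHeadAttention` hands `attention_mask` to its softmax layer as a boolean keep/discard mask, whereas the CLIP encoder pre-computes an additive float mask (0.0 where attention is allowed, a large negative value where it is not) and adds it straight onto the attention scores, exactly as the removed `call` did with `ops.add`. A minimal sketch of that additive convention follows; the helper name and sizes are illustrative, not part of this commit.

import numpy as np
from keras import ops


def causal_additive_mask(seq_len, dtype="float32"):
    # Illustrative helper: 0.0 where attention is allowed, a large negative
    # value where it is not, so masked scores vanish after the softmax.
    return ops.convert_to_tensor(
        np.triu(np.full((seq_len, seq_len), -1e9, dtype=dtype), k=1)
    )


mask = causal_additive_mask(4)
# mask[i, j] == 0.0 for j <= i and -1e9 for j > i; adding it to the raw
# attention scores before the softmax is the convention `_masked_softmax`
# above relies on.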

keras_nlp/src/models/stable_diffusion_v3/clip_encoder_block.py

Lines changed: 9 additions & 4 deletions
@@ -33,6 +33,11 @@ def __init__(
         **kwargs,
     ):
         super().__init__(**kwargs)
+        if hidden_dim % num_heads != 0:
+            raise ValueError(
+                "`hidden_dim` must be divisible by `num_heads`. "
+                f"Received: hidden_dim={hidden_dim}, num_heads={num_heads}"
+            )
         self.hidden_dim = hidden_dim
         self.num_heads = num_heads
         self.intermediate_dim = intermediate_dim
@@ -45,8 +50,8 @@ def __init__(
             epsilon=0.00001, dtype=self.dtype_policy, name="layer_norm_1"
         )
         self.attention = CLIPAttention(
-            self.num_heads,
-            self.hidden_dim,
+            num_heads,
+            hidden_dim // num_heads,
             dtype=self.dtype_policy,
             name="attention",
         )
@@ -65,7 +70,7 @@ def __init__(
 
     def build(self, input_shape):
         self.layer_norm_1.build(input_shape)
-        self.attention.build(input_shape)
+        self.attention.build(input_shape, input_shape, input_shape)
         self.layer_norm_2.build(input_shape)
         self.dense_1.build(input_shape)
         input_shape = self.dense_1.compute_output_shape(input_shape)
@@ -85,7 +90,7 @@ def _compute_attention(self, x, attention_mask=None, training=None):
             else None
         )
         mask = attention_mask
-        return self.attention(x, attention_mask=mask, training=training)
+        return self.attention(x, x, x, attention_mask=mask, training=training)
 
     def call(self, x, attention_mask=None, training=None):
         residual = x
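
With the switch to `MultiHeadAttention`, the encoder block now passes a per-head `key_dim` (`hidden_dim // num_heads`) instead of the full `hidden_dim`, builds the layer with separate query/value/key shapes, and calls it in self-attention form with query, value, and key all set to `x`. A rough usage sketch with made-up sizes; it uses the stock Keras layer so the snippet stands alone (`CLIPAttention` only changes the mask handling).

import numpy as np
from keras import layers

num_heads, hidden_dim, seq_len = 8, 512, 77  # illustrative sizes only
attention = layers.MultiHeadAttention(num_heads, hidden_dim // num_heads)

x = np.random.rand(2, seq_len, hidden_dim).astype("float32")
y = attention(x, x, x)  # self-attention: query, value, key are all `x`
# The output is projected back to the query's last dimension by default.
assert tuple(y.shape) == (2, seq_len, hidden_dim)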
