
Commit c55c998

james77777778 authored and mattdangerw committed
Add VAEImageDecoder for StableDiffusionV3 (keras-team#1796)
* Add `VAEImageDecoder` for StableDiffusionV3
* Use `keras.Model` for `VAEImageDecoder` and follow the coding style in `VAEAttention`
1 parent c330460 commit c55c998

File tree

2 files changed: +303 −0 lines changed
keras_nlp/src/models/stable_diffusion_v3/vae_attention.py
Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

from keras import layers
from keras import ops

from keras_nlp.src.utils.keras_utils import standardize_data_format


class VAEAttention(layers.Layer):
    def __init__(self, filters, groups=32, data_format=None, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.data_format = standardize_data_format(data_format)
        gn_axis = -1 if self.data_format == "channels_last" else 1

        self.group_norm = layers.GroupNormalization(
            groups=groups,
            axis=gn_axis,
            epsilon=1e-6,
            dtype=self.dtype_policy,
            name="group_norm",
        )
        self.query_conv2d = layers.Conv2D(
            filters,
            1,
            1,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="query_conv2d",
        )
        self.key_conv2d = layers.Conv2D(
            filters,
            1,
            1,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="key_conv2d",
        )
        self.value_conv2d = layers.Conv2D(
            filters,
            1,
            1,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="value_conv2d",
        )
        self.softmax = layers.Softmax(dtype="float32")
        self.output_conv2d = layers.Conv2D(
            filters,
            1,
            1,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="output_conv2d",
        )

        self.groups = groups
        self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters))

    def build(self, input_shape):
        self.group_norm.build(input_shape)
        self.query_conv2d.build(input_shape)
        self.key_conv2d.build(input_shape)
        self.value_conv2d.build(input_shape)
        self.output_conv2d.build(input_shape)

    def call(self, inputs, training=None):
        x = self.group_norm(inputs)
        query = self.query_conv2d(x)
        key = self.key_conv2d(x)
        value = self.value_conv2d(x)

        if self.data_format == "channels_first":
            query = ops.transpose(query, (0, 2, 3, 1))
            key = ops.transpose(key, (0, 2, 3, 1))
            value = ops.transpose(value, (0, 2, 3, 1))
        shape = ops.shape(inputs)
        b = shape[0]
        query = ops.reshape(query, (b, -1, self.filters))
        key = ops.reshape(key, (b, -1, self.filters))
        value = ops.reshape(value, (b, -1, self.filters))

        # Compute attention.
        query = ops.multiply(
            query, ops.cast(self._inverse_sqrt_filters, query.dtype)
        )
        # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1]
        attention_scores = ops.einsum("abc,adc->abd", query, key)
        attention_scores = ops.cast(
            self.softmax(attention_scores), self.compute_dtype
        )
        # [B, H2 * W2, C], [B, H0 * W0, H1 * W1] -> [B, H1 * W1, C]
        attention_output = ops.einsum("abc,adb->adc", value, attention_scores)
        x = ops.reshape(attention_output, shape)

        x = self.output_conv2d(x)
        if self.data_format == "channels_first":
            x = ops.transpose(x, (0, 3, 1, 2))
        x = ops.add(x, inputs)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "groups": self.groups,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return input_shape
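A minimal usage sketch of the layer above (the shape and filter values are illustrative assumptions, not part of the commit; the import path is taken from the decoder file's own import). VAEAttention applies single-head self-attention over the spatial positions of a feature map and adds the result back to its input, so the output shape matches the input shape.

import numpy as np

from keras_nlp.src.models.stable_diffusion_v3.vae_attention import VAEAttention

# Assumed channels_last input: [batch, height, width, channels].
# `filters` must match the channel count and be divisible by `groups`
# for the group normalization.
features = np.random.rand(1, 8, 8, 512).astype("float32")
attention = VAEAttention(filters=512, groups=32)
outputs = attention(features)  # -> shape (1, 8, 8, 512), same as the input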
keras_nlp/src/models/stable_diffusion_v3/vae_image_decoder.py
Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras
from keras import layers

from keras_nlp.src.models.stable_diffusion_v3.vae_attention import VAEAttention
from keras_nlp.src.utils.keras_utils import standardize_data_format


class VAEImageDecoder(keras.Model):
    def __init__(
        self,
        stackwise_num_filters,
        stackwise_num_blocks,
        output_channels=3,
        latent_shape=(None, None, 16),
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        data_format = standardize_data_format(data_format)
        gn_axis = -1 if data_format == "channels_last" else 1

        # === Functional Model ===
        latent_inputs = layers.Input(shape=latent_shape)

        x = layers.Conv2D(
            stackwise_num_filters[0],
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=dtype,
            name="input_projection",
        )(latent_inputs)
        x = apply_resnet_block(
            x,
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=dtype,
            name="input_block0",
        )
        x = VAEAttention(
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=dtype,
            name="input_attention",
        )(x)
        x = apply_resnet_block(
            x,
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=dtype,
            name="input_block1",
        )

        # Stacks.
        for i, filters in enumerate(stackwise_num_filters):
            for j in range(stackwise_num_blocks[i]):
                x = apply_resnet_block(
                    x,
                    filters,
                    data_format=data_format,
                    dtype=dtype,
                    name=f"block{i}_{j}",
                )
            if i != len(stackwise_num_filters) - 1:
                # No upsampling in the last block.
                x = layers.UpSampling2D(
                    2,
                    data_format=data_format,
                    dtype=dtype,
                    name=f"upsample_{i}",
                )(x)
                x = layers.Conv2D(
                    filters,
                    3,
                    1,
                    padding="same",
                    data_format=data_format,
                    dtype=dtype,
                    name=f"upsample_{i}_conv",
                )(x)

        # Output block.
        x = layers.GroupNormalization(
            groups=32,
            axis=gn_axis,
            epsilon=1e-6,
            dtype=dtype,
            name="output_norm",
        )(x)
        x = layers.Activation("swish", dtype=dtype, name="output_activation")(x)
        image_outputs = layers.Conv2D(
            output_channels,
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=dtype,
            name="output_projection",
        )(x)
        super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs)

        # === Config ===
        self.stackwise_num_filters = stackwise_num_filters
        self.stackwise_num_blocks = stackwise_num_blocks
        self.output_channels = output_channels
        self.latent_shape = latent_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "stackwise_num_filters": self.stackwise_num_filters,
                "stackwise_num_blocks": self.stackwise_num_blocks,
                "output_channels": self.output_channels,
                "latent_shape": self.latent_shape,
            }
        )
        return config


def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None):
    data_format = standardize_data_format(data_format)
    gn_axis = -1 if data_format == "channels_last" else 1
    input_filters = x.shape[gn_axis]

    residual = x
    x = layers.GroupNormalization(
        groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm1"
    )(x)
    x = layers.Activation("swish", dtype=dtype)(x)
    x = layers.Conv2D(
        filters,
        3,
        1,
        padding="same",
        data_format=data_format,
        dtype=dtype,
        name=f"{name}_conv1",
    )(x)
    x = layers.GroupNormalization(
        groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm2"
    )(x)
    x = layers.Activation("swish", dtype=dtype)(x)
    x = layers.Conv2D(
        filters,
        3,
        1,
        padding="same",
        data_format=data_format,
        dtype=dtype,
        name=f"{name}_conv2",
    )(x)
    if input_filters != filters:
        residual = layers.Conv2D(
            filters,
            1,
            1,
            data_format=data_format,
            dtype=dtype,
            name=f"{name}_residual_projection",
        )(residual)
    x = layers.Add(dtype=dtype)([residual, x])
    return x
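A minimal instantiation sketch of the functional model above. The filter counts and block counts are illustrative assumptions, not the official StableDiffusionV3 preset, and the module path assumes the file shown in this diff. Each stack except the last upsamples by 2x, so with four stacks a 64x64 latent decodes to a 512x512 image.

import numpy as np

from keras_nlp.src.models.stable_diffusion_v3.vae_image_decoder import (
    VAEImageDecoder,
)

decoder = VAEImageDecoder(
    stackwise_num_filters=[512, 512, 256, 128],  # assumed values
    stackwise_num_blocks=[3, 3, 3, 3],  # assumed values
    output_channels=3,
    latent_shape=(None, None, 16),
)
# Assumed channels_last latents: [batch, height, width, 16].
latents = np.random.rand(1, 64, 64, 16).astype("float32")
images = decoder(latents)  # three 2x upsamples -> shape (1, 512, 512, 3)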
