File tree Expand file tree Collapse file tree 1 file changed +8
-6
lines changed
keras_nlp/src/models/stable_diffusion_v3 Expand file tree Collapse file tree 1 file changed +8
-6
lines changed Original file line number Diff line number Diff line change @@ -237,11 +237,11 @@ def get_config(self):
237237class MMDiT (keras .Model ):
238238 def __init__ (
239239 self ,
240- patch_size , # 2
241- num_heads , # 24
242- hidden_dim , # 64 * 24
243- depth , # 24
244- position_size , # 192
240+ patch_size ,
241+ num_heads ,
242+ hidden_dim ,
243+ depth ,
244+ position_size ,
245245 output_dim ,
246246 mlp_ratio = 4.0 ,
247247 latent_shape = (64 , 64 , 16 ),
@@ -253,7 +253,9 @@ def __init__(
253253 ):
254254 data_format = standardize_data_format (data_format )
255255 if data_format != "channels_last" :
256- raise NotImplementedError
256+ raise NotImplementedError (
257+ "Currently only 'channels_last' is supported."
258+ )
257259 position_sequence_length = position_size * position_size
258260 output_dim_in_final = patch_size ** 2 * output_dim
259261
You can’t perform that action at this time.
0 commit comments