Commit e9aae31

Z Image model. (#10892)
1 parent 0c18842 commit e9aae31

File tree: 7 files changed, +199 -152 lines

comfy/ldm/lumina/model.py

81 additions, 144 deletions
@@ -11,6 +11,7 @@
 from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
 from comfy.ldm.modules.attention import optimized_attention_masked
 from comfy.ldm.flux.layers import EmbedND
+from comfy.ldm.flux.math import apply_rope
 import comfy.patcher_extension
 
 
@@ -31,6 +32,7 @@ def __init__(
         n_heads: int,
         n_kv_heads: Optional[int],
         qk_norm: bool,
+        out_bias: bool = False,
         operation_settings={},
     ):
         """
@@ -59,7 +61,7 @@ def __init__(
         self.out = operation_settings.get("operations").Linear(
             n_heads * self.head_dim,
             dim,
-            bias=False,
+            bias=out_bias,
             device=operation_settings.get("device"),
             dtype=operation_settings.get("dtype"),
         )
@@ -70,35 +72,6 @@ def __init__(
         else:
             self.q_norm = self.k_norm = nn.Identity()
 
-    @staticmethod
-    def apply_rotary_emb(
-        x_in: torch.Tensor,
-        freqs_cis: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Apply rotary embeddings to input tensors using the given frequency
-        tensor.
-
-        This function applies rotary embeddings to the given query 'xq' and
-        key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
-        input tensors are reshaped as complex numbers, and the frequency tensor
-        is reshaped for broadcasting compatibility. The resulting tensors
-        contain rotary embeddings and are returned as real tensors.
-
-        Args:
-            x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
-            freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
-                exponentials.
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
-                and key tensor with rotary embeddings.
-        """
-
-        t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
-        t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
-        return t_out.reshape(*x_in.shape)
-
     def forward(
         self,
         x: torch.Tensor,
@@ -134,8 +107,7 @@ def forward(
         xq = self.q_norm(xq)
         xk = self.k_norm(xk)
 
-        xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
-        xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
+        xq, xk = apply_rope(xq, xk, freqs_cis)
 
         n_rep = self.n_local_heads // self.n_local_kv_heads
         if n_rep >= 1:
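
Note: the deleted apply_rotary_emb staticmethod and comfy.ldm.flux.math.apply_rope perform the same pairwise rotation; the flux helper simply handles the query and key tensors in one call (and, as an implementation detail, upcasts to float32 internally). For reference, the removed body restated standalone, assuming freqs_cis is laid out as stacked 2x2 rotation matrices ([..., head_dim // 2, 2, 2]) the way EmbedND produces them:

import torch

def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Group the feature dimension into pairs: [..., d] -> [..., d // 2, 1, 2].
    t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
    # Apply the per-pair 2x2 rotation stored in freqs_cis, written out with
    # broadcasting: out = R[..., 0] * x_even + R[..., 1] * x_odd.
    t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
    # Flatten the pairs back to the original feature dimension.
    return t_out.reshape(*x_in.shape)
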
@@ -215,6 +187,8 @@ def __init__(
         norm_eps: float,
         qk_norm: bool,
         modulation=True,
+        z_image_modulation=False,
+        attn_out_bias=False,
         operation_settings={},
     ) -> None:
         """
@@ -235,10 +209,10 @@ def __init__(
         super().__init__()
         self.dim = dim
         self.head_dim = dim // n_heads
-        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
+        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings)
         self.feed_forward = FeedForward(
             dim=dim,
-            hidden_dim=4 * dim,
+            hidden_dim=dim,
             multiple_of=multiple_of,
             ffn_dim_multiplier=ffn_dim_multiplier,
             operation_settings=operation_settings,
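
Note: the FeedForward call changes from hidden_dim=4 * dim (with ffn_dim_multiplier unset) to hidden_dim=dim with the new ffn_dim_multiplier default of 4.0 further down. Assuming Lumina's usual llama-style FFN sizing (scale hidden_dim by 2/3, apply the multiplier when given, round up to a multiple of multiple_of), the default width works out the same; a quick sanity check under that assumption:

# Hypothetical restatement of the sizing logic, for checking only.
def ffn_hidden(hidden_dim: int, ffn_dim_multiplier, multiple_of: int = 256) -> int:
    hidden_dim = int(2 * hidden_dim / 3)
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

dim = 2304
assert ffn_hidden(4 * dim, None) == ffn_hidden(dim, 4.0)  # 6144 == 6144
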
@@ -252,16 +226,27 @@ def __init__(
 
         self.modulation = modulation
         if modulation:
-            self.adaLN_modulation = nn.Sequential(
-                nn.SiLU(),
-                operation_settings.get("operations").Linear(
-                    min(dim, 1024),
-                    4 * dim,
-                    bias=True,
-                    device=operation_settings.get("device"),
-                    dtype=operation_settings.get("dtype"),
-                ),
-            )
+            if z_image_modulation:
+                self.adaLN_modulation = nn.Sequential(
+                    operation_settings.get("operations").Linear(
+                        min(dim, 256),
+                        4 * dim,
+                        bias=True,
+                        device=operation_settings.get("device"),
+                        dtype=operation_settings.get("dtype"),
+                    ),
+                )
+            else:
+                self.adaLN_modulation = nn.Sequential(
+                    nn.SiLU(),
+                    operation_settings.get("operations").Linear(
+                        min(dim, 1024),
+                        4 * dim,
+                        bias=True,
+                        device=operation_settings.get("device"),
+                        dtype=operation_settings.get("dtype"),
+                    ),
+                )
 
     def forward(
         self,
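
Note: with z_image_modulation the block's adaLN head takes a 256-wide conditioning vector (matching the t_embedder output_size introduced below) and drops the SiLU in front of the projection; the legacy Lumina path is unchanged. A minimal side-by-side sketch, with plain nn.Linear standing in for the operation_settings wrapper (illustrative only):

import torch.nn as nn

def make_adaln(dim: int, z_image_modulation: bool) -> nn.Sequential:
    if z_image_modulation:
        # Z Image: 256-wide input, straight into the projection.
        return nn.Sequential(nn.Linear(min(dim, 256), 4 * dim, bias=True))
    # Lumina/NextDiT default: SiLU, then project from min(dim, 1024).
    return nn.Sequential(nn.SiLU(), nn.Linear(min(dim, 1024), 4 * dim, bias=True))
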
@@ -323,7 +308,7 @@ class FinalLayer(nn.Module):
     The final layer of NextDiT.
     """
 
-    def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
+    def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings={}):
         super().__init__()
         self.norm_final = operation_settings.get("operations").LayerNorm(
             hidden_size,
@@ -340,10 +325,15 @@ def __init__(self, hidden_size, patch_size, out_channels, operation_settings={})
             dtype=operation_settings.get("dtype"),
         )
 
+        if z_image_modulation:
+            min_mod = 256
+        else:
+            min_mod = 1024
+
         self.adaLN_modulation = nn.Sequential(
             nn.SiLU(),
             operation_settings.get("operations").Linear(
-                min(hidden_size, 1024),
+                min(hidden_size, min_mod),
                 hidden_size,
                 bias=True,
                 device=operation_settings.get("device"),
@@ -373,12 +363,16 @@ def __init__(
         n_heads: int = 32,
         n_kv_heads: Optional[int] = None,
         multiple_of: int = 256,
-        ffn_dim_multiplier: Optional[float] = None,
+        ffn_dim_multiplier: float = 4.0,
         norm_eps: float = 1e-5,
         qk_norm: bool = False,
         cap_feat_dim: int = 5120,
         axes_dims: List[int] = (16, 56, 56),
         axes_lens: List[int] = (1, 512, 512),
+        rope_theta=10000.0,
+        z_image_modulation=False,
+        time_scale=1.0,
+        pad_tokens_multiple=None,
         image_model=None,
         device=None,
         dtype=None,
@@ -390,6 +384,8 @@ def __init__(
         self.in_channels = in_channels
         self.out_channels = in_channels
         self.patch_size = patch_size
+        self.time_scale = time_scale
+        self.pad_tokens_multiple = pad_tokens_multiple
 
         self.x_embedder = operation_settings.get("operations").Linear(
             in_features=patch_size * patch_size * in_channels,
@@ -411,6 +407,7 @@ def __init__(
                     norm_eps,
                     qk_norm,
                     modulation=True,
+                    z_image_modulation=z_image_modulation,
                     operation_settings=operation_settings,
                 )
                 for layer_id in range(n_refiner_layers)
@@ -434,7 +431,7 @@ def __init__(
             ]
         )
 
-        self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
+        self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings)
         self.cap_embedder = nn.Sequential(
             operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
             operation_settings.get("operations").Linear(
@@ -457,18 +454,24 @@ def __init__(
                     ffn_dim_multiplier,
                     norm_eps,
                     qk_norm,
+                    z_image_modulation=z_image_modulation,
+                    attn_out_bias=False,
                     operation_settings=operation_settings,
                 )
                 for layer_id in range(n_layers)
             ]
         )
         self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
+        self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
+
+        if self.pad_tokens_multiple is not None:
+            self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
+            self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
 
         assert (dim // n_heads) == sum(axes_dims)
         self.axes_dims = axes_dims
         self.axes_lens = axes_lens
-        self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
+        self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims)
         self.dim = dim
         self.n_heads = n_heads
 
@@ -503,108 +506,42 @@ def patchify_and_embed(
         bsz = len(x)
         pH = pW = self.patch_size
         device = x[0].device
-        dtype = x[0].dtype
 
-        if cap_mask is not None:
-            l_effective_cap_len = cap_mask.sum(dim=1).tolist()
-        else:
-            l_effective_cap_len = [num_tokens] * bsz
+        if self.pad_tokens_multiple is not None:
+            pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
+            cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
 
-        if cap_mask is not None and not torch.is_floating_point(cap_mask):
-            cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max
+        cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
+        cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
 
-        img_sizes = [(img.size(1), img.size(2)) for img in x]
-        l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
+        B, C, H, W = x.shape
+        x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
 
-        max_seq_len = max(
-            (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
-        )
-        max_cap_len = max(l_effective_cap_len)
-        max_img_len = max(l_effective_img_len)
-
-        position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)
-
-        for i in range(bsz):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
-            H, W = img_sizes[i]
-            H_tokens, W_tokens = H // pH, W // pW
-            assert H_tokens * W_tokens == img_len
-
-            rope_options = transformer_options.get("rope_options", None)
-            h_scale = 1.0
-            w_scale = 1.0
-            h_start = 0
-            w_start = 0
-            if rope_options is not None:
-                h_scale = rope_options.get("scale_y", 1.0)
-                w_scale = rope_options.get("scale_x", 1.0)
-
-                h_start = rope_options.get("shift_y", 0.0)
-                w_start = rope_options.get("shift_x", 0.0)
-
-            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
-            position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
-            row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
-            col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
-            position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
-            position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
-
-        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
-
-        # build freqs_cis for cap and image individually
-        cap_freqs_cis_shape = list(freqs_cis.shape)
-        # cap_freqs_cis_shape[1] = max_cap_len
-        cap_freqs_cis_shape[1] = cap_feats.shape[1]
-        cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
-        img_freqs_cis_shape = list(freqs_cis.shape)
-        img_freqs_cis_shape[1] = max_img_len
-        img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
-        for i in range(bsz):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
-            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
-            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]
+        H_tokens, W_tokens = H // pH, W // pW
+        x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
+        x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
+        x_pos_ids[:, :, 1] = torch.arange(H_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
+        x_pos_ids[:, :, 2] = torch.arange(W_tokens, dtype=torch.float32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
 
-        # refine context
-        for layer in self.context_refiner:
-            cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
-
-        # refine image
-        flat_x = []
-        for i in range(bsz):
-            img = x[i]
-            C, H, W = img.size()
-            img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
-            flat_x.append(img)
-        x = flat_x
-        padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
-        padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
-        for i in range(bsz):
-            padded_img_embed[i, :l_effective_img_len[i]] = x[i]
-            padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max
-
-        padded_img_embed = self.x_embedder(padded_img_embed)
-        padded_img_mask = padded_img_mask.unsqueeze(1)
-        for layer in self.noise_refiner:
-            padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
+        if self.pad_tokens_multiple is not None:
+            pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
+            x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
+            x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
 
-        if cap_mask is not None:
-            mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
-            mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
-        else:
-            mask = None
+        freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
 
-        padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
-        for i in range(bsz):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
+        # refine context
+        for layer in self.context_refiner:
+            cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
 
-            padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
-            padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]
+        padded_img_mask = None
+        for layer in self.noise_refiner:
+            x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
 
+        padded_full_embed = torch.cat((cap_feats, x), dim=1)
+        mask = None
+        img_sizes = [(H, W)] * bsz
+        l_effective_cap_len = [cap_feats.shape[1]] * bsz
         return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
 
     def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
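
Note: the rewritten patchify_and_embed assumes a single uniform batch (x is now a dense [B, C, H, W] tensor rather than a list of images), which is what lets all the per-sample padding and masking above be deleted. Two details worth calling out, shown as a standalone sketch with toy sizes (illustrative, not the module code): axis 0 of the 3-axis position ids numbers caption tokens 1..L and pins every image token at L + 1, while axes 1 and 2 carry the patch row/column; and (-n) % m is the smallest non-negative pad that rounds a length n up to a multiple of m, which is how pad_extra is computed when pad_tokens_multiple is set:

import torch

cap_len, H_tokens, W_tokens, multiple = 4, 2, 3, 32  # toy sizes, assumed for illustration

# Axis 0: caption tokens get ids 1..cap_len; image tokens all share cap_len + 1.
cap_pos_ids = torch.zeros(1, cap_len, 3)
cap_pos_ids[:, :, 0] = torch.arange(cap_len, dtype=torch.float32) + 1.0

x_pos_ids = torch.zeros(1, H_tokens * W_tokens, 3)
x_pos_ids[:, :, 0] = cap_len + 1
# Axes 1 and 2: patch row and column indices, flattened row-major.
x_pos_ids[:, :, 1] = torch.arange(H_tokens, dtype=torch.float32).view(-1, 1).repeat(1, W_tokens).flatten()
x_pos_ids[:, :, 2] = torch.arange(W_tokens, dtype=torch.float32).view(1, -1).repeat(H_tokens, 1).flatten()

# Pad rounding: (-n) % m is the smallest count making n + pad a multiple of m.
n = x_pos_ids.shape[1]          # 6 image tokens
pad_extra = (-n) % multiple     # 26
assert (n + pad_extra) % multiple == 0
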
@@ -627,7 +564,7 @@ def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwa
         y: (N,) tensor of text tokens/features
         """
 
-        t = self.t_embedder(t, dtype=x.dtype)  # (N, D)
+        t = self.t_embedder(t * self.time_scale, dtype=x.dtype)  # (N, D)
         adaln_input = t
 
         cap_feats = self.cap_embedder(cap_feats)  # (N, L, D)  # todo check if able to batchify w.o. redundant compute
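
Note: multiplying by self.time_scale (default 1.0, so existing models are unaffected) lets a variant remap the sampler's timestep range before the sinusoidal embedding. The narrower conditioning vector this feeds relies on the TimestepEmbedder change diffed below; a condensed sketch of that constructor, with plain nn.Linear standing in for the operations wrapper (illustrative only):

from typing import Optional
import torch.nn as nn

class TimestepEmbedderSketch(nn.Module):
    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256,
                 output_size: Optional[int] = None):
        super().__init__()
        # output_size defaults to hidden_size, preserving the old behavior;
        # Z Image passes output_size=256 to feed its narrower adaLN heads.
        if output_size is None:
            output_size = hidden_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, output_size, bias=True),
        )
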

comfy/ldm/modules/diffusionmodules/mmdit.py

4 additions, 2 deletions
@@ -211,12 +211,14 @@ class TimestepEmbedder(nn.Module):
     Embeds scalar timesteps into vector representations.
     """
 
-    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
         super().__init__()
+        if output_size is None:
+            output_size = hidden_size
         self.mlp = nn.Sequential(
             operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
             nn.SiLU(),
-            operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
+            operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
         )
         self.frequency_embedding_size = frequency_embedding_size
 