101 changes: 87 additions & 14 deletions comfy/ldm/kandinsky5/model.py
@@ -6,6 +6,12 @@
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.kandinsky5.utils_nabla import (
fractal_flatten,
fractal_unflatten,
fast_sta_nabla,
nabla,
)

def attention(q, k, v, heads, transformer_options={}):
return optimized_attention(
@@ -116,14 +122,17 @@ def _compute_qk(self, x, freqs, proj_fn, norm_fn):
result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
return apply_rope1(norm_fn(result), freqs)

def _forward(self, x, freqs, transformer_options={}):
def _forward(self, x, freqs, sparse_params=None, transformer_options={}):
q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
if sparse_params is None:
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
else:
out = nabla(q, k, v, sparse_params)
return self.out_layer(out)

def _forward_chunked(self, x, freqs, transformer_options={}):
def _forward_chunked(self, x, freqs, sparse_params=None, transformer_options={}):
def process_chunks(proj_fn, norm_fn):
x_chunks = torch.chunk(x, self.num_chunks, dim=1)
freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
@@ -135,14 +144,17 @@ def process_chunks(proj_fn, norm_fn):
q = process_chunks(self.to_query, self.query_norm)
k = process_chunks(self.to_key, self.key_norm)
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
if sparse_params is None:
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
else:
out = nabla(q, k, v, sparse_params)
return self.out_layer(out)

def forward(self, x, freqs, transformer_options={}):
def forward(self, x, freqs, sparse_params=None, transformer_options={}):
if x.shape[1] > 8192:
return self._forward_chunked(x, freqs, transformer_options=transformer_options)
return self._forward_chunked(x, freqs, sparse_params=sparse_params, transformer_options=transformer_options)
else:
return self._forward(x, freqs, transformer_options=transformer_options)
return self._forward(x, freqs, sparse_params=sparse_params, transformer_options=transformer_options)


class CrossAttention(SelfAttention):
@@ -251,12 +263,12 @@ def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=Non
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)

def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
def forward(self, visual_embed, text_embed, time_embed, freqs, sparse_params=None, transformer_options={}):
self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
# self attention
shift, scale, gate = get_shift_scale_gate(self_attn_params)
visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
visual_out = self.self_attention(visual_out, freqs, sparse_params=sparse_params, transformer_options=transformer_options)
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
# cross attention
shift, scale, gate = get_shift_scale_gate(cross_attn_params)
@@ -369,21 +381,82 @@ def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_o

visual_embed = self.visual_embeddings(x)
visual_shape = visual_embed.shape[:-1]
visual_embed = visual_embed.flatten(1, -2)

blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
transformer_options["block_type"] = "double"

B, _, T, H, W = x.shape
NABLA_THR = 31 # long (10 sec) generation
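# 31 latent frames correspond to 121 input frames under the ((length - 1) // 4) + 1 temporal packing in nodes_kandinsky5.py, so the NABLA sparse path only activates for long (>10 sec) generations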
if T > NABLA_THR:
assert self.patch_size[0] == 1

# pro video model uses lower P at higher resolutions
P = 0.7 if self.model_dim == 4096 and H * W >= 14080 else 0.9

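# bring the rope frequencies back onto the (T, H, W) grid, then reorder tokens and frequencies into the fractal (8x8 block) layout used by the sparse attention path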
freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])
visual_embed, freqs = fractal_flatten(visual_embed, freqs, visual_shape[1:])
pt, ph, pw = self.patch_size
T, H, W = T // pt, H // ph, W // pw

wT, wW, wH = 11, 3, 3
sta_mask = fast_sta_nabla(T, H // 8, W // 8, wT, wH, wW, device=x.device)

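# sparse_params drives the NABLA branch in SelfAttention: sta_mask is the block-level sliding-tile mask from fast_sta_nabla, and P is the cumulative-softmax threshold used to binarize the estimated block map in nablaT_v2 (the 'topcdf' selection)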
sparse_params = dict(
sta_mask=sta_mask.unsqueeze_(0).unsqueeze_(0),
attention_type="nabla",
to_fractal=True,
P=P,
wT=wT, wW=wW, wH=wH,
add_sta=True,
visual_shape=(T, H, W),
method="topcdf",
)
else:
sparse_params = None
visual_embed = visual_embed.flatten(1, -2)

for i, block in enumerate(self.visual_transformer_blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
return block(
x=args["x"],
context=args["context"],
time_embed=args["time_embed"],
freqs=args["freqs"],
sparse_params=args.get("sparse_params"),
transformer_options=args.get("transformer_options"),
)
visual_embed = blocks_replace[("double_block", i)](
{
"x": visual_embed,
"context": context,
"time_embed": time_embed,
"freqs": freqs,
"sparse_params": sparse_params,
"transformer_options": transformer_options,
},
{"original_block": block_wrap},
)["x"]
else:
visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
visual_embed = block(
visual_embed,
context,
time_embed,
freqs=freqs,
sparse_params=sparse_params,
transformer_options=transformer_options,
)

if T > NABLA_THR:
visual_embed = fractal_unflatten(
visual_embed,
visual_shape[1:],
)
else:
visual_embed = visual_embed.reshape(*visual_shape, -1)

visual_embed = visual_embed.reshape(*visual_shape, -1)
return self.out_layer(visual_embed, time_embed)

def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
146 changes: 146 additions & 0 deletions comfy/ldm/kandinsky5/utils_nabla.py
@@ -0,0 +1,146 @@
import math

import torch
from torch import Tensor
from torch.nn.attention.flex_attention import BlockMask, flex_attention


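# fractal_flatten groups tokens into 1x8x8 spatio-temporal patches (64 tokens each) so that flex_attention's 64-token blocks line up with local spatial neighborhoods; fractal_unflatten undoes the reordering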
def fractal_flatten(x, rope, shape):
pixel_size = 8
x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1)
rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1)
x = x.flatten(1, 2)
rope = rope.flatten(1, 2)
return x, rope


def fractal_unflatten(x, shape):
pixel_size = 8
x = x.reshape(x.shape[0], -1, pixel_size**2, x.shape[-1])
x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1)
return x

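# local_patching reorders a (duration, height, width) token grid into (num_groups, g1*g2*g3) contiguous blocks; local_merge below is its inverse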
def local_patching(x, shape, group_size, dim=0):
duration, height, width = shape
g1, g2, g3 = group_size
x = x.reshape(
*x.shape[:dim],
duration // g1,
g1,
height // g2,
g2,
width // g3,
g3,
*x.shape[dim + 3 :]
)
x = x.permute(
*range(len(x.shape[:dim])),
dim,
dim + 2,
dim + 4,
dim + 1,
dim + 3,
dim + 5,
*range(dim + 6, len(x.shape))
)
x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3)
return x


def local_merge(x, shape, group_size, dim=0):
duration, height, width = shape
g1, g2, g3 = group_size
x = x.reshape(
*x.shape[:dim],
duration // g1,
height // g2,
width // g3,
g1,
g2,
g3,
*x.shape[dim + 2 :]
)
x = x.permute(
*range(len(x.shape[:dim])),
dim,
dim + 3,
dim + 1,
dim + 4,
dim + 2,
dim + 5,
*range(dim + 6, len(x.shape))
)
x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3)
return x

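# builds a boolean (T*H*W, T*H*W) block mask that is True for block pairs whose temporal/spatial indices differ by at most wT//2, wH//2, wW//2, i.e. a 3D sliding window over attention blocks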
def fast_sta_nabla(T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda") -> Tensor:
l = torch.Tensor([T, H, W]).amax()
r = torch.arange(0, l, 1, dtype=torch.int16, device=device)
mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
sta_t, sta_h, sta_w = (
mat[:T, :T].flatten(),
mat[:H, :H].flatten(),
mat[:W, :W].flatten(),
)
sta_t = sta_t <= wT // 2
sta_h = sta_h <= wH // 2
sta_w = sta_w <= wW // 2
sta_hw = (
(sta_h.unsqueeze(1) * sta_w.unsqueeze(0))
.reshape(H, H, W, W)
.transpose(1, 2)
.flatten()
)
sta = (
(sta_t.unsqueeze(1) * sta_hw.unsqueeze(0))
.reshape(T, T, H * W, H * W)
.transpose(1, 2)
)
return sta.reshape(T * H * W, T * H * W)

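# estimates a block-level attention map by mean-pooling q and k over 64-token blocks, keeps the top key blocks whose cumulative softmax mass covers thr, ORs in the sliding-tile mask, and packs the result into a flex_attention BlockMask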
def nablaT_v2(q: Tensor, k: Tensor, sta: Tensor, thr: float = 0.9) -> BlockMask:
# Map estimation
B, h, S, D = q.shape
s1 = S // 64
qa = q.reshape(B, h, s1, 64, D).mean(-2)
ka = k.reshape(B, h, s1, 64, D).mean(-2).transpose(-2, -1)
map = qa @ ka

map = torch.softmax(map / math.sqrt(D), dim=-1)
# Map binarization
vals, inds = map.sort(-1)
cvals = vals.cumsum_(-1)
mask = (cvals >= 1 - thr).int()
mask = mask.gather(-1, inds.argsort(-1))
mask = torch.logical_or(mask, sta)

# BlockMask creation
kv_nb = mask.sum(-1).to(torch.int32)
kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32)
return BlockMask.from_kv_blocks(
torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None
)

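# q/k/v arrive as (batch, seq, heads, head_dim) from SelfAttention; they are transposed to (batch, heads, seq, head_dim) for flex_attention, and the output is transposed back and flattened over heads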
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
def nabla(query, key, value, sparse_params=None):
query = query.transpose(1, 2).contiguous()
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
block_mask = nablaT_v2(
query,
key,
sparse_params["sta_mask"],
thr=sparse_params["P"],
)
out = (
flex_attention(
query,
key,
value,
block_mask=block_mask
)
.transpose(1, 2)
.contiguous()
)
out = out.flatten(-2, -1)
return out
3 changes: 3 additions & 0 deletions comfy_extras/nodes_kandinsky5.py
@@ -34,6 +34,9 @@ def define_schema(cls):

@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
if length > 121: # 10 sec generation, for nabla
height = 128 * round(height / 128)
width = 128 * round(width / 128)
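# snapping to multiples of 128 keeps the downstream token grid divisible by the 8x8 blocks assumed by the NABLA sparse-attention path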
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
cond_latent_out = {}
if start_image is not None: