Commit

Initial commit
Vovanm88 committed Aug 13, 2024
0 parents commit f19a097
Showing 7 changed files with 1,086 additions and 0 deletions.
3 changes: 3 additions & 0 deletions __init__.py
@@ -0,0 +1,3 @@
from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
219 changes: 219 additions & 0 deletions layers.py
@@ -0,0 +1,219 @@
import math
from dataclasses import dataclass

import torch
from einops import rearrange
from torch import Tensor, nn

from .xflux.src.flux.math import attention, rope
from .xflux.src.flux.modules.layers import LoRALinearLayer

def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.

    :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: multiplier applied to t before the frequencies are computed.
    :return: an (N, D) Tensor of positional embeddings.
    """
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
t.device
)

args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(t)
return embedding
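
def _timestep_embedding_example():
    # Illustrative usage sketch, not part of the original code: the timestep
    # values and dim=256 below are arbitrary. The first dim // 2 columns hold
    # cosine terms and the remaining columns hold sine terms.
    t = torch.tensor([0.25, 0.5])
    emb = timestep_embedding(t, dim=256)  # -> shape (2, 256)
    return emb
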

class DoubleStreamBlockLorasMixerProcessor(nn.Module):
    """Stacks several double-stream LoRA processors and sums their contributions."""

    def __init__(self):
        super().__init__()
        self.qkv_lora1 = []
        self.proj_lora1 = []
        self.qkv_lora2 = []
        self.proj_lora2 = []
        self.lora_weight = []
        self.names = []

    def add_lora(self, processor):
        """Absorb the LoRA layers (and their weights) from another processor."""
if isinstance(processor, DoubleStreamBlockLorasMixerProcessor):
self.qkv_lora1+=processor.qkv_lora1
self.qkv_lora2+=processor.qkv_lora2
self.proj_lora1+=processor.proj_lora1
self.proj_lora2+=processor.proj_lora2
self.lora_weight+=processor.lora_weight
else:
if hasattr(processor, "qkv_lora1"):
self.qkv_lora1.append(processor.qkv_lora1)
if hasattr(processor, "proj_lora1"):
self.proj_lora1.append(processor.proj_lora1)
if hasattr(processor, "qkv_lora2"):
self.qkv_lora2.append(processor.qkv_lora2)
if hasattr(processor, "proj_lora2"):
self.proj_lora2.append(processor.proj_lora2)
if hasattr(processor, "lora_weight"):
self.lora_weight.append(processor.lora_weight)

    def get_loras(self):
return (
self.qkv_lora1, self.qkv_lora2,
self.proj_lora1, self.proj_lora2,
self.lora_weight
)

    def set_loras(self, qkv1s, qkv2s, proj1s, proj2s, w8s):
for el in qkv1s:
self.qkv_lora1.append(el)
for el in qkv2s:
self.qkv_lora2.append(el)
for el in proj1s:
self.proj_lora1.append(el)
for el in proj2s:
self.proj_lora2.append(el)
for el in w8s:
self.lora_weight.append(el)

    def add_shift(self, layer, origin, inputs, gating=1.0):
        # Accumulate every stacked LoRA's output into `origin` in place,
        # scaled by its stored weight and the optional gating factor.
        count = len(layer)
        for i in range(count):
            origin += layer[i](inputs) * self.lora_weight[i] * gating

def forward(self, attn, img, txt, vec, pe, **attention_kwargs):
img_mod1, img_mod2 = attn.img_mod(vec)
txt_mod1, txt_mod2 = attn.txt_mod(vec)

# prepare image for attention
img_modulated = attn.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift

        # equivalent to: qkv(img_modulated) + sum_i qkv_lora1[i](img_modulated) * lora_weight[i]
        img_qkv = attn.img_attn.qkv(img_modulated)
        self.add_shift(self.qkv_lora1, img_qkv, img_modulated)

img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)

# prepare txt for attention
        txt_modulated = attn.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift

        # equivalent to: qkv(txt_modulated) + sum_i qkv_lora2[i](txt_modulated) * lora_weight[i]
        txt_qkv = attn.txt_attn.qkv(txt_modulated)
        self.add_shift(self.qkv_lora2, txt_qkv, txt_modulated)

txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)

# run actual attention
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)

attn1 = attention(q, k, v, pe=pe)
txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]

        # calculate the img blocks
        # equivalent to: img_mod1.gate * (proj(img_attn) + sum_i proj_lora1[i](img_attn) * lora_weight[i])
        img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
        self.add_shift(self.proj_lora1, img, img_attn, img_mod1.gate)

img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        # equivalent to: txt_mod1.gate * (proj(txt_attn) + sum_i proj_lora2[i](txt_attn) * lora_weight[i])
        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
        self.add_shift(self.proj_lora2, txt, txt_attn, txt_mod1.gate)

txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
return img, txt

class DoubleStreamBlockLoraProcessor(nn.Module):
    """Double-stream block processor with a single LoRA pair on the qkv and proj projections."""

    def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
        super().__init__()
self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
self.lora_weight = lora_weight

def forward(self, attn, img, txt, vec, pe, **attention_kwargs):
img_mod1, img_mod2 = attn.img_mod(vec)
txt_mod1, txt_mod2 = attn.txt_mod(vec)

# prepare image for attention
img_modulated = attn.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)

# prepare txt for attention
txt_modulated = attn.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)

# run actual attention
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)

attn1 = attention(q, k, v, pe=pe)
txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]

        # calculate the img blocks
img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img_mod1.gate * self.proj_lora1(img_attn) * self.lora_weight
img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt_mod1.gate * self.proj_lora2(txt_attn) * self.lora_weight
txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
return img, txt
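
def _lora_mixer_example():
    # Illustrative usage sketch, not part of the original code: the dim, rank,
    # and weight values are assumptions (3072 matching Flux's double-stream
    # hidden size). Two single-LoRA processors are merged into one mixer so
    # both LoRAs contribute to the same block.
    lora_a = DoubleStreamBlockLoraProcessor(dim=3072, rank=16, lora_weight=0.8)
    lora_b = DoubleStreamBlockLoraProcessor(dim=3072, rank=16, lora_weight=0.5)
    mixer = DoubleStreamBlockLorasMixerProcessor()
    mixer.add_lora(lora_a)
    mixer.add_lora(lora_b)
    return mixer
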

class DoubleStreamBlockProcessor(nn.Module):
    """Plain double-stream block processor without LoRA layers."""

    def __init__(self):
        super().__init__()

    def __call__(self, attn, img, txt, vec, pe, **attention_kwargs):
img_mod1, img_mod2 = attn.img_mod(vec)
txt_mod1, txt_mod2 = attn.txt_mod(vec)

# prepare image for attention
img_modulated = attn.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = attn.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)

# prepare txt for attention
txt_modulated = attn.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = attn.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)

# run actual attention
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)

attn1 = attention(q, k, v, pe=pe)
txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]

        # calculate the img blocks
img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
return img, txt

    def forward(self, attn, img, txt, vec, pe, **attention_kwargs):
        return self.__call__(attn, img, txt, vec, pe, **attention_kwargs)