-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
replace flux rope with a triton kernel
- Loading branch information
Showing
13 changed files
with
372 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,322 @@ | ||
# https://github.com/black-forest-labs/flux | ||
|
||
from dataclasses import dataclass | ||
|
||
import torch | ||
from einops import rearrange | ||
from torch import Tensor, nn | ||
|
||
import triton_kernels as tk | ||
|
||
|
||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: | ||
q, k = apply_rope(q, k, pe) | ||
|
||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v) | ||
x = rearrange(x, "B H L D -> B L (H D)") | ||
|
||
return x | ||
|
||
|
||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: | ||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) | ||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) | ||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] | ||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] | ||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) | ||
|
||
|
||
class RMSNorm(torch.nn.Module): | ||
def __init__(self, dim: int): | ||
super().__init__() | ||
self.scale = nn.Parameter(torch.ones(dim)) | ||
|
||
def forward(self, x: Tensor): | ||
x_dtype = x.dtype | ||
x = x.float() | ||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) | ||
return (x * rrms).to(dtype=x_dtype) * self.scale | ||
|
||
|
||
class QKNorm(torch.nn.Module): | ||
def __init__(self, dim: int): | ||
super().__init__() | ||
self.query_norm = RMSNorm(dim) | ||
self.key_norm = RMSNorm(dim) | ||
|
||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: | ||
q = self.query_norm(q) | ||
k = self.key_norm(k) | ||
return q.to(v), k.to(v) | ||
|
||
|
||
class SelfAttention(nn.Module): | ||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): | ||
super().__init__() | ||
self.num_heads = num_heads | ||
head_dim = dim // num_heads | ||
|
||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) | ||
self.norm = QKNorm(head_dim) | ||
self.proj = nn.Linear(dim, dim) | ||
|
||
def forward(self, x: Tensor, pe: Tensor) -> Tensor: | ||
qkv = self.qkv(x) | ||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) | ||
q, k = self.norm(q, k, v) | ||
x = attention(q, k, v, pe=pe) | ||
x = self.proj(x) | ||
return x | ||
|
||
|
||
@dataclass | ||
class ModulationOut: | ||
shift: Tensor | ||
scale: Tensor | ||
gate: Tensor | ||
|
||
|
||
class Modulation(nn.Module): | ||
def __init__(self, dim: int, double: bool): | ||
super().__init__() | ||
self.is_double = double | ||
self.multiplier = 6 if double else 3 | ||
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) | ||
|
||
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: | ||
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) | ||
|
||
return ( | ||
ModulationOut(*out[:3]), | ||
ModulationOut(*out[3:]) if self.is_double else None, | ||
) | ||
|
||
|
||
class DoubleStreamBlock(nn.Module): | ||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): | ||
super().__init__() | ||
|
||
mlp_hidden_dim = int(hidden_size * mlp_ratio) | ||
self.num_heads = num_heads | ||
self.hidden_size = hidden_size | ||
self.img_mod = Modulation(hidden_size, double=True) | ||
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) | ||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) | ||
|
||
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) | ||
self.img_mlp = nn.Sequential( | ||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True), | ||
nn.GELU(approximate="tanh"), | ||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True), | ||
) | ||
|
||
self.txt_mod = Modulation(hidden_size, double=True) | ||
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) | ||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) | ||
|
||
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) | ||
self.txt_mlp = nn.Sequential( | ||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True), | ||
nn.GELU(approximate="tanh"), | ||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True), | ||
) | ||
|
||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: | ||
img_mod1, img_mod2 = self.img_mod(vec) | ||
txt_mod1, txt_mod2 = self.txt_mod(vec) | ||
|
||
# prepare image for attention | ||
img_modulated = self.img_norm1(img) | ||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift | ||
img_qkv = self.img_attn.qkv(img_modulated) | ||
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) | ||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) | ||
|
||
# prepare txt for attention | ||
txt_modulated = self.txt_norm1(txt) | ||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift | ||
txt_qkv = self.txt_attn.qkv(txt_modulated) | ||
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) | ||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) | ||
|
||
# run actual attention | ||
q = torch.cat((txt_q, img_q), dim=2) | ||
k = torch.cat((txt_k, img_k), dim=2) | ||
v = torch.cat((txt_v, img_v), dim=2) | ||
|
||
attn = attention(q, k, v, pe=pe) | ||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] | ||
|
||
# calculate the img bloks | ||
img = img + img_mod1.gate * self.img_attn.proj(img_attn) | ||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) | ||
|
||
# calculate the txt bloks | ||
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) | ||
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) | ||
return img, txt | ||
|
||
|
||
class SingleStreamBlock(nn.Module): | ||
""" | ||
A DiT block with parallel linear layers as described in | ||
https://arxiv.org/abs/2302.05442 and adapted modulation interface. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
hidden_size: int, | ||
num_heads: int, | ||
mlp_ratio: float = 4.0, | ||
qk_scale: float | None = None, | ||
): | ||
super().__init__() | ||
self.hidden_dim = hidden_size | ||
self.num_heads = num_heads | ||
head_dim = hidden_size // num_heads | ||
self.scale = qk_scale or head_dim**-0.5 | ||
|
||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio) | ||
# qkv and mlp_in | ||
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) | ||
# proj and mlp_out | ||
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) | ||
|
||
self.norm = QKNorm(head_dim) | ||
|
||
self.hidden_size = hidden_size | ||
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) | ||
|
||
self.mlp_act = nn.GELU(approximate="tanh") | ||
self.modulation = Modulation(hidden_size, double=False) | ||
|
||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: | ||
mod, _ = self.modulation(vec) | ||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift | ||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) | ||
|
||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) | ||
q, k = self.norm(q, k, v) | ||
|
||
# compute attention | ||
q, k = apply_rope(q, k, pe) | ||
attn = torch.nn.functional.scaled_dot_product_attention(q, k, v) | ||
attn = rearrange(attn, "B H L D -> B L (H D)") | ||
|
||
# compute activation in mlp stream, cat again and run second linear layer | ||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) | ||
return x + mod.gate * output | ||
|
||
|
||
class SingleStreamBlockTriton(nn.Module): | ||
""" | ||
A DiT block with parallel linear layers as described in | ||
https://arxiv.org/abs/2302.05442 and adapted modulation interface. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
hidden_size: int, | ||
num_heads: int, | ||
mlp_ratio: float = 4.0, | ||
qk_scale: float | None = None, | ||
): | ||
super().__init__() | ||
self.hidden_dim = hidden_size | ||
self.num_heads = num_heads | ||
head_dim = hidden_size // num_heads | ||
self.scale = qk_scale or head_dim**-0.5 | ||
|
||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio) | ||
# qkv and mlp_in | ||
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) | ||
# proj and mlp_out | ||
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) | ||
|
||
self.norm = QKNorm(head_dim) | ||
|
||
self.hidden_size = hidden_size | ||
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) | ||
|
||
self.mlp_act = nn.GELU(approximate="tanh") | ||
self.modulation = Modulation(hidden_size, double=False) | ||
|
||
@torch.compile | ||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: | ||
mod, _ = self.modulation(vec) | ||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift | ||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) | ||
|
||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) | ||
q, k = self.norm(q, k, v) | ||
|
||
# compute attention | ||
q, k = apply_rope(q, k, pe) | ||
attn = torch.nn.functional.scaled_dot_product_attention(q, k, v) | ||
attn = rearrange(attn, "B H L D -> B L (H D)") | ||
|
||
# compute activation in mlp stream, cat again and run second linear layer | ||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) | ||
return x + mod.gate * output | ||
|
||
|
||
if __name__ == "__main__": | ||
hidden_size = 3072 | ||
num_heads = 24 | ||
mlp_ratio = 4.0 | ||
head_dim = hidden_size // num_heads | ||
|
||
batch_size = 1 | ||
seq_len = 4336 | ||
|
||
device = "cuda" | ||
|
||
block = SingleStreamBlock( | ||
hidden_size=hidden_size, | ||
num_heads=num_heads, | ||
mlp_ratio=mlp_ratio, | ||
) | ||
block_triton = SingleStreamBlockTriton( | ||
hidden_size=hidden_size, | ||
num_heads=num_heads, | ||
mlp_ratio=mlp_ratio, | ||
) | ||
block_triton.load_state_dict(block.state_dict()) | ||
block = block.to(device) | ||
block_triton = block_triton.to(device) | ||
|
||
x = torch.randn(batch_size, seq_len, hidden_size).to(device) | ||
vec = torch.randn(batch_size, hidden_size).to(device) | ||
pe = torch.randn(batch_size, 1, seq_len, head_dim // 2, 2, 2).to(device) | ||
|
||
out = block(x=x, vec=vec, pe=pe) | ||
out_triton = block_triton(x=x, vec=vec, pe=pe) | ||
|
||
torch.testing.assert_close(out, out_triton, atol=1e-5, rtol=0) | ||
|
||
# warmup | ||
warmup_count = 5 | ||
|
||
for i in range(warmup_count): | ||
out = block(x=x, vec=vec, pe=pe) | ||
for i in range(warmup_count): | ||
out_compiled = block_triton(x=x, vec=vec, pe=pe) | ||
|
||
# run | ||
run_count = 100 | ||
|
||
start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) | ||
start.record() | ||
for i in range(run_count): | ||
out = block(x=x, vec=vec, pe=pe) | ||
end.record() | ||
torch.cuda.synchronize() | ||
print(f"baseline block time: {start.elapsed_time(end):.2f} ms") | ||
|
||
start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) | ||
start.record() | ||
for i in range(run_count): | ||
out_compiled = block_triton(x=x, vec=vec, pe=pe) | ||
end.record() | ||
torch.cuda.synchronize() | ||
print(f"compiled block time: {start.elapsed_time(end):.2f} ms") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.