Commit

replace flux rope with a triton kernel

sjjeong94 committed Oct 1, 2024
1 parent e27cf1b commit 07a8a8e
Showing 13 changed files with 372 additions and 28 deletions.
6 changes: 3 additions & 3 deletions benchmarks/layer_norm.py
@@ -21,9 +21,9 @@
 )
 def bench_layer_norm_modulation(batch_size, seq_len, embed_dim, provider, device="cuda"):
     # create data
-    x = torch.randn(batch_size, seq_len, embed_dim).to(device)
-    scale = torch.randn(batch_size, 1, embed_dim).to(device)
-    shift = torch.randn(batch_size, 1, embed_dim).to(device)
+    x = torch.randn([batch_size, seq_len, embed_dim], device=device)
+    scale = torch.randn([batch_size, 1, embed_dim], device=device)
+    shift = torch.randn([batch_size, 1, embed_dim], device=device)

     def y_fwd():
         if provider == "triton":
4 changes: 2 additions & 2 deletions benchmarks/rms_norm.py
@@ -21,8 +21,8 @@
 )
 def bench_rms_norm(batch_size, num_heads, seq_len, head_dim, provider, device="cuda"):
     # create data
-    x = torch.randn(batch_size, num_heads, seq_len, head_dim).to(device)
-    scale = torch.randn(head_dim).to(device)
+    x = torch.randn([batch_size, num_heads, seq_len, head_dim], device=device)
+    scale = torch.randn([head_dim], device=device)

     def y_fwd():
         if provider == "triton":
6 changes: 3 additions & 3 deletions benchmarks/rope.py
@@ -21,9 +21,9 @@
 )
 def bench_apply_rope(batch_size, num_heads, seq_len, head_dim, provider, device="cuda"):
     # create data
-    xq = torch.randn(batch_size, num_heads, seq_len, head_dim).to(device)
-    xk = torch.randn(batch_size, num_heads, seq_len, head_dim).to(device)
-    freqs_cis = torch.randn(1, 1, seq_len, head_dim // 2, 2, 2).to(device)
+    xq = torch.randn([batch_size, num_heads, seq_len, head_dim], device=device)
+    xk = torch.randn([batch_size, num_heads, seq_len, head_dim], device=device)
+    freqs_cis = torch.randn([1, 1, seq_len, head_dim // 2, 2, 2], device=device)

     def y_fwd():
         if provider == "triton":
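Note: all three benchmark diffs above make the same change: the test tensors are now created directly on the GPU with the device= keyword instead of being allocated on the CPU and copied over with .to(device). A minimal illustration of the two forms (the shape here is arbitrary):

    import torch

    # allocates on the CPU first, then copies the tensor to the GPU
    x_copy = torch.randn(4, 1024, 3072).to("cuda")

    # allocates directly on the GPU, skipping the host-side allocation and transfer
    x_direct = torch.randn([4, 1024, 3072], device="cuda")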
12 changes: 1 addition & 11 deletions scripts/flux.py
@@ -16,16 +16,6 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
     return x


-def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
-    assert dim % 2 == 0
-    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
-    omega = 1.0 / (theta**scale)
-    out = torch.einsum("...n,d->...nd", pos, omega)
-    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
-    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
-    return out.float()
-
-
 def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
     xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
     xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
@@ -315,7 +305,7 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
     torch.cuda.synchronize()
     print(f"baseline block time: {start.elapsed_time(end):.2f} ms")

-    tart, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
+    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
     start.record()
     for i in range(run_count):
         out_compiled = block_compiled(x=x, vec=vec, pe=pe)
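For reference, the rope() helper removed above precomputes, for each position p and each frequency index d in 0..D/2-1 (D = head dimension), a 2x2 rotation matrix with angle p*omega_d, and apply_rope multiplies every consecutive channel pair of q and k by that matrix:

    \omega_d = \theta^{-2d/D}, \qquad
    R(p, d) = \begin{pmatrix} \cos(p\,\omega_d) & -\sin(p\,\omega_d) \\ \sin(p\,\omega_d) & \cos(p\,\omega_d) \end{pmatrix}, \qquad
    \begin{pmatrix} x'_{2d} \\ x'_{2d+1} \end{pmatrix} = R(p, d) \begin{pmatrix} x_{2d} \\ x_{2d+1} \end{pmatrix}

In the code, freqs_cis[..., 0] and freqs_cis[..., 1] hold the two columns of R(p, d), so apply_rope evaluates the rotation with two multiplies and one add per output channel.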
322 changes: 322 additions & 0 deletions scripts/flux_triton.py
@@ -0,0 +1,322 @@
# https://github.com/black-forest-labs/flux

from dataclasses import dataclass

import torch
from einops import rearrange
from torch import Tensor, nn

import triton_kernels as tk


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


class RMSNorm(torch.nn.Module):
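    # The norm is computed in float32 for numerical stability and cast back to the input dtype;
    # scale is a learned per-channel gain applied after normalization.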
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        return (x * rrms).to(dtype=x_dtype) * self.scale


class QKNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)


class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x


@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )


class DoubleStreamBlock(nn.Module):
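    # The img and txt streams each receive (shift, scale, gate) triples from Modulation; activations
    # are modulated as (1 + scale) * norm(x) + shift, and gate scales each residual update.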
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        return img, txt


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        q, k = apply_rope(q, k, pe)
        attn = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        attn = rearrange(attn, "B H L D -> B L (H D)")

        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod.gate * output
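# SingleStreamBlockTriton below is identical to SingleStreamBlock above except that forward is
# decorated with @torch.compile; on CUDA the default Inductor backend lowers the pointwise
# modulation, norm, and RoPE math into generated Triton kernels.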


class SingleStreamBlockTriton(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    @torch.compile
    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        q, k = apply_rope(q, k, pe)
        attn = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        attn = rearrange(attn, "B H L D -> B L (H D)")

        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod.gate * output


if __name__ == "__main__":
    hidden_size = 3072
    num_heads = 24
    mlp_ratio = 4.0
    head_dim = hidden_size // num_heads

    batch_size = 1
    seq_len = 4336

    device = "cuda"

    block = SingleStreamBlock(
        hidden_size=hidden_size,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
    )
    block_triton = SingleStreamBlockTriton(
        hidden_size=hidden_size,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
    )
    block_triton.load_state_dict(block.state_dict())
    block = block.to(device)
    block_triton = block_triton.to(device)

    x = torch.randn(batch_size, seq_len, hidden_size).to(device)
    vec = torch.randn(batch_size, hidden_size).to(device)
    pe = torch.randn(batch_size, 1, seq_len, head_dim // 2, 2, 2).to(device)

    out = block(x=x, vec=vec, pe=pe)
    out_triton = block_triton(x=x, vec=vec, pe=pe)

    torch.testing.assert_close(out, out_triton, atol=1e-5, rtol=0)

    # warmup
    warmup_count = 5

    for i in range(warmup_count):
        out = block(x=x, vec=vec, pe=pe)
    for i in range(warmup_count):
        out_compiled = block_triton(x=x, vec=vec, pe=pe)

    # run
    run_count = 100

    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    for i in range(run_count):
        out = block(x=x, vec=vec, pe=pe)
    end.record()
    torch.cuda.synchronize()
    print(f"baseline block time: {start.elapsed_time(end):.2f} ms")

    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    for i in range(run_count):
        out_compiled = block_triton(x=x, vec=vec, pe=pe)
    end.record()
    torch.cuda.synchronize()
    print(f"compiled block time: {start.elapsed_time(end):.2f} ms")
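Note: as an alternative to the hand-rolled CUDA-event timing above, the same comparison could use Triton's built-in benchmarking helper. A minimal sketch, assuming triton.testing.do_bench is available in the installed Triton version and reusing block, block_triton, x, vec, and pe from the script:

    import triton.testing

    # do_bench handles warmup and repetition internally and returns a time in milliseconds
    ms_baseline = triton.testing.do_bench(lambda: block(x=x, vec=vec, pe=pe))
    ms_compiled = triton.testing.do_bench(lambda: block_triton(x=x, vec=vec, pe=pe))
    print(f"baseline {ms_baseline:.2f} ms / compiled {ms_compiled:.2f} ms")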
2 changes: 1 addition & 1 deletion tests/test_layer_norm.py
@@ -5,7 +5,7 @@


 @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
-@pytest.mark.parametrize("seq_len", [256, 512, 1024])
+@pytest.mark.parametrize("seq_len", [256, 512, 1024, 4336])
 @pytest.mark.parametrize("embed_dim", [1024, 2048, 3072])
 @pytest.mark.parametrize("device", ["cuda"])
 def test_layer_norm_modulation(batch_size, seq_len, embed_dim, device):