Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/scripts/runpod_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def wait_for_pod(pod_id):
"Timed out waiting for RunPod to reach RUNNING state")

# Wait for ports to be assigned
max_attempts = 35
max_attempts = 50
attempts = 0
while attempts < max_attempts:
response = requests.get(f"{PODS_API}/{pod_id}", headers=HEADERS)
Expand Down
2 changes: 1 addition & 1 deletion csrc/sliding_tile_attention/st_attn/st_attn_h100.cu
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ sta_forward(torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor o,
auto threads = NUM_WORKERS * kittens::WARP_THREADS;
if (has_text) {
// TORCH_CHECK(seq_len % (CONSUMER_WARPGROUPS*kittens::TILE_DIM*4) == 0, "sequence length must be divisible by 192");
dim3 grid_image(seq_len/(CONSUMER_WARPGROUPS*kittens::TILE_ROW_DIM<bf16>*4-2), qo_heads, batch);
dim3 grid_image(seq_len/(CONSUMER_WARPGROUPS*kittens::TILE_ROW_DIM<bf16>*4)-2, qo_heads, batch);
dim3 grid_text(2, qo_heads, batch);
if (!process_text) {
if (kernel_t_size == 3 && kernel_h_size == 3 && kernel_w_size == 3) {
Expand Down
3 changes: 2 additions & 1 deletion fastvideo/v1/attention/backends/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,10 @@ def __init__(
num_heads: int,
head_size: int,
softmax_scale: float,
dropout_rate: float = 0.0,
causal: bool = False,
num_kv_heads: Optional[int] = None,
prefix: str = "",
**extra_impl_args,
) -> None:
raise NotImplementedError

Expand Down
27 changes: 18 additions & 9 deletions fastvideo/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@
from typing import List, Optional, Type

import torch
from flash_attn import flash_attn_func
from flash_attn import flash_attn_func as flash_attn_2_func

try:
from flash_attn_interface import flash_attn_func as flash_attn_3_func

# flash_attn 3 has slightly different API: it returns lse by default
flash_attn_func = lambda q, k, v, softmax_scale, causal: flash_attn_3_func(
q, k, v, softmax_scale, causal)[0]
except ImportError:
flash_attn_func = flash_attn_2_func

from fastvideo.v1.attention.backends.abstract import (AttentionBackend,
AttentionImpl,
Expand Down Expand Up @@ -45,12 +54,12 @@ def __init__(
self,
num_heads: int,
head_size: int,
dropout_rate: float,
causal: bool,
softmax_scale: float,
num_kv_heads: Optional[int] = None,
prefix: str = "",
**extra_impl_args,
) -> None:
self.dropout_rate = dropout_rate
self.causal = causal
self.softmax_scale = softmax_scale

Expand All @@ -61,10 +70,10 @@ def forward(
value: torch.Tensor,
attn_metadata: AttentionMetadata,
):
output = flash_attn_func(query,
key,
value,
dropout_p=self.dropout_rate,
softmax_scale=self.softmax_scale,
causal=self.causal)
output = flash_attn_func(
query, # type: ignore[no-untyped-call]
key,
value,
softmax_scale=self.softmax_scale,
causal=self.causal)
return output
7 changes: 4 additions & 3 deletions fastvideo/v1/attention/backends/sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ def __init__(
self,
num_heads: int,
head_size: int,
dropout_rate: float,
causal: bool,
softmax_scale: float,
num_kv_heads: Optional[int] = None,
prefix: str = "",
**extra_impl_args,
) -> None:
self.dropout_rate = dropout_rate
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout = extra_impl_args.get("dropout_p", 0.0)

def forward(
self,
Expand All @@ -60,7 +61,7 @@ def forward(
value = value.transpose(1, 2)
attn_kwargs = {
"attn_mask": None,
"dropout_p": self.dropout_rate,
"dropout_p": self.dropout,
"is_causal": self.causal,
"scale": self.softmax_scale
}
Expand Down
23 changes: 11 additions & 12 deletions fastvideo/v1/attention/backends/sliding_tile_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def get_builder_cls() -> Type["SlidingTileAttentionMetadataBuilder"]:

@dataclass
class SlidingTileAttentionMetadata(AttentionMetadata):
text_length: int
current_timestep: int


class SlidingTileAttentionMetadataBuilder(AttentionMetadataBuilder):
Expand All @@ -80,9 +80,7 @@ def build(
inference_args: InferenceArgs,
) -> SlidingTileAttentionMetadata:

# TODO(will): not implemented yet
return SlidingTileAttentionMetadata(current_timestep=current_timestep,
text_length=0)
return SlidingTileAttentionMetadata(current_timestep=current_timestep, )


class SlidingTileAttentionImpl(AttentionImpl):
Expand All @@ -91,10 +89,11 @@ def __init__(
self,
num_heads: int,
head_size: int,
dropout_rate: float,
causal: bool,
softmax_scale: float,
num_kv_heads: Optional[int] = None,
prefix: str = "",
**extra_impl_args,
) -> None:
# TODO(will-refactor): for now this is the mask strategy, but maybe we should
# have a more general config for STA?
Expand All @@ -106,7 +105,7 @@ def __init__(
mask_strategy = json.load(f)

mask_strategy = dict_to_3d_list(mask_strategy)

self.prefix = prefix
self.mask_strategy = mask_strategy
sp_group = get_sp_group()
self.sp_size = sp_group.world_size
Expand Down Expand Up @@ -171,8 +170,11 @@ def forward(
assert self.mask_strategy[
0] is not None, "mask_strategy[0] cannot be None for SlidingTileAttention"

text_length = attn_metadata.text_length

timestep = attn_metadata.current_timestep
# pattern:'.double_blocks.0.attn.impl' or '.single_blocks.0.attn.impl'
layer_idx = int(self.prefix.split('.')[-3])
# TODO: remove hardcode
text_length = q.shape[1] - (30 * 48 * 80)
query = q.transpose(1, 2)
key = k.transpose(1, 2)
value = v.transpose(1, 2)
Expand All @@ -182,13 +184,10 @@ def forward(
current_rank = sp_group.rank_in_group
start_head = current_rank * head_num
windows = [
self.mask_strategy[head_idx + start_head]
self.mask_strategy[timestep][layer_idx][head_idx + start_head]
for head_idx in range(head_num)
]

hidden_states = sliding_tile_attention(query, key, value, windows,
text_length).transpose(1, 2)

hidden_states = hidden_states.transpose(1, 2)

return hidden_states
29 changes: 15 additions & 14 deletions fastvideo/v1/attention/layer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional
from typing import List, Optional

import torch
import torch.nn as nn
Expand All @@ -12,6 +12,7 @@
from fastvideo.v1.distributed.parallel_state import (
get_sequence_model_parallel_rank, get_sequence_model_parallel_world_size)
from fastvideo.v1.forward_context import ForwardContext, get_forward_context
from fastvideo.v1.platforms import _Backend


class DistributedAttention(nn.Module):
Expand All @@ -22,13 +23,12 @@ def __init__(self,
num_heads: int,
head_size: int,
num_kv_heads: Optional[int] = None,
dropout_rate: float = 0.0,
softmax_scale: Optional[float] = None,
causal: bool = False,
supported_attention_backends: Optional[List[_Backend]] = None,
prefix: str = "",
**extra_impl_args) -> None:
super().__init__()
# self.dropout_rate = dropout_rate
# self.causal = causal
if softmax_scale is None:
self.softmax_scale = head_size**-0.5
else:
Expand All @@ -38,14 +38,17 @@ def __init__(self,
num_kv_heads = num_heads

dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(head_size, dtype, distributed=True)
attn_backend = get_attn_backend(
head_size,
dtype,
supported_attention_backends=supported_attention_backends)
impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads=num_heads,
head_size=head_size,
dropout_rate=dropout_rate,
causal=causal,
softmax_scale=self.softmax_scale,
num_kv_heads=num_kv_heads,
prefix=f"{prefix}.impl",
**extra_impl_args)
self.num_heads = num_heads
self.head_size = head_size
Expand Down Expand Up @@ -97,7 +100,6 @@ def forward(
qkv = sequence_model_parallel_all_to_all_4D(qkv,
scatter_dim=2,
gather_dim=1)

# Apply backend-specific preprocess_qkv
qkv = self.impl.preprocess_qkv(qkv, ctx_attn_metadata)

Expand All @@ -124,8 +126,7 @@ def forward(
output = output[:, :seq_len * world_size]
# TODO: make this asynchronous
replicated_output = sequence_model_parallel_all_gather(
replicated_output, dim=2)

replicated_output.contiguous(), dim=2)
# Apply backend-specific postprocess_output
output = self.impl.postprocess_output(output, ctx_attn_metadata)

Expand All @@ -143,13 +144,11 @@ def __init__(self,
num_heads: int,
head_size: int,
num_kv_heads: Optional[int] = None,
dropout_rate: float = 0.0,
softmax_scale: Optional[float] = None,
causal: bool = False,
supported_attention_backends: Optional[List[_Backend]] = None,
**extra_impl_args) -> None:
super().__init__()
# self.dropout_rate = dropout_rate
# self.causal = causal
if softmax_scale is None:
self.softmax_scale = head_size**-0.5
else:
Expand All @@ -158,11 +157,13 @@ def __init__(self,
num_kv_heads = num_heads

dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(head_size, dtype, distributed=False)
attn_backend = get_attn_backend(
head_size,
dtype,
supported_attention_backends=supported_attention_backends)
impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads=num_heads,
head_size=head_size,
dropout_rate=dropout_rate,
softmax_scale=self.softmax_scale,
num_kv_heads=num_kv_heads,
causal=causal,
Expand Down
30 changes: 7 additions & 23 deletions fastvideo/v1/attention/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

import os
from contextlib import contextmanager
from functools import cache
from typing import Generator, Optional, Type, cast
from typing import Generator, List, Optional, Type, cast

import torch

Expand Down Expand Up @@ -82,29 +81,15 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
def get_attn_backend(
head_size: int,
dtype: torch.dtype,
distributed: bool,
) -> Type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
# value to be returned from the cache if the value changes between calls.
return _cached_get_attn_backend(
head_size=head_size,
dtype=dtype,
distributed=distributed,
)


@cache
def _cached_get_attn_backend(
head_size: int,
dtype: torch.dtype,
distributed: bool,
supported_attention_backends: Optional[List[_Backend]] = None,
) -> Type[AttentionBackend]:
# Check whether a particular choice of backend was
# previously forced.
#
# THIS SELECTION OVERRIDES THE FASTVIDEO_ATTENTION_BACKEND
# ENVIRONMENT VARIABLE.
if not supported_attention_backends:
raise ValueError("supported_attention_backends is empty")
selected_backend = None
backend_by_global_setting: Optional[_Backend] = (
get_global_forced_attn_backend())
Expand All @@ -116,12 +101,11 @@ def _cached_get_attn_backend(
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)

if selected_backend is None:
selected_backend = _Backend.FLASH_ATTN

# get device-specific attn_backend
if selected_backend not in supported_attention_backends:
selected_backend = None
attention_cls = current_platform.get_attn_backend_cls(
selected_backend, head_size, dtype, distributed)
selected_backend, head_size, dtype)
if not attention_cls:
raise ValueError(
f"Invalid attention backend for {current_platform.device_name}")
Expand Down
14 changes: 0 additions & 14 deletions fastvideo/v1/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
FASTVIDEO_RINGBUFFER_WARNING_INTERVAL: int = 60
FASTVIDEO_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None
FASTVIDEO_USE_TRITON_FLASH_ATTN: bool = False
FASTVIDEO_FLASH_ATTN_VERSION: Optional[int] = None
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
FASTVIDEO_CACHE_ROOT: str = os.path.expanduser("~/.cache/fastvideo")
Expand Down Expand Up @@ -127,18 +125,6 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
"LD_LIBRARY_PATH":
lambda: os.environ.get("LD_LIBRARY_PATH", None),

# flag to control if fastvideo should use triton flash attention
"FASTVIDEO_USE_TRITON_FLASH_ATTN":
lambda:
(os.environ.get("FASTVIDEO_USE_TRITON_FLASH_ATTN", "True").lower() in
("true", "1")),

# Force fastvideo to use a specific flash-attention version (2 or 3), only valid
# when using the flash-attention backend.
"FASTVIDEO_FLASH_ATTN_VERSION":
lambda: maybe_convert_int(
os.environ.get("FASTVIDEO_FLASH_ATTN_VERSION", None)),

# Internal flag to enable Dynamo fullgraph capture
"FASTVIDEO_TEST_DYNAMO_FULLGRAPH_CAPTURE":
lambda: bool(
Expand Down
4 changes: 3 additions & 1 deletion fastvideo/v1/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class ScaleResidual(nn.Module):
Applies gated residual connection.
"""

def __init__(self) -> None:
def __init__(self, prefix: str = ""):
super().__init__()

def forward(self, residual: torch.Tensor, x: torch.Tensor,
Expand All @@ -139,6 +139,7 @@ def __init__(
eps: float = 1e-6,
elementwise_affine: bool = False,
dtype: torch.dtype = torch.float32,
prefix: str = "",
):
super().__init__()
if norm_type == "rms":
Expand Down Expand Up @@ -189,6 +190,7 @@ def __init__(
eps: float = 1e-6,
elementwise_affine: bool = False,
dtype: torch.dtype = torch.float32,
prefix: str = "",
):
super().__init__()
if norm_type == "rms":
Expand Down
Loading
Loading