103 changes: 103 additions & 0 deletions src/diffusers/models/attention_dispatch.py
@@ -41,6 +41,7 @@
is_torch_xla_version,
is_xformers_available,
is_xformers_version,
is_mindie_sd_available,
)
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS

@@ -63,6 +64,7 @@
_CAN_USE_NPU_ATTN = is_torch_npu_available()
_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)
_CAN_USE_MINDIESD_ATTN = is_mindie_sd_available()


if _CAN_USE_FLASH_ATTN:
@@ -142,6 +144,13 @@
else:
xops = None


if _CAN_USE_MINDIESD_ATTN:
from mindiesd import attention_forward as mindie_sd_attn_forward
else:
mindie_sd_attn_forward = None


# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
if torch.__version__ >= "2.4.0":
_custom_op = torch.library.custom_op
@@ -215,6 +224,9 @@ class AttentionBackendName(str, Enum):
# `xformers`
XFORMERS = "xformers"

# `mindiesd`
_MINDIE_SD_LASER = "_mindie_sd_la"


class _AttentionBackendRegistry:
_backends = {}
@@ -470,6 +482,12 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`."
)

    elif backend == AttentionBackendName._MINDIE_SD_LASER:
        if not _CAN_USE_MINDIESD_ATTN:
            raise RuntimeError(
                f"MindIE SD attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `mindiesd`."
            )


@functools.lru_cache(maxsize=128)
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
@@ -893,6 +911,47 @@ def _sage_attention_backward_op(
raise NotImplementedError("Backward pass is not implemented for Sage attention.")


def _mindie_sd_laser_attn_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _save_ctx: bool = True,
    _parallel_config: Optional["ParallelConfig"] = None,
):
    if enable_gqa:
        raise ValueError("`enable_gqa` is not yet supported for MindIE SD Laser Attention.")
    if return_lse:
        raise ValueError("MindIE SD attention backend does not support setting `return_lse=True`.")

    # Tensors are handed to the kernel as-is and interpreted in BNSD
    # (batch, num_heads, seq_len, head_dim) layout; `attn_mask`, `dropout_p`,
    # `is_causal`, and `scale` are accepted for interface compatibility but
    # are not forwarded to the kernel.
    out = mindie_sd_attn_forward(
        query,
        key,
        value,
        opt_mode="manual",
        op_type="ascend_laser_attention",
        layout="BNSD",
    )
    return out


def _mindie_sd_laser_attn_backward_op(
    ctx: torch.autograd.function.FunctionCtx,
    grad_out: torch.Tensor,
    *args,
    **kwargs,
):
    raise NotImplementedError("Backward pass is not implemented for MindIE SD Laser Attention.")


# ===== Context parallel =====


@@ -2012,3 +2071,47 @@ def _xformers_attention(
out = out.flatten(2, 3)

return out


@_AttentionBackendRegistry.register(
    AttentionBackendName._MINDIE_SD_LASER,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _mindie_sd_laser_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
    return_lse: bool = False,
    _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
    if return_lse:
        raise ValueError("MindIE SD attention backend does not support setting `return_lse=True`.")
    if _parallel_config is None:
        # Tensors are handed to the kernel as-is and interpreted in BNSD
        # (batch, num_heads, seq_len, head_dim) layout.
        out = mindie_sd_attn_forward(
            query,
            key,
            value,
            opt_mode="manual",
            op_type="ascend_laser_attention",
            layout="BNSD",
        )
    else:
        out = _templated_context_parallel_attention(
            query,
            key,
            value,
            None,
            dropout_p,
            None,
            scale,
            None,
            return_lse,
            forward_op=_mindie_sd_laser_attn_forward_op,
            backward_op=_mindie_sd_laser_attn_backward_op,
            _parallel_config=_parallel_config,
        )
    return out
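
Reviewer note: a minimal sketch of how the new backend could be exercised end to end once this lands. It assumes an Ascend NPU environment with `torch_npu` and `mindiesd` installed; the checkpoint name is illustrative, and `set_attention_backend` is the existing attention-dispatch entry point on diffusers models, not something added in this PR.

import torch
from diffusers import FluxTransformer2DModel

# Illustrative checkpoint; any model whose attention runs through the dispatcher works.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to("npu")

# "_mindie_sd_la" is the string registered for AttentionBackendName._MINDIE_SD_LASER above;
# if `mindiesd` is missing, this surfaces the RuntimeError added in _check_attention_backend_requirements.
transformer.set_attention_backend("_mindie_sd_la")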
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
@@ -122,6 +122,7 @@
is_wandb_available,
is_xformers_available,
is_xformers_version,
is_mindie_sd_available,
requires_backends,
)
from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video
4 changes: 4 additions & 0 deletions src/diffusers/utils/import_utils.py
@@ -229,6 +229,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
_aiter_available, _aiter_version = _is_package_available("aiter")
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
_mindie_sd_available, _mindie_sd_version = _is_package_available("mindiesd")


def is_torch_available():
@@ -414,6 +415,9 @@ def is_aiter_available():
def is_kornia_available():
return _kornia_available


def is_mindie_sd_available():
    return _mindie_sd_available


# docstyle-ignore
FLAX_IMPORT_ERROR = """
Expand Down