Skip to content

chore: revert attention decomposition due to flux bug #3332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2730,6 +2730,38 @@ def aten_ops_max_pool(
)


def attention_validator(
node: Node, settings: Optional[CompilationSettings] = None
) -> bool:
# Currently, `attn_mask` is not supported
return args_bounds_check(node.args, 3) is None


@dynamo_tensorrt_converter(
torch.nn.functional.scaled_dot_product_attention,
capability_validator=attention_validator,
supports_dynamic_shapes=True,
)
def tensorrt_scaled_dot_product_attention(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.attention.scaled_dot_product_attention(
ctx,
target,
SourceIR.TORCHTRT_LOWERED,
name,
args[0],
args[1],
args[2],
args_bounds_check(args, 5, False),
kwargs.get("scale", None),
)


@dynamo_tensorrt_converter(torch.ops.aten.reshape.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.view.default, supports_dynamic_shapes=True)
@enforce_tensor_types(
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
activation,
addmm,
arange,
attention,
cast,
cat,
condition,
Expand Down
165 changes: 165 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import math
from typing import Optional, Union

import numpy as np
import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt._enums import dtype
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
SourceIR,
cast_trt_tensor,
get_trt_tensor,
)
from torch_tensorrt.fx.types import TRTTensor


def tril(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
) -> TRTTensor:
# the lower triangle of the tensor means the rows greater than and equal to the cols
row = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", input, 0)
col = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", input, 1)
rc = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", row, col)
arange_tensor = impl.arange.arange(
ctx, target, source_ir, name + "_arange", start=0, end=rc, step=1
)
# get the rows
row_tensor = impl.elementwise.trunc_div(
ctx, target, source_ir, name + "_trunc_div_col", arange_tensor, col
)
# get the cols
col_tensor = impl.elementwise.fmod(
ctx, target, source_ir, name + "_trunc_div_row", arange_tensor, col
)
cond = impl.elementwise.ge(
ctx, target, source_ir, name + "_ge", row_tensor, col_tensor
)
return impl.shuffle.reshape(
ctx, target, source_ir, name + "_reshape", cond, [row, col]
)


def scaled_dot_product_attention(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
query: TRTTensor,
key: TRTTensor,
value: TRTTensor,
is_causal: bool,
scale: Optional[float],
) -> TRTTensor:
# implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
mm = impl.matmul.matrix_multiply(
ctx,
target,
source_ir,
name + "_mm",
query,
key,
other_matrix_op=trt.MatrixOperation.TRANSPOSE,
)
if scale is None:
scale = query.shape[-1]
if scale < 0:
# dynamic shape
scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1)
sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale)
else:
# static shape
sqrt_scaled = math.sqrt(scale)
scaled = impl.elementwise.div(
ctx,
target,
source_ir,
name + "_scale",
mm,
sqrt_scaled,
)
else:
scaled = impl.elementwise.mul(
ctx,
target,
source_ir,
name + "_scale",
mm,
scale,
)

if is_causal:
L, S = query.shape[-2], key.shape[-2]
if L >= 0 and S >= 0:
# static shape
attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
else:
# if any of the L or S is dynamic shape
if L < 0:
L = impl.shape.shape(
ctx, target, source_ir, name + "_shape_0", query, -2
)
if S < 0:
S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, -2)

LS = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", L, S)

# this is to generate a tensor which has shape (L, S), type is int32
arange_tensor = impl.arange.arange(
ctx, target, source_ir, name=name + "_arange", start=0, end=LS, step=1
)
shape_tensor = impl.shuffle.reshape(
ctx, target, source_ir, name + "_reshape", arange_tensor, [L, S]
)

# since we want our attn_bias to be in float32, so cast it to float32
shape_tensor = cast_trt_tensor(
ctx, shape_tensor, trt.float32, name + "_casted", target, source_ir
)

# initialize the attn_bias as the zeros tensor
attn_bias = impl.elementwise.mul(
ctx, target, source_ir, name + "_mul_zero", shape_tensor, 0.0
)

# generate the mask tensor
tril_tensor = tril(ctx, target, source_ir, name + "_tril", shape_tensor)
temp_mask = impl.unary.logical_not(
ctx, target, source_ir, name + "_logical_not", tril_tensor
)
inf_tensor = impl.elementwise.mul(
ctx, target, source_ir, name + "_mul_-inf", shape_tensor, float("-inf")
)
cond = impl.elementwise.eq(
ctx, target, source_ir, name + "_cond_true", temp_mask, bool(True)
)
# mask out the certain part of the attn_bias
attn_bias = impl.condition.select(
ctx, target, source_ir, name + "_select", inf_tensor, attn_bias, cond
)

scaled = impl.elementwise.add(
ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias
)

softmax = impl.normalization.softmax(
ctx, target, source_ir, name + "_softmax", scaled, -1, False
)
out = impl.matmul.matrix_multiply(
ctx,
target,
source_ir,
name + "_out",
softmax,
value,
)

return out
130 changes: 2 additions & 128 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from enum import Enum, auto
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional

import torch
from torch._decomp import register_decomposition
Expand Down Expand Up @@ -423,135 +423,9 @@ def instance_norm_decomposition(
)


@register_torch_trt_decomposition(
aten.scaled_dot_product_attention, registry=TORCH_TRT_DECOMPOSITIONS
)
def scaled_dot_product_attention_decomposition(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: Optional[float] = None,
enable_gqa: bool = False,
) -> torch.Tensor:
L, S = query.size(-2), key.size(-2)
device = query.device
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device)

if is_causal:
assert attn_mask is None, "attn_mask must be None when is_causal=True"
temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0)
attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))

if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
else:
attn_bias = attn_mask + attn_bias

if enable_gqa:
key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

attn_weight = query @ key.transpose(-2, -1)

if scale is None:
scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int))
attn_weight = attn_weight / scale
else:
attn_weight = attn_weight * scale

attn_weight = attn_weight + attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
return attn_weight @ value


@register_torch_trt_decomposition(
aten._scaled_dot_product_flash_attention, registry=TORCH_TRT_DECOMPOSITIONS
)
def scaled_dot_product_flash_attention_decomposition(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: Optional[float] = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.SymInt,
torch.SymInt,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
attn = scaled_dot_product_attention_decomposition(
query, key, value, None, dropout_p, is_causal, scale=scale
)
return attn, None, None, None, 0, 0, None, None, None


@register_torch_trt_decomposition(
aten._scaled_dot_product_efficient_attention, registry=TORCH_TRT_DECOMPOSITIONS
)
def scaled_dot_product_efficient_attention_decomposition(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[torch.Tensor],
compute_log_sumexp: bool,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
attn = scaled_dot_product_attention_decomposition(
query, key, value, attn_bias, dropout_p, is_causal, scale=scale
)
return attn, None, None, None


@register_torch_trt_decomposition(
aten._scaled_dot_product_cudnn_attention, registry=TORCH_TRT_DECOMPOSITIONS
)
def scaled_dot_product_cudnn_attention_decomposition(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[torch.Tensor],
compute_log_sumexp: bool,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: Optional[float] = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.SymInt,
torch.SymInt,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
attn = scaled_dot_product_attention_decomposition(
query, key, value, attn_bias, dropout_p, is_causal, scale=scale
)
return attn, None, None, None, 0, 0, None, None, None


@register_torch_trt_decomposition(
torch.ops.aten.full_like, registry=TORCH_TRT_DECOMPOSITIONS
)
) # type: ignore
def full_like_decomposition(*args, **kwargs) -> torch.Tensor:
input = args[0]
shape = args[0].shape
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .accumulate_fp32_matmul import accumulate_fp32_matmul
from .constant_folding import constant_fold
from .fuse_prims_broadcast import fuse_prims_broadcast
from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
from .pass_manager import DynamoPassManager
from .remove_assert_scalar import remove_assert_scalar
from .remove_detach import remove_detach
Expand All @@ -22,6 +23,7 @@
repair_input_as_output,
fuse_prims_broadcast,
replace_max_pool_with_indices,
lower_scaled_dot_product_attention,
view_to_reshape,
remove_assert_scalar,
accumulate_fp32_matmul,
Expand Down
Loading
Loading