chore: revert attention decomposition due to flux bug (#3332)
peri044 authored Dec 20, 2024
1 parent 8eff5a6 commit 68c2d45
Showing 8 changed files with 623 additions and 478 deletions.
32 changes: 32 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2730,6 +2730,38 @@ def aten_ops_max_pool(
)


def attention_validator(
node: Node, settings: Optional[CompilationSettings] = None
) -> bool:
# Currently, `attn_mask` is not supported
return args_bounds_check(node.args, 3) is None


@dynamo_tensorrt_converter(
torch.nn.functional.scaled_dot_product_attention,
capability_validator=attention_validator,
supports_dynamic_shapes=True,
)
def tensorrt_scaled_dot_product_attention(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.attention.scaled_dot_product_attention(
ctx,
target,
SourceIR.TORCHTRT_LOWERED,
name,
args[0],
args[1],
args[2],
args_bounds_check(args, 5, False),
kwargs.get("scale", None),
)
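
(Editorial aside, not part of the diff: a hedged sketch of which call forms the validator above admits, assuming the standard torch.nn.functional.scaled_dot_product_attention signature (query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *, scale=None). Only calls without an attn_mask are eligible for conversion; anything passing a mask is left to PyTorch.)

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 16, 64)

F.scaled_dot_product_attention(q, k, v)                   # no mask: validator accepts
F.scaled_dot_product_attention(q, k, v, None, 0.0, True)  # is_causal passed as args[5]: validator accepts
causal_mask = torch.ones(16, 16, dtype=torch.bool).tril()
F.scaled_dot_product_attention(q, k, v, causal_mask)      # explicit attn_mask: validator rejects, op stays in PyTorch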


@dynamo_tensorrt_converter(torch.ops.aten.reshape.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.view.default, supports_dynamic_shapes=True)
@enforce_tensor_types(
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -2,6 +2,7 @@
activation,
addmm,
arange,
attention,
cast,
cat,
condition,
165 changes: 165 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/attention.py
@@ -0,0 +1,165 @@
import math
from typing import Optional, Union

import numpy as np
import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt._enums import dtype
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
SourceIR,
cast_trt_tensor,
get_trt_tensor,
)
from torch_tensorrt.fx.types import TRTTensor


def tril(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
) -> TRTTensor:
    # the lower triangle keeps the elements whose row index is greater than or equal to the column index
row = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", input, 0)
col = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", input, 1)
rc = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", row, col)
arange_tensor = impl.arange.arange(
ctx, target, source_ir, name + "_arange", start=0, end=rc, step=1
)
# get the rows
row_tensor = impl.elementwise.trunc_div(
ctx, target, source_ir, name + "_trunc_div_col", arange_tensor, col
)
# get the cols
col_tensor = impl.elementwise.fmod(
ctx, target, source_ir, name + "_trunc_div_row", arange_tensor, col
)
cond = impl.elementwise.ge(
ctx, target, source_ir, name + "_ge", row_tensor, col_tensor
)
return impl.shuffle.reshape(
ctx, target, source_ir, name + "_reshape", cond, [row, col]
)


def scaled_dot_product_attention(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
query: TRTTensor,
key: TRTTensor,
value: TRTTensor,
is_causal: bool,
scale: Optional[float],
) -> TRTTensor:
# implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
mm = impl.matmul.matrix_multiply(
ctx,
target,
source_ir,
name + "_mm",
query,
key,
other_matrix_op=trt.MatrixOperation.TRANSPOSE,
)
if scale is None:
scale = query.shape[-1]
if scale < 0:
# dynamic shape
scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1)
sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale)
else:
# static shape
sqrt_scaled = math.sqrt(scale)
scaled = impl.elementwise.div(
ctx,
target,
source_ir,
name + "_scale",
mm,
sqrt_scaled,
)
else:
scaled = impl.elementwise.mul(
ctx,
target,
source_ir,
name + "_scale",
mm,
scale,
)

if is_causal:
L, S = query.shape[-2], key.shape[-2]
if L >= 0 and S >= 0:
# static shape
attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
else:
            # L and/or S is dynamic, so fetch the dimension at runtime
if L < 0:
L = impl.shape.shape(
ctx, target, source_ir, name + "_shape_0", query, -2
)
if S < 0:
S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, -2)

LS = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", L, S)

            # build an int32 index tensor of shape (L, S)
arange_tensor = impl.arange.arange(
ctx, target, source_ir, name=name + "_arange", start=0, end=LS, step=1
)
shape_tensor = impl.shuffle.reshape(
ctx, target, source_ir, name + "_reshape", arange_tensor, [L, S]
)

            # cast to float32 since attn_bias should be float32
shape_tensor = cast_trt_tensor(
ctx, shape_tensor, trt.float32, name + "_casted", target, source_ir
)

            # initialize attn_bias as a zero tensor
attn_bias = impl.elementwise.mul(
ctx, target, source_ir, name + "_mul_zero", shape_tensor, 0.0
)

# generate the mask tensor
tril_tensor = tril(ctx, target, source_ir, name + "_tril", shape_tensor)
temp_mask = impl.unary.logical_not(
ctx, target, source_ir, name + "_logical_not", tril_tensor
)
inf_tensor = impl.elementwise.mul(
ctx, target, source_ir, name + "_mul_-inf", shape_tensor, float("-inf")
)
cond = impl.elementwise.eq(
ctx, target, source_ir, name + "_cond_true", temp_mask, bool(True)
)
            # set the masked (upper-triangular) positions of attn_bias to -inf
attn_bias = impl.condition.select(
ctx, target, source_ir, name + "_select", inf_tensor, attn_bias, cond
)

scaled = impl.elementwise.add(
ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias
)

softmax = impl.normalization.softmax(
ctx, target, source_ir, name + "_softmax", scaled, -1, False
)
out = impl.matmul.matrix_multiply(
ctx,
target,
source_ir,
name + "_out",
softmax,
value,
)

return out
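
(Editorial aside: the tril() helper above derives the row >= col lower-triangular condition from a flattened arange using truncated division and modulo. A minimal plain-PyTorch sketch of the same index arithmetic, outside of TensorRT:)

import torch

L, S = 4, 6
idx = torch.arange(L * S)
rows = idx // S                      # trunc_div by the number of columns gives the row index
cols = idx % S                       # fmod by the number of columns gives the column index
mask = (rows >= cols).reshape(L, S)  # keep positions on or below the main diagonal
assert torch.equal(mask, torch.ones(L, S, dtype=torch.bool).tril())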
130 changes: 2 additions & 128 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -1,6 +1,6 @@
import logging
from enum import Enum, auto
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional

import torch
from torch._decomp import register_decomposition
@@ -423,135 +423,9 @@ def instance_norm_decomposition(
)


@register_torch_trt_decomposition(
aten.scaled_dot_product_attention, registry=TORCH_TRT_DECOMPOSITIONS
)
def scaled_dot_product_attention_decomposition(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: Optional[float] = None,
enable_gqa: bool = False,
) -> torch.Tensor:
L, S = query.size(-2), key.size(-2)
device = query.device
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device)

if is_causal:
assert attn_mask is None, "attn_mask must be None when is_causal=True"
temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0)
attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))

if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
else:
attn_bias = attn_mask + attn_bias

if enable_gqa:
key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

attn_weight = query @ key.transpose(-2, -1)

if scale is None:
scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int))
attn_weight = attn_weight / scale
else:
attn_weight = attn_weight * scale

attn_weight = attn_weight + attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
return attn_weight @ value
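
(Editorial aside: the decomposition removed above follows the reference formulation in the PyTorch documentation. A quick numerical sanity check, assuming the function above is still importable, against torch.nn.functional.scaled_dot_product_attention:)

import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out = scaled_dot_product_attention_decomposition(q, k, v, None, 0.0, True)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)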


@register_torch_trt_decomposition(
aten._scaled_dot_product_flash_attention, registry=TORCH_TRT_DECOMPOSITIONS
)
def scaled_dot_product_flash_attention_decomposition(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: Optional[float] = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.SymInt,
torch.SymInt,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
attn = scaled_dot_product_attention_decomposition(
query, key, value, None, dropout_p, is_causal, scale=scale
)
return attn, None, None, None, 0, 0, None, None, None


@register_torch_trt_decomposition(
aten._scaled_dot_product_efficient_attention, registry=TORCH_TRT_DECOMPOSITIONS
)
def scaled_dot_product_efficient_attention_decomposition(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[torch.Tensor],
compute_log_sumexp: bool,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
attn = scaled_dot_product_attention_decomposition(
query, key, value, attn_bias, dropout_p, is_causal, scale=scale
)
return attn, None, None, None


@register_torch_trt_decomposition(
aten._scaled_dot_product_cudnn_attention, registry=TORCH_TRT_DECOMPOSITIONS
)
def scaled_dot_product_cudnn_attention_decomposition(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[torch.Tensor],
compute_log_sumexp: bool,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: Optional[float] = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.SymInt,
torch.SymInt,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
attn = scaled_dot_product_attention_decomposition(
query, key, value, attn_bias, dropout_p, is_causal, scale=scale
)
return attn, None, None, None, 0, 0, None, None, None


@register_torch_trt_decomposition(
torch.ops.aten.full_like, registry=TORCH_TRT_DECOMPOSITIONS
)
) # type: ignore
def full_like_decomposition(*args, **kwargs) -> torch.Tensor:
input = args[0]
shape = args[0].shape
@@ -7,6 +7,7 @@
from .accumulate_fp32_matmul import accumulate_fp32_matmul
from .constant_folding import constant_fold
from .fuse_prims_broadcast import fuse_prims_broadcast
from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
from .pass_manager import DynamoPassManager
from .remove_assert_scalar import remove_assert_scalar
from .remove_detach import remove_detach
@@ -22,6 +23,7 @@
repair_input_as_output,
fuse_prims_broadcast,
replace_max_pool_with_indices,
lower_scaled_dot_product_attention,
view_to_reshape,
remove_assert_scalar,
accumulate_fp32_matmul,

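For context on what this revert restores: scaled dot-product attention is now handled by the re-added lowering pass and TensorRT converter rather than the removed decompositions. A minimal end-to-end sketch (assuming the standard torch_tensorrt.compile entry point with the dynamo IR; not part of the commit) of compiling a module whose forward calls F.scaled_dot_product_attention without an attn_mask:

import torch
import torch.nn.functional as F
import torch_tensorrt

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        # no attn_mask, so the capability validator accepts this call
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.half)
trt_model = torch_tensorrt.compile(
    Attn().eval().cuda(),
    ir="dynamo",
    inputs=[q, k, v],
    enabled_precisions={torch.half},
)
out = trt_model(q, k, v)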