Brevitas make_fx generating different graph #762

Closed · vivekkhandelwal1 opened this issue Nov 17, 2023 · 4 comments

@vivekkhandelwal1

Hi @volcacius, I see a difference between the graph produced by the upstream PyTorch make_fx and the one produced by the Brevitas make_fx: the Brevitas version does not apply the PyTorch decompositions. I'm using the dev branch (https://github.com/Xilinx/brevitas/tree/dev). Here's the code to reproduce the issue.

import torch
from brevitas.backport.fx.experimental.proxy_tensor import (
    make_fx as brevitas_make_fx,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions

class FlashAttention(torch.nn.Module):
    def forward(self, query, key, value):
        # Keep only the attention output (index 0) of the flash-attention op.
        return torch.ops.aten._scaled_dot_product_flash_attention(
            query, key, value,
            dropout_p=0.0, is_causal=True, return_debug_mask=False, scale=None,
        )[0]

attn = FlashAttention()

query = torch.randn(1, 1, 5, 5, dtype=torch.float32)
key = torch.randn(1, 1, 5, 5, dtype=torch.float32)
value = torch.randn(1, 1, 5, 5, dtype=torch.float32)

inputs = (query, key, value)

# Both tracers are given the same decomposition table for the op.
decomps_list = [
    torch.ops.aten._scaled_dot_product_flash_attention.default,
]

fx_g = make_fx(
    attn,
    decomposition_table=get_decompositions(decomps_list),
)(*inputs)

brevitas_fx_g = brevitas_make_fx(
    attn,
    decomposition_table=get_decompositions(decomps_list),
)(*inputs)

print("Torch Fx graph: \n", fx_g.graph)
print("Brevitas Fx graph: \n", brevitas_fx_g.graph)

Graphs obtained:

Torch Fx graph: 
 graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %empty : [num_users=0] = call_function[target=torch.ops.aten.empty.memory_format](args = ([1, 5, 1, 5],), kwargs = {dtype: torch.float32, device: cpu, pin_memory: False})
    %empty_1 : [num_users=0] = call_function[target=torch.ops.aten.empty.memory_format](args = ([],), kwargs = {dtype: torch.int64, device: cpu, pin_memory: False})
    %empty_2 : [num_users=0] = call_function[target=torch.ops.aten.empty.memory_format](args = ([],), kwargs = {dtype: torch.int64, device: cpu, pin_memory: False})
    %empty_3 : [num_users=0] = call_function[target=torch.ops.aten.empty.memory_format](args = ([],), kwargs = {dtype: torch.int64, device: cpu, pin_memory: False})
    %empty_4 : [num_users=0] = call_function[target=torch.ops.aten.empty.memory_format](args = ([],), kwargs = {dtype: torch.int64, device: cpu, pin_memory: False})
    %empty_5 : [num_users=0] = call_function[target=torch.ops.aten.empty.memory_format](args = ([],), kwargs = {dtype: torch.float32, device: cpu, pin_memory: False})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%arg0_1, 0.668740304976422), kwargs = {})
    %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([5, 5],), kwargs = {dtype: torch.bool, layout: torch.strided, device: cpu})
    %tril : [num_users=2] = call_function[target=torch.ops.aten.tril.default](args = (%ones,), kwargs = {})
    %zeros_like : [num_users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%tril,), kwargs = {dtype: torch.float32})
    %logical_not : [num_users=1] = call_function[target=torch.ops.aten.logical_not.default](args = (%tril,), kwargs = {})
    %masked_fill_ : [num_users=1] = call_function[target=torch.ops.aten.masked_fill_.Scalar](args = (%zeros_like, %logical_not, -inf), kwargs = {})
    %transpose : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%arg1_1, -2, -1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%transpose, 0.668740304976422), kwargs = {})
    %expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%mul, [1, 1, 5, 5]), kwargs = {})
    %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand, [1, 5, 5]), kwargs = {})
    %expand_1 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%mul_1, [1, 1, 5, 5]), kwargs = {})
    %view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_1, [1, 5, 5]), kwargs = {})
    %bmm : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view, %view_1), kwargs = {})
    %view_2 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%bmm, [1, 1, 5, 5]), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_2, %masked_fill_), kwargs = {})
    %_softmax : [num_users=1] = call_function[target=torch.ops.aten._softmax.default](args = (%add, -1, False), kwargs = {})
    %expand_2 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%_softmax, [1, 1, 5, 5]), kwargs = {})
    %view_3 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_2, [1, 5, 5]), kwargs = {})
    %expand_3 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%arg2_1, [1, 1, 5, 5]), kwargs = {})
    %view_4 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_3, [1, 5, 5]), kwargs = {})
    %bmm_1 : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view_3, %view_4), kwargs = {})
    %view_5 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%bmm_1, [1, 1, 5, 5]), kwargs = {})
    %transpose_1 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%view_5, 1, 2), kwargs = {})
    %transpose_2 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%transpose_1, 1, 2), kwargs = {})
    return transpose_2
Brevitas Fx graph: 
 graph():
    %arg0_1 : [#users=1] = placeholder[target=arg0_1]
    %arg1_1 : [#users=1] = placeholder[target=arg1_1]
    %arg2_1 : [#users=1] = placeholder[target=arg2_1]
    %_scaled_dot_product_flash_attention_default : [#users=9] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%arg0_1, %arg1_1, %arg2_1, 0.0, True), kwargs = {})
    %getitem : [#users=1] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention_default, 0), kwargs = {})
    %getitem_1 : [#users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention_default, 1), kwargs = {})
    %getitem_2 : [#users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention_default, 2), kwargs = {})
    %getitem_3 : [#users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention_default, 3), kwargs = {})
    %getitem_4 : [#users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention_default, 4), kwargs = {})
    %getitem_5 : [#users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention_default, 5), kwargs = {})
    %getitem_6 : [#users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention_default, 6), kwargs = {})
    %getitem_7 : [#users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention_default, 7), kwargs = {})
    %getitem_8 : [#users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention_default, 8), kwargs = {})
    return getitem
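
As a side note, the difference can also be checked programmatically instead of by comparing the dumps. A minimal sketch reusing fx_g and brevitas_fx_g from the repro above (the contains_op helper is mine, not part of either API):

sdpa = torch.ops.aten._scaled_dot_product_flash_attention.default

def contains_op(gm, op):
    # True if any node in the traced graph still calls `op` directly,
    # i.e. the decomposition was not applied during tracing.
    return any(
        n.op == "call_function" and n.target is op
        for n in gm.graph.nodes
    )

print(contains_op(fx_g, sdpa))           # False: upstream make_fx decomposed it
print(contains_op(brevitas_fx_g, sdpa))  # True: brevitas make_fx did not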
@vivekkhandelwal1 (Author)

@volcacius, PTAL at this issue.
CC: @jinchen62

@volcacius (Contributor)

Which PyTorch version are you on? With 2.1.1 I see:

Torch Fx graph: 
 graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %_scaled_dot_product_flash_attention : [num_users=9] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%arg0_1, %arg1_1, %arg2_1, 0.0, True), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention, 0), kwargs = {})
    %getitem_1 : [num_users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention, 1), kwargs = {})
    %getitem_2 : [num_users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention, 2), kwargs = {})
    %getitem_3 : [num_users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention, 3), kwargs = {})
    %getitem_4 : [num_users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention, 4), kwargs = {})
    %getitem_5 : [num_users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention, 5), kwargs = {})
    %getitem_6 : [num_users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention, 6), kwargs = {})
    %getitem_7 : [num_users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention, 7), kwargs = {})
    %getitem_8 : [num_users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention, 8), kwargs = {})
    return getitem
Brevitas Fx graph: 
 graph():
    %arg0_1 : [#users=1] = placeholder[target=arg0_1]
    %arg1_1 : [#users=1] = placeholder[target=arg1_1]
    %arg2_1 : [#users=1] = placeholder[target=arg2_1]
    %_scaled_dot_product_flash_attention_default : [#users=9] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%arg0_1, %arg1_1, %arg2_1, 0.0, True), kwargs = {})
    %getitem : [#users=1] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention_default, 0), kwargs = {})
    %getitem_1 : [#users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention_default, 1), kwargs = {})
    %getitem_2 : [#users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention_default, 2), kwargs = {})
    %getitem_3 : [#users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention_default, 3), kwargs = {})
    %getitem_4 : [#users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention_default, 4), kwargs = {})
    %getitem_5 : [#users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention_default, 5), kwargs = {})
    %getitem_6 : [#users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention_default, 6), kwargs = {})
    %getitem_7 : [#users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention_default, 7), kwargs = {})
    %getitem_8 : [#users=0] = call_function[target=operator.getitem](args = (%_scaled_dot_product_flash_attention_default, 8), kwargs = {})
    return getitem

@vivekkhandelwal1 (Author)

> Which PyTorch version are you on? With 2.1.1 I see:

torch 2.2.0.dev20231115+cpu
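
For completeness, the exact build can be confirmed with:

import torch
print(torch.__version__)  # prints 2.2.0.dev20231115+cpu here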

@volcacius (Contributor)

It should be fixed by PR #763, which I just merged into dev. Thanks for spotting it.
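
For anyone landing here, a quick way to confirm the fix after updating to the dev branch is to re-run the repro and check that the flash-attention op is gone from the Brevitas graph too (a sketch reusing the repro's names):

# Re-trace with the patched brevitas make_fx; the flash-attention op
# should now be decomposed away, matching the upstream graph.
brevitas_fx_g = brevitas_make_fx(
    attn,
    decomposition_table=get_decompositions(decomps_list),
)(*inputs)

sdpa = torch.ops.aten._scaled_dot_product_flash_attention.default
assert not any(
    n.op == "call_function" and n.target is sdpa
    for n in brevitas_fx_g.graph.nodes
), "flash-attention op was not decomposed"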
