[torch-mlir] Support lowering of aten constraint ops #3943

Open · wants to merge 4 commits into base: main

Conversation

praveen-g-ctt

This PR adds support for lowering the following aten constraint ops:
  1. aten::sym_constrain_range
  2. aten::sym_constrain_range_for_size
  3. aten::_assert_scalar

def forward(self, x):
    a = x.item()
    torch._check_is_size(a)
    # max should be >= 2
Contributor

I don't understand this comment. Why 2?

Author

Updated the comment in the test. This check comes from the ATen native implementation:

https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Constraints.cpp#L66

Collaborator

@zjgarvey left a comment

Thanks for the PR!

I'd like to request a few small changes, and I'd also like to at least have a discussion for context on these ops and their relationship with torch.symbolic_int and torch.bind_symbolic_shape.

Comment on lines +3627 to +3623
rewriter.create<cf::AssertOp>(loc, compareVal,
                              rewriter.getStringAttr(assertMessage));
Collaborator

If I remember correctly, this op was getting generated in some SDXL model we were importing to MLIR? Are the sym_constrain_range_for_size ops intended to provide symbolic shape constraints similar to symbolic_int and bind_symbolic_shape? For example, is it desirable for us to try to canonicalize patterns like:

func.func @sym_constrain(%arg0 : !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?],f32> {
    %int10 = torch.constant.int 10
    %int7 = torch.constant.int 7
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %int-1 = torch.constant.int -1
    %none = torch.constant.none
    %dim0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int
    %dim1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int
    torch.aten.sym_constrain_range %dim0, %int1, %int10 : !torch.int, !torch.int, !torch.int
    torch.aten.sym_constrain_range %dim1, %int1, %int7 : !torch.int, !torch.int, !torch.int
    %list = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
    %view = torch.aten.view %arg0, %list : !torch.vtensor<[?,?,3],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
    return %view : !torch.vtensor<[?],f32>
}

into something like:

func.func @sym_constrain(%arg0 : !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?],f32> {
    %int-1 = torch.constant.int -1
    %none = torch.constant.none
    %symbolic_int0 = torch.symbolic_int "s0" {min_val = 1, max_val = 10} : !torch.int
    %symbolic_int1 = torch.symbolic_int "s1" {min_val = 1, max_val = 7} : !torch.int
    torch.bind_symbolic_shape %arg0, [%symbolic_int0, %symbolic_int1], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
    %list = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
    %view = torch.aten.view %arg0, %list : !torch.vtensor<[?,?,3],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
    return %view : !torch.vtensor<[?],f32>
}

I think it would be fine to have these kinds of ops lower to asserts as a backup, but IREE can take advantage of the symbolic shape ops by converting to IR like:

    %c3 = arith.constant 3 : index
    %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
    %1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
    %2 = hal.tensor.import wait(%arg1) => %arg0 : !hal.buffer_view -> tensor<?x?x3xf32>{%0, %1}
    %3 = util.assume.int %0<umin = 1, umax = 10> : index
    %4 = util.assume.int %1<umin = 1, umax = 7> : index
    %5 = arith.muli %3, %4 : index
    %6 = arith.muli %5, %c3 : index
    %7 = flow.tensor.reshape %2 : tensor<?x?x3xf32>{%3, %4} -> tensor<?xf32>{%6}
    %8 = hal.tensor.barrier join(%7 : tensor<?xf32>) => %arg2 : !hal.fence
    %9 = hal.tensor.export %8 : tensor<?xf32>{%6} -> !hal.buffer_view
    util.return %9 : !hal.buffer_view

This IR can then be used to inform the compiler about appropriate intrinsics/tiling sizes for things like matmuls.

This comment doesn't necessarily need to be addressed here, but if we can get some context as to why these ops are introduced in examples, it might be helpful for improvements in the near future.

Author

Are the sym_constrain_range_for_size ops intended to provide symbolic shape constraints similar to symbolic_int and bind_symbolic_shape?

-> @zjgarvey Yes, the sym_constrain_range_for_size op is intended to specify symbolic shape constraints for scalar values.

The constraints specified in this op are evaluated along with the other expressions used in symbol evaluation, such as the min/max values for dynamically shaped tensors. If the value is already guaranteed to be within the specified range, the sym_constrain_range_for_size node gets removed altogether while generating the graph, as in the following test:

import torch
import torch.nn as nn
from torch.export import Dim

from torch_mlir import fx


def test_constraints():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            a = y.shape[0]
            torch.sym_constrain_range_for_size(a, max=11)
            return torch.broadcast_to(x, (a, -1))

    # Sample inputs
    x = torch.randn(1, 4)
    y = torch.randn(8)

    dim_0 = Dim("dim_0", max=10)
    dynamic_shapes = {
        "x": {},
        "y": {0: dim_0},
    }

    m = fx.export_and_import(
        Basic(),
        x,
        y,
        dynamic_shapes=dynamic_shapes,
        func_name="test_constraints",
        import_symbolic_shape_expressions=True,
    )
    print(m)

This comment doesn't necessarily need to be addressed here, but if we can get some context as to why these ops are introduced in examples, it might be helpful for improvements in the near future.

-> Some of the checks, such as torch._check_is_size(), introduce the sym_constrain_range_for_size op. I was not able to generate more examples that include other operations involving symbolic constraints along with this. I guess we can investigate this further with more examples.
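
For reference, here is a minimal sketch (not taken from this PR's tests; the names and values are illustrative, and the exact graph nodes depend on the PyTorch version) of how torch._check_is_size on a data-dependent value surfaces these constraint/assert ops during export:

import torch
from torch.export import export


class DataDependent(torch.nn.Module):
    def forward(self, x):
        # Data-dependent scalar: becomes an unbacked symbolic int during export
        a = x.item()
        # Marks `a` as a size; export records this as constraint/assert nodes
        # (sym_constrain_range_for_size / _assert_scalar) in the graph
        torch._check_is_size(a)
        return torch.zeros(a)


ep = export(DataDependent(), (torch.tensor(5),))
print(ep.graph)  # inspect the printed graph for the constraint/assert nodes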

Resolved review threads (outdated):
  lib/Conversion/TorchToLinalg/Uncategorized.cpp (3 threads)
  lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp (1 thread)
}

if (!isa<Torch::ConstantNoneOp>(maxOp)) {
  // Verify that max value is greater than 2
Collaborator

I think I have the same question as @mgehre-amd. Is it possible for the max value to be 0? This would just mean the symbolic value is actually 0, correct?

Author

@zjgarvey @mgehre-amd For the sym_constrain_range_for_size op, there is an assumption that the max value is greater than 2 when the value is used as a size. The generic variant sym_constrain_range does not have this requirement.

https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.constrain_range.html
https://github.com/pytorch/pytorch/blob/f7000350905be5073892e0b23df681c0281be0f0/torch/__init__.py#L2682
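
To illustrate the difference, a rough sketch (hypothetical values; the threshold and error come from the ATen check linked above, not from this PR):

import torch

x = torch.tensor(1)
a = x.item()

# Generic variant: any consistent min/max range is accepted
torch.ops.aten.sym_constrain_range(a, min=0, max=1)

# Size variant: the ATen implementation additionally rejects small max values
# (see the Constraints.cpp check linked earlier), so a call like the commented
# one below is expected to fail with the "max value" error discussed here
# torch.sym_constrain_range_for_size(a, max=1)
torch.sym_constrain_range_for_size(a, max=11)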

def forward(self, x):
    a = x.item()
    # The check below introduces the aten._assert_scalar op
    torch._check_is_size(a)
Collaborator

Why do we need this check in the test?

Author

The _check_is_size call generates the _assert_scalar op with an assertion message. I have modified the test case to invoke the aten op directly with a custom assertion message.
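
A rough sketch of what calling the aten op directly can look like (hypothetical condition and message; the actual test added in this PR may differ):

def forward(self, x):
    a = x.item()
    # Call aten::_assert_scalar directly so the lowering sees a custom
    # assertion message instead of the one torch._check_is_size generates
    torch.ops.aten._assert_scalar(a != 0, "Expected a to be non-zero")
    return x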
