[torch-mlir] Support lowering of aten constraint ops #3943
base: main
Conversation
praveen-g-ctt
commented
Jan 7, 2025
- aten::sym_constrain_range
- aten::sym_constrain_range_for_size
- aten::_assert_scalar
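For context, a minimal sketch of the kind of program that introduces these ops when exported through the fx importer. The module, input value, and function name below are illustrative and not taken from the PR's tests; torch._check_is_size is the check discussed later in this thread.

import torch
from torch_mlir import fx

class ConstraintExample(torch.nn.Module):
    def forward(self, x):
        a = x.item()
        # Per the discussion below, this check introduces
        # aten::sym_constrain_range_for_size and aten::_assert_scalar
        # in the exported graph.
        torch._check_is_size(a)
        return torch.ones(a)

# Export and import the module (argument value is illustrative).
m = fx.export_and_import(ConstraintExample(), torch.tensor(5), func_name="constraint_example")
print(m)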
def forward(self, x):
    a = x.item()
    torch._check_is_size(a)
    # max should be >= 2
I don't understand this comment. Why 2?
Updated the comment in the test. This check is from the ATen native implementation.
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Constraints.cpp#L66
f59bd0b to 1dd573c
Thanks for the PR!
I'd like to request a few small changes, and I'd also like to at least have a discussion for context on these ops and their relationship with torch.symbolic_int and torch.bind_symbolic_shape.
rewriter.create<cf::AssertOp>(loc, compareVal,
                              rewriter.getStringAttr(assertMessage));
If I remember correctly, this op was getting generated in some sdxl model we were importing to mlir? Are the sym_constrain_range_for_size ops intended to provide symbolic shape constraints similar to symbolic_int and bind_symbolic_shape? For example, is it desirable for us to try to canonicalize patterns like:
func.func @sym_constrain(%arg0 : !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?],f32> {
  %int10 = torch.constant.int 10
  %int7 = torch.constant.int 7
  %int1 = torch.constant.int 1
  %int0 = torch.constant.int 0
  %int-1 = torch.constant.int -1
  %none = torch.constant.none
  %dim0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
  %dim1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
  torch.aten.sym_constrain_range %dim0, %int1, %int10 : !torch.int, !torch.int, !torch.int
  torch.aten.sym_constrain_range %dim1, %int1, %int7 : !torch.int, !torch.int, !torch.int
  %list = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
  %view = torch.aten.view %arg0, %list : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
  return %view : !torch.vtensor<[?],f32>
}
into something like:
func.func @sym_constrain(%arg0 : !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?],f32> {
  %int-1 = torch.constant.int -1
  %none = torch.constant.none
  %symbolic_int0 = torch.symbolic_int "s0" {min_val = 1, max_val = 10} : !torch.int
  %symbolic_int1 = torch.symbolic_int "s1" {min_val = 1, max_val = 7} : !torch.int
  torch.bind_symbolic_shape %arg0, [%symbolic_int0, %symbolic_int1], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
  %list = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
  %view = torch.aten.view %arg0, %list : !torch.vtensor<[?,?,3],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
  return %view : !torch.vtensor<[?],f32>
}
I think it would be fine to have these kinds of ops lower to asserts as a backup, but IREE can take advantage of the symbolic shape ops by converting to IR like:
%c3 = arith.constant 3 : index
%0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
%1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
%2 = hal.tensor.import wait(%arg1) => %arg0 : !hal.buffer_view -> tensor<?x?x3xf32>{%0, %1}
%3 = util.assume.int %0<umin = 1, umax = 10> : index
%4 = util.assume.int %1<umin = 1, umax = 7> : index
%5 = arith.muli %3, %4 : index
%6 = arith.muli %5, %c3 : index
%7 = flow.tensor.reshape %2 : tensor<?x?x3xf32>{%3, %4} -> tensor<?xf32>{%6}
%8 = hal.tensor.barrier join(%7 : tensor<?xf32>) => %arg2 : !hal.fence
%9 = hal.tensor.export %8 : tensor<?xf32>{%6} -> !hal.buffer_view
util.return %9 : !hal.buffer_view
This information can then be used to inform the compiler about appropriate intrinsics/tiling sizes for things like matmuls.
This comment doesn't necessarily need to be addressed here, but if we can get some context as to why these ops are introduced in examples, it might be helpful for improvements in the near future.
Are the sym_constrain_range_for_size ops intended to provide symbolic shape constraints similar to symbolic_int and bind_symbolic_shape?
-> @zjgarvey Yes, the sym_constrain_range_for_size op is intended to specify symbolic shape constraints for scalar values.
The constraints specified in this op are evaluated along with the other expressions used in symbol evaluation, such as the min/max values for dynamically shaped tensors. If the value is assumed to be within the range of the specified constraints, the sym_constrain_range_for_size node gets removed altogether while generating the graph.
import torch
import torch.nn as nn
from torch.export import Dim
from torch_mlir import fx

def test_constraints():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            a = y.shape[0]
            torch.sym_constrain_range_for_size(a, max=11)
            return torch.broadcast_to(x, (a, -1))

    # Sample inputs
    x = torch.randn(1, 4)
    y = torch.randn(8)

    dim_0 = Dim("dim_0", max=10)
    dynamic_shapes = {
        "x": {},
        "y": {0: dim_0},
    }

    m = fx.export_and_import(
        Basic(),
        x,
        y,
        dynamic_shapes=dynamic_shapes,
        func_name="test_constraints",
        import_symbolic_shape_expressions=True,
    )
    print(m)
This comment doesn't necessarily need to be addressed here, but if we can get some context as to why these ops are introduced in examples, it might be helpful for improvements in the near future.
-> Some of the checks, such as torch._check_is_size(), introduce the sym_constrain_range_for_size op. I was not able to generate more examples that include other operations involving symbolic constraints along with this. I guess we can investigate this further with more examples.
}

if (!isa<Torch::ConstantNoneOp>(maxOp)) {
  // Verify that max value is greater than 2
I think I have the same question as @mgehre-amd. Is it possible for the max value to be 0? This would just mean the symbolic value is actually 0, correct?
@zjgarvey @mgehre-amd For the sym_constrain_range_for_size op, there is an assumption that the max value is greater than 2 when the value is used as a size. The generic variant sym_constrain_range does not have this requirement.
https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.constrain_range.html
https://github.com/pytorch/pytorch/blob/f7000350905be5073892e0b23df681c0281be0f0/torch/__init__.py#L2682
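To illustrate the distinction, here is a rough sketch (the module and the max value are made up; torch.sym_constrain_range_for_size is the same Python entry point used in the test snippet above):

import torch

class ConstrainForSize(torch.nn.Module):
    def forward(self, x):
        a = x.item()
        # Size-like constraint: per the ATen native implementation linked
        # above, the max bound is additionally validated for the *_for_size
        # variant; the generic sym_constrain_range op has no such
        # requirement on max.
        torch.sym_constrain_range_for_size(a, max=10)
        return torch.zeros(a)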
def forward(self, x):
    a = x.item()
    # The below check introduces the aten._assert_scalar op
    torch._check_is_size(a)
Why do we need this check in the test?
The _check_is_size call generates the _assert_scalar op with an assertion message. I have modified the test case to invoke the aten op directly with a custom assertion message.
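A rough sketch of what invoking the aten op directly might look like (the condition and message below are illustrative placeholders, not the exact ones used in the updated test):

import torch

class AssertScalarExample(torch.nn.Module):
    def forward(self, x):
        a = x.item()
        # Call the aten op directly with a custom assertion message.
        torch.ops.aten._assert_scalar(a != 0, "Expected the scalar value to be non-zero")
        return x + a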
1. aten::sym_constrain_range 2. aten::sym_constrain_range_for_size 3. aten::_assert_scalar
f42e9ba to cc7de9a