From 3851e4ec86a9e23702d7cc4481fe9a9a56cd9a1e Mon Sep 17 00:00:00 2001 From: AmosLewis Date: Mon, 2 Dec 2024 08:54:54 -0800 Subject: [PATCH] Add conv2d e2e test from convnext model --- .../torch_mlir_e2e_test/test_suite/conv.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index e6332579d575..ad5dc6064a71 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -256,6 +256,38 @@ def Convolution2DStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) +class Convolution2DNextStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 80, 72, 72], torch.float32, True), + ([80, 1, 7, 7], torch.float32, True), + ([80], torch.float32, True), + ] + ) + def forward(self, inputVec, weight, bias): + return torch.ops.aten.convolution( + inputVec, + weight, + bias=bias, + stride=[1, 1], + padding=[3, 3], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=80, + ) + + +@register_test_case(module_factory=lambda: Convolution2DNextStaticModule()) +def Convolution2DNextStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 80, 72, 72), tu.rand(80, 1, 7, 7), tu.rand(80)) + + class Convolution2DStridedModule(torch.nn.Module): def __init__(self): super().__init__()