From 0ce1c5929b01f578f4aed641b2e141594f53ffc2 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 30 Jan 2024 15:14:52 -0800 Subject: [PATCH] chore: use ir flag --- tests/py/dynamo/models/test_output_format.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/py/dynamo/models/test_output_format.py b/tests/py/dynamo/models/test_output_format.py index 3d2e747ceb..5f2bdedf07 100644 --- a/tests/py/dynamo/models/test_output_format.py +++ b/tests/py/dynamo/models/test_output_format.py @@ -1,11 +1,8 @@ import unittest import pytest -import timm import torch import torch_tensorrt as torchtrt -import torchvision.models as models -from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity assertions = unittest.TestCase() @@ -31,7 +28,7 @@ def forward(self, x): model = MyModule().eval().cuda() input = torch.randn((1, 3, 224, 224)).to("cuda") - trt_ep = torchtrt.compile(model, ir="dynamo", inputs=[input], min_block_size=1) + trt_ep = torchtrt.compile(model, ir=ir, inputs=[input], min_block_size=1) assertions.assertTrue( isinstance(trt_ep, torch.export.ExportedProgram), msg=f"test_output_format output type does not match with torch.export.ExportedProgram", @@ -39,7 +36,7 @@ def forward(self, x): trt_ts = torchtrt.compile( model, - ir="dynamo", + ir=ir, inputs=[input], min_block_size=1, output_format="torchscript", @@ -51,7 +48,7 @@ def forward(self, x): trt_gm = torchtrt.compile( model, - ir="dynamo", + ir=ir, inputs=[input], min_block_size=1, output_format="graph_module", @@ -60,3 +57,5 @@ def forward(self, x): isinstance(trt_gm, torch.fx.GraphModule), msg=f"test_output_format output type does not match with torch.fx.GraphModule", ) + # Clean up model env + torch._dynamo.reset()