Skip to content

Commit

Permalink
chore: use ir flag
Browse files Browse the repository at this point in the history
  • Loading branch information
peri044 committed Jan 30, 2024
1 parent 793f17b commit 0ce1c59
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions tests/py/dynamo/models/test_output_format.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import unittest

import pytest
import timm
import torch
import torch_tensorrt as torchtrt
import torchvision.models as models
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

assertions = unittest.TestCase()

Expand All @@ -31,15 +28,15 @@ def forward(self, x):
model = MyModule().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")

trt_ep = torchtrt.compile(model, ir="dynamo", inputs=[input], min_block_size=1)
trt_ep = torchtrt.compile(model, ir=ir, inputs=[input], min_block_size=1)
assertions.assertTrue(
isinstance(trt_ep, torch.export.ExportedProgram),
msg=f"test_output_format output type does not match with torch.export.ExportedProgram",
)

trt_ts = torchtrt.compile(
model,
ir="dynamo",
ir=ir,
inputs=[input],
min_block_size=1,
output_format="torchscript",
Expand All @@ -51,7 +48,7 @@ def forward(self, x):

trt_gm = torchtrt.compile(
model,
ir="dynamo",
ir=ir,
inputs=[input],
min_block_size=1,
output_format="graph_module",
Expand All @@ -60,3 +57,5 @@ def forward(self, x):
isinstance(trt_gm, torch.fx.GraphModule),
msg=f"test_output_format output type does not match with torch.fx.GraphModule",
)
# Clean up model env
torch._dynamo.reset()

0 comments on commit 0ce1c59

Please sign in to comment.