Skip to content

Commit

Permalink
Fix usage example
Browse files Browse the repository at this point in the history
  • Loading branch information
ohadravid committed Dec 22, 2024
1 parent 25075e2 commit c98e36d
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions docsrc/user_guide/saving_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Here's an example usage
model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
# trt_ep is a torch.fx.GraphModule object
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs)
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs)
# Later, you can load it and run inference
Expand All @@ -52,7 +52,7 @@ b) Torchscript
model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
# trt_gm is a torch.fx.GraphModule object
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs)
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)
# Later, you can load it and run inference
Expand All @@ -73,7 +73,7 @@ For `ir=ts`, this behavior stays the same in 2.X versions as well.
model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
trt_ts = torch_tensorrt.compile(model, ir="ts", inputs) # Output is a ScriptModule object
trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs) # Output is a ScriptModule object
torch.jit.save(trt_ts, "trt_model.ts")
# Later, you can load it and run inference
Expand Down

0 comments on commit c98e36d

Please sign in to comment.