Skip to content

Commit

Permalink
Check torch onnx export validity (#78)
Browse files Browse the repository at this point in the history
* add torch export verifier

* Add comments

* lint fix

* resolve comments

* fix lint
  • Loading branch information
ramkrishna2910 authored Jan 9, 2024
1 parent a52414c commit 8204801
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
40 changes: 37 additions & 3 deletions src/turnkeyml/build/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import copy
from typing import Union
import torch
import torch.onnx.verification
import numpy as np
import onnxruntime
import onnxmltools
Expand Down Expand Up @@ -279,9 +280,45 @@ def fire(self, state: build.State):
default_warnings = warnings.showwarning
warnings.showwarning = _warn_to_stdout

stats = fs.Stats(
state.cache_dir, state.config.build_name, state.evaluation_id
)

# Verify if the exported model matches the input torch model
try:
# The `torch.onnx.verification.find_mismatch()` takes input arguments to the
# model as `input_args (Tuple[Any, ...])`
export_verification = torch.onnx.verification.find_mismatch(
state.model,
tuple(state.inputs.values()),
opset_version=state.config.onnx_opset)

# `export_verification.has_mismatch()` returns True if a mismatch is found and
# False otherwise. If no mismatch is found,# `is_export_valid` is set to "Valid",
# indicating successful verification.
# If a mismatch is found, `is_export_valid` is set to "Invalid", indicating
# the verification failed.
if not export_verification.has_mismatch():
is_export_valid = "valid"
else:
is_export_valid = "invalid"

# The except block catches any type of exception that might occur during the
# verification process. If any exception occurs,`is_export_valid` is set to
# "Unverified", indicating that the verification process could not be completed,
# and therefore the model's export status is unverified.
except Exception: # pylint: disable=broad-except
is_export_valid = "unverified"

stats.save_model_eval_stat(
fs.Keys.TORCH_ONNX_EXPORT_VALIDITY,
is_export_valid,
)

# Export the model to ONNX
output_path = base_onnx_file(state)
os.makedirs(onnx_dir(state), exist_ok=True)

torch.onnx.export(
state.model,
dummy_inputs,
Expand Down Expand Up @@ -309,9 +346,6 @@ def fire(self, state: build.State):
if check_model(output_path, success_msg, fail_msg):
state.intermediate_results = [output_path]

stats = fs.Stats(
state.cache_dir, state.config.build_name, state.evaluation_id
)
stats.save_model_eval_stat(
fs.Keys.ONNX_FILE,
output_path,
Expand Down
4 changes: 3 additions & 1 deletion src/turnkeyml/common/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,9 @@ class Keys:
BUILD_STATUS = "build_status"
# Indicates status of the most recent benchmark tool run: FunctionStatus
BENCHMARK_STATUS = "benchmark_status"

# Indicates the match between the TorchScript IR graph and
# the exported onnx model (verified with torch.onnx.verification)
TORCH_ONNX_EXPORT_VALIDITY = "torch_export_validity"

class FunctionStatus:
RUNNING = "running"
Expand Down

0 comments on commit 8204801

Please sign in to comment.