-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtest_exported_pose_estimator.py
34 lines (23 loc) · 1.35 KB
/
test_exported_pose_estimator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import argparse
import torch
from common.test import load_exported_model
from pose_estimation.metrics import CocoPoseEvaluation
from pose_estimation.datasets import PoseEstimationCoco
from pose_estimation.trainers.pose_estimator_trainer import create_validation_image_transform
def main():
    """Evaluate an exported (TorchScript or TensorRT) pose estimator on COCO.

    Parses command-line arguments, loads the exported model via
    ``load_exported_model``, and builds the non-training split of
    ``PoseEstimationCoco``. That split uses no data augmentation and the
    validation image transform. The function then runs
    ``CocoPoseEvaluation`` on it, passing ``--output_path`` through.

    Exits with an argparse usage error (status 2) when required arguments
    are missing, or when neither ``--torch_script_path`` nor ``--trt_path``
    is given.
    """
    parser = argparse.ArgumentParser(description='Test exported pose estimator')
    parser.add_argument('--dataset_root', type=str, help='Choose the dataset root path', required=True)
    parser.add_argument('--torch_script_path', type=str, help='Choose the TorchScript path')
    parser.add_argument('--trt_path', type=str, help='Choose the TensorRT path')
    parser.add_argument('--output_path', type=str, help='Choose the output path', required=True)
    args = parser.parse_args()

    # Each model path is optional on its own, but at least one is needed to
    # have a model to evaluate. Fail fast with a usage message instead of
    # handing two ``None`` values to the loader.
    if args.torch_script_path is None and args.trt_path is None:
        parser.error('one of --torch_script_path or --trt_path is required')

    model, device = load_exported_model(args.torch_script_path, args.trt_path)

    dataset = PoseEstimationCoco(args.dataset_root,
                                 train=False,
                                 data_augmentation=False,
                                 image_transforms=create_validation_image_transform())
    # One sample per batch, in dataset order, loaded in the main process.
    dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

    coco_pose_evaluation = CocoPoseEvaluation(model, device, dataset_loader, args.output_path)
    coco_pose_evaluation.evaluate()
# Run the evaluation only when executed as a script, not when imported.
if '__main__' == __name__:
    main()