-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference.py
40 lines (29 loc) · 1.21 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import cv2
from config import NUM_CLASSES, IN_CHANNELS
from model import MRIModel
from argparse import ArgumentParser
def inference_model(model, image):
    """Run a single-image forward pass and return a per-pixel probability map.

    Args:
        model: torch module producing segmentation logits. It is switched to
            eval mode and called under ``torch.no_grad()``.
        image: array-like image laid out (H, W, C), e.g. the BGR array returned
            by ``cv2.imread``. Values are passed to the model unscaled --
            NOTE(review): confirm this matches the training preprocessing.

    Returns:
        numpy array holding the sigmoid of the model's first output channel,
        oriented (H, W) so it lines up with ``image`` (assuming the model
        preserves spatial size -- TODO confirm for the chosen architecture).
    """
    model.eval()
    with torch.no_grad():
        batch = torch.as_tensor(image, dtype=torch.float32).unsqueeze(0)
        # Swapping dims 1 and -1 turns (1, H, W, C) into (1, C, W, H): channels
        # first, but with height and width exchanged. Kept as-is because the
        # checkpoint may have been trained with this same convention.
        batch = torch.transpose(batch, -1, 1)
        # Follow the model onto whatever device its weights live on.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            batch = batch.to(first_param.device)
        pr_mask = model(batch).sigmoid()
    # Drop the batch dim and keep the first class channel -> (W, H); transpose
    # back to (H, W) so the mask overlays the original image correctly.
    mask = pr_mask.squeeze(0)[0, :, :]
    # .cpu() handles GPU outputs; .contiguous() gives OpenCV a C-ordered buffer.
    return mask.transpose(0, 1).contiguous().cpu().numpy()
def main():
    """Command-line entry point: segment one image and display the result.

    Builds an ``MRIModel`` from --model/--backbone, loads the ``"state_dict"``
    entry of --checkpoint into it, runs ``inference_model`` on the image read
    from --image, and shows the image and predicted mask in OpenCV windows
    until a key is pressed.
    """
    parser = ArgumentParser()
    parser.add_argument('--model', type=str, required=True, help='Model architecture')
    parser.add_argument('--backbone', type=str, required=True, help='Model backbone from segmentation-models-pytorch')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
    parser.add_argument('--image', type=str, required=True, help='Input image')
    args = parser.parse_args()

    model = MRIModel(args.model, args.backbone, IN_CHANNELS, NUM_CLASSES)
    # map_location="cpu" lets checkpoints saved from a GPU load on CPU-only
    # hosts; load_state_dict copies the tensors into the model's own
    # parameters, so the result is unchanged on GPU machines.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])

    # cv2.imread reports failure by returning None instead of raising.
    # NOTE(review): the array is in BGR channel order -- confirm the model was
    # trained on BGR input.
    image = cv2.imread(args.image)
    if image is None:
        parser.error(f"could not read image: {args.image}")

    pr_mask = inference_model(model, image)
    cv2.imshow('Image', image)
    cv2.imshow('Pred Mask', pr_mask)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
if __name__ == "__main__":
    # Launch the CLI only when this file is executed directly, not on import.
    main()