# evaluate.py
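"""Evaluation script for trained segmentation models.

Loads a model checkpoint, runs it over the validation split, and writes the
average Dice/IoU metrics plus a few example visualizations to a results
directory.
"""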
import argparse
import torch
import logging
import yaml
from pathlib import Path
from tqdm import tqdm
from datetime import datetime
from segmentation import (
    SegmentationConfig,
    SegmentationDataset,
    get_dataloader,
    UNet,
    FCN,
    LinkNet,
    DeepLabV3,
    Visualizer,
    SegmentationMetrics
)

def setup_logging(save_dir: str):
    """Setup logging configuration."""
    # Make sure the output directory exists before attaching the file handler;
    # otherwise FileHandler raises FileNotFoundError for a fresh save_dir.
    Path(save_dir).mkdir(parents=True, exist_ok=True)
    log_file = Path(save_dir) / f'evaluation_{datetime.now():%Y%m%d_%H%M%S}.log'
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(str(log_file)),
            logging.StreamHandler()
        ]
    )
    return logging.getLogger(__name__)

def get_model(model_type: str, config: SegmentationConfig):
    """Get model based on type."""
    models = {
        'fcn': FCN,
        'linknet': LinkNet,
        'unet': UNet,
        'deeplabv3': DeepLabV3
    }
    if model_type not in models:
        raise ValueError(f"Model type {model_type} not supported. Choose from {list(models.keys())}")
    return models[model_type](config)

def evaluate(config: SegmentationConfig, model_path: str, save_dir: str):
    """Evaluate model on test set."""
    logger = setup_logging(save_dir)
    device = torch.device(config.DEVICE)

    # Create save directory
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Load model
    logger.info(f"Loading model from {model_path}")
    model = get_model(config.MODEL_TYPE, config)
    try:
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        raise

    model = model.to(device)
    model.eval()

    # Create dataset and dataloader
    logger.info("Creating test dataset")
    test_dataset = SegmentationDataset(config, mode='val')  # Using val set for evaluation
    test_loader = get_dataloader(test_dataset, config)

    # Initialize metrics and visualizer
    metrics = SegmentationMetrics()
    visualizer = Visualizer(config)

    # Evaluation loop
    logger.info("Starting evaluation")
    total_metrics = {'dice': 0, 'iou': 0}
    examples = []

    try:
        with torch.no_grad():
            for i, (images, masks) in enumerate(tqdm(test_loader, desc='Evaluating')):
                # Move to device
                images = images.to(device)
                masks = masks.to(device)

                # Forward pass
                outputs = model(images)

                # Calculate metrics
                batch_metrics = metrics.calculate_metrics(outputs, masks)

                # Update metrics
                for k in total_metrics:
                    total_metrics[k] += batch_metrics[k]

                # Save some examples for visualization
                if i < 5:  # Save first 5 batches
                    examples.append((images.cpu(), masks.cpu(), outputs.cpu()))

                # Clear some memory
                if device.type == 'cuda':
                    torch.cuda.empty_cache()

        # Calculate average metrics
        num_batches = len(test_loader)
        avg_metrics = {k: v / num_batches for k, v in total_metrics.items()}

        # Save results
        logger.info("Saving results")

        # Save metrics
        metrics_file = save_dir / 'metrics.txt'
        with open(metrics_file, 'w') as f:
            for k, v in avg_metrics.items():
                f.write(f'{k}: {v:.4f}\n')
        logger.info(f"Metrics saved to {metrics_file}")

        # Save visualizations
        for i, (images, masks, outputs) in enumerate(examples):
            save_path = save_dir / f'example_{i}.png'
            visualizer.visualize_batch(
                images, masks, outputs,
                save_path=save_path
            )
            logger.info(f"Visualization {i+1} saved to {save_path}")

        logger.info("Evaluation completed successfully")
        return avg_metrics

    except Exception as e:
        logger.error(f"Error during evaluation: {str(e)}")
        raise

def main():
    parser = argparse.ArgumentParser(description='Evaluate segmentation model')
    parser.add_argument('--config', type=str, required=True,
                        help='path to config file')
    parser.add_argument('--model-path', type=str, required=True,
                        help='path to model checkpoint')
    parser.add_argument('--save-dir', type=str, default='evaluation_results',
                        help='directory to save results')
    args = parser.parse_args()

    # Load configuration
    with open(args.config, 'r') as f:
        config_dict = yaml.safe_load(f)

    config = SegmentationConfig()
    for k, v in config_dict.items():
        setattr(config, k, v)

    # Run evaluation
    try:
        metrics = evaluate(config, args.model_path, args.save_dir)
        print("\nEvaluation Results:")
        for k, v in metrics.items():
            print(f"{k}: {v:.4f}")
    except Exception as e:
        print(f"Evaluation failed: {str(e)}")
        raise

if __name__ == '__main__':
    main()
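
# Example usage (paths are illustrative; point --config at a YAML file whose keys
# match SegmentationConfig attributes, e.g. MODEL_TYPE and DEVICE -- any other
# required fields depend on the segmentation package):
#
#   python evaluate.py \
#       --config configs/segmentation.yaml \
#       --model-path checkpoints/best_model.pth \
#       --save-dir evaluation_results
#
# A minimal config sketch (only MODEL_TYPE and DEVICE are read directly here):
#
#   MODEL_TYPE: unet    # one of: fcn, linknet, unet, deeplabv3
#   DEVICE: cuda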