"""run_model.py: train a CNN + spatial-pyramid-pooling (SPP) wafer-defect classifier and log the run to Weights & Biases."""
import os
import argparse
from datetime import datetime

import torch
import wandb
from torch.optim import Adam

from wdd.model.cnn_spp import make_spp_training_net
from wdd.model.model_training import train_model

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--learning_rate', type=float, required=True,
        help='learning rate for the optimizer')
    parser.add_argument(
        '--weight_decay', type=float, required=True,
        help='weight decay for the optimizer')
    parser.add_argument(
        '--batch_step_size', type=int, required=True,
        help='batch step size for the optimizer')
    parser.add_argument(
        '--num_cnn_layers', type=int, required=True,
        help='number of CNN layers')
    parser.add_argument(
        '--num_spp_outputs', type=int, required=True,
        help='number of SPP output levels')
    parser.add_argument(
        '--num_linear_layers', type=int, required=True,
        help='number of linear layers')
    parser.add_argument(
        '--transform_prob_threshold', type=float, required=True,
        help='probability threshold for the transform')
    parser.add_argument(
        '--epochs', type=int, required=True,
        help='number of epochs')
    parser.add_argument(
        '--name', type=str, required=True,
        help='base name for wandb logging')
    parser.add_argument(
        '--outpath', type=str, required=True,
        help='output directory for wandb logging')
    parser.add_argument(
        '--binary', type=int, required=True,
        help='run the binary classifier (1) or not (0)')
    parser.add_argument(
        '--just_defects', type=int, required=True,
        help='train on just the defect classes (1) or not (0)')
    args = parser.parse_args()
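
    # Example invocation (illustrative only; every value below is an
    # assumption, not taken from the source):
    #   python run_model.py --learning_rate 1e-3 --weight_decay 1e-4 \
    #       --batch_step_size 32 --num_cnn_layers 4 --num_spp_outputs 3 \
    #       --num_linear_layers 3 --transform_prob_threshold 0.5 --epochs 10 \
    #       --name baseline --outpath ./wandb_logs --binary 0 --just_defects 0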

    # Derive the architecture from the CLI arguments.
    # CNN channel counts double per layer: (1, 2, 4, ...).
    cnn_channels = tuple(2 ** i for i in range(args.num_cnn_layers))
    # SPP pooling grids grow in odd sizes: (1, 1), (3, 3), (5, 5), ...
    spp_output_sizes = [(1 + 2 * i, 1 + 2 * i) for i in range(args.num_spp_outputs)]
    # Linear layer widths halve towards the output layer, whose width is the
    # number of classes: 2 (binary), 8 (defects only), or 9 (all classes).
    if args.binary:
        linear_output_sizes = tuple(2 * 2 ** (i - 1) for i in range(args.num_linear_layers, 0, -1))
    elif args.just_defects:
        linear_output_sizes = tuple(8 * 2 ** (i - 1) for i in range(args.num_linear_layers, 0, -1))
    else:
        linear_output_sizes = tuple(9 * 2 ** (i - 1) for i in range(args.num_linear_layers, 0, -1))
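
    # Worked example (illustrative values, not from the source): with
    # --num_cnn_layers 4, --num_spp_outputs 3, --num_linear_layers 3 and
    # --binary 1, the derivations above yield
    #   cnn_channels        = (1, 2, 4, 8)
    #   spp_output_sizes    = [(1, 1), (3, 3), (5, 5)]
    #   linear_output_sizes = (8, 4, 2)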

    model_parameters = dict(
        cnn_channels=cnn_channels,
        spp_output_sizes=spp_output_sizes,
        linear_output_sizes=linear_output_sizes,
    )
    config = dict(
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        batch_step_size=args.batch_step_size,
        num_cnn_layers=args.num_cnn_layers,
        num_spp_outputs=args.num_spp_outputs,
        num_linear_layers=args.num_linear_layers,
        transform_prob_threshold=args.transform_prob_threshold,
        epochs=args.epochs,
        model_parameters=model_parameters,
        use_cuda=False,
        binary=args.binary,
        just_defects=args.just_defects,
    )

    net = make_spp_training_net(config)
    learning_rate = config['learning_rate']
    weight_decay = config['weight_decay']
    optimizer = Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Initialise wandb logging.
    outpath = args.outpath
    name = f'{args.name}_{datetime.now().isoformat()}'
    wandb.init(
        project='wafer-defect-detection',
        config=config,
        name=name,
        dir=outpath,
    )
    # Log gradients and parameters every 10 batches.
    wandb.watch(
        net,
        log='all',
        log_freq=10,
    )
    # Track the best value of each metric in the run summary.
    wandb.define_metric('training_loss', summary='min')
    wandb.define_metric('validation_loss', summary='min')
    wandb.define_metric('balanced_f1', summary='max')
    wandb.define_metric('exp_avg_validation_loss', summary='min')

    train_model(
        net,
        optimizer,
        args.epochs,
        name,
        log=True,
    )
    # Save the trained weights into the wandb run directory.
    torch.save(net.state_dict(), os.path.join(wandb.run.dir, name + '.pt'))
    wandb.finish()
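
    # Minimal sketch of reloading the saved checkpoint later (assumes the
    # same `config` dict is rebuilt first; the path is a hypothetical
    # example, not a real artifact of this run):
    #
    #   net = make_spp_training_net(config)
    #   net.load_state_dict(torch.load('wandb/run-.../<name>.pt', map_location='cpu'))
    #   net.eval()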