main.py (forked from ristea/aed-mae)
import argparse
import datetime
import json
import os
import time
from pathlib import Path

import torch
from timm.optim import optim_factory
from timm.utils import NativeScaler
from torch.utils.tensorboard import SummaryWriter

from configs.configs import get_configs_avenue, get_configs_shanghai
from data.test_dataset import AbnormalDatasetGradientsTest
from data.train_dataset import AbnormalDatasetGradientsTrain
from engine_train import train_one_epoch, test_one_epoch
from inference import inference
from model.model_factory import mae_cvt_patch16, mae_cvt_patch8
from util import misc
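

# Entry point: builds the train/test datasets and the masked autoencoder
# (mae_cvt_patch16 for Avenue, mae_cvt_patch8 otherwise), then dispatches to
# training or inference according to args.run_type.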
def main(args):
    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    log_writer = SummaryWriter(log_dir=args.output_dir)
    device = args.device

    if args.run_type == 'train':
        dataset_train = AbnormalDatasetGradientsTrain(args)
        print(dataset_train)
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        data_loader_train = torch.utils.data.DataLoader(
            dataset_train, sampler=sampler_train,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=False,
        )
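
    # The test loader is needed in both modes: per-epoch evaluation during
    # training, and standalone scoring during inference.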
    dataset_test = AbnormalDatasetGradientsTest(args)
    print(dataset_test)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size, num_workers=args.num_workers,
        pin_memory=args.pin_mem, drop_last=False,
    )

    # Define the model: patch size 16 for Avenue, patch size 8 otherwise.
    if args.dataset == 'avenue':
        model = mae_cvt_patch16(norm_pix_loss=args.norm_pix_loss, img_size=args.input_size,
                                use_only_masked_tokens_ab=args.use_only_masked_tokens_ab,
                                abnormal_score_func=args.abnormal_score_func,
                                masking_method=args.masking_method,
                                grad_weighted_loss=args.grad_weighted_rec_loss).float()
    else:
        model = mae_cvt_patch8(norm_pix_loss=args.norm_pix_loss, img_size=args.input_size,
                               use_only_masked_tokens_ab=args.use_only_masked_tokens_ab,
                               abnormal_score_func=args.abnormal_score_func,
                               masking_method=args.masking_method,
                               grad_weighted_loss=args.grad_weighted_rec_loss).float()
    model.to(device)

    if args.run_type == "train":
        do_training(args, data_loader_test, data_loader_train, device, log_writer, model)
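    # Inference: merge the best student weights into the best teacher
    # checkpoint, then score the test set with the combined model.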
    elif args.run_type == "inference":
        # map_location keeps loading robust when the checkpoints were saved
        # on a different device than the current one.
        student = torch.load(os.path.join(args.output_dir, "checkpoint-best-student.pth"),
                             map_location=device)['model']
        teacher = torch.load(os.path.join(args.output_dir, "checkpoint-best.pth"),
                             map_location=device)['model']
        for key in student:
            if 'student' in key:
                teacher[key] = student[key]
        model.load_state_dict(teacher, strict=False)
        with torch.no_grad():
            inference(model, data_loader_test, device, args=args)
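

# Training loop: AdamW with timm-style parameter grouping, per-epoch testing,
# and checkpoint selection driven by the 'micro' test score.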
def do_training(args, data_loader_test, data_loader_train, device, log_writer, model):
    print("actual lr: %.2e" % args.lr)

    # Following timm: set weight decay to 0 for bias and norm layers.
    param_groups = optim_factory.param_groups_weight_decay(model, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    # Resume from an existing checkpoint if one is configured.
    misc.load_model(args=args, model=model, optimizer=optimizer, loss_scaler=loss_scaler)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    best_micro = 0.0
    best_micro_student = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        train_stats = train_one_epoch(
            model, data_loader_train,
            optimizer, device, epoch,
            log_writer=log_writer,
            args=args
        )
        log_stats_train = {**{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch}

        test_stats = test_one_epoch(
            model, data_loader_test, device, epoch, log_writer=log_writer, args=args
        )
        log_stats_test = {**{f'test_{k}': v for k, v in test_stats.items()}, 'epoch': epoch}
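
        # Checkpointing: refresh the "latest" checkpoint every epoch, keep the
        # best teacher by the 'micro' score, and, once self-distillation starts
        # (epoch >= start_TS_epoch), also keep the best student separately.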
        if args.output_dir:
            misc.save_model(args=args, model=model, optimizer=optimizer,
                            loss_scaler=loss_scaler, epoch=epoch, latest=True)
            if test_stats['micro'] > best_micro:
                best_micro = test_stats['micro']
                misc.save_model(args=args, model=model, optimizer=optimizer,
                                loss_scaler=loss_scaler, epoch=epoch, best=True)
            if args.start_TS_epoch <= epoch and test_stats['micro'] > best_micro_student:
                best_micro_student = test_stats['micro']
                misc.save_model(args=args, model=model, optimizer=optimizer,
                                loss_scaler=loss_scaler, epoch=epoch, best=True, student=True)

        if args.output_dir:
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log_train.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats_train) + "\n")
            with open(os.path.join(args.output_dir, "log_test.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats_test) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
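

# A single CLI flag picks the dataset; all remaining settings come from the
# corresponding config object, which replaces the parsed args entirely.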
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='avenue')
    args = parser.parse_args()

    if args.dataset == 'avenue':
        args = get_configs_avenue()
    else:
        args = get_configs_shanghai()

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
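
# Usage sketch (run_type, output_dir, etc. are set inside the config objects):
#
#   python main.py --dataset avenue     # Avenue configs (patch-16 model)
#   python main.py --dataset shanghai   # any other value selects ShanghaiTech
#
# To evaluate saved checkpoints instead of training, set run_type to
# 'inference' in the relevant get_configs_* function before launching.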