# main_prune.py
# Forked from byfate/Stronger-yolo-pytorch.
from models import *
from trainers import *
import json
from yacscfg import _C as cfg
import os
from torch import optim
import argparse
import numpy as np
from thop import clever_format,profile
from pruning.l1norm import l1normPruner
from pruning.slimming import SlimmingPruner
from mmcv.runner import load_checkpoint
import torch
def main(args):
    """Prune a detector, report its cost, then fine-tune or evaluate it.

    Two copies of the configured model are built: ``model`` is handed to the
    dataset trainer, ``newmodel`` is handed to the pruner.  After
    ``pruner.prune()`` the FLOPs and parameter counts of both are printed.

    Depending on ``args.Prune.do_test``:

    * false -- both models are tested with ``validiter=10``, the results are
      printed, then ``pruner.finetune()`` runs and its result is printed.
    * true  -- the checkpoint ``checkpoint-best-ft{pruneratio}.pth`` under the
      trainer's ``save_path`` is loaded into ``newmodel`` and the pruner's
      test is run with ``validiter=-1``.

    Args:
        args: config node.  Fields read here: ``devices``, ``MODEL``
            (including ``MODEL.modeltype``), ``OPTIM.lr_initial``,
            ``OPTIM.milestones``, ``DATASET.dataset`` and ``Prune``
            (including ``Prune.do_test`` and ``Prune.pruneratio``).
    """
    # Restrict the visible GPUs before any CUDA call is made.
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(g) for g in args.devices)

    # NOTE(review): eval() resolves class names taken from the config against
    # the star-imported `models` / `trainers` namespaces.  Only run this with
    # trusted configuration files.
    # The model type is read from `args` (not the module-level `cfg`) so the
    # function depends only on the config it was given.
    model_cls = eval(args.MODEL.modeltype)
    model = model_cls(cfg=args.MODEL).cuda().eval()
    newmodel = model_cls(cfg=args.MODEL).cuda().eval()

    optimizer = optim.Adam(model.parameters(), lr=args.OPTIM.lr_initial)
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.OPTIM.milestones, gamma=0.1)
    trainer_cls = eval('Trainer_{}'.format(args.DATASET.dataset))
    _Trainer = trainer_cls(args=args,
                           model=model,
                           optimizer=optimizer,
                           lrscheduler=scheduler)

    pruner = SlimmingPruner(_Trainer, newmodel, cfg=args.Prune)
    # pruner=l1normPruner(_Trainer,newmodel,pruneratio=0.)
    pruner.prune()

    # ---- count ops / params of both models on a fixed 1x3x512x512 input
    dummy_input = torch.randn(1, 3, 512, 512).cuda()
    flops, params = profile(model, inputs=(dummy_input,), verbose=False)
    flops, params = clever_format([flops, params], "%.3f")
    flopsnew, paramsnew = profile(newmodel, inputs=(dummy_input,), verbose=False)
    flopsnew, paramsnew = clever_format([flopsnew, paramsnew], "%.3f")
    print("flops:{}->{}, params: {}->{}".format(flops, flopsnew, params, paramsnew))

    if not args.Prune.do_test:
        # validiter=10 presumably limits the evaluation length -- confirm in
        # the pruner implementation.
        resultold = pruner.test(newmodel=False, validiter=10)
        resultnew = pruner.test(newmodel=True, validiter=10)
        print("original map:{},pruned map:{}".format(resultold, resultnew))
        bestfinetune = pruner.finetune()
        print("finetuned map:{}".format(bestfinetune))
    else:
        ckpt_path = os.path.join(
            _Trainer.save_path,
            'checkpoint-best-ft{}.pth'.format(args.Prune.pruneratio))
        # mmcv's load_checkpoint expects a filename/URI and loads it itself;
        # handing it the result of torch.load() fails inside mmcv.
        load_checkpoint(newmodel, ckpt_path)
        pruner.test(newmodel=True, validiter=-1)
if __name__ == '__main__':
    # Script entry point: read the config file path plus any trailing
    # KEY VALUE overrides, fold both into the global config, freeze it and
    # hand it to main().
    cli = argparse.ArgumentParser(description="DEMO configuration")
    cli.add_argument("--config-file", default='configs/strongerv3_prune.yaml')
    cli.add_argument(
        "opts",
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )
    parsed = cli.parse_args()

    # The file is merged first, then the command-line overrides.
    cfg.merge_from_file(parsed.config_file)
    cfg.merge_from_list(parsed.opts)
    cfg.freeze()

    main(args=cfg)