# NOTE: removed GitHub page-scrape artifacts (site chrome and gutter line numbers) that preceded the source.
# encoding: utf-8
# Author : Floyed<[email protected]>
# Datetime : 2022/7/26 18:56
# User : Floyed
# Product : PyCharm
# Project : BrainCog
# File : vgg_snn.py
# explain :
from functools import partial
from torch.nn import functional as F
import torchvision,pprint
from copy import deepcopy
from timm.models import register_model
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.encoder.encoder import *
from braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule
from braincog.model_zoo.resnet19_snn import *
from braincog.model_zoo.resnet import resnet34
from braincog.model_zoo.sew_resnet import *
from braincog.model_zoo.vgg_snn import *
from braincog.datasets import is_dvs_data
def n_detach(self):
    """Detach this node's state tensors (``mem``, ``spike``) from the autograd graph in place."""
    for attr in ('mem', 'spike'):
        setattr(self, attr, getattr(self, attr).detach())
def detach(self):
    """Walk the module tree and invoke ``n_detach`` on every submodule that defines it.

    Intended to be bound onto ``BaseModule`` so a whole network's node state can be
    cut from the autograd graph in one call.
    """
    for sub in filter(lambda m: hasattr(m, 'n_detach'), self.modules()):
        sub.n_detach()
# Monkey-patch: attach the detachment helpers onto BrainCog's base classes so
# any node/module instance gains in-place autograd-graph cutting support.
BaseNode.n_detach=n_detach
BaseModule.detach=detach
def n_deepcopy(self, ori):
    """Copy ``ori``'s state tensors (``mem``, ``spike``) into this node.

    Each tensor is cloned then detached, so the copy shares no storage with
    the source and carries no autograd history.
    """
    for name in ('mem', 'spike'):
        setattr(self, name, getattr(ori, name).clone().detach())
def m_deepcopy(self, ori):
    """Pairwise-copy node state from ``ori``'s module tree into this module's tree.

    Modules are matched positionally via ``zip``; a pair is copied only when
    both sides expose ``n_deepcopy`` (i.e. both are stateful nodes).
    """
    for dst, src in zip(self.modules(), ori.modules()):
        both_nodes = hasattr(dst, 'n_deepcopy') and hasattr(src, 'n_deepcopy')
        if both_nodes:
            dst.n_deepcopy(src)
# Monkey-patch: attach the state-copy helpers onto BrainCog's base classes so a
# module tree's node state can be duplicated (detached) from another instance.
BaseNode.n_deepcopy=n_deepcopy
BaseModule.deepcopy=m_deepcopy
@register_model
class metarightsltet(BaseModule):
    """Self-teaching SNN wrapper around a configurable learner network.

    Wraps a per-timestep learner (built from ``kwargs['learner']``) and, during
    training, builds a soft target from the timesteps whose prediction matches
    the ground-truth label, returning the mean output plus a hard loss and a
    soft (distillation-style) loss.
    """

    def __init__(self,
                 num_classes=10,
                 step=8,
                 node_type=LIFNode,
                 encode_type='direct',
                 *args,
                 **kwargs):
        """Build the meta wrapper.

        :param num_classes: number of output classes, forwarded to the learner.
        :param step: number of simulation timesteps.
        :param node_type: neuron node class; partially applied with kwargs if a BaseNode subclass.
        :param encode_type: input encoding scheme for BaseModule.
        Required kwargs: ``dataset``, ``learner``, ``copyopt``, ``loc``.
        """
        super().__init__(step, encode_type, *args, **kwargs)
        self.n_preact = kwargs['n_preact'] if 'n_preact' in kwargs else False
        self.num_classes = num_classes
        self.node = node_type
        if issubclass(self.node, BaseNode):
            # Bind simulation kwargs into the node constructor up front.
            self.node = partial(self.node, **kwargs, step=step)
        self.dataset = kwargs['dataset']
        self.kdloss = nn.KLDivLoss()  # NOTE(review): not used anywhere in this file — confirm before removing
        # NOTE(review): eval() executes arbitrary code from the 'learner' config
        # string; safe only for trusted configs. Prefer a name->class registry.
        self.learner = eval(kwargs['learner'])(num_classes=num_classes,
                                               step=step,
                                               node_type=LIFNode,
                                               encode_type='direct',
                                               sum_output=False,
                                               reshape_output=False,
                                               *args,
                                               **kwargs)
        self.copyopt = kwargs["copyopt"]
        self.loc = kwargs["loc"]

    def forward(self, inputs, target=None, loss_fn=None, softloss_fn=None):
        """Run the learner; in training mode also compute hard and soft losses.

        :param inputs: network input batch.
        :param target: class labels, required (non-None) when ``self.training``.
        :param loss_fn: hard-label criterion, applied to the mean output.
        :param softloss_fn: soft-target criterion, applied per timestep.
        :return: mean output in eval mode; ``(mean_output, loss1, loss2)`` in training mode.
        """
        self.target = target
        self.loss_fn = loss_fn
        self.softloss_fn = softloss_fn
        # Learner is configured with sum_output=False, so it yields one output
        # per timestep; materialize them once.
        outputs = list(self.learner(inputs))
        mean_output = sum(outputs) / len(outputs)
        if self.training:
            # (T, B, C) -> (B, T, C): batch-first stack of per-step outputs.
            stacked = torch.stack(outputs).transpose(0, 1)
            # Mask of timesteps whose argmax agrees with the true label.
            correct_mask = (stacked.max(2)[1] == self.target.unsqueeze(1))
            # Soft target: sum of outputs over only the "correct" timesteps.
            soft_targets = (stacked * correct_mask.unsqueeze(2)).sum(1)
            loss1 = self.loss_fn(mean_output, self.target)
            loss2 = sum(self.softloss_fn(o, soft_targets) for o in outputs) / len(outputs)
            return mean_output, loss1, loss2
        return mean_output