-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
143 lines (104 loc) · 5.99 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import copy
import os
import torch
import SETTINGS
from faultManager.FaultListManager import FLManager
from faultManager.FaultInjectionManager import FaultInjectionManager
from ofmapManager.OutputFeatureMapsManager import OutputFeatureMapsManager
from utils import get_network, get_device, get_loader, get_fault_list, clean_inference, output_definition, \
get_fault_list, clean_inference, output_definition, fault_list_gen, csv_summary
def main():
    """Run every pipeline stage enabled in SETTINGS, in order.

    Stages (each gated by a SETTINGS flag):

    1. ``FAULT_LIST_GENERATION``: call ``fault_list_gen()``.
    2. ``FAULTS_INJECTION`` or ``ONLY_CLEAN_INFERENCE``: select the device,
       load the dataset loader and the network, and run a clean inference
       accuracy test. Unless ``ONLY_CLEAN_INFERENCE`` is set, it then loads
       the clean output and runs the weight fault-injection campaign.
    3. ``FI_ANALYSIS``: call ``output_definition`` with a test loader. The
       loader from stage 2 is reused when it exists; otherwise a new one is
       created.
    4. ``FI_ANALYSIS_SUMMARY``: call ``csv_summary()``.

    When ``ONLY_CLEAN_INFERENCE`` is set, the function returns right after
    the clean inference test, so stages 3 and 4 are skipped.
    """
    if SETTINGS.FAULT_LIST_GENERATION:
        fault_list_gen()
    else:
        print('Fault list generation is disabled')

    # Only the injection stage creates the loader. The analysis stage checks
    # this sentinel instead of catching the NameError a missing local raises.
    loader = None

    if SETTINGS.FAULTS_INJECTION or SETTINGS.ONLY_CLEAN_INFERENCE:
        # Set deterministic algorithms
        torch.use_deterministic_algorithms(mode=True)

        # Select the device
        device = get_device(use_cuda0=SETTINGS.USE_CUDA_0,
                            use_cuda1=SETTINGS.USE_CUDA_1)
        print(f'Using device {device}')

        # Load the dataset
        _, loader = get_loader(network_name=SETTINGS.NETWORK,
                               batch_size=SETTINGS.BATCH_SIZE,
                               dataset_name=SETTINGS.DATASET)

        # Load the network
        network = get_network(network_name=SETTINGS.NETWORK,
                              device=device,
                              dataset_name=SETTINGS.DATASET)

        print('clean inference accuracy test:')
        clean_inference(network, loader, device, SETTINGS.NETWORK)
        if SETTINGS.ONLY_CLEAN_INFERENCE:
            # A clean-only run is complete here. Return normally (exit
            # status 0) instead of exiting with an error code.
            return

        # Folders for the clean and faulty feature maps
        clean_fm_folder = SETTINGS.CLEAN_FM_FOLDER
        faulty_fm_folder = SETTINGS.FAULTY_FM_FOLDER
        os.makedirs(clean_fm_folder, exist_ok=True)
        os.makedirs(faulty_fm_folder, exist_ok=True)

        # Folder containing the clean output
        clean_output_folder = SETTINGS.CLEAN_OUTPUT_FOLDER

        # NOTE: module_classes selects which layers' output feature maps are
        # handled, so changing it changes which OFMs get saved.
        module_classes = SETTINGS.MODULE_CLASSES
        feature_maps_layer_names = [name.replace('.weight', '')
                                    for name, module in network.named_modules()
                                    if isinstance(module, module_classes)]
        print('feature maps layer names:')
        print(feature_maps_layer_names)

        clean_ofm_manager = OutputFeatureMapsManager(network=network,
                                                     loader=loader,
                                                     module_classes=module_classes,
                                                     device=device,
                                                     fm_folder=clean_fm_folder,
                                                     clean_output_folder=clean_output_folder)
        # Load the clean output; it is handed to the FaultInjectionManager below
        clean_ofm_manager.load_clean_output()

        # Build the fault list generator. input_size is the shape of the first
        # dataset sample with a leading batch dimension of 1 added
        # (assumes dataset items are (input, label) pairs - TODO confirm).
        fault_list_generator = FLManager(network=network,
                                         network_name=SETTINGS.NETWORK,
                                         device=device,
                                         module_class=SETTINGS.MODULE_CLASSES_FAULT_LIST,
                                         input_size=loader.dataset[0][0].unsqueeze(0).shape,
                                         save_ifm=True)

        # Get the fault list and the injectable modules for the fault model
        fault_list, injectable_modules = get_fault_list(fault_model=SETTINGS.FAULT_MODEL,
                                                        fault_list_generator=fault_list_generator)

        # Execute the fault injection campaign on the network weights
        fault_injection_executor = FaultInjectionManager(network=network,
                                                         network_name=SETTINGS.NETWORK,
                                                         device=device,
                                                         loader=loader,
                                                         clean_output=clean_ofm_manager.clean_output,
                                                         injectable_modules=injectable_modules)
        fault_injection_executor.run_faulty_campaign_on_weight(fault_model=SETTINGS.FAULT_MODEL,
                                                               fault_list=fault_list,
                                                               first_batch_only=False,
                                                               force_n=SETTINGS.FAULTS_TO_INJECT,
                                                               save_output=SETTINGS.SAVE_FAULTY_OUTPUT,
                                                               save_ofm=SETTINGS.SAVE_FAULTY_OFM,
                                                               ofm_folder=faulty_fm_folder)
    else:
        print('Fault injection is disabled')

    if SETTINGS.FI_ANALYSIS:
        if loader is None:
            # The injection stage was skipped, so no loader exists yet.
            print('No loader found to save the labels, creating a new one')
            _, loader = get_loader(network_name=SETTINGS.NETWORK,
                                   batch_size=SETTINGS.BATCH_SIZE,
                                   dataset_name=SETTINGS.DATASET)
        # Errors from output_definition now propagate instead of being
        # silently retried with a fresh loader.
        output_definition(test_loader=loader, batch_size=SETTINGS.BATCH_SIZE)
        print('Done')
    else:
        print('Fault injection analysis is disabled')

    if SETTINGS.FI_ANALYSIS_SUMMARY:
        print('Generating csv summary')
        csv_summary()
        print('csv summary generated')
if __name__ == '__main__':
    # Script entry point: run the pipeline only when this file is executed
    # directly, never when it is imported as a module.
    main()