-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathAllCnn.py
97 lines (70 loc) · 2.78 KB
/
AllCnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
############################################################################################
#
# Project: Asociacion De Investigacion En Inteligencia Artificial Para La Leucemia Peter Moss
# Repository: ALL-IDB Classifiers
#
# Author: Adam Milton-Barker
# Contributors:
#
# Title: AllCnn Wrapper Class
# Description: Core AllCnn wrapper class for the ALL-IDB Classifiers project.
# License: MIT License
# Last Modified: 2019-07-23
#
############################################################################################
import sys
from Classes.Helpers import Helpers
from Classes.DataP1 import Data as DataP1
from Classes.ModelP1 import Model as ModelP1
class AllCnn():
    """ALL Papers AllCnn wrapper class.

    Core AllCnn wrapper class for the TensorFlow 2.0 ALL Papers project. It
    holds the run configuration (optimizer name, augmentation flag and model
    type) and drives the data-preparation, training and evaluation pipeline
    for each replicated paper.
    """

    def __init__(self):
        """Create the wrapper with default options and an unset model type."""
        self.Helpers = Helpers("Core")
        # When True, paper_1() uses the augmentation data-preparation path.
        self.do_augmentation = False
        # Optimizer name forwarded to DataP1/ModelP1 (main() sets "adam" or
        # "rmsprop").
        self.optimizer = ""
        # Model identifier forwarded to DataP1/ModelP1 (e.g. "model_1").
        # None means "not configured yet"; paper_1() refuses to run until it
        # has been set.
        self.model_type = None

    def paper_1(self):
        """Replicate the model proposed in Paper 1.

        Replicates the network and data splits outlined in the "Acute Leukemia
        Classification Using Convolution Neural Network In Clinical Decision
        Support System" paper using TensorFlow 2.0.
        https://airccj.org/CSCP/vol7/csit77505.pdf

        The prepared data wrapper is stored on ``self.DataP1`` and the model
        wrapper on ``self.ModelP1``.

        Raises:
            ValueError: if ``self.model_type`` has not been configured.
        """
        # getattr() also covers instances created without __init__.
        model_type = getattr(self, "model_type", None)
        if model_type is None:
            raise ValueError(
                "AllCnn.model_type must be set (e.g. 'model_1') "
                "before calling paper_1()")

        # Data preparation: sort, prepare (plain or augmented), shuffle, split.
        self.DataP1 = DataP1(model_type, self.optimizer, self.do_augmentation)
        self.DataP1.data_and_labels_sort()
        if self.do_augmentation:
            self.DataP1.data_and_labels_augmentation_prepare()
        else:
            self.DataP1.data_and_labels_prepare()
        self.DataP1.shuffle()
        self.DataP1.get_split()

        # Model: build, train, persist, then evaluate on the held-out split.
        self.ModelP1 = ModelP1(model_type,
                               self.DataP1.X_train, self.DataP1.X_test,
                               self.DataP1.y_train, self.DataP1.y_test,
                               self.optimizer, self.do_augmentation)
        self.ModelP1.build_network()
        self.ModelP1.compile_and_train()
        self.ModelP1.save_model_as_json()
        self.ModelP1.save_weights()
        self.ModelP1.predictions()
        self.ModelP1.evaluate_model()
        self.ModelP1.plot_metrics()
        self.ModelP1.confusion_matrix()
        self.ModelP1.figures_of_merit()
# Module-level shared instance: this rebinds the name ``AllCnn`` from the class
# to an instance of it, so after this line the class itself is no longer
# reachable under that name. The instance is created at import time (its
# __init__ constructs Helpers("Core")), and main() configures and runs it.
AllCnn = AllCnn()
def main(argv=None):
    """Configure the module-level ``AllCnn`` instance from the command line and run it.

    Expected usage::

        python AllCnn.py <Adam|RMSprop> <1|2> <True|False>

    The three arguments select the optimizer, the model number and whether
    data augmentation is used. Every argument is validated before the shared
    instance is modified, so a bad invocation leaves it untouched.

    Args:
        argv: Argument vector including the program name at index 0.
            Defaults to ``sys.argv``.

    Raises:
        SystemExit: (non-zero status) if too few arguments are given, or the
            optimizer or model number is not recognised.
    """
    if argv is None:
        argv = sys.argv

    optimizers = {"Adam": "adam", "RMSprop": "rmsprop"}
    model_types = {"1": "model_1", "2": "model_2"}
    usage = "Usage: python AllCnn.py <Adam|RMSprop> <1|2> <True|False>"

    if len(argv) < 4:
        sys.exit(usage)
    optimizer_arg, model_arg, augmentation_arg = argv[1], argv[2], argv[3]

    if optimizer_arg not in optimizers:
        sys.exit("Unknown optimizer %r. %s" % (optimizer_arg, usage))
    if model_arg not in model_types:
        sys.exit("Unknown model number %r. %s" % (model_arg, usage))

    AllCnn.optimizer = optimizers[optimizer_arg]
    # Only the exact string 'True' enables augmentation; anything else disables it.
    AllCnn.do_augmentation = augmentation_arg == 'True'
    AllCnn.model_type = model_types[model_arg]

    if model_arg == '1':
        AllCnn.paper_1()
    else:
        print("Model 2 is currently not available yet")


if __name__ == "__main__":
    main()