Skip to content

Commit

Permalink
adjust custom classifier info
Browse files Browse the repository at this point in the history
  • Loading branch information
maxcorsini committed Feb 11, 2022
1 parent a6bb735 commit 60a5cc9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
14 changes: 10 additions & 4 deletions TagLab.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,7 @@ def __init__(self, parent=None):

# training results
self.classifier_name = None
self.network_name = None
self.dataset_train = None

# NETWORKS
Expand Down Expand Up @@ -2165,6 +2166,7 @@ def resetAll(self):
self.trainResultsWidget = None
self.progress_bar = None
self.classifier_name = None
self.network_name = None
self.dataset_train_info = None
self.project = Project()
self.project.loadDictionary(self.default_dictionary)
Expand Down Expand Up @@ -3875,7 +3877,7 @@ def trainNewNetwork(self):
L2 = self.trainYourNetworkWidget.getWeightDecay()
batch_size = self.trainYourNetworkWidget.getBatchSize()

classifier_name = self.trainYourNetworkWidget.editClassifierName.text()
classifier_name = self.trainYourNetworkWidget.editNetworkName.text()
network_name = self.trainYourNetworkWidget.editNetworkName.text() + ".net"
network_filename = os.path.join(os.path.join(self.taglab_dir, "models"), network_name)

Expand Down Expand Up @@ -3921,6 +3923,7 @@ def trainNewNetwork(self):

# info about the classifier created
self.classifier_name = classifier_name
self.network_name = network_name
self.dataset_train_info = dataset_train_info

self.deleteProgressBar()
Expand All @@ -3943,18 +3946,21 @@ def confirmTraining(self):

new_classifier = dict()
new_classifier["Classifier Name"] = self.classifier_name
new_classifier["Average Norm."] = list(self.dataset_train_info.dataset_average)
new_classifier["Weights"] = self.network_name
new_classifier["Num. Classes"] = self.dataset_train_info.num_classes
new_classifier["Classes"] = list(self.dataset_train_info.dict_target)

# read the target scale factor
# scale
target_pixel_size_file = os.path.join(self.trainResultsWidget.dataset_folder, "target-scale-factor.txt")
fl = open(target_pixel_size_file, "r")
line = fl.readline()
fl.close()
target_pixel_size = float(line)

new_classifier["Scale"] = target_pixel_size

new_classifier["Average Norm."] = list(self.dataset_train_info.dataset_average)

# update config file
self.available_classifiers.append(new_classifier)
newconfig = dict()
newconfig["Available Classifiers"] = self.available_classifiers
Expand Down
4 changes: 0 additions & 4 deletions source/QtTYNWidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,6 @@ def __init__(self, annotations, parent=None):
self.editDatasetFolder = QLineEdit("temp")
self.editDatasetFolder.setStyleSheet("background-color: rgb(55,55,55); border: 1px solid rgb(90,90,90)")
self.editDatasetFolder.setFixedWidth(LINEWIDTH)
self.editClassifierName = QLineEdit("myclassifier")
self.editClassifierName.setStyleSheet("background-color: rgb(55,55,55); border: 1px solid rgb(90,90,90)")
self.editClassifierName.setFixedWidth(LINEWIDTH)
self.editClassifierName.setReadOnly(False)
self.editNetworkName = QLineEdit("mynetwork")
self.editNetworkName.setStyleSheet("background-color: rgb(55,55,55); border: 1px solid rgb(90,90,90)")
self.editNetworkName.setFixedWidth(LINEWIDTH)
Expand Down

0 comments on commit 60a5cc9

Please sign in to comment.