Skip to content

Commit

Permalink
Fix several bugs: lower default batch size, pass text_flag to ClassificationMetric at construction, and add a script to drop deleted files from the dataset split
Browse files Browse the repository at this point in the history
  • Loading branch information
zdaiot committed Dec 25, 2019
1 parent 75f24a5 commit f392d27
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 5 deletions.
2 changes: 1 addition & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def get_classify_config():
parser = argparse.ArgumentParser()

# -----------------------------------------超参数设置-----------------------------------------
parser.add_argument('--batch_size', type=int, default=72, help='batch size')
parser.add_argument('--batch_size', type=int, default=48, help='batch size')
parser.add_argument('--epoch', type=int, default=100, help='epoch')
parser.add_argument('--lr', type=float, default=1e-3, help='init lr')
parser.add_argument('--weight_decay', type=float, default=0, help='weight_decay in optimizer')
Expand Down
2 changes: 1 addition & 1 deletion demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def __prepare__(self, label_json_path):
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
transforms = None
config.image_size = [416, 416]
config.selected_fold = [0]

weight_path = os.path.join('checkpoints', model_type)
lists = os.listdir(weight_path) # 获得文件夹内所有文件
Expand Down
2 changes: 1 addition & 1 deletion train_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(self, config, fold, train_labels_number):
# 初始化分类度量准则类
with open("online-service/model/label_id_name.json", 'r', encoding='utf-8') as json_file:
self.class_names = list(json.load(json_file).values())
self.classification_metric = ClassificationMetric(self.class_names, self.model_path)
self.classification_metric = ClassificationMetric(self.class_names, self.model_path, text_flag=0)

self.max_accuracy_valid = 0

Expand Down
3 changes: 1 addition & 2 deletions train_classifier_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def __init__(self, config, fold, train_labels_number):
# 初始化分类度量准则类
with open(config.local_data_root+'label_id_name.json', 'r', encoding='utf-8') as json_file:
self.class_names = list(json.load(json_file).values())
self.classification_metric = ClassificationMetric(self.class_names, self.model_path)
self.classification_metric = ClassificationMetric(self.class_names, self.model_path, text_flag=0)

self.max_accuracy_valid = 0

Expand Down Expand Up @@ -306,7 +306,6 @@ def validation(self, valid_loader):
oa,
average_accuracy,
kappa,
text_flag=0,
font_fname="../font/simhei.ttf"
)
else:
Expand Down
45 changes: 45 additions & 0 deletions utils/rm_deleted_files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import json
import os

# Paths are relative to the project root. Image files that were manually
# moved into ``delete_bak`` are treated as deleted and must be removed from
# every fold of the saved train/val dataset split.
DELETE_FILES_PATH = 'data/huawei_data/delete_bak'
DATASET_SPLIT_FILE = 'data/huawei_data/dataset_split.json'


def filter_deleted_samples(fold_lists, deleted, tag):
    """Drop deleted samples from every fold of one split.

    Args:
        fold_lists: list of folds, each ``[sample_names, labels]`` where the
            two inner lists are aligned by index.
        deleted: collection of file names to remove; pass a ``set`` so the
            per-sample membership test is O(1).
        tag: name used in the progress printout, e.g. ``'Train'`` or ``'Val'``.

    Returns:
        (kept, removed): ``kept`` mirrors ``fold_lists`` with the deleted
        samples filtered out; ``removed`` is a flat list of
        ``[sample, label]`` pairs dropped across all folds.
    """
    kept = []
    removed = []
    for fold_index, (samples, labels) in enumerate(fold_lists):
        kept_samples = []
        kept_labels = []
        for sample, label in zip(samples, labels):
            if sample in deleted:
                removed.append([sample, label])
                print('[%s Fold %d] Remove: %s, Label: %d' % (tag, fold_index, sample, label))
            else:
                # Sample was not deleted; keep it and its label.
                kept_samples.append(sample)
                kept_labels.append(label)
        kept.append([kept_samples, kept_labels])
    return kept, removed


def main():
    """Rebuild the dataset split JSON without the deleted samples."""
    # Build a set once so membership tests inside the filter loops are O(1).
    # NOTE(review): keeps the original suffix check ``endswith('jpg')`` (no
    # dot), which also matches names like ``foo_jpg`` — confirm intended.
    deleted = {f for f in os.listdir(DELETE_FILES_PATH) if f.endswith('jpg')}

    with open(DATASET_SPLIT_FILE, 'r') as f:
        train_list, val_list = json.load(f)

    undelete_train_list, _ = filter_deleted_samples(train_list, deleted, 'Train')
    undelete_val_list, _ = filter_deleted_samples(val_list, deleted, 'Val')

    # NOTE(review): as in the original, the output lands in the current
    # working directory rather than next to DATASET_SPLIT_FILE — confirm.
    with open('dataset_split_delete.json', 'w') as f:
        json.dump([undelete_train_list, undelete_val_list], f, ensure_ascii=False)


if __name__ == '__main__':
    main()

0 comments on commit f392d27

Please sign in to comment.