diff --git a/config.py b/config.py
index 5c46aff..4243c5d 100644
--- a/config.py
+++ b/config.py
@@ -18,7 +18,7 @@ def get_classify_config():
     parser = argparse.ArgumentParser()
 
     # -----------------------------------------hyperparameter settings-----------------------------------------
-    parser.add_argument('--batch_size', type=int, default=72, help='batch size')
+    parser.add_argument('--batch_size', type=int, default=48, help='batch size')
     parser.add_argument('--epoch', type=int, default=100, help='epoch')
     parser.add_argument('--lr', type=float, default=1e-3, help='init lr')
     parser.add_argument('--weight_decay', type=float, default=0, help='weight_decay in optimizer')
diff --git a/demo.py b/demo.py
index 9e3eb39..126cbd3 100644
--- a/demo.py
+++ b/demo.py
@@ -143,7 +143,7 @@ def __prepare__(self, label_json_path):
     mean = (0.485, 0.456, 0.406)
     std = (0.229, 0.224, 0.225)
     transforms = None
-    config.image_size = [416, 416]
+    config.selected_fold = [0]
    weight_path = os.path.join('checkpoints', model_type)
     lists = os.listdir(weight_path)  # get all files in the folder
 
diff --git a/train_classifier.py b/train_classifier.py
index 88a2bde..b249ebf 100644
--- a/train_classifier.py
+++ b/train_classifier.py
@@ -98,7 +98,7 @@ def __init__(self, config, fold, train_labels_number):
         # initialize the classification metric class
         with open("online-service/model/label_id_name.json", 'r', encoding='utf-8') as json_file:
             self.class_names = list(json.load(json_file).values())
-        self.classification_metric = ClassificationMetric(self.class_names, self.model_path)
+        self.classification_metric = ClassificationMetric(self.class_names, self.model_path, text_flag=0)
 
         self.max_accuracy_valid = 0
 
diff --git a/train_classifier_online.py b/train_classifier_online.py
index e3200ef..50a0c94 100644
--- a/train_classifier_online.py
+++ b/train_classifier_online.py
@@ -157,7 +157,7 @@ def __init__(self, config, fold, train_labels_number):
         # initialize the classification metric class
         with open(config.local_data_root+'label_id_name.json', 'r', encoding='utf-8') as json_file:
             self.class_names = list(json.load(json_file).values())
-        self.classification_metric = ClassificationMetric(self.class_names, self.model_path)
+        self.classification_metric = ClassificationMetric(self.class_names, self.model_path, text_flag=0)
 
         self.max_accuracy_valid = 0
 
@@ -306,7 +306,6 @@ def validation(self, valid_loader):
                 oa,
                 average_accuracy,
                 kappa,
-                text_flag=0,
                 font_fname="../font/simhei.ttf"
             )
         else:
diff --git a/utils/rm_deleted_files.py b/utils/rm_deleted_files.py
new file mode 100644
index 0000000..141b7d0
--- /dev/null
+++ b/utils/rm_deleted_files.py
@@ -0,0 +1,45 @@
+import json
+import os
+
+delete_files_path = 'data/huawei_data/delete_bak'
+dataset_split_file = 'data/huawei_data/dataset_split.json'
+
+delete_files = [f for f in os.listdir(delete_files_path) if f.endswith('jpg')]
+with open(dataset_split_file, 'r') as f:
+    train_list, val_list = json.load(f)
+
+delete_files_count = []
+undelete_train_list = []
+for fold_index, fold_list in enumerate(train_list):
+    fold_sample_list = fold_list[0]
+    fold_label_list = fold_list[1]
+    undelete_fold_sample = []
+    undelete_fold_label = []
+    for sample_index, sample in enumerate(fold_sample_list):
+        if sample not in delete_files:
+            # not among the deleted files
+            undelete_fold_sample.append(sample)
+            undelete_fold_label.append(fold_label_list[sample_index])
+        else:
+            delete_files_count.append([sample, fold_label_list[sample_index]])
+            print('[Train Fold %d] Remove: %s, Label: %d' % (fold_index, sample, fold_label_list[sample_index]))
+    undelete_train_list.append([undelete_fold_sample, undelete_fold_label])
+
+undelete_val_list = []
+for fold_index, fold_list in enumerate(val_list):
+    fold_sample_list = fold_list[0]
+    fold_label_list = fold_list[1]
+    undelete_fold_sample = []
+    undelete_fold_label = []
+    for sample_index, sample in enumerate(fold_sample_list):
+        if sample not in delete_files:
+            # not among the deleted files
+            undelete_fold_sample.append(sample)
+            undelete_fold_label.append(fold_label_list[sample_index])
+        else:
+            delete_files_count.append([sample, fold_label_list[sample_index]])
+            print('[Val Fold %d] Remove: %s, Label: %d' % (fold_index, sample, fold_label_list[sample_index]))
+    undelete_val_list.append([undelete_fold_sample, undelete_fold_label])
+
+with open('dataset_split_delete.json', 'w') as f:
+    json.dump([undelete_train_list, undelete_val_list], f, ensure_ascii=False)
\ No newline at end of file
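
Note on the ClassificationMetric changes: taken together, the two constructor hunks and the @@ -306,7 +306,6 @@ validation hunk move text_flag from the per-validation call into the constructor, so the flag is now fixed once per trainer. A minimal sketch of the shape this implies, inferred from the diff alone (the repo's real class has more state and methods; the body below is an illustrative placeholder):

class ClassificationMetric:
    # Sketch only: the constructor signature is taken from the call sites
    # in this diff; everything else about the real class is assumed.
    def __init__(self, class_names, save_path, text_flag=0):
        self.class_names = class_names  # human-readable label names
        self.save_path = save_path      # where metric artifacts are written
        # text_flag is now set once at construction, so validation-time
        # calls no longer need to pass it (see the -306,7 hunk above).
        self.text_flag = text_flag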
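
Note on utils/rm_deleted_files.py: the train and val loops are identical apart from the split name in the log line. A possible follow-up deduplication (a sketch only, not part of this patch; filter_split is a hypothetical helper) that also uses a set for O(1) membership tests instead of scanning the deleted-file list per sample:

import json
import os

def filter_split(split_list, delete_files, removed, split_name):
    """Return split_list with deleted samples dropped from every fold."""
    kept_folds = []
    for fold_index, (samples, labels) in enumerate(split_list):
        kept_samples, kept_labels = [], []
        for sample, label in zip(samples, labels):
            if sample in delete_files:
                # record and log the removed sample, as the original script does
                removed.append([sample, label])
                print('[%s Fold %d] Remove: %s, Label: %d'
                      % (split_name, fold_index, sample, label))
            else:
                kept_samples.append(sample)
                kept_labels.append(label)
        kept_folds.append([kept_samples, kept_labels])
    return kept_folds

delete_files = {f for f in os.listdir('data/huawei_data/delete_bak') if f.endswith('jpg')}
with open('data/huawei_data/dataset_split.json', 'r') as f:
    train_list, val_list = json.load(f)

removed = []
undelete_train_list = filter_split(train_list, delete_files, removed, 'Train')
undelete_val_list = filter_split(val_list, delete_files, removed, 'Val')

with open('dataset_split_delete.json', 'w') as f:
    json.dump([undelete_train_list, undelete_val_list], f, ensure_ascii=False)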