Commit

add hwp dataset

zdaiot committed Dec 25, 2019
1 parent 4526f15 commit 75f24a5
Showing 6 changed files with 191 additions and 43 deletions.
7 changes: 5 additions & 2 deletions config.py
@@ -19,7 +19,7 @@ def get_classify_config():

# ----------------------------------------- Hyperparameter settings -----------------------------------------
parser.add_argument('--batch_size', type=int, default=72, help='batch size')
parser.add_argument('--epoch', type=int, default=150, help='epoch')
parser.add_argument('--epoch', type=int, default=100, help='epoch')
parser.add_argument('--lr', type=float, default=1e-3, help='init lr')
parser.add_argument('--weight_decay', type=float, default=0, help='weight_decay in optimizer')

@@ -55,6 +55,8 @@ def get_classify_config():
parser.add_argument('--val_size', type=float, default=0.2, help='the ratio of val data when n_splits=1.')
parser.add_argument('--load_split_from_file', type=str, default='data/huawei_data/dataset_split_delete.json',
help='Loading dataset split from this file')
parser.add_argument('--dataset_from_folder', type=str2bool, nargs='?', const=True, default=False,
help='If True, then load datasets distinguished by train and valid')

# ----------------------------------------- Model settings -----------------------------------------
parser.add_argument('--model_type', type=str, default='se_resnext101_32x4d',
@@ -86,6 +88,7 @@ def get_classify_config():
parser.add_argument('--train_url', type=str, default='./checkpoints',
help='the path to save training outputs. For example: s3://ai-competition-zdaiot/logs/')
parser.add_argument('--data_url', type=str, default='data/huawei_data/combine')
parser.add_argument('--model_snapshots_name', type=str, default='model_snapshots')
parser.add_argument('--init_method', type=str)

config = parser.parse_args()
@@ -99,4 +102,4 @@ def get_classify_config():
config = get_classify_config()
print(config.augmentation_flag)
print(config.image_size)
print(config.multi_scale_size, type(config.multi_scale_size))
print(config.dataset_from_folder, type(config.multi_scale_size))
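The new --dataset_from_folder argument depends on a str2bool helper that this diff references but does not define; a typical sketch (an assumption — the repository's actual helper may differ) is:

import argparse

def str2bool(v):
    # Accept both bare flags ("--dataset_from_folder") and explicit
    # values ("--dataset_from_folder true"); reject anything else.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

Because of nargs='?' with const=True, passing the bare flag enables the folder-based loaders, while omitting it keeps the default split-file behaviour.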
64 changes: 64 additions & 0 deletions datasets/create_dataset.py
@@ -430,7 +430,71 @@ def multi_scale_transforms(image_size, images, mean=(0.485, 0.456, 0.406), std=(
        images_resize[index] = image

    return images_resize


def get_dataloader_from_folder(
        data_root,
        image_size,
        transforms,
        mean,
        std,
        batch_size,
        multi_scale=False,
):
    samples_files = [f for f in os.listdir(data_root) if f.endswith('.jpg')]
    train_samples_list = []
    train_labels_list = []
    val_samples_list = []
    val_labels_list = []
    for sample_file in samples_files:
        label_file = sample_file.replace('.jpg', '.txt')
        with open(os.path.join(data_root, label_file), 'r') as label_f:
            # Each label file is assumed to hold "name, label" lines,
            # e.g. "img_1.jpg, 21"; the second field is the class id.
            for line in label_f.readlines():
                label = int(line.split(' ')[1])
                if 'train' in sample_file:
                    train_samples_list.append(sample_file)
                    train_labels_list.append(label)
                else:
                    val_samples_list.append(sample_file)
                    val_labels_list.append(label)

    train_dataset = TrainDataset(
        data_root,
        train_samples_list,
        train_labels_list,
        image_size,
        transforms=transforms,
        mean=mean,
        std=std,
        multi_scale=multi_scale,
    )
    # By default, multi-scale is not applied to the validation set
    val_dataset = ValDataset(
        data_root,
        val_samples_list,
        val_labels_list,
        image_size,
        mean=mean,
        std=std,
        multi_scale=multi_scale
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=8,
        pin_memory=True,
        shuffle=True
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        num_workers=8,
        pin_memory=True,
        shuffle=False
    )
    # Placeholder per-class sample counts for the 54 classes (uniform
    # weights); the real counts are not computed here.
    return train_dataloader, val_dataloader, [1 for x in range(54)], [1 for x in range(54)]


if __name__ == "__main__":
data_root = 'data/huawei_data/train_data'
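A minimal usage sketch for the new loader, mirroring how the training scripts call it; the image size and std values are illustrative assumptions (the mean matches the default shown in the hunk header above):

train_loader, val_loader, train_labels_number, _ = get_dataloader_from_folder(
    data_root='data/huawei_data/train_data',
    image_size=224,                      # assumption; config.image_size in practice
    transforms=None,                     # no augmentation pipeline
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),           # assumed ImageNet std
    batch_size=72,
    multi_scale=False,
)
for images, labels in train_loader:
    pass  # forward pass, loss computation, etc.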
53 changes: 53 additions & 0 deletions datasets/imbalanced_sampler.py
@@ -0,0 +1,53 @@
import torch
import torch.utils.data
import torchvision


class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
    """Samples elements randomly from a given list of indices for an imbalanced dataset.
    Arguments:
        indices (list, optional): a list of indices
        num_samples (int, optional): number of samples to draw
    """

    def __init__(self, dataset, indices=None, num_samples=None):

        # if indices is not provided,
        # all elements in the dataset will be considered
        self.indices = list(range(len(dataset))) \
            if indices is None else indices

        # if num_samples is not provided,
        # draw `len(indices)` samples in each iteration
        self.num_samples = len(self.indices) \
            if num_samples is None else num_samples

        # distribution of classes in the dataset
        label_to_count = {}
        for idx in self.indices:
            label = self._get_label(dataset, idx)
            if label in label_to_count:
                label_to_count[label] += 1
            else:
                label_to_count[label] = 1

        # weight for each sample: the inverse frequency of its class
        weights = [1.0 / label_to_count[self._get_label(dataset, idx)]
                   for idx in self.indices]
        self.weights = torch.DoubleTensor(weights)

    def _get_label(self, dataset, idx):
        # Only MNIST and ImageFolder are supported; other dataset types
        # (including this project's TrainDataset) would need a branch here.
        dataset_type = type(dataset)
        if dataset_type is torchvision.datasets.MNIST:
            return dataset.train_labels[idx].item()
        elif dataset_type is torchvision.datasets.ImageFolder:
            return dataset.imgs[idx][1]
        else:
            raise NotImplementedError

    def __iter__(self):
        return (self.indices[i] for i in torch.multinomial(
            self.weights, self.num_samples, replacement=True))

    def __len__(self):
        return self.num_samples
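A usage sketch with one of the two dataset types _get_label currently supports (torchvision.datasets.ImageFolder); the directory path is a hypothetical class-per-folder layout:

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

dataset = datasets.ImageFolder(
    'data/train_by_class',               # hypothetical path
    transform=transforms.Compose([transforms.Resize((224, 224)),
                                  transforms.ToTensor()]),
)
# The sampler replaces shuffle=True: it draws with replacement, weighting
# each sample by the inverse frequency of its class.
loader = DataLoader(dataset, batch_size=72,
                    sampler=ImbalancedDatasetSampler(dataset))

Note that this commit adds the sampler as a standalone utility; neither training script imports it yet.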
13 changes: 7 additions & 6 deletions solver.py
@@ -97,24 +97,25 @@ def save_checkpoint(self, save_path, state, is_best):
save_best_path = '/'.join(save_path.split('/')[:-1] + ['model_best.pth'])
shutil.copyfile(save_path, save_best_path)

def save_checkpoint_online(self, save_path, state, is_best, bucket_name):
def save_checkpoint_online(self, save_path, state, is_best, bucket_name, model_snapshots_name):
''' Save the model parameters.
Args:
save_path: str, path to save the weights to
state: dict, dictionary holding the model parameters, the maximum dice score, and other info
is_best: bool, whether this is the best model so far
bucket_name: str, name of the OBS bucket
model_snapshots_name: str, remote folder the snapshots are copied to
Return:
None
'''
torch.save(state, save_path)
# mox.file handles both local paths and OBS paths transparently
# see https://github.com/huaweicloud/ModelArts-Lab/blob/master/docs/moxing_api_doc/MoXing_API_File.md
if not mox.file.exists(os.path.join(bucket_name, 'model_snapshots', 'model')):
mox.file.make_dirs(os.path.join(bucket_name, 'model_snapshots', 'model'))
if not mox.file.exists(os.path.join(bucket_name, model_snapshots_name, 'model')):
mox.file.make_dirs(os.path.join(bucket_name, model_snapshots_name, 'model'))

for file in glob.glob('/'.join(save_path.split('/')[:-1]) + '/events*'):
mox.file.copy(file, os.path.join(bucket_name, 'model_snapshots', 'model', os.path.basename(file)))
mox.file.copy(file, os.path.join(bucket_name, model_snapshots_name, 'model', os.path.basename(file)))

if is_best:
print('Saving Best Model.')
@@ -124,8 +125,8 @@ def save_checkpoint_online(self, save_path, state, is_best, bucket_name):
os.remove(save_path)

mox.file.copy_parallel('/'.join(save_path.split('/')[:-1]),
os.path.join(bucket_name, 'model_snapshots', 'model'))
mox.file.copy_parallel('../online-service/model', os.path.join(bucket_name, 'model_snapshots', 'model'))
os.path.join(bucket_name, model_snapshots_name, 'model'))
mox.file.copy_parallel('../online-service/model', os.path.join(bucket_name, model_snapshots_name, 'model'))

def load_checkpoint(self, load_path):
''' Load the model parameters.
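A hedged sketch of a call to the updated save_checkpoint_online; the paths and state contents are illustrative, and the bucket name is borrowed from the train_url help text in config.py:

solver.save_checkpoint_online(
    save_path='model_snapshots/model/epoch_10.pth',     # hypothetical local path
    state={'epoch': 10, 'state_dict': model.state_dict()},
    is_best=True,
    bucket_name='s3://ai-competition-zdaiot',
    model_snapshots_name=config.model_snapshots_name,   # default: 'model_snapshots'
)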
43 changes: 28 additions & 15 deletions train_classifier.py
@@ -15,7 +15,7 @@
from solver import Solver
from utils.set_seed import seed_torch
from models.build_model import PrepareModel
from datasets.create_dataset import GetDataloader
from datasets.create_dataset import GetDataloader, get_dataloader_from_folder
from losses.get_loss import Loss
from utils.classification_metric import ClassificationMetric
from datasets.data_augmentation import DataAugmentation
@@ -293,21 +293,34 @@ def init_log(self):
else:
transforms = None

get_dataloader = GetDataloader(
data_root,
folds_split=folds_split,
test_size=test_size,
choose_dataset=config.choose_dataset,
load_split_from_file=config.load_split_from_file
)
if config.dataset_from_folder:
train_dataloaders, val_dataloaders, train_labels_number, _ = get_dataloader_from_folder(
data_root,
config.image_size,
transforms,
mean,
std,
config.batch_size,
multi_scale,
)
train_dataloaders, val_dataloaders, train_labels_number_folds = [train_dataloaders], [val_dataloaders], [train_labels_number]

else:
get_dataloader = GetDataloader(
data_root,
folds_split=folds_split,
test_size=test_size,
choose_dataset=config.choose_dataset,
load_split_from_file=config.load_split_from_file
)

train_dataloaders, val_dataloaders, train_labels_number_folds, _ = get_dataloader.get_dataloader(
config.batch_size,
config.image_size,
mean, std,
transforms=transforms,
multi_scale=multi_scale
)
train_dataloaders, val_dataloaders, train_labels_number_folds, _ = get_dataloader.get_dataloader(
config.batch_size,
config.image_size,
mean, std,
transforms=transforms,
multi_scale=multi_scale
)

for fold_index, [train_loader, valid_loader, train_labels_number] in enumerate(zip(train_dataloaders, val_dataloaders, train_labels_number_folds)):
if fold_index in config.selected_fold:
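In both training scripts, the folder-based branch reuses the existing fold loop unchanged: the single fixed split is wrapped in one-element lists so it behaves like 1-fold cross-validation. Schematically (a sketch of the pattern, matching the loop shown above):

# k-fold: each list holds k loaders and the body runs k times;
# folder split: each list holds exactly one entry, so it runs once.
for fold_index, [train_loader, valid_loader, train_labels_number] in enumerate(
        zip(train_dataloaders, val_dataloaders, train_labels_number_folds)):
    if fold_index in config.selected_fold:
        pass  # build model, train, validate for this fold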
54 changes: 34 additions & 20 deletions train_classifier_online.py
@@ -19,7 +19,7 @@
from solver import Solver
from utils.set_seed import seed_torch
from models.build_model import PrepareModel
from datasets.create_dataset import GetDataloader
from datasets.create_dataset import GetDataloader, get_dataloader_from_folder
from losses.get_loss import Loss
from utils.classification_metric import ClassificationMetric
from datasets.data_augmentation import DataAugmentation
@@ -61,7 +61,7 @@ def prepare_data_on_modelarts(args):
print(pip.read())

# train_local: local path where outputs are saved during training; train_url is the OBS location they are moved to
args.train_local = os.path.join(args.local_data_root, 'model_snapshots')
args.train_local = os.path.join(args.local_data_root, args.model_snapshots_name)
if not os.path.exists(args.train_local):
os.mkdir(args.train_local)

@@ -251,7 +251,8 @@ def train(self, train_loader, valid_loader):
),
state,
is_best,
self.bucket_name
self.bucket_name,
config.model_snapshots_name
)

# Write to TensorBoard
@@ -305,6 +306,7 @@ def validation(self, valid_loader):
oa,
average_accuracy,
kappa,
text_flag=0,
font_fname="../font/simhei.ttf"
)
else:
@@ -346,23 +348,35 @@ def init_log(self):
else:
transforms = None

get_dataloader = GetDataloader(
config.data_local,
folds_split=folds_split,
test_size=test_size,
label_names_path=config.local_data_root+'label_id_name.json',
choose_dataset=config.choose_dataset,
load_split_from_file=config.load_split_from_file
)

train_dataloaders, val_dataloaders, train_labels_number_folds, _ = get_dataloader.get_dataloader(
config.batch_size,
config.image_size,
mean, std,
transforms=transforms,
multi_scale=multi_scale,
draw_distribution=False
)
if config.dataset_from_folder:
train_dataloaders, val_dataloaders, train_labels_number, _ = get_dataloader_from_folder(
config.data_local,
config.image_size,
transforms,
mean,
std,
config.batch_size,
multi_scale,
)
train_dataloaders, val_dataloaders, train_labels_number_folds = [train_dataloaders], [val_dataloaders], [train_labels_number]
else:
get_dataloader = GetDataloader(
config.data_local,
folds_split=folds_split,
test_size=test_size,
label_names_path=config.local_data_root+'label_id_name.json',
choose_dataset=config.choose_dataset,
load_split_from_file=config.load_split_from_file
)

train_dataloaders, val_dataloaders, train_labels_number_folds, _ = get_dataloader.get_dataloader(
config.batch_size,
config.image_size,
mean, std,
transforms=transforms,
multi_scale=multi_scale,
draw_distribution=False
)

for fold_index, [train_loader, valid_loader, train_labels_number] in enumerate(zip(train_dataloaders, val_dataloaders, train_labels_number_folds)):
if fold_index in config.selected_fold:
