-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patharguments.py
127 lines (116 loc) · 6.5 KB
/
arguments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import argparse
class ArgParser(object):
    """Command-line argument builder wrapping an ``argparse.ArgumentParser``.

    The constructor registers the model, data and miscellaneous options.
    Training-specific options are registered separately by
    ``add_train_arguments`` so they are only present when requested.
    """

    def __init__(self):
        """Create the underlying parser and register the common options."""
        parser = argparse.ArgumentParser()

        # Model related arguments
        parser.add_argument('--shiftRegressionFlag', default=0, type=int,
                            help="if turn the task to regression task")
        parser.add_argument('--dataFlag', default=1, type=int,
                            help="if load diff or same data classes")
        parser.add_argument('--audShiftFlag', default=0, type=int,
                            help="if do audio waveform shift")
        parser.add_argument('--id', default='',
                            help="a name for identifying the model")
        parser.add_argument('--num_mix', default=2, type=int,
                            help="number of sounds to mix")
        parser.add_argument('--num_class', default=11, type=int,
                            help="number of classes in the dataset")
        parser.add_argument('--arch_sound', default='unet7',
                            help="architecture of net_sound")
        parser.add_argument('--arch_frame', default='resnet18dilated',
                            help="architecture of net_frame")
        parser.add_argument('--num_channels', default=32, type=int,
                            help='number of channels')
        parser.add_argument('--num_frames', default=1, type=int,
                            help='number of frames')
        parser.add_argument('--stride_frames', default=1, type=int,
                            help='sampling stride of frames')
        parser.add_argument('--mask_thres', default=0.5, type=float,
                            help="threshold in the case of binary masks")
        parser.add_argument('--loss', default='l1',
                            help="loss function to use")
        parser.add_argument('--weighted_loss', default=0, type=int,
                            help="weighted loss")
        parser.add_argument('--log_freq', default=1, type=int,
                            help="log frequency scale")

        # Data related arguments
        parser.add_argument('--num_gpus', default=1, type=int,
                            help='number of gpus to use')
        parser.add_argument('--batch_size_per_gpu', default=32, type=int,
                            help='input batch size')
        parser.add_argument('--workers', default=32, type=int,
                            help='number of data loading workers')
        parser.add_argument('--num_val', default=-1, type=int,
                            help='number of images to evaluate')
        # NOTE(review): this help text duplicates --num_val's; it was
        # presumably meant to describe visualization -- confirm and reword.
        parser.add_argument('--num_vis', default=40, type=int,
                            help='number of images to evaluate')
        parser.add_argument('--audLen', default=65535, type=int,
                            help='sound length')
        parser.add_argument('--audRate', default=11025, type=int,
                            help='sound sampling rate')
        parser.add_argument('--stft_frame', default=1022, type=int,
                            help="stft frame length")
        parser.add_argument('--stft_hop', default=256, type=int,
                            help="stft hop length")
        parser.add_argument('--imgSize', default=224, type=int,
                            help='size of input frame')
        parser.add_argument('--frameRate', default=8, type=float,
                            help='video frame sampling rate')
        parser.add_argument('--vid_dur', default=48, type=float,
                            help='duration of video')
        parser.add_argument('--shift_dur', default=1, type=float,
                            help='shift duration of video')
        parser.add_argument('--margin_dur', default=24, type=float,
                            help='the margin dur set at the start and end of video')
        parser.add_argument('--non_inter_dur', default=1, type=float,
                            help='set 1 sec for non intersection between aligned audio and shifted audio')
        parser.add_argument('--triplet_margin', default=2.0, type=float,
                            help='set margin for triplet loss')
        parser.add_argument('--dataset', default='MUSIC21_MIT', type=str,
                            help='Used dataset MUSIC21')
        parser.add_argument('--best_err', default=300, type=float,
                            help='best err')
        parser.add_argument('--best_sdr', default=-100, type=float,
                            help='best sdr')

        # Misc arguments
        parser.add_argument('--seed', default=1234, type=int,
                            help='manual seed')
        parser.add_argument('--ckpt', default='./myckpt',
                            help='folder to output checkpoints')
        parser.add_argument('--disp_iter', type=int, default=20,
                            help='frequency to display')
        parser.add_argument('--eval_epoch', type=int, default=1,
                            help='frequency to evaluate')

        self.parser = parser
        # Tracks whether add_train_arguments() already ran, so that the
        # training options are never registered twice on the same parser.
        self._train_args_added = False

    def add_train_arguments(self):
        """Register the training options on ``self.parser``.

        Safe to call more than once: repeated calls are no-ops, because
        argparse raises an error when an option string is registered twice.
        """
        # getattr keeps this working for subclasses whose __init__ set
        # self.parser without initialising the flag.
        if getattr(self, '_train_args_added', False):
            return

        parser = self.parser

        parser.add_argument('--mode', default='train',
                            help="train/eval")
        parser.add_argument('--list_train',
                            default='data/train.csv')
        parser.add_argument('--list_val',
                            default='data/val.csv')
        parser.add_argument('--dup_trainset', default=100, type=int,
                            help='duplicate so that one epoch has more iters')

        # optimization related arguments
        parser.add_argument('--num_epoch', default=100, type=int,
                            help='epochs to train for')
        parser.add_argument('--lr_frame', default=1e-4, type=float, help='LR')
        parser.add_argument('--lr_sound', default=1e-3, type=float, help='LR')
        parser.add_argument('--lr_steps',
                            nargs='+', type=int, default=[40, 60],
                            help='steps to drop LR in epochs')
        parser.add_argument('--beta1', default=0.9, type=float,
                            help='momentum for sgd, beta1 for adam')
        parser.add_argument('--weight_decay', default=1e-4, type=float,
                            help='weights regularizer')

        self._train_args_added = True

    def print_arguments(self, args):
        """Print every attribute of ``args`` as a ``name value`` line.

        Args:
            args: A namespace-like object; its attributes are read with
                ``vars(args)``.
        """
        print("Input arguments:")
        for key, val in vars(args).items():
            print("{:16} {}".format(key, val))

    def parse_train_arguments(self, argv=None):
        """Register the training options, parse, print and return the result.

        Args:
            argv: Optional list of argument strings to parse. When ``None``
                (the default) argparse reads ``sys.argv[1:]``.

        Returns:
            The ``argparse.Namespace`` produced by the parser.
        """
        self.add_train_arguments()
        args = self.parser.parse_args(argv)
        self.print_arguments(args)
        return args