# run.py
from configs.base import ParamManager
from dataloaders.base import DataManager
from backbones.base import ModelManager
from methods import method_map
from utils.functions import save_results
import logging
import argparse
import sys
import os
import datetime
def parse_arguments(argv=None):
    """Declare the command-line options of the script and parse them.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the process command line (``sys.argv[1:]``) is
            parsed, which is what callers passing no argument get.

    Returns:
        argparse.Namespace: one attribute per option declared below.
    """
    parser = argparse.ArgumentParser()

    # Task and logging options.
    parser.add_argument('--type', type=str, default='open_intent_detection', help="Type style.")
    parser.add_argument('--logger_name', type=str, default='Detection', help="Logger name for open intent detection.")
    parser.add_argument('--log_dir', type=str, default='logs', help="Logger directory.")

    # Dataset options.
    parser.add_argument("--dataset", default='banking', type=str, help="The name of the dataset to train selected")
    parser.add_argument("--known_cls_ratio", default=0.75, type=float, help="The ratio of known classes")
    parser.add_argument("--labeled_ratio", default=1.0, type=float, help="The ratio of labeled samples in the training set")

    # Method, backbone and run-mode options.
    parser.add_argument("--method", type=str, default='ADB', help="which method to use")
    parser.add_argument("--train", action="store_true", help="Whether to train the model")
    parser.add_argument("--pretrain", action="store_true", help="Whether to pre-train the model")
    parser.add_argument("--save_model", action="store_true", help="save trained-model for open intent detection")
    parser.add_argument("--backbone", type=str, default='bert', help="which backbone to use")
    parser.add_argument("--config_file_name", type=str, default='ADB.py', help="The name of the config file.")
    parser.add_argument('--seed', type=int, default=0, help="random seed for initialization")
    parser.add_argument("--gpu_id", type=str, default='0', help="Select the GPU id")

    # Input / output locations.
    parser.add_argument("--pipe_results_path", type=str, default='pipe_results', help="the path to save results of pipeline methods")
    # The default is resolved relative to sys.path[0] at call time.
    parser.add_argument("--data_dir", default=sys.path[0]+'/../data', type=str,
                        help="The input data dir. Should contain the .csv files (or other data files) for the task.")
    # NOTE(review): this default is a machine-specific absolute path; it is
    # kept unchanged for backward compatibility -- pass --output_dir explicitly.
    parser.add_argument("--output_dir", default='/home/sharing/disk1/zhaoshaojie/baseline_test/TEXTOIR', type=str,
                        help="The output directory where all train data will be written.")
    parser.add_argument("--model_dir", default='models', type=str,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--load_pretrained_method", default=None, type=str,
                        help="The name of the pre-trained method to load.")
    parser.add_argument("--result_dir", type=str, default='results', help="The path to save results")
    parser.add_argument("--results_file_name", type=str, default='results.csv', help="The file name of all the results.")
    parser.add_argument("--save_results", action="store_true", help="save final results for open intent detection")

    # Training options.
    parser.add_argument("--loss_fct", default="CrossEntropyLoss", help="The loss function for training.")

    # parse_args(None) falls back to sys.argv[1:].
    return parser.parse_args(argv)
def set_logger(args):
    """Configure and return the logger named ``args.logger_name``.

    The logger gets two handlers: a file handler (DEBUG and above) writing
    to a timestamped file inside ``args.log_dir``, and a console stream
    handler (INFO and above).

    Args:
        args: Object exposing ``log_dir``, ``method``, ``dataset``,
            ``backbone``, ``known_cls_ratio``, ``labeled_ratio`` and
            ``logger_name`` attributes.

    Returns:
        logging.Logger: the configured logger.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.log_dir, exist_ok=True)

    timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    file_name = f"{args.method}_{args.dataset}_{args.backbone}_{args.known_cls_ratio}_{args.labeled_ratio}_{timestamp}.log"

    logger = logging.getLogger(args.logger_name)
    logger.setLevel(logging.DEBUG)

    # logging.getLogger returns the same object for the same name, so a
    # repeated call would otherwise stack handlers and duplicate every
    # message. Drop (and close) whatever a previous call attached.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    # File handler: full DEBUG detail, with timestamps.
    fh = logging.FileHandler(os.path.join(args.log_dir, file_name))
    fh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(message)s'))
    fh.setLevel(logging.DEBUG)
    logger.addHandler(fh)

    # Console handler: INFO and above only, without timestamps.
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    ch.setFormatter(logging.Formatter('%(name)s - %(message)s'))
    logger.addHandler(ch)

    return logger
def run(args, data, model, logger):
    """Run the selected method: optional training, then testing, then optional saving.

    Args:
        args: Run configuration; ``method``, ``logger_name``, ``train``,
            ``save_results``, ``result_dir`` and ``results_file_name`` are read.
        data: Data object handed to the method manager and to train/test.
        model: Model object handed to the method manager.
        logger: Logger used for progress messages.
    """
    # Look up the manager registered for the requested method and build it.
    manager_cls = method_map[args.method]
    method = manager_cls(args, data, model, logger_name=args.logger_name)

    # Training only happens when requested.
    if args.train:
        logger.info('Training Begin...')
        method.train(args, data)
        logger.info('Training Finished...')

    # Testing runs whether or not training was requested.
    logger.info('Testing begin...')
    outputs = method.test(args, data)
    logger.info('Testing finished...')

    if not args.save_results:
        return

    results_path = os.path.join(args.result_dir, args.results_file_name)
    logger.info('Results saved in %s', str(results_path))
    save_results(args, outputs)
if __name__ == '__main__':
    # Script entry point: parse options, set up logging, build the
    # parameter/data/model managers, then hand everything to run().
    sys.path.append('.')

    args = parse_arguments()
    logger = set_logger(args)

    logger.info('Open Intent Detection Begin...')
    logger.info('Parameters Initialization...')

    # From here on `args` is the object exposed by ParamManager, which is
    # used below both as a mapping (keys()/indexing) and via attributes.
    args = ParamManager(args).args

    # Dump every parameter at DEBUG level between two banner lines.
    banner_rule = "=" * 30
    logger.debug(banner_rule + " Params " + banner_rule)
    for k in args.keys():
        logger.debug(f"{k}:\t{args[k]}")
    logger.debug(banner_rule + " End Params " + banner_rule)

    logger.info('Data and Model Preparation...')
    data = DataManager(args, logger_name=args.logger_name)
    model = ModelManager(args, data, logger_name=args.logger_name)

    run(args, data, model, logger)
    logger.info('Open Intent Detection Finished...')