-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathconfig.py
59 lines (51 loc) · 2.19 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# coding: utf-8
# Author: Miracle Yoo
# E-mail: [email protected]
import torch
class Config(object):
    """Hyper-parameters and file paths for training/testing the text classifiers.

    Every option is a plain instance attribute created in ``__init__``, so
    each ``Config()`` owns its own copies of the list-valued options.

    Any existing option can be overridden by keyword, e.g.
    ``Config(BATCH_SIZE=64, MODEL_NAME="TextCNNIncDeep_CHAR_v10")``.
    The derived options ``NUM_TRAIN`` and ``SUMMARY_PATH`` are recomputed
    from the overridden inputs unless they are themselves passed explicitly.
    A keyword that does not name an existing option raises ``TypeError``,
    so a misspelled option is not silently ignored.
    """

    def __init__(self, **overrides):
        # ---- Common settings ----
        # Evaluated once, when this Config is constructed.
        self.USE_CUDA = torch.cuda.is_available()
        self.MODEL = "TextCNNIncDeep"
        self.MODEL_NAME = "TextCNNIncDeep_CHAR_v9"
        self.RUNNING_ON_SERVER = True
        # Derived from MODEL_NAME so the summary directory can no longer
        # drift from the model name (it was hard-coded as
        # "summary/TextCNNIncDeep_CH_AR_v9", which did not match).
        self.SUMMARY_PATH = "summary/" + self.MODEL_NAME
        self.NET_SAVE_PATH = "./source/trained_net/"
        self.TRAIN_DATASET_PATH = "/disk/Beibei_Dataset/train_data_v5.csv"
        self.TEST_DATASET_PATH = "/disk/Beibei_Dataset/yibot_test_1617.csv"
        self.NUM_EPOCHS = 50000
        self.BATCH_SIZE = 400
        # NOTE(review): computed as epochs * batch size; confirm this is the
        # quantity the training loop expects.
        self.NUM_TRAIN = self.NUM_EPOCHS * self.BATCH_SIZE
        self.NUM_TEST = 0
        self.TEST_STEP = 100
        # Presumably the k used for top-k evaluation; verify against the caller.
        self.TOP_NUM = 3
        self.NUM_WORKERS = 8
        self.IS_TRAINING = True
        self.ENSEMBLE_TEST = False
        self.LEARNING_RATE = 0.001
        self.RE_TRAIN = False
        self.USE_PAIR_MAPPING = False
        self.USE_TRAD2SIMP = False
        self.TEST_POSITION = 'Gangge Server'

        # ---- Parameters shared across models ----
        self.OPTIMIZER = 'Adam'
        self.USE_CHAR = True
        self.SENT_LEN = 30
        self.USE_WORD2VEC = False
        # Name kept as-is (sic, "balance") because other modules read it.
        self.BANLANCE = True
        self.NUM_CLASSES = 1890
        self.EMBEDDING_DIM = 300
        self.VOCAB_SIZE = 20029
        self.CHAR_SIZE = 3403
        self.NUM_ID_FEATURE_MAP = 250

        # ---- TextCNN model settings ----
        self.TITLE_DIM = 512
        # Name kept as-is (sic, "linear") because other modules read it.
        self.LINER_HID_SIZE = 1409
        self.KERNEL_SIZE = [2, 2, 3, 3, 4, 4, 5, 5, 7, 7]
        self.TITLE_EMBEDDING = 256  # only used in TextCNNCos

        # ---- TextCNNInc model settings ----
        self.SIN_KER_SIZE = [1, 1, 3, 3]  # single convolution kernel
        self.DOU_KER_SIZE = [(1, 3), (3, 5), (3, 3), (5, 5)]  # double convolution kernel

        # ---- Model ensembling ----
        self.MODEL_NAME_LIST = ["TextCNNINC_CHAR_v5.pth", "TextLSTM_CHAR_v5.pth"]
        # Parallel to MODEL_NAME_LIST; presumably per-model ensemble weights.
        self.MODEL_THETA_LIST = [1, 0.8]

        if overrides:
            self._apply_overrides(overrides)

    def _apply_overrides(self, overrides):
        """Set caller-supplied options, then refresh the derived ones.

        Raises:
            TypeError: if any key does not name an option set in ``__init__``.
        """
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError(
                "Config got unexpected option(s): " + ", ".join(unknown))
        for key, value in overrides.items():
            setattr(self, key, value)
        # Keep derived values consistent with their inputs unless the caller
        # pinned them explicitly.
        if "NUM_TRAIN" not in overrides:
            self.NUM_TRAIN = self.NUM_EPOCHS * self.BATCH_SIZE
        if "SUMMARY_PATH" not in overrides:
            self.SUMMARY_PATH = "summary/" + self.MODEL_NAME