train.py
# coding=utf-8
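"""Build two probability tables from segmented training data: a base-structure
table (how often each semantic template occurs) and a semantic table (how often
each segment occurs under a given semantic tag). Results are written to
../data/base_prob.json and ../data/seman_prob.json."""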
import re
import json
import random

import nltk
from nltk.corpus import wordnet as wn
from nltk.stem import SnowballStemmer
# Second table: probability of a segment given its semantic tag.
def semantic_prob(semantic_segs_list):
    semantic_tag_list = []
    # (segment, semantic tag) pairs
    segandstag_list = []
    for segs in semantic_segs_list:
        # each item is a (segment, POS tag, semantic tag) triple
        seg = segs[0]
        seman = segs[2]
        semantic_tag_list.append(seman)
        segandstag_list.append((seg, seman))
    # semantic pattern dictionary
    semantic_dict = {}
    # estimate the probabilities by counting
    segandstag_set = set(segandstag_list)
    for segandstag in list(segandstag_set):
        seg, stag = segandstag
        term = stag + "->" + seg
        # relative frequency of the segment under its tag, plus a tiny random offset
        prob = segandstag_list.count(segandstag) / semantic_tag_list.count(stag) + random.random() / 1000000
        semantic_dict.update({term: prob})
    with open("../data/seman_prob.json", "w") as seman_prob_file:
        json.dump(semantic_dict, seman_prob_file)
    return semantic_dict
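# Illustrative seman_prob.json entry (hypothetical values):
#   {"love.n.01->love": 0.2500003}
# i.e. the estimated probability of the segment "love" under the semantic tag
# "love.n.01", including the small random offset added above.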
# First table: probability of each base structure (semantic template).
def base_struct_prob(base_list):
    base_prob_dict = {}
    # deduplicate before counting
    base_set = set(base_list)
    for bs in list(base_set):
        base_prob_dict.update({bs: base_list.count(bs) / len(base_list)})
    with open("../data/base_prob.json", "w") as base_prob_file:
        json.dump(base_prob_dict, base_prob_file)
    return base_prob_dict
def generate_base_struct(segs):
    # concatenate the semantic tags into a template such as "[love.n.01][number]"
    base = ""
    for seg, tag, semantic in segs:
        base = base + "[" + semantic + "]"
    return base
def semantic_classify(segs):
    stemmer = SnowballStemmer("english")
    verb = 'V.*'
    noun = 'N.*'
    semantic_segs = []
    c = ""
    for seg, tag in segs:
        if tag == 'gap':
            c = get_gap_type(seg)
        elif re.match(verb, tag) or re.match(noun, tag):
            # look the stemmed word up in WordNet and take its first synset
            synsets = wn.synsets(stemmer.stem(seg))
            if len(synsets) > 0:
                c = synsets[0].name()
            else:
                # not found in WordNet: fall back to a synthetic synset-style name
                c = stemmer.stem(seg) + ".nn.01"
        else:
            c = tag
        semantic_segs.append((seg, tag, c))
    return semantic_segs
def pos_tag(segs, words, gaps):
    # POS-tag word segments with NLTK; segments listed in gaps get the 'gap' tag.
    tagged_segs = []
    for seg in segs:
        if seg in gaps:
            tagged_seg = [(seg, 'gap')]
        elif seg in words:
            token = nltk.word_tokenize(seg)
            tagged_seg = nltk.pos_tag(token)
        else:
            # segment appears in neither list; skip it rather than reuse a stale tag
            continue
        tagged_segs = tagged_segs + tagged_seg
    return tagged_segs
def get_gap_type(gap):
    # classify a gap segment by the character classes it contains
    number = "^[0-9]*$"
    char = "^[A-Za-z]+$"
    num_char = "^[A-Za-z0-9]+$"
    special = "^[%&',;=+_@!()?>$\x22]+"
    if re.match(number, gap):
        return 'number'
    elif re.match(char, gap):
        return 'char'
    elif re.match(num_char, gap):
        return 'num_char'
    elif re.match(special, gap):
        return 'special'
    else:
        return 'all_mixed'
if __name__ == '__main__':
    # nltk.download('punkt')
    # nltk.download('wordnet')
    # list = semantic_classify(pos_tag(segs, words, gaps))
    # print(generate_base_struct(list))
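    # Each line of data0.json is expected to be a standalone JSON object with
    # "sorted_segs", "words" and "gaps" keys, e.g. (hypothetical):
    #   {"sorted_segs": ["password", "123"], "words": ["password"], "gaps": ["123"]}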
    with open("../data/data0.json", 'r') as data:
        lines = data.readlines()
    base_list = []
    segs_list = []
    for line in lines:
        record = json.loads(line)
        sorted_segs = record.get("sorted_segs")
        words = record.get("words")
        gaps = record.get("gaps")
        tagged_segs = pos_tag(sorted_segs, words, gaps)
        semantic_tagged_segs = semantic_classify(tagged_segs)
        base = generate_base_struct(semantic_tagged_segs)
        base_list.append(base)
        for seg in semantic_tagged_segs:
            segs_list.append(seg)
    # uncomment to (re)write the probability tables
    # base_struct_prob(base_list)
    # semantic_prob(segs_list)