-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpreprocess.py
84 lines (72 loc) · 2.56 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import argparse
import json
import re
def update_chain(chain, seq):
    """Record one token sequence in the nested chain structure.

    Walks `seq` from its first token down, creating a node per token with
    the shape {"children": {...}, "score": int}, and bumps each visited
    node's score by one.  An empty `seq` is a no-op.
    """
    node = chain
    for token in seq:
        # Create the node on first sight, then count this occurrence.
        entry = node.setdefault(token, {"children": {}, "score": 0})
        entry["score"] += 1
        node = entry["children"]
def prune_chain_node(node, k):
    """Keep only the `k` highest-scoring children of `node`, recursively.

    Children are ranked by descending score (ties keep insertion order,
    as Python's sort is stable); everything below rank `k` is dropped.
    The surviving subtrees are pruned the same way.
    """
    ranked = sorted(
        node["children"].items(),
        key=lambda item: item[1]["score"],
        reverse=True,
    )
    node["children"] = {token: data for token, data in ranked[:k]}
    for data in node["children"].values():
        prune_chain_node(data, k)
def normalize_chain_node(node):
    """Turn child scores into probabilities, recursively.

    Each child's score is divided by the sum of its siblings' scores so
    the children of a node sum to 1.0.  A node with no children (or an
    all-zero total) is left unchanged; recursion still visits subtrees.
    """
    children = node["children"].values()
    total = sum(child["score"] for child in children)
    if total > 0:
        for child in children:
            child["score"] /= total
    for child in children:
        normalize_chain_node(child)
def build_markov_chain(text, depth, k):
    """Build an order-`depth` Markov chain from raw text.

    Tokenizes `text` (lowercased, split on non-word characters, empties
    dropped), counts every sliding window of `depth + 1` tokens into a
    nested trie, optionally prunes each branch to its top `k` children
    (`k is None` means no pruning), then normalizes scores into
    probabilities.  Progress is reported on stdout.

    Returns a dict with the trie under "chain" and the deduplicated
    vocabulary under "tokens" (unordered, since it comes from a set).
    """
    parts = (piece.strip() for piece in re.split(r"(\W)", text.lower()))
    tokens = [token for token in parts if token]
    print("Text contains", len(tokens), "tokens")
    print("Building chain...")
    chain = {}
    # Sliding windows of depth + 1 consecutive tokens.
    for window in zip(*(tokens[offset:] for offset in range(depth + 1))):
        update_chain(chain, window)
    if k is not None:
        print("Pruning...")
        for root in chain.values():
            prune_chain_node(root, k)
    print("Normalizing...")
    for root in chain.values():
        normalize_chain_node(root)
    return {
        "chain": chain,
        "tokens": list(set(tokens)),
    }
def main():
    """CLI entry point: read a UTF-8 text file, build the chain, dump JSON.

    The output file is named `<name>_<depth>_<top_k>.json`; a --top-k of
    0 disables pruning (passed through as k=None).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str, help="Input text file (utf8)")
    parser.add_argument("name", type=str, help="Output name")
    parser.add_argument("-k", "--top-k", type=int, default=0, help="Pruning amount (0 means no pruning)")
    parser.add_argument("-d", "--depth", type=int, default=1, help="Co-occurrences depth (more means more context)")
    args = parser.parse_args()
    with open(args.input, "r", encoding="utf8") as source:
        corpus = source.read()
    top_k = None if args.top_k == 0 else args.top_k
    model = build_markov_chain(corpus, depth=args.depth, k=top_k)
    destination = f"{args.name}_{args.depth}_{args.top_k}.json"
    with open(destination, "w", encoding="utf8") as sink:
        json.dump(model, sink)
if __name__ == "__main__":
    main()