# fun.py
import random
from re import search

import nltk
from nltk import everygrams
from nltk.lm import MLE

# Point NLTK at the bundled data directory.
nltk_data_path = "Assets/nltk_data"
if nltk_data_path not in nltk.data.path:
    nltk.data.path.append(nltk_data_path)

# For n-gram nltk integration
def load_data():
    """
    Load the training text and return a list of token lists, one per line.
    """
    sentences = []
    with open("Assets/plot_training.txt", "r", encoding="utf8") as f:
        for line in f:
            line = line.rstrip("\n")
            # Keep only lines that contain at least one letter.
            if search('[a-zA-Z]', line):
                sentences.append(nltk.word_tokenize(line.lower()))
    return sentences

def build_vocab(sentences):
    """
    Take a list of tokenized sentences and return a vocab, including the
    sentence-boundary pads.
    """
    vocab = ['<s>', '</s>']
    for sentence in sentences:
        for token in sentence:
            if token not in vocab:
                vocab.append(token)
    return vocab

def build_ngrams(n, sentences):
    """
    Take a list of unpadded sentences and build all n-grams of order "n"
    for each sentence.
    """
    padded_sentences = []
    all_ngrams = []
    start_pad = '<s>'
    end_pad = '</s>'
    for sentence in sentences:
        new_sentence = []
        # Pad both ends with n-1 boundary markers so edge tokens get full contexts.
        if n >= 2:
            new_sentence.extend([start_pad] * (n - 1))
        new_sentence.extend(sentence)
        if n >= 2:
            new_sentence.extend([end_pad] * (n - 1))
        padded_sentences.append(new_sentence)
    for padded in padded_sentences:
        all_ngrams.append(list(everygrams(padded, min_len=n, max_len=n)))
    return all_ngrams

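# A quick sketch of the expected shape (illustrative tokens, not from the
# training data): trigrams of a two-token sentence, padded on both sides.
#
#   build_ngrams(3, [["hello", "world"]])
#   # -> [[('<s>', '<s>', 'hello'), ('<s>', 'hello', 'world'),
#   #      ('hello', 'world', '</s>'), ('world', '</s>', '</s>')]]
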
def bigram_next_token(start_tokens=("<s>", ) * 3):
    """
    Take some starting tokens and return the most likely next token, with its
    probability, under an n-gram model where n = len(start_tokens) + 1.
    """
    next_token, prob = None, None
    data = load_data()
    n = len(start_tokens) + 1
    ngrams = build_ngrams(n, data)
    total_count = 0
    freq_dict_matches = {}
    # Count every n-gram whose first n-1 tokens match the start tokens.
    for line in ngrams:
        for gram in line:
            if start_tokens == gram[0:len(start_tokens)]:
                total_count += 1
                if gram in freq_dict_matches:
                    freq_dict_matches[gram] += 1
                else:
                    freq_dict_matches[gram] = 1
    # No matching context in the training data.
    if total_count == 0:
        return next_token, prob
    top_value, top_key = max(zip(freq_dict_matches.values(), freq_dict_matches.keys()))
    prob = top_value / total_count
    next_token = top_key[-1]
    return next_token, prob

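# Usage sketch (the returned token and probability depend entirely on the
# training file):
#
#   token, prob = bigram_next_token(("<s>",))        # true bigram lookup
#   token, prob = bigram_next_token(("<s>", "the"))  # trigram lookup
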
def train_ngram_lm(n):
    """
    Train an n-gram language model of the order given by "n".
    """
    lm = MLE(n)
    data = load_data()
    train = build_ngrams(n, data)
    vocab = build_vocab(data)
    lm.fit(train, vocab)
    return lm

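# Minimal usage sketch, assuming the standard nltk.lm generation API:
#
#   lm = train_ngram_lm(3)
#   tokens = lm.generate(20, text_seed=["<s>", "<s>"])
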
def string_cleaner(text):
    """
    Detokenize and tidy generated text: strip the pads, reattach punctuation,
    and capitalize character names.
    """
    text = text.capitalize()
    text = text.replace("</s>", "")
    text = text.replace("<s>", "")
    # Reattach punctuation that the tokenizer split off.
    text = text.replace(" ,", ",")
    text = text.replace(" n't", "n't")
    text = text.replace(" ’", "'")
    text = text.replace("’ ", "'")
    text = text.replace("' ", "'")
    text = text.replace(" '", "'")
    text = text.replace("` ", "'")
    text = text.replace(" `", "'")
    text = text.replace(" :", ":")
    text = text.replace(" .", ".")
    text = text.replace(" !", "!")
    text = text.replace(" ;", ";")
    text = text.replace(" ?", "?")
    text = text.replace(" )", ")")
    text = text.replace("( ", "(")
    # Restore capitalization of character names (plain substring replacement).
    text = text.replace("talia", "Talia")
    text = text.replace("mae", "Mae")
    text = text.replace("placer", "Placer")
    text = text.replace("placeholder", "Placeholder")
    text = text.replace("tracey", "Tracey")
    text = text.replace("madeline", "Madeline")
    text = text.replace("nyala", "Nyala")
    text = text.replace("sarah", "Sarah")
    text = text.replace("cadence", "Cadence")
    text = text.replace("dani", "Dani")
    text = text.replace("hib", "Hib")
    text = text.replace("amara", "Amara")
    text = text.replace("estelle", "Estelle")
    return text

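# Illustrative before/after (made-up input; "ca n't" is how the tokenizer
# splits "can't"):
#
#   string_cleaner("talia ca n't wait .")
#   # -> "Talia can't wait."
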
def build_limited_vocab(sentences, exclusions):
    """
    Take a list of sentences and return a vocab with the excluded words
    removed.
    """
    vocab = ['<s>', '</s>']
    for sentence in sentences:
        for token in sentence:
            if token not in vocab:
                vocab.append(token)
    # Drop exclusions one by one; skip words that are not in the vocab instead
    # of bailing out on the first miss.
    for word in exclusions:
        if word in vocab:
            vocab.remove(word)
    return vocab

def train_limited_ngram_lm(n, exclusions):
    """
    Train an n-gram language model of order "n", dropping any n-gram that
    contains an excluded word.
    """
    lm = MLE(n)
    data = load_data()
    train = build_ngrams(n, data)
    clean_train = []
    for sentence in train:
        new_sentence = [gram for gram in sentence if not any(excl in gram for excl in exclusions)]
        clean_train.append(new_sentence)
    vocab = build_limited_vocab(data, exclusions)
    lm.fit(clean_train, vocab)
    return lm

def multi_plot(num_sentences):
    """Generate a paragraph of num_sentences independently generated sentences."""
    paragraph = []
    for _ in range(num_sentences):
        paragraph.append(two_seed_generator([], []))
    return " ".join(paragraph)

def two_seed_generator(feeder, exclusions):
    """
    Generate one sentence from a trigram model, seeded with up to two tokens
    and avoiding any excluded words. Accepts the seeds and exclusions as
    either a whitespace-separated string or a list of tokens.
    """
    if feeder is None:
        seed_list = []
    elif isinstance(feeder, str):
        seed_list = feeder.split()
    else:
        seed_list = list(feeder)
    if exclusions is None:
        excl_list = []
    elif isinstance(exclusions, str):
        excl_list = exclusions.split()
    else:
        excl_list = list(exclusions)
    if len(seed_list) == 0:
        seed1 = "<s>"
        seed2 = "<s>"
    elif len(seed_list) == 1:
        seed1 = "<s>"
        seed2 = seed_list[0]
    else:
        seed1 = seed_list[0]
        seed2 = seed_list[1]
    n = 3
    num_words = 60
    text_seed = [seed1, seed2]
    lm = train_limited_ngram_lm(n, excl_list)
    output = []
    # Keep user-supplied seeds in the final text; drop the pad tokens.
    if seed1 != "<s>":
        output.append([seed1])
    if seed2 != "<s>":
        output.append([seed2])
    tries = 0
    error_msg = '<:ruri_no:921119263547326515>'
    while True:
        try:
            line = lm.generate(num_words, text_seed=text_seed)
            break
        except ValueError:
            # Generation is not always successful (e.g. unseen contexts);
            # retry a bounded number of times before giving up.
            tries += 1
            if tries >= 50:
                return error_msg
    output.append(line)
    out = [item for sublist in output for item in sublist]
    str_out = " ".join(out)
    # Keep only the first sentence of the generated text.
    sent_out = str_out.split('.')[0] + "."
    cleaned = string_cleaner(sent_out)
    return cleaned

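# Usage sketch (output depends on the training data and the RNG):
#
#   two_seed_generator("talia", "")          # one seed token, no exclusions
#   two_seed_generator(["the", "crew"], [])  # two seed tokens as a list
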
def string_parser(text="<s>"):
    """
    Split user input of the form "seed words - excluded words" into seed and
    exclusion token lists.
    """
    seed_list = []
    excl_list = []
    if '-' in text:
        # Split on the first dash only, so extra dashes don't break unpacking.
        left, right = text.split("-", 1)
        excl_list = right.split()
        seed_list = left.split()
    else:
        seed_list = text.split()
    return seed_list, excl_list

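# Example of the expected split:
#
#   string_parser("talia mae - sarah")
#   # -> (['talia', 'mae'], ['sarah'])
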
def plot_better(text):
    """Parse a raw command string and generate a sentence from it."""
    seed_list, excl_list = string_parser(text)
    output = two_seed_generator(seed_list, excl_list)
    return output

# For "mad lib" fillers
def _random_line(path):
    """Return a random non-empty line from a word-list file."""
    with open(path, "r", encoding="utf8") as f:
        words = [line.rstrip("\n") for line in f if line.strip()]
    return random.choice(words)

def verb():
    return _random_line("Assets/verbs_list.txt")

def noun():
    return _random_line("Assets/noun_list.txt")

def person():
    return _random_line("Assets/person_list.txt")

def place():
    return _random_line("Assets/places_list.txt")

def adj():
    return _random_line("Assets/adj_list.txt")

def adv():
    return _random_line("Assets/adv_list.txt")

def color():
    return _random_line("Assets/color_list.txt")

def emote():
    return _random_line("Assets/emote_list.txt")

def filler(text):
    """
    Replace placeholder words (verb, noun, person, place, adj, adv, color,
    emotion) in the input with random picks from the word lists.
    """
    fillers = {
        "verb": verb,
        "noun": noun,
        "person": person,
        "place": place,
        "adj": adj,
        "adv": adv,
        "color": color,
        "emotion": emote,
    }
    out_list = []
    for word in text.split():
        pick = fillers.get(word.lower())
        out_list.append(pick() if pick else word)
    return " ".join(out_list)

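# Usage sketch: each placeholder word is swapped for a random entry from its
# word list; every other word passes through unchanged.
#
#   filler("a adj noun in the place")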