Commit 1572a1e (parent 2421d9b)
Showing 5 changed files with 286 additions and 27 deletions.
@@ -0,0 +1,62 @@
from scipy.spatial import KDTree
import sys
from dbloader import EntityLoader, yagoReader, csvRead, \
    convertToEmbeddings, loadGensim


class ShortestWord2VecDistanceClassifier:
    """Label each embedding with the nearest target word, or None when the
    nearest neighbour is further away than the threshold."""

    def __init__(self, threshold, target_words, target_embeddings):
        self.threshold = threshold
        self.target_words = target_words
        self.vec_tree = KDTree(target_embeddings)

    def closest_word(self, embeddings):
        distances, indices = self.vec_tree.query(embeddings)
        results = [self.target_words[i] if d < self.threshold else None
                   for d, i in zip(distances, indices)]
        return results

    def closest_word_with_distance(self, embeddings):
        distances, indices = self.vec_tree.query(embeddings)
        results = [(self.target_words[i], d) if d < self.threshold else (None, d)
                   for d, i in zip(distances, indices)]
        return results


# raise the recursion limit so that Python doesn't explode:
# KDTree construction recurses once per split
sys.setrecursionlimit(10000)


yago_obj = EntityLoader('data/yagoFacts.tsv', yagoReader)
yago_obj.cache()

wiki_obj = EntityLoader('data/wikipedia-full-reverb.txt', csvRead)
wiki_obj.cache()

yago = yago_obj._df
wiki = wiki_obj._df

gensim = loadGensim()

# entities appear in both the subject (col 0) and object (col 2) positions
yago_entities = yago.iloc[:, 0].append(yago.iloc[:, 2]).unique()
yago_entity_embeddings = convertToEmbeddings(yago_entities, gensim)


# note: we can make this ~3x faster by only calculating the mappings
# for unique wiki entries (only 30% of the total)

wiki_entities_1 = wiki.iloc[:, 0]
wiki_entity_1_embeddings = convertToEmbeddings(wiki_entities_1, gensim)

wiki_entities_2 = wiki.iloc[:, 2]
wiki_entity_2_embeddings = convertToEmbeddings(wiki_entities_2, gensim)

model = ShortestWord2VecDistanceClassifier(threshold=1,
                                           target_words=yago_entities,
                                           target_embeddings=yago_entity_embeddings)

wiki['e1p'] = model.closest_word(wiki_entity_1_embeddings)
wiki['e2p'] = model.closest_word(wiki_entity_2_embeddings)

wiki.to_csv('wiki_pred.tsv', sep='\t', index=False)
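
To see the matching step in isolation, here is a minimal, self-contained sketch of the same KDTree lookup with made-up 3-dimensional vectors (real word2vec embeddings are typically 300-dimensional, and the threshold of 1 above is presumably tuned for that space):

from scipy.spatial import KDTree
import numpy as np

targets = ['cat', 'dog']
target_vecs = np.array([[0.0, 0.0, 1.0],
                        [1.0, 0.0, 0.0]])
tree = KDTree(target_vecs)

queries = np.array([[0.1, 0.0, 0.9],    # close to 'cat'
                    [5.0, 5.0, 5.0]])   # far from every target
distances, indices = tree.query(queries)
labels = [targets[i] if d < 1 else None for d, i in zip(distances, indices)]
print(labels)  # ['cat', None]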
@@ -0,0 +1,125 @@
from scipy.spatial import KDTree
import json
import os
import pickle
import time
from dbloader import EntityLoader, yagoReader, csvRead, \
    convertToEmbeddings, loadGensim


# jumping through all these hoops to avoid loading gensim into memory
# unless we absolutely have to

def load_or_pickle(pickle_file, load_func, *args):
    if pickle_file in os.listdir():
        with open(pickle_file, 'rb') as f:
            obj = pickle.load(f)
    else:
        # note: a failure inside load_func leaves a truncated cache file behind
        with open(pickle_file, 'wb') as f:
            obj = load_func(*args)
            pickle.dump(obj, f)
    return obj


def load_yago_entities(yago):
    # entities appear in both the subject (col 0) and object (col 2) positions
    return yago.iloc[:, 0].append(yago.iloc[:, 2]).unique()


def load_embeddings(entities):
    gensim = loadGensim()
    return convertToEmbeddings(entities, gensim)


def get_wiki_entities(wiki, index):
    return wiki.iloc[:, index]


class ShortestWord2VecDistanceClassifier:
    def __init__(self, threshold, target_words, target_embeddings):
        self.threshold = threshold
        self.target_words = target_words
        self.vec_tree = KDTree(target_embeddings)

    def closest_word(self, embeddings):
        distances, indices = self.vec_tree.query(embeddings)
        results = [self.target_words[i] if d < self.threshold else None
                   for d, i in zip(distances, indices)]
        return results

    def closest_word_with_distance(self, embeddings):
        distances, indices = self.vec_tree.query(embeddings)
        results = [(self.target_words[i], d) if d < self.threshold else (None, d)
                   for d, i in zip(distances, indices)]
        return results

    def closest_word_single(self, embedding):
        distance, index = self.vec_tree.query(embedding)
        if distance < self.threshold:
            return self.target_words[index]
        else:
            return None


yago_obj = EntityLoader('data/yagoFacts.tsv', yagoReader)
yago_obj.cache()

wiki_obj = EntityLoader('data/wikipedia-full-reverb.txt', csvRead)
wiki_obj.cache()

yago = yago_obj._df
wiki = wiki_obj._df

yago_entities = load_or_pickle('yago_entities.pickle', load_yago_entities, yago)
yago_entity_embeddings = load_or_pickle('yago_entity_embeddings.pickle', load_embeddings, yago_entities)

wiki_entities_1 = load_or_pickle('wiki_entities_1.pickle', get_wiki_entities, wiki, 0)
wiki_entity_1_embeddings = load_or_pickle('wiki_entity_1_embeddings.pickle', load_embeddings, wiki_entities_1)

wiki_entities_2 = load_or_pickle('wiki_entities_2.pickle', get_wiki_entities, wiki, 2)
wiki_entity_2_embeddings = load_or_pickle('wiki_entity_2_embeddings.pickle', load_embeddings, wiki_entities_2)

model = ShortestWord2VecDistanceClassifier(threshold=1,
                                           target_words=yago_entities,
                                           target_embeddings=yago_entity_embeddings)


def wiki_unique_entitiy_map(wiki_entities_1, wiki_entities_2):
    # map each unique entity to its embedding, with the '<UNK>' sentinel
    # marking entries that have not been classified yet
    wiki_entities_unique = list(wiki_entities_1.unique()) + list(wiki_entities_2.unique())
    wiki_embeddings_unique = load_embeddings(wiki_entities_unique)
    unique_entity_map = {ent: (emb, '<UNK>') for ent, emb in zip(wiki_entities_unique, wiki_embeddings_unique)}
    return unique_entity_map


if __name__ == '__main__':

    # the name is rebound from the builder function to its result
    wiki_unique_entitiy_map = load_or_pickle('wiki_unique_entitiy_map.pickle',
                                             wiki_unique_entitiy_map,
                                             wiki_entities_1, wiki_entities_2)

    start = time.time()
    i = 0
    chk = max(len(wiki_unique_entitiy_map) // 100, 1)  # guard against tiny maps
    for entity in wiki_unique_entitiy_map:
        i += 1
        embedding, target_class = wiki_unique_entitiy_map[entity]
        # don't recalculate entries restored from a checkpoint
        if target_class == '<UNK>':
            target_class = model.closest_word_single(embedding)
            wiki_unique_entitiy_map[entity] = (embedding, target_class)

        # make a checkpoint every 1%
        if i % chk == 0:
            end = time.time()
            print("Checkpoint: {}, took {}".format(i, end - start))
            with open('wiki_unique_entitiy_map.pickle', 'wb') as f:
                pickle.dump(wiki_unique_entitiy_map, f)
            start = end

    with open('wiki_unique_entitiy_map.pickle', 'wb') as f:
        pickle.dump(wiki_unique_entitiy_map, f)

    entity_map = {entity: wiki_unique_entitiy_map[entity][1] for entity in wiki_unique_entitiy_map}
    with open('entity_map.json', 'w') as fp:
        json.dump(entity_map, fp)
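
The main block combines two ideas: the '<UNK>' sentinel marks entities not yet classified, and the map is periodically re-pickled so a crash loses at most roughly 1% of progress; on restart, load_or_pickle reloads the partial map and the finished entries are skipped. A stripped-down sketch of that resume pattern, with a hypothetical classify callable standing in for the model:

import os
import pickle

def resumable_run(items, classify, ckpt='progress.pickle'):
    # reload partial progress if a checkpoint exists
    if os.path.exists(ckpt):
        with open(ckpt, 'rb') as f:
            state = pickle.load(f)
    else:
        state = {item: '<UNK>' for item in items}

    for i, item in enumerate(state, 1):
        if state[item] == '<UNK>':               # skip work restored from the checkpoint
            state[item] = classify(item)
        if i % max(len(state) // 100, 1) == 0:   # flush roughly every 1%
            with open(ckpt, 'wb') as f:
                pickle.dump(state, f)

    with open(ckpt, 'wb') as f:                  # final flush
        pickle.dump(state, f)
    return state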
@@ -0,0 +1,36 @@
# Taken from:
# https://stackoverflow.com/questions/42653386/does-pickle-randomly-fail-with-oserror-on-large-files?rq=1
import pickle
import sys
import os


def save_as_pickled_object(obj, filepath):
    """
    This is a defensive way to write pickle.write,
    allowing for very large files on all platforms
    """
    # some platforms fail on a single write larger than 2 GB, so chunk it
    max_bytes = 2**31 - 1
    bytes_out = pickle.dumps(obj)
    n_bytes = sys.getsizeof(bytes_out)  # slight over-count; slicing past the end is harmless
    with open(filepath, 'wb') as f_out:
        for idx in range(0, n_bytes, max_bytes):
            f_out.write(bytes_out[idx:idx + max_bytes])


def try_to_load_as_pickled_object_or_None(filepath):
    """
    This is a defensive way to write pickle.load,
    allowing for very large files on all platforms
    """
    max_bytes = 2**31 - 1
    try:
        input_size = os.path.getsize(filepath)
        bytes_in = bytearray(0)
        with open(filepath, 'rb') as f_in:
            for _ in range(0, input_size, max_bytes):
                bytes_in += f_in.read(max_bytes)
        obj = pickle.loads(bytes_in)
    except Exception:  # missing, truncated, or unreadable file
        return None
    return obj
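
Usage is symmetric: dump with one helper, read back with the other, and get None rather than an exception when the file is missing or unreadable. A quick sketch with an arbitrary object and made-up file names:

big = {'embeddings': list(range(10**6))}   # any picklable object
save_as_pickled_object(big, 'big.pickle')

restored = try_to_load_as_pickled_object_or_None('big.pickle')
assert restored == big
missing = try_to_load_as_pickled_object_or_None('does_not_exist.pickle')  # -> None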
@@ -0,0 +1,34 @@
import json
from dbloader import EntityLoader, csvRead, yagoReader

with open('./data/entity_map_2.json') as fp:
    entity_map = json.load(fp)


def lookup_entity_map(entity):
    if entity is not None:
        # entities absent from the map are treated as unmatched
        return entity_map.get(entity)
    else:
        return None


yago_obj = EntityLoader('data/yagoFacts.tsv', yagoReader)
yago_obj.cache()

wiki_obj = EntityLoader('data/wikipedia-full-reverb.txt', csvRead)
wiki_obj.cache()

yago = yago_obj._df
wiki = wiki_obj._df

wiki = wiki.dropna()

wiki['e1p'] = wiki['e1'].apply(lookup_entity_map)
wiki['e2p'] = wiki['e2'].apply(lookup_entity_map)

wiki_no_none = wiki[wiki['e1p'].notnull() & wiki['e2p'].notnull()][['e1p', 'rel', 'e2p']]

df_merge = wiki_no_none.merge(yago, left_on=['e1p', 'e2p'], right_on=['e1', 'e2'])

df_merge[['rel_x', 'rel_y']].drop_duplicates().to_csv('merged_no_dupes.tsv', index=False, sep='\t')
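
The closing merge is what pairs the two relation vocabularies: every wiki triple whose predicted endpoints (e1p, e2p) match a YAGO fact's endpoints contributes one (rel_x, rel_y) row, where _x/_y are pandas' default suffixes for the clashing rel columns. A toy illustration of the same join, with made-up rows in roughly YAGO's notation:

import pandas as pd

wiki_no_none = pd.DataFrame({'e1p': ['Barack_Obama'],
                             'rel': ['was born in'],
                             'e2p': ['Honolulu']})
yago = pd.DataFrame({'e1': ['Barack_Obama'],
                     'rel': ['<wasBornIn>'],
                     'e2': ['Honolulu']})

merged = wiki_no_none.merge(yago, left_on=['e1p', 'e2p'], right_on=['e1', 'e2'])
print(merged[['rel_x', 'rel_y']])
#          rel_x        rel_y
# 0  was born in  <wasBornIn>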