-
Notifications
You must be signed in to change notification settings - Fork 23
/
Copy pathcluster_embeddings.py
executable file
·77 lines (71 loc) · 3.52 KB
/
cluster_embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import argparse
import io
import os
import re
import subprocess
import sys
import time
from collections import defaultdict
# parse/validate arguments
argparser = argparse.ArgumentParser()
argparser.add_argument("-i", "--input_filename", help='input word vector representations.', required=True)
argparser.add_argument("-o", "--output_filename", help='output word cluster bitstrings.', required=True)
args = argparser.parse_args()
print("Clustering word vectors...")
# cluster_embeddings() (MATLAB) is given this path; the branchings are read back from it.
temp_filename = args.output_filename + "temp"
# Delete a branchings file left over from an earlier run, so that a failed
# MATLAB run can never be mistaken for a successful one.
if os.path.exists(temp_filename):
    os.remove(temp_filename)
# MATLAB char literals are single-quoted; an embedded single quote is doubled.
quoted_input_filename = "'" + args.input_filename.replace("'", "''") + "'"
quoted_temp_filename = "'" + temp_filename.replace("'", "''") + "'"
# The try/catch makes an error inside cluster_embeddings terminate MATLAB with
# a non-zero status instead of leaving it at its command prompt.
matlab_command = 'try, cluster_embeddings({}, {}); catch err, disp(err.message); exit(1); end; exit'.format(
    quoted_input_filename, quoted_temp_filename)
# Arguments are passed as a list (no shell involved), so file names containing
# spaces, quotes or shell metacharacters reach MATLAB unchanged.
matlab_status = subprocess.call(['matlab', '-nodesktop', '-nosplash', '-nojvm', '-nodisplay', '-r', matlab_command])
if matlab_status != 0 or not os.path.exists(temp_filename):
    sys.exit('MATLAB clustering failed (exit status {}); {} was not created.'.format(matlab_status, temp_filename))
# read branchings (bottom up)
print("Reading the hierarchy...")
# parent cluster id -> (left child id, right child id)
parent_to_children_cluster_ids = {}
# child cluster id -> (parent cluster id, 0 if it is the left child / 1 if the right one)
child_to_parent_and_direction = {}
with io.open(temp_filename, encoding='utf8') as branchings_file:
    # first line consists of an integer, the number of vectors in input file
    vectors_count = int(branchings_file.readline())
    # Explicit checks instead of `assert`: this is external input, and asserts
    # are skipped entirely under `python -O`.
    if vectors_count <= 0:
        raise ValueError('{}: expected a positive vector count on the first line, got {}'.format(
            temp_filename, vectors_count))
    merges_count = 0
    for line_number, line in enumerate(branchings_file, 2):
        fields = line.split()
        if not fields:
            # tolerate blank lines, e.g. an extra trailing newline
            continue
        if len(fields) != 3:
            raise ValueError('{}:{}: expected "left_id right_id distance", got {!r}'.format(
                temp_filename, line_number, line))
        merges_count += 1
        # Leaves are ids 1..vectors_count and the k-th merge creates cluster
        # vectors_count + k (presumably mirroring MATLAB `linkage` numbering —
        # confirm against cluster_embeddings.m).
        parent_cluster_id = vectors_count + merges_count
        left_cluster_id, right_cluster_id = int(fields[0]), int(fields[1])
        # the merge distance is parsed (rejecting malformed values) but not used further
        distance = float(fields[2])
        parent_to_children_cluster_ids[parent_cluster_id] = (left_cluster_id, right_cluster_id)
        child_to_parent_and_direction[left_cluster_id] = (parent_cluster_id, 0)
        child_to_parent_and_direction[right_cluster_id] = (parent_cluster_id, 1)
# a full binary hierarchy over vectors_count leaves has exactly vectors_count - 1 merges
if merges_count != vectors_count - 1:
    raise ValueError('{}: expected {} merges for {} vectors, found {}'.format(
        temp_filename, vectors_count - 1, vectors_count, merges_count))
# accumulate bit strings (top down)
# The root gets the empty bitstring; every internal cluster hands its own
# bitstring plus '0' to its left child and plus '1' to its right child.
# An explicit stack is used rather than recursion because a binary hierarchy
# over n leaves can be up to n - 1 levels deep.
print("Computing bitstrings...")
root_cluster_id = 2 * vectors_count - 1
# the last merge is the root: it is nobody's child, but it does have children
assert(root_cluster_id not in child_to_parent_and_direction)
assert(root_cluster_id in parent_to_children_cluster_ids)
cluster_id_to_bitstring = {root_cluster_id: ''}
pending_ids = [root_cluster_id]
nodes_counter = 0
while pending_ids:
    node_id = pending_ids.pop()
    if node_id <= vectors_count:
        # leaf: it has a parent but no children, and nothing to propagate
        assert(node_id in child_to_parent_and_direction)
        assert(node_id not in parent_to_children_cluster_ids)
        continue
    assert(node_id in parent_to_children_cluster_ids)
    assert(node_id in cluster_id_to_bitstring)
    prefix = cluster_id_to_bitstring[node_id]
    # left child gets '0', right child gets '1'; left is pushed first, so the
    # right subtree is popped (visited) first
    for branch_bit, child_id in zip('01', parent_to_children_cluster_ids[node_id]):
        cluster_id_to_bitstring[child_id] = prefix + branch_bit
        pending_ids.append(child_id)
    nodes_counter += 1
# every leaf and every internal cluster must have received a bitstring
assert(len(cluster_id_to_bitstring) == 2 * vectors_count - 1)
# persist.
# Write one "<word> <bitstring>" line per word vector, in input order: the k-th
# vector line after the header is looked up as leaf k of the hierarchy
# (assumes cluster_embeddings.m numbers vectors in file order — TODO confirm).
print("Writing bitstrings to file...")
with io.open(args.input_filename, encoding='utf8') as word_vectors_file, io.open(args.output_filename, encoding='utf8', mode='w') as cluster_bitstrings_file:
    # the first line is treated as metadata, not as a word vector
    metadata = word_vectors_file.readline()
    print('the first line in {} which reads "{}" has been ignored'.format(args.input_filename, metadata.strip()))
    lines_counter = 1
    for line in word_vectors_file:
        if not line.strip():
            # blank lines (e.g. a trailing empty line) carry no word
            continue
        # Ids above vectors_count are internal clusters: without this check an
        # extra line would silently receive an internal node's bitstring.
        if lines_counter > vectors_count:
            raise ValueError('{} has more word vectors than the {} that were clustered'.format(
                args.input_filename, vectors_count))
        current_word = line.split(' ')[0]
        current_bitstring = cluster_id_to_bitstring[lines_counter]
        cluster_bitstrings_file.write(u'{} {}\n'.format(current_word, current_bitstring))
        lines_counter += 1
    if lines_counter - 1 != vectors_count:
        raise ValueError('{} has {} word vectors, but {} were clustered'.format(
            args.input_filename, lines_counter - 1, vectors_count))
print("Done.")