-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplotnormalisedcossim - old.py
119 lines (97 loc) · 4.02 KB
/
plotnormalisedcossim - old.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 21 09:43:22 2021
@author: aluenen

Plot, for each half year from 2010 to 2018, the normalised cosine similarity
between 'kanker' and a fixed set of neighbour words, using one Word2Vec model
per half year. Each similarity is z-scored against the distribution of
similarities between 'kanker' and the whole vocabulary of that model.
"""
import os
import statistics

import gensim
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from gensim.models.word2vec import Word2Vec
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Show the working directory (debug aid; the model paths below are absolute).
print(os.getcwd())
# Word whose semantic neighbourhood is tracked across the half-year models.
target = 'kanker'
# Directory holding one Word2Vec model per half year, saved as '<year>_<part>.model'.
modeldir = "D:/Newsdata/Models/"

# The neighbour lists below were presumably picked from the output of:
#compass = Word2Vec.load("D:/Newsdata/Models/compass.model")
#print(compass.wv.most_similar('kanker', topn=20))

# Nearest neighbours of 'kanker' that are NOT a type of cancer (trimmed to the
# first five; the remaining candidates are kept in the trailing comment).
compassnns = ['alzheimer', 'ziekte', 'reuma', 'suikerziekte',
              'tuberculose']  # , 'aids', 'malaria', 'epilepsie', 'hiv']
# Alternative neighbour sets used for other runs:
# nearest neighbours of 'borstkanker' (breast cancer)
#compassnns = ['suikerziekte', 'alzheimer', 'reuma', 'trombose', 'anorexia']
# nearest neighbours of 'darmkanker' (bowel cancer)
#compassnns = ['trombose', 'hartfalen', 'suikerziekte', 'melanoom', 'aderverkalking']
# nearest neighbours of 'baarmoederhalskanker' (cervical cancer)
#compassnns = ['taaislijmziekte', 'HPV', 'hiv', 'kinkhoest', 'melanoom', 'trombose']

# normalisationfactors: {yearpart: [mean, std]} of the cosine similarities
# between `target` and every vocabulary word of that half-year model (the
# target itself is included, contributing a similarity of ~1.0).
normalisationfactors = {}
# data: {word: [normalised cossim for '2010_1', '2010_2', ..., '2018_2']}
data = {word: [] for word in compassnns}
for year in range(2010, 2019):
    for part in range(1, 3):
        yearpart = str(year) + '_' + str(part)
        # Load each model only once: it feeds both the normalisation
        # statistics and the neighbour similarities.
        w2v = Word2Vec.load(modeldir + yearpart + '.model')
        # numpy rather than statistics.mean: the vocabulary can be large and
        # the statistics module is pure Python and slow on large inputs.
        similarities = np.array(
            [w2v.wv.similarity(target, word) for word in w2v.wv.vocab],
            dtype=np.float64)
        mean = similarities.mean()
        std = similarities.std()  # population std (ddof=0), as np.std computed before
        normalisationfactors[yearpart] = [mean, std]
        for word in compassnns:
            try:
                cossim = w2v.wv.similarity(target, word)
            except KeyError:
                # Word absent from this half year's vocabulary: record a gap
                # (NaN is drawn as a break in the line) instead of crashing.
                print(word, 'not in vocabulary of', yearpart)
                data[word].append(float('nan'))
                continue
            # z-score of this similarity relative to the whole vocabulary.
            data[word].append((cossim - mean) / std)
print(normalisationfactors)
print(data)
# X-axis tick labels: one '<year>_<part>' label per half-year model, in the
# order the models are processed (year-major, part 1 before part 2).
x = [
    str(year) + '_' + str(part)
    for year in range(2010, 2019)
    for part in range(1, 3)
]
# Draw one line per neighbour word: its normalised similarity to 'kanker'
# for every half year in `x`. An explicit Figure/Axes is used instead of
# relying on plt.subplot(111) returning the axes plt.plot created implicitly.
fig, ax = plt.subplots()
for word, values in data.items():
    ax.plot(x, values, label=word)
# Rotate the half-year tick labels once (not once per plotted word).
ax.tick_params(axis='x', labelrotation=45)
ax.set_xlabel('Year')
ax.set_ylabel('Normalised Cosine similarity')
ax.set_title("Normalised cosine similarity between 'kanker' and nearest neighbours")
# Shrink the axes to 80% of its width so the legend fits to its right.
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
# Put a legend to the right of the current axis.
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.show()