-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathjensenshannon.py
32 lines (23 loc) · 980 Bytes
/
jensenshannon.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Compute the Jensen-Shannon divergence between the accuracies of settings A and B."""
import sys
import pandas as pd
import numpy as np
from scipy.spatial.distance import jensenshannon
# Models reported, in output order.
MODELS = ("MLP", "GCN", "GCN-64", "GraphSAGE", "GAT")


def _setting_accuracies(df_model, setting):
    """Return the 'accuracy' values of rows whose 'setting' equals ``setting`` as a float array."""
    return df_model.loc[df_model['setting'] == setting, 'accuracy'].to_numpy(dtype=float)


def setting_divergence(df_model, setting_a='A', setting_b='B'):
    """Return ``(divergence, distance)`` between two settings' accuracy vectors.

    ``scipy.spatial.distance.jensenshannon`` normalises each input vector to
    sum to 1 and returns the Jensen-Shannon *distance*, i.e. the square root
    of the divergence (natural log by default). The divergence is therefore
    the square of that distance.

    NOTE(review): the two vectors are paired element-wise in row order, and
    each accuracy vector is treated as a probability distribution over those
    rows. If rows are meant to align by some key (e.g. an 'epoch' column),
    sort ``df_model`` by it before calling -- confirm intended semantics.

    Args:
        df_model: DataFrame with at least 'setting' and 'accuracy' columns.
        setting_a: value of 'setting' selecting the first vector.
        setting_b: value of 'setting' selecting the second vector.

    Returns:
        Tuple ``(divergence, distance)`` as Python floats.

    Raises:
        ValueError: if either setting has no rows, or the two settings have
            different numbers of rows (element-wise pairing is impossible).
    """
    accuracy_a = _setting_accuracies(df_model, setting_a)
    accuracy_b = _setting_accuracies(df_model, setting_b)
    if accuracy_a.size == 0 or accuracy_b.size == 0:
        raise ValueError("no rows for setting %s and/or %s" % (setting_a, setting_b))
    if accuracy_a.size != accuracy_b.size:
        raise ValueError(
            "settings %s and %s have different numbers of rows (%d vs %d)"
            % (setting_a, setting_b, accuracy_a.size, accuracy_b.size))
    distance = float(jensenshannon(accuracy_a, accuracy_b))
    # jensenshannon returns the distance; square it to get the divergence.
    return distance ** 2, distance


def main(argv=None):
    """Print the per-model A/B Jensen-Shannon divergence for one pretraining value.

    Usage: jensenshannon.py RESULTS_CSV [PRETRAINING]
    PRETRAINING defaults to 200. The CSV must contain 'pretraining', 'model',
    'setting' and 'accuracy' columns.

    Args:
        argv: argument vector in ``sys.argv`` form; defaults to ``sys.argv``.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        sys.exit("usage: jensenshannon.py RESULTS_CSV [PRETRAINING]")
    data = pd.read_csv(argv[1])
    pretraining = int(argv[2]) if len(argv) > 2 else 200
    print("Pretraining =", pretraining)
    data = data[data['pretraining'] == pretraining]
    for model in MODELS:
        df_model = data[data['model'] == model]
        print("\tModel:", model)
        try:
            divergence, distance = setting_divergence(df_model)
        except ValueError as err:
            # Report and continue so one incomplete model does not abort the run.
            print("\tSkipped:", err)
            continue
        print("\tJensen-Shannon divergence between A/B: %.4f" % divergence)
        print("\tJensen-Shannon distance between A/B: %.4f" % distance)


if __name__ == "__main__":
    main()