Skip to content

Commit

Permalink
fixes in algos, fix evaluation, cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Tommos0 committed Dec 13, 2019
1 parent 0b7c00c commit 9a9d8ee
Show file tree
Hide file tree
Showing 13 changed files with 274 additions and 604 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
*.egg-info
__pychache__
*.pyc
.ipynb_checkpoints
.ipynb_checkpoints
.idea
2 changes: 0 additions & 2 deletions .idea/.gitignore

This file was deleted.

6 changes: 0 additions & 6 deletions .idea/inspectionProfiles/profiles_settings.xml

This file was deleted.

6 changes: 0 additions & 6 deletions .idea/libraries/R_User_Library.xml

This file was deleted.

94 changes: 0 additions & 94 deletions .idea/misc.xml

This file was deleted.

8 changes: 0 additions & 8 deletions .idea/modules.xml

This file was deleted.

14 changes: 0 additions & 14 deletions .idea/tadpole-algorithms.iml

This file was deleted.

6 changes: 0 additions & 6 deletions .idea/vcs.xml

This file was deleted.

222 changes: 222 additions & 0 deletions tadpole_algorithms/evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
import numpy as np
from datetime import datetime
from . import MAUC
from sklearn.metrics import confusion_matrix


def calcBCA(estimLabels, trueLabels, nrClasses):
# Balanced Classification Accuracy

bcaAll = []
for c0 in range(nrClasses):
for c1 in range(c0 + 1, nrClasses):
# c0 = positive class & c1 = negative class
TP = np.sum((estimLabels == c0) & (trueLabels == c0))
TN = np.sum((estimLabels == c1) & (trueLabels == c1))
FP = np.sum((estimLabels == c1) & (trueLabels == c0))
FN = np.sum((estimLabels == c0) & (trueLabels == c1))

# sometimes the sensitivity of specificity can be NaN, if the user doesn't forecast one of the classes.
# In this case we assume a default value for sensitivity/specificity
if (TP + FN) == 0:
sensitivity = 0.5
else:
sensitivity = (TP * 1.) / (TP + FN)

if (TN + FP) == 0:
specificity = 0.5
else:
specificity = (TN * 1.) / (TN + FP)

bcaCurr = 0.5 * (sensitivity + specificity)
bcaAll += [bcaCurr]
# print('bcaCurr %f TP %f TN %f FP %f FN %f' % (bcaCurr, TP, TN, FP, FN))

return np.mean(bcaAll)


def parseData(d4Df, forecastDf, diagLabels):
trueDiag = d4Df['Diagnosis']
trueADAS = d4Df['ADAS13']
trueVents = d4Df['Ventricles']

nrSubj = d4Df.shape[0]

zipTrueLabelAndProbs = []

hardEstimClass = -1 * np.ones(nrSubj, int)
adasEstim = -1 * np.ones(nrSubj, float)
adasEstimLo = -1 * np.ones(nrSubj, float) # lower margin
adasEstimUp = -1 * np.ones(nrSubj, float) # upper margin
ventriclesEstim = -1 * np.ones(nrSubj, float)
ventriclesEstimLo = -1 * np.ones(nrSubj, float) # lower margin
ventriclesEstimUp = -1 * np.ones(nrSubj, float) # upper margin

# print('subDf.keys()', forecastDf['Forecast Date'])
invalidResultReturn = (None, None, None, None, None, None, None, None, None, None, None)
invalidFlag = False
# for each subject in D4 match the closest user forecasts
for s in range(nrSubj):
currSubjMask = d4Df['RID'].iloc[s] == forecastDf['RID']
currSubjData = forecastDf[currSubjMask]

# if subject is missing
if currSubjData.shape[0] == 0:
print('WARNING: Subject RID %s missing from user forecasts' % d4Df['RID'].iloc[s])
invalidFlag = True
continue

# if not all forecast months are present
if currSubjData.shape[0] < 5 * 12: # check if at least 5 years worth of forecasts exist
print('WARNING: Missing forecast months for subject with RID %s' % d4Df['RID'].iloc[s])
invalidFlag = True
continue

currSubjData = currSubjData.reset_index(drop=True)

timeDiffsScanCog = [d4Df['CognitiveAssessmentDate'].iloc[s] - d for d in currSubjData['Forecast Date']]
# print('Forecast Date 2',currSubjData['Forecast Date'])
indexMin = np.argsort(np.abs(timeDiffsScanCog))[0]
# print('timeDiffsScanMri', indexMin, timeDiffsScanMri)

pCN = currSubjData['CN relative probability'].iloc[indexMin]
pMCI = currSubjData['MCI relative probability'].iloc[indexMin]
pAD = currSubjData['AD relative probability'].iloc[indexMin]

# normalise the relative probabilities by their sum
pSum = (pCN + pMCI + pAD) / 3
pCN /= pSum
pMCI /= pSum
pAD /= pSum

hardEstimClass[s] = np.argmax([pCN, pMCI, pAD])

adasEstim[s] = currSubjData['ADAS13'].iloc[indexMin]
adasEstimLo[s] = currSubjData['ADAS13 50% CI lower'].iloc[indexMin]
adasEstimUp[s] = currSubjData['ADAS13 50% CI upper'].iloc[indexMin]

# for the mri scan find the forecast closest to the scan date,
# which might be different from the cognitive assessment date
timeDiffsScanMri = [d4Df['ScanDate'].iloc[s] - d for d in currSubjData['Forecast Date']]
indexMinMri = np.argsort(np.abs(timeDiffsScanMri))[0]

ventriclesEstim[s] = currSubjData['Ventricles_ICV'].iloc[indexMinMri]
ventriclesEstimLo[s] = currSubjData['Ventricles_ICV 50% CI lower'].iloc[indexMinMri]
ventriclesEstimUp[s] = currSubjData['Ventricles_ICV 50% CI upper'].iloc[indexMinMri]
# print('%d probs' % d4Df['RID'].iloc[s], pCN, pMCI, pAD)

if isinstance(trueDiag.iloc[s], str) or not np.isnan(trueDiag.iloc[s]):
zipTrueLabelAndProbs += [(trueDiag.iloc[s], [pCN, pMCI, pAD])]

if invalidFlag:
# if at least one subject was missing or if
raise ValueError('Submission was incomplete. Please resubmit')

# If there are NaNs in D4, filter out them along with the corresponding user forecasts
# This can happen if rollover subjects don't come for visit in ADNI3.
notNanMaskDiag = np.logical_not(np.isnan(trueDiag))
trueDiagFilt = trueDiag[notNanMaskDiag]
hardEstimClassFilt = hardEstimClass[notNanMaskDiag]

notNanMaskADAS = np.logical_not(np.isnan(trueADAS))
trueADASFilt = trueADAS[notNanMaskADAS]
adasEstim = adasEstim[notNanMaskADAS]
adasEstimLo = adasEstimLo[notNanMaskADAS]
adasEstimUp = adasEstimUp[notNanMaskADAS]

notNanMaskVents = np.logical_not(np.isnan(trueVents))
trueVentsFilt = trueVents[notNanMaskVents]
ventriclesEstim = ventriclesEstim[notNanMaskVents]
ventriclesEstimLo = ventriclesEstimLo[notNanMaskVents]
ventriclesEstimUp = ventriclesEstimUp[notNanMaskVents]

assert trueDiagFilt.shape[0] == hardEstimClassFilt.shape[0]
assert trueADASFilt.shape[0] == adasEstim.shape[0] == adasEstimLo.shape[0] == adasEstimUp.shape[0]
assert trueVentsFilt.shape[0] == ventriclesEstim.shape[0] == \
ventriclesEstimLo.shape[0] == ventriclesEstimUp.shape[0]

return zipTrueLabelAndProbs, hardEstimClassFilt, adasEstim, adasEstimLo, adasEstimUp, \
ventriclesEstim, ventriclesEstimLo, ventriclesEstimUp, trueDiagFilt, trueADASFilt, trueVentsFilt


def evaluate_forecast(d4Df, forecastDf):
"""
Evaluates one submission.
Parameters
----------
d4Df - Pandas data frame containing the D4 dataset
subDf - Pandas data frame containing user forecasts for D2 subjects.
Returns
-------
mAUC - multiclass Area Under Curve
bca - balanced classification accuracy
adasMAE - ADAS13 Mean Aboslute Error
ventsMAE - Ventricles Mean Aboslute Error
adasCovProb - ADAS13 Coverage Probability for 50% confidence interval
ventsCovProb - Ventricles Coverage Probability for 50% confidence interval
"""
# d4Df['CognitiveAssessmentDate'] = [datetime.strptime(x, '%Y-%m-%d') for x in d4Df['CognitiveAssessmentDate']]
# d4Df['ScanDate'] = [datetime.strptime(x, '%Y-%m-%d') for x in d4Df['ScanDate']]

forecastDf['Forecast Date'] = [datetime.strptime(x, '%Y-%m-%d') for x in forecastDf['Forecast Date']]
# todo: considers every month estimate to be the actual first day 2017-01

d4Df['CognitiveAssessmentDate'] = [datetime.strptime(x, '%Y-%m-%d') for x in d4Df['CognitiveAssessmentDate']]
d4Df['ScanDate'] = [datetime.strptime(x, '%Y-%m-%d') for x in d4Df['ScanDate']]
mapping = {'CN': 0, 'MCI': 1, 'AD': 2}
d4Df.replace({'Diagnosis': mapping}, inplace=True)

diagLabels = ['CN', 'MCI', 'AD']

zipTrueLabelAndProbs, hardEstimClass, adasEstim, adasEstimLo, adasEstimUp, \
ventriclesEstim, ventriclesEstimLo, ventriclesEstimUp, trueDiagFilt, trueADASFilt, trueVentsFilt = \
parseData(d4Df, forecastDf, diagLabels)
zipTrueLabelAndProbs = list(zipTrueLabelAndProbs)

########## compute metrics for the clinical status #############

##### Multiclass AUC (mAUC) #####

nrClasses = len(diagLabels)
mAUC = MAUC.MAUC(zipTrueLabelAndProbs, num_classes=nrClasses)

### Balanced Classification Accuracy (BCA) ###
# print('hardEstimClass', np.unique(hardEstimClass), hardEstimClass)
trueDiagFilt = trueDiagFilt.astype(int)
# print('trueDiagFilt', np.unique(trueDiagFilt), trueDiagFilt)
bca = calcBCA(hardEstimClass, trueDiagFilt, nrClasses=nrClasses)

## Confusion matrix ## Added by Esther Bron
conf = confusion_matrix(hardEstimClass, trueDiagFilt.values, [0, 1, 2])
conf = np.transpose(conf) # Transposed to match confusion matrix on web site

print(conf)

####### compute metrics for Ventricles and ADAS13 ##########

#### Mean Absolute Error (MAE) #####

adasMAE = np.mean(np.abs(adasEstim - trueADASFilt))
ventsMAE = np.mean(np.abs(ventriclesEstim - trueVentsFilt))

##### Weighted Error Score (WES) ####
adasCoeffs = 1 / (adasEstimUp - adasEstimLo)
adasWES = np.sum(adasCoeffs * np.abs(adasEstim - trueADASFilt)) / np.sum(adasCoeffs)

ventsCoeffs = 1 / (ventriclesEstimUp - ventriclesEstimLo)
ventsWES = np.sum(ventsCoeffs * np.abs(ventriclesEstim - trueVentsFilt)) / np.sum(ventsCoeffs)

#### Coverage Probability Accuracy (CPA) ####

adasCovProb = (np.sum((adasEstimLo < trueADASFilt) &
(adasEstimUp > trueADASFilt)) * 1.) / trueADASFilt.shape[0]
adasCPA = np.abs(adasCovProb - 0.5)

ventsCovProb = (np.sum((ventriclesEstimLo < trueVentsFilt) &
(ventriclesEstimUp > trueVentsFilt)) * 1.) / trueVentsFilt.shape[0]
ventsCPA = np.abs(ventsCovProb - 0.5)

return mAUC, bca, adasMAE, ventsMAE, adasWES, ventsWES, adasCPA, ventsCPA, adasEstim, trueADASFilt
Loading

0 comments on commit 9a9d8ee

Please sign in to comment.