Commit

Add option to return probabilities
lcmmichielsen committed Jan 5, 2024
1 parent 70a7db9 commit 5f35746
Showing 6 changed files with 103 additions and 33 deletions.
1 change: 1 addition & 0 deletions scHPL/evaluate.py
@@ -11,6 +11,7 @@
 import seaborn as sns
 
 from .utils import TreeNode
+# from utils import TreeNode
 
 def hierarchical_F1(true_labels,
                     pred_labels,
8 changes: 6 additions & 2 deletions scHPL/learn.py
@@ -12,6 +12,10 @@
 from .utils import TreeNode, create_tree, print_tree
 from .predict import predict_labels
 from .update import update_tree
+# from train import train_tree
+# from utils import TreeNode, create_tree, print_tree
+# from predict import predict_labels
+# from update import update_tree
 
 try:
     from typing import Literal
@@ -142,8 +146,8 @@ def learn_tree(data: AnnData,
                           distkNN)
 
     # Predict labels other dataset
-    labels_2_pred = predict_labels(data_2, tree, threshold=rej_threshold)
-    labels_1_pred = predict_labels(data_1, tree_2, threshold=rej_threshold)
+    labels_2_pred,_ = predict_labels(data_2, tree, threshold=rej_threshold)
+    labels_1_pred,_ = predict_labels(data_1, tree_2, threshold=rej_threshold)
 
     # Update first tree and labels second dataset
     tree, mis_pop = update_tree(tree, labels_1.reshape(-1,1),
11 changes: 8 additions & 3 deletions scHPL/predict.py
@@ -7,6 +7,7 @@
 import numpy as np
 from numpy import linalg as LA
 from .utils import TreeNode
+# from utils import TreeNode
 
 def predict_labels(testdata,
                    tree: TreeNode,
@@ -26,6 +27,8 @@ def predict_labels(testdata,
     Returns
     ------
     Predicted labels
+    Probability of the predicted labels (only works for kNN, since SVM
+    doesn't return probabilities)
     '''
 
     useRE = False
@@ -50,8 +53,8 @@
         dimred = True
 
     labels_all = []
+    prob_all = np.zeros((np.shape(testdata)[0],1))
     for idx, testpoint in enumerate(testdata):
-        # print(idx)
         if useRE:
             if rej_RE[idx]:
                 labels_all.append('Rejected (RE)')
@@ -71,7 +74,7 @@
 
             ### Reject cells based on distance
             predict=True
-            dist,idx = parentnode.classifier.kneighbors(testpoint, return_distance=True)
+            dist,_ = parentnode.classifier.kneighbors(testpoint, return_distance=True)
             if(np.mean(dist) > parentnode.get_maxDist()):
                 labels.append('Rejection (dist)')
                 predict=False
@@ -83,6 +86,7 @@
             #If score higher than threshold -> iterate further over tree
             if score > threshold:
                 labels.append(label[0])
+                prob_all[idx] = score
                 oldparent = parentnode
                 for n in parentnode.descendants:
                     if n.name[0] == label:
@@ -122,7 +126,8 @@
         # Label cell with last predicted label
         labels_all.append(labels[-1])
 
-    return np.asarray(labels_all)
+
+    return np.asarray(labels_all), prob_all
 
 def _predict_node(testpoint, n, dimred):
     '''Use the local classifier of a node to predict the label of a cell.
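With this change, predict_labels returns two values: the array of predicted labels and a per-cell probability array of shape (n_cells, 1), which is only populated for kNN classifiers since the SVM classifiers do not return probabilities. A minimal caller-side sketch, assuming an already trained tree and a query matrix; the variable names and the example threshold value are illustrative, not taken from this commit:

    from scHPL.predict import predict_labels

    # 'tree' is a previously trained scHPL classification tree and 'testdata'
    # a cell-by-gene matrix of query cells (both assumed to exist here).
    labels, probs = predict_labels(testdata, tree, threshold=0.25)

    # Callers that do not need the probabilities can discard them,
    # as learn.py now does:
    labels, _ = predict_labels(testdata, tree, threshold=0.25)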
9 changes: 9 additions & 0 deletions scHPL/train.py
@@ -15,6 +15,7 @@
 from sklearn.model_selection import StratifiedKFold
 from scipy.stats import ttest_ind
 from .utils import TreeNode
+# from utils import TreeNode
 import copy as cp
 
 try:
@@ -350,3 +351,11 @@ def _find_negativesamples(labels, group, n):
 
     return group
 
+
+
+
+
+
+
+
+
2 changes: 2 additions & 0 deletions scHPL/update.py
@@ -9,6 +9,8 @@
 import pandas as pd
 from .utils import TreeNode
 from .evaluate import confusion_matrix
+# from utils import TreeNode
+# from evaluate import confusion_matrix
 
 def update_tree(tree: TreeNode,
                 y_true1,
105 changes: 77 additions & 28 deletions vignettes/tutorial.ipynb

Large diffs are not rendered by default.
