Skip to content

Commit

Permalink
Change input to anndata
Browse files Browse the repository at this point in the history
  • Loading branch information
lcmmichielsen committed Jul 19, 2021
1 parent 802af82 commit ee35f64
Show file tree
Hide file tree
Showing 8 changed files with 178 additions and 124 deletions.
11 changes: 6 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
numpy~=1.19.2
scipy~=1.5.2
scikit-learn~=0.23.2
pandas~=1.1.2
newick~=1.0.0
numpy>=1.19.2
scipy>=1.5.2
scikit-learn>=0.23.2
pandas>=1.1.2
newick~=1.0.0
anndata>=0.7.4
4 changes: 2 additions & 2 deletions scHPL/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 6 09:20:51 2021
Created on Mon Jul 19 10:31:39 2021
@author: lcmmichielsen
"""

from .evaluate import hierarchical_F1, confusion_matrix
from .predict import predict_labels
from .progressive_learning import learn_tree
from .learn import learn_tree
from .train import train_tree
from .update import update_tree
from .utils import TreeNode, add_node, create_tree, print_tree, read_tree
156 changes: 156 additions & 0 deletions scHPL/learn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Nov 22 14:11:01 2019
@author: Lieke
"""

import numpy as np

import anndata

from .train import train_tree
from .utils import TreeNode, create_tree, print_tree
from .predict import predict_labels
from .update import update_tree

def learn_tree(data: anndata,
batch_key: str,
batch_order: list,
cell_type_key: str,
tree: TreeNode = None,
retrain: bool = False,
batch_added: list = None,
classifier: str = 'svm_occ',
dimred: bool = True,
useRE: bool = True,
FN: float = 1,
threshold: float = 0.25,
return_missing: bool = True
):

'''
Apply the hierarchical progressive learning pipeline
Parameters
----------
data: Anndata
anndata object
batch_key: String
Key where the batches in the data can be found.
batch_order: List
List containing the order in which the batches should be added
to the tree.
cell_type_key: String
Key where the celltype labels in the data can be found.
tree: TreeNode = None
Tree to start updating.
retrain: Boolean = False
If 'True', the inputted tree will be retrained (needed if tree or
datasets are changed after intial construction).
batch_added: List = None
Indicates which batches were used to build the existing tree.
classifier: String = 'svm_occ'
Classifier to use (either 'svm' or 'svm_occ').
dimred: Boolean = True
If 'True' PCA is applied before training the classifier.
useRE: Boolean = True
If 'True', cells are also rejected based on the reconstruction error.
FN: Float = 1
Percentage of false negatives allowed when determining the threshold
for the reconstruction error.
threshold: Float = 0.25
Threshold to use when matching the labels.
return_missing: Boolean = True
If 'True' missing nodes are returned to the user, else missing
nodes are attached to the root node.
Return
------
tree_c: trained classification tree
'''

missing_pop=[]

xx = data.X
labels = np.array(data.obs[cell_type_key].values, dtype=str)
batches = data.obs[batch_key]

if(tree == None):
tree = create_tree('root')
firstbatch = batch_order[0]
batch_order = batch_order[1:]
idx_1 = np.where(batches == firstbatch)[0]
labels_1 = labels[idx_1]
tree = _construct_tree(tree, labels_1)
retrain = True
else:
idx_1 = np.isin(batches, batch_added)

labels_1 = labels[idx_1]
data_1 = xx[idx_1]

for b in batch_order:

print('Adding dataset', str(b), 'to the tree')

idx_2 = np.where(batches == b)[0]
data_2 = xx[idx_2]
labels_2 = labels[idx_2]
tree_2 = create_tree('root2')
tree_2 = _construct_tree(tree_2, labels_2)

# Train the trees
if retrain:
tree = train_tree(data_1, labels_1, tree, classifier, dimred, useRE, FN)
else:
retrain = True

tree_2 = train_tree(data_2, labels_2, tree_2, classifier, dimred, useRE, FN)

# Predict labels other dataset
labels_2_pred = predict_labels(data_2, tree)
labels_1_pred = predict_labels(data_1, tree_2)

# Update first tree and labels second dataset
tree, labels_2_new, mis_pop = update_tree(labels_1.reshape(-1,1),
labels_1_pred.reshape(-1,1),
labels_2.reshape(-1,1),
labels_2_pred.reshape(-1,1),
threshold, tree,
return_missing = return_missing)
missing_pop.extend(mis_pop)

print('Updated tree:')
print_tree(tree)

#concatenate the two datasets
data_1 = np.concatenate((data_1, data_2), axis = 0)
labels_1 = np.concatenate((np.squeeze(labels_1), np.squeeze(labels_2_new)),
axis = 0)




# Train the final tree
tree = train_tree(data_1, labels_1, tree, classifier, dimred, useRE, FN)

if return_missing:
return tree, missing_pop
else:
return tree



def _construct_tree(tree, labels):
'''
Construct a flat tree
'''

unique_labels = np.unique(labels)

for ul in unique_labels:
newnode = TreeNode(ul)
tree[0].add_descendant(newnode)

return tree
5 changes: 1 addition & 4 deletions scHPL/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
"""
import numpy as np
from numpy import linalg as LA
import pandas as pd
from .utils import TreeNode

def predict_labels(testdata, tree):
Expand Down Expand Up @@ -36,18 +35,16 @@ def predict_labels(testdata, tree):

RE_error2 = LA.norm(testdata - test_rec, axis = 1)
rej_RE = RE_error2 > t
# print("Cells rejected using RE: ", np.sum(rej_RE))

# Do PCA if needed
dimred = False
if tree[0].get_dimred():
pca, pcs = tree[0].get_pca()
testdata = pca.transform(testdata)
testdata = pd.DataFrame(testdata)
dimred = True

labels_all = []
for idx, testpoint in enumerate(testdata.values):
for idx, testpoint in enumerate(testdata):

if useRE:
if rej_RE[idx]:
Expand Down
98 changes: 0 additions & 98 deletions scHPL/progressive_learning.py

This file was deleted.

21 changes: 9 additions & 12 deletions scHPL/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import numpy as np
from numpy import linalg as LA
import pandas as pd
from sklearn import svm
from sklearn.decomposition import PCA
from sklearn.utils._testing import ignore_warnings
Expand Down Expand Up @@ -54,8 +53,8 @@ def train_tree(data, labels, tree, classifier = 'svm_occ', dimred = True, useRE

for trainindex, testindex in sss.split(data, labels):

train = data.iloc[trainindex]
test = data.iloc[testindex]
train = data[trainindex]
test = data[testindex]

pca = PCA(n_components = num_components, random_state = 0)
pca.fit(train)
Expand All @@ -82,7 +81,6 @@ def train_tree(data, labels, tree, classifier = 'svm_occ', dimred = True, useRE
tree[0].set_dimred(True)

data = pca.transform(data)
data = pd.DataFrame(data)

#recursively train the classifiers for each node in the tree
for n in tree[0].descendants:
Expand Down Expand Up @@ -141,11 +139,11 @@ def _find_pcs(data, labels, group, n, numgenes):

# positive samples
this_class = np.where(group == 1)[0]
this_data = data.iloc[this_class]
this_data = data[this_class]

# negative samples
other_class = np.where(group == 2)[0]
other_data = data.iloc[other_class]
other_data = data[other_class]

statistic, pvalue = ttest_ind(this_data, other_data, equal_var = False)

Expand All @@ -155,9 +153,7 @@ def _find_pcs(data, labels, group, n, numgenes):
if len(explaining_pcs) == 0:
explaining_pcs = np.argsort(pvalue)[:5]

# print(n.name, ': ', len(explaining_pcs))

data = data.iloc[:,explaining_pcs]
data = data[:,explaining_pcs]

# Save the explaining pcs in the tree
n.set_pca(None, explaining_pcs)
Expand All @@ -180,10 +176,11 @@ def _train_svm(data, labels, group, n):
# group == 2 --> negative samples
group = _find_negativesamples(labels, group, n)
idx_svm = np.where((group == 1) | (group == 2))[0]
data_svm = data.iloc[idx_svm]
data_svm = data[idx_svm]
group_svm = group[idx_svm]

clf = svm.LinearSVC(random_state=1).fit(data_svm, group_svm)

n.set_classifier(clf) #save classifier to the node


Expand All @@ -199,7 +196,7 @@ def _train_occ(data, labels, group, n):
n: node
'''

data_group = data.iloc[np.where(group == 1)[0]]
data_group = data[np.where(group == 1)[0]]

clf = svm.OneClassSVM(gamma = 'scale', nu = 0.05).fit(data_group)
n.set_classifier(clf) #save classifier to the node
Expand Down
Loading

0 comments on commit ee35f64

Please sign in to comment.