Merge pull request lcmmichielsen#13 from Mye-InfoBank/master
Implement FAISS GPU support
lcmmichielsen authored Apr 30, 2024
2 parents a0b244d + ac6652d commit e7f3a2e
Showing 5 changed files with 55 additions and 33 deletions.
10 changes: 9 additions & 1 deletion scHPL/faissKNeighbors.py
@@ -9,16 +9,24 @@
 import numpy as np
 
 class FaissKNeighbors:
-    def __init__(self, k=50):
+    def __init__(self, k=50, gpu=None):
         self.index = None
         self.y = None
         self.k = k
+        self.gpu = gpu
 
     def fit(self, X, y):
         self.index = faiss.IndexFlatL2(X.shape[1])
+        if self.gpu is not None:
+            self.to_gpu(self.gpu)
+
         self.index.add(X.astype(np.float32))
         self.y = y
 
+    def to_gpu(self, gpu):
+        res = faiss.StandardGpuResources()
+        self.index = faiss.index_cpu_to_gpu(res, gpu, self.index)
+
     def predict(self, X):
         distances, indices = self.index.search(X.astype(np.float32), k=self.k)
         votes = self.y[indices]
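Taken on its own, the updated class can be exercised directly. A minimal sketch, assuming a GPU build of FAISS (faiss-gpu) and a visible CUDA device 0; the toy arrays are illustrative only:

    import numpy as np
    from scHPL.faissKNeighbors import FaissKNeighbors

    X = np.random.rand(1000, 32).astype(np.float32)   # toy training data
    y = np.random.randint(0, 3, size=1000)            # toy integer labels

    clf = FaissKNeighbors(k=50, gpu=0)   # fit() moves the flat L2 index to GPU 0
    clf.fit(X, y)
    predictions = clf.predict(np.random.rand(10, 32).astype(np.float32))

Leaving gpu=None (the default) keeps the previous CPU-only behaviour, so existing callers are unaffected.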
15 changes: 9 additions & 6 deletions scHPL/learn.py
@@ -18,9 +18,9 @@
 # from update import update_tree
 
 try:
-    from typing import Literal
+    from typing import Literal, Optional
 except ImportError:
-    from typing_extensions import Literal
+    from typing_extensions import Literal, Optional
 
 
 def learn_tree(data: AnnData,
@@ -36,11 +36,12 @@ def learn_tree(data: AnnData,
                distkNN: int = 99,
                dimred: bool = False,
                useRE: bool = True,
-               FN: float = 0.5, 
+               FN: float = 0.5,
                rej_threshold: float = 0.5,
                match_threshold: float = 0.25,
                attach_missing: bool = False,
-               print_conf: bool = False
+               print_conf: bool = False,
+               gpu: Optional[int] = None
                ):
 
     '''Learn a classification tree based on multiple labeled datasets.
@@ -93,6 +94,8 @@ def learn_tree(data: AnnData,
             If 'True' missing nodes are attached to the root node.
         print_conf: Boolean = False
             Whether to print the confusion matrices during the matching step.
+        gpu: int = None
+            GPU index to use for the Faiss library (only used when classifier='knn')
 
         Returns
         -------
@@ -137,13 +140,13 @@
         if retrain:
             tree = train_tree(data_1, labels_1, tree, classifier,
                               dimred, useRE, FN, n_neighbors, dynamic_neighbors,
-                              distkNN)
+                              distkNN, gpu=gpu)
         else:
             retrain = True
 
         tree_2 = train_tree(data_2, labels_2, tree_2, classifier,
                             dimred, useRE, FN, n_neighbors, dynamic_neighbors,
-                            distkNN)
+                            distkNN, gpu=gpu)
 
         # Predict labels other dataset
         labels_2_pred,_ = predict_labels(data_2, tree, threshold=rej_threshold)
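At the user-facing level, the GPU is opted into with a single new keyword. A hedged sketch of a learn_tree call: data, classifier, dimred and gpu appear in this diff, while adata, the return variable name, and the batch_key/batch_order/cell_type_key keywords are assumptions based on typical scHPL usage, not on this changeset:

    from scHPL import learn

    tree = learn.learn_tree(data=adata,                       # assumed AnnData object
                            batch_key='batch',                # assumed keyword
                            batch_order=['batch1', 'batch2'], # assumed keyword
                            cell_type_key='celltype',         # assumed keyword
                            classifier='knn',
                            dimred=True,
                            gpu=0)   # new: build the FAISS kNN indexes on GPU 0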
16 changes: 14 additions & 2 deletions scHPL/predict.py
@@ -7,11 +7,18 @@
 import numpy as np
 from numpy import linalg as LA
 from .utils import TreeNode
+from .faissKNeighbors import FaissKNeighbors
+try:
+    from tqdm import tqdm
+except ImportError:
+    def tqdm(x):
+        return x
 # from utils import TreeNode
 
 def predict_labels(testdata,
                    tree: TreeNode,
-                   threshold: float = 0.5):
+                   threshold: float = 0.5,
+                   gpu=None):
     '''Use the trained tree to predict the labels of a new dataset.
 
     Parameters
@@ -51,10 +58,15 @@ def predict_labels(testdata,
         pca, pcs = tree[0].get_pca()
         testdata = pca.transform(testdata)
         dimred = True
 
+    if (tree[0].classifier and
+            tree[0].classifier.__class__ == FaissKNeighbors and
+            gpu is not None):
+        tree[0].classifier.to_gpu(gpu)
+
     labels_all = []
     prob_all = np.zeros((np.shape(testdata)[0],1))
-    for idx, testpoint in enumerate(testdata):
+    for idx, testpoint in enumerate(tqdm(testdata)):
         if useRE:
             if rej_RE[idx]:
                 labels_all.append('Rejected (RE)')
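predict.py gains the same opt-in at prediction time: if the root classifier is a FaissKNeighbors instance, its index is moved to a GPU just before the per-cell loop, which now also shows progress via tqdm when that package is available. A minimal sketch, assuming adata_test holds the query expression matrix and tree comes from learn_tree:

    from scHPL.predict import predict_labels

    labels_pred, _ = predict_labels(testdata=adata_test.X,   # assumed test matrix
                                    tree=tree,
                                    threshold=0.5,
                                    gpu=0)   # moves the stored FAISS index to GPU 0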
22 changes: 13 additions & 9 deletions scHPL/train.py
@@ -19,9 +19,9 @@
 import copy as cp
 
 try:
-    from typing import Literal
+    from typing import Literal, Optional
 except ImportError:
-    from typing_extensions import Literal
+    from typing_extensions import Literal, Optional
 
 
 @ignore_warnings(category=ConvergenceWarning)
@@ -34,7 +34,8 @@ def train_tree(data,
                FN: float = 0.5,
                n_neighbors: int = 50,
                dynamic_neighbors: bool = True,
-               distkNN: int = 99):
+               distkNN: int = 99,
+               gpu: Optional[int] = None):
     '''Train a hierarchical classifier.
 
     Parameters
@@ -66,6 +67,8 @@
         cell and it's closest neighbor of the training set. Threshold is
         set to the distkNN's percentile of distances within the training
         set
+    gpu: int | None = None
+        GPU index to use for the Faiss library (only used when classifier='knn')
 
     Returns
     -------
@@ -129,7 +132,7 @@ def train_tree(data,
             except:
                 None
         _,_ = _train_parentnode(data, labels_train, tree[0], n_neighbors,
-                                dynamic_neighbors, distkNN)
+                                dynamic_neighbors, distkNN, gpu=gpu)
     else:
         for n in tree[0].descendants:
             _ = _train_node(data, labels, n, classifier, dimred, numgenes)
@@ -175,7 +178,7 @@ def _train_node(data, labels, n, classifier, dimred, numgenes):
 
     return group
 
-def _train_parentnode(data, labels, n, n_neighbors, dynamic_neighbors, distkNN):
+def _train_parentnode(data, labels, n, n_neighbors, dynamic_neighbors, distkNN, gpu=None):
     '''Train a knn classifier. In contrast to the linear svm and oc svm, this
     is trained for each parent node instead of each child node
 
@@ -187,6 +190,7 @@ def _train_parentnode(data, labels, n, n_neighbors, dynamic_neighbors, distkNN):
     classifier: which classifier to use
     dimred: dimensionality reduction
     numgenes: number of genes in the training data
+    gpu: GPU index to use for the Faiss library (only used when classifier='knn')
 
     Return
     ------
@@ -203,15 +207,15 @@ def _train_parentnode(data, labels, n, n_neighbors, dynamic_neighbors, distkNN):
         for j in n.descendants:
             group_new, labels_new = _train_parentnode(data, labels, j,
                                                       n_neighbors, dynamic_neighbors,
-                                                      distkNN)
+                                                      distkNN, gpu=gpu)
             group[np.where(group_new == 1)[0]] = 1
             labels[np.where(group_new == 1)[0]] = labels_new[np.where(group_new == 1)[0]]
         if n.name != None:
             # special case; if n has only 1 child
             if len(n.descendants) == 1:
                 group[np.squeeze(np.isin(labels, n.name))] = 1
             # train_knn
-            _train_knn(data,labels,group,n,n_neighbors,dynamic_neighbors,distkNN)
+            _train_knn(data,labels,group,n,n_neighbors,dynamic_neighbors,distkNN,gpu=gpu)
             # rename all group == 1 to node.name
             group[np.squeeze(np.isin(labels, n.name))] = 1
             labels[group==1] = n.name[0]
@@ -271,7 +275,7 @@ def _train_svm(data, labels, group, n):
     n.set_classifier(clf) #save classifier to the node
 
 
-def _train_knn(data, labels, group, n, n_neighbors, dynamic_neighbors, distkNN):
+def _train_knn(data, labels, group, n, n_neighbors, dynamic_neighbors, distkNN, gpu=None):
     '''Train a linear svm and attach to the node
 
     Parameters:
@@ -300,7 +304,7 @@ def _train_knn(data, labels, group, n, n_neighbors, dynamic_neighbors, distkNN):
         try:
             import faiss
             from .faissKNeighbors import FaissKNeighbors
-            clf = FaissKNeighbors(k=k)
+            clf = FaissKNeighbors(k=k, gpu=gpu)
             clf.fit(data_knn, labels_knn)
             #print('Using FAISS library')
 
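Because train_tree simply forwards gpu down through _train_parentnode and _train_knn, the GPU choice can be made once up front. A hedged sketch: faiss.get_num_gpus() exists only in GPU builds of FAISS (hence the hasattr guard), the positional argument order mirrors the learn.py calls above, and data, labels and tree are assumed to come from the usual scHPL preprocessing:

    import faiss
    from scHPL.train import train_tree

    # Fall back to the CPU path (gpu=None) when no GPU build or device is available.
    use_gpu = hasattr(faiss, "get_num_gpus") and faiss.get_num_gpus() > 0
    gpu = 0 if use_gpu else None

    tree = train_tree(data, labels, tree, 'knn',   # data/labels/tree assumed to exist
                      False, True, 0.5,            # dimred, useRE, FN
                      50, True, 99,                # n_neighbors, dynamic_neighbors, distkNN
                      gpu=gpu)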
25 changes: 10 additions & 15 deletions scHPL/utils.py
@@ -53,7 +53,7 @@ def set_classifier(self, classifier):
         """
         Add a classifier to the node.
         """
-        self.classifier = copy.deepcopy(classifier)
+        self.classifier = classifier
 
     def get_classifier(self):
         return self.classifier
@@ -371,20 +371,15 @@ def _print_node(node, hor, ver_steps, fig, new_nodes):
         x, y = ([np.max([0.05, hor-0.045]), hor], [ver, ver])
         line = mlines.Line2D(x,y, lw=1)
         fig.add_artist(line)
-    
-    # Add textbox
-    if np.isin(node.name[0], new_nodes):
-        txt = r"$\bf{" + node.name[0] + "}$"
-    else:
-        txt = node.name[0]
-        
-    for n in node.name:
-        if(n != node.name[0]):
-            if np.isin(n, new_nodes):
-                txt = txt + ' & ' + r"$\bf{" + n + "}$"
-            else:
-                txt = txt + ' & ' + n
-    
+
+    def format_node(name):
+        if np.isin(name, new_nodes):
+            return r"$\bf{" + name.replace("_", "\_") + "}$"
+        else:
+            return name
+
+    txt = " & ".join([format_node(n) for n in node.name])
+
     fig.text(hor,ver, txt, size=10,
              ha = 'left', va='center',
              bbox = dict(boxstyle='round', fc='w', ec='k'))
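The refactor in _print_node replaces the duplicated bold-formatting branches with a single helper and, importantly, escapes underscores so matplotlib's mathtext does not interpret them as subscripts. A self-contained illustration with hypothetical node names (not part of the diff):

    import numpy as np

    new_nodes = ['CD4_T']              # hypothetical newly added population
    node_name = ['CD4_T', 'T cell']    # hypothetical node.name list

    def format_node(name):
        if np.isin(name, new_nodes):
            return r"$\bf{" + name.replace("_", "\\_") + "}$"
        return name

    print(" & ".join([format_node(n) for n in node_name]))
    # -> $\bf{CD4\_T}$ & T cell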
