Skip to content

Commit

Permalink
Update predict.py
Browse files Browse the repository at this point in the history
  • Loading branch information
lcmmichielsen committed Apr 30, 2024
1 parent e7f3a2e commit e653a5e
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions scHPL/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
import numpy as np
from numpy import linalg as LA
from .utils import TreeNode
from .faissKNeighbors import FaissKNeighbors
try:
from .faissKNeighbors import FaissKNeighbors
except:
None
try:
from tqdm import tqdm
except ImportError:
Expand Down Expand Up @@ -58,11 +61,14 @@ def predict_labels(testdata,
pca, pcs = tree[0].get_pca()
testdata = pca.transform(testdata)
dimred = True

if (tree[0].classifier and
tree[0].classifier.__class__ == FaissKNeighbors and
gpu is not None):
tree[0].classifier.to_gpu(gpu)

try:
if (tree[0].classifier and
tree[0].classifier.__class__ == FaissKNeighbors and
gpu is not None):
tree[0].classifier.to_gpu(gpu)
except:
None

labels_all = []
prob_all = np.zeros((np.shape(testdata)[0],1))
Expand Down

0 comments on commit e653a5e

Please sign in to comment.