Skip to content

Commit

Permalink
Merge pull request #235 from YosefLab/missing_data_speedup
Browse files Browse the repository at this point in the history
sped up average missing data
  • Loading branch information
mattjones315 authored Feb 6, 2024
2 parents b4d6208 + 8e8cab1 commit 5b8a452
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 46 deletions.
31 changes: 21 additions & 10 deletions cassiopeia/data/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,28 @@ def get_lca_characters(
all_states = [
vec[i] for vec in vecs if vec[i] != missing_state_indicator
]
chars = set.intersection(
*map(
set,
[
state if is_ambiguous_state(state) else [state]
for state in all_states
],

# this check is specifically if all_states consists of a single
# ambiguous state.
if len(list(set(all_states))) == 1:
state = all_states[0]
# lca_vec[i] = state
if is_ambiguous_state(state) and len(state) == 1:
lca_vec[i] = state[0]
else:
lca_vec[i] = all_states[0]
else:
chars = set.intersection(
*map(
set,
[
state if is_ambiguous_state(state) else [state]
for state in all_states
],
)
)
)
if len(chars) == 1:
lca_vec[i] = list(chars)[0]
if len(chars) == 1:
lca_vec[i] = list(chars)[0]
return lca_vec


Expand Down
76 changes: 42 additions & 34 deletions cassiopeia/solver/missing_data_methods.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""This file contains included missing data imputation methods."""

from typing import Dict, List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -56,50 +57,57 @@ def assign_missing_average(
sample_names, missing
)

def score_side(subset_character_matrix, missing_sample):
def score_side(subset_character_states, query_states, weights):
score = 0
for char in range(character_matrix.shape[1]):
state = character_array[missing_sample, char]
if state != missing_state_indicator and state != 0:
all_states = unravel_ambiguous_states(
subset_character_matrix[:, char]
)
state_counts = np.unique(all_states, return_counts=True)

if is_ambiguous_state(state):
ambiguous_states = [s for s in state if s != 0]
for ambiguous_state in ambiguous_states:
ind = np.where(state_counts[0] == ambiguous_state)
if len(ind[0]) > 0:
if weights:
score += (
weights[char][ambiguous_state]
* state_counts[1][ind[0][0]]
)
else:
score += state_counts[1][ind[0][0]]

for char in range(len(subset_character_states)):

query_state = [
q
for q in query_states[char]
if q != 0 and q != missing_state_indicator
]
all_states = np.array(subset_character_states[char])
for q in query_state:
if weights:
score += weights[char][q] * np.count_nonzero(
all_states == q
)
else:
ind = np.where(state_counts[0] == state)
if len(ind[0]) > 0:
if weights:
score += (
weights[char][state]
* state_counts[1][ind[0][0]]
)
else:
score += state_counts[1][ind[0][0]]
score += np.count_nonzero(all_states == q)

return score

subset_character_array_left = character_array[left_indices, :]
subset_character_array_right = character_array[right_indices, :]

all_left_states = [
unravel_ambiguous_states(subset_character_array_left[:, char])
for char in range(subset_character_array_left.shape[1])
]
all_right_states = [
unravel_ambiguous_states(subset_character_array_right[:, char])
for char in range(subset_character_array_right.shape[1])
]

for sample_index in missing_indices:
left_score = score_side(subset_character_array_left, sample_index)
right_score = score_side(subset_character_array_right, sample_index)

if left_score / len(left_set) > right_score / len(right_set):
all_states_for_sample = [
unravel_ambiguous_states([character_array[sample_index, char]])
for char in range(character_array.shape[1])
]

left_score = score_side(
np.array(all_left_states, dtype=object),
np.array(all_states_for_sample, dtype=object),
weights,
)
right_score = score_side(
np.array(all_right_states, dtype=object),
np.array(all_states_for_sample, dtype=object),
weights,
)

if (left_score / len(left_set)) > (right_score / len(right_set)):
left_set.append(sample_names[sample_index])
else:
right_set.append(sample_names[sample_index])
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ nbsphinx = {version = "*", optional = true}
nbsphinx-link = {version = "*", optional = true}
networkx = "==3.1"
ngs-tools = ">=1.5.6"
numba = ">=0.51.0"
numba = ">=0.51.0,<0.59.0"
numpy = ">=1.22"
opencv-python = {version = ">=4.5.4.60", optional = true}
pandas = ">=1.1.4"
Expand All @@ -75,7 +75,7 @@ pyvista = {version = "=0.41.0", optional = true}
scanpydoc = {version = ">=0.5", optional = true}
scikit-image = {version = ">=0.19.1", optional = true}
scikit-learn = {version = ">=1.0.2", optional = true}
scipy = ">=1.2.0"
scipy = ">=1.2.0,<=1.11.4"
sphinx = {version = ">=3.4", optional = true}
sphinx-autodoc-typehints = {version = "*", optional = true}
sphinx-gallery = {version = ">0.6", optional = true}
Expand Down
11 changes: 11 additions & 0 deletions test/data_tests/data_utilities_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,17 @@ def test_lca_characters_ambiguous(self):
)
self.assertEqual(ret_vec, [1, 2, 3, 0, 5])

def test_lca_characters_ambiguous_and_missing(self):
vecs = [
[(1, 1), (0, 2), (3, 0), (4,), (5,)],
[1, -1, -1, 3, -1],
[1, -1, (3, 0), 2, -1],
]
ret_vec = data_utilities.get_lca_characters(
vecs, missing_state_indicator=-1
)
self.assertEqual(ret_vec, [1, (0,2), (3,0), 0, 5])

def test_resolve_most_abundant(self):
state = (1, 2, 3, 3)
self.assertEqual(data_utilities.resolve_most_abundant(state), 3)
Expand Down

0 comments on commit 5b8a452

Please sign in to comment.