diff --git a/cassiopeia/data/utilities.py b/cassiopeia/data/utilities.py index 35cde4f4..bcf008f4 100755 --- a/cassiopeia/data/utilities.py +++ b/cassiopeia/data/utilities.py @@ -61,17 +61,28 @@ def get_lca_characters( all_states = [ vec[i] for vec in vecs if vec[i] != missing_state_indicator ] - chars = set.intersection( - *map( - set, - [ - state if is_ambiguous_state(state) else [state] - for state in all_states - ], + + # this check is specifically if all_states consists of a single + # ambiguous state. + if len(list(set(all_states))) == 1: + state = all_states[0] + # lca_vec[i] = state + if is_ambiguous_state(state) and len(state) == 1: + lca_vec[i] = state[0] + else: + lca_vec[i] = all_states[0] + else: + chars = set.intersection( + *map( + set, + [ + state if is_ambiguous_state(state) else [state] + for state in all_states + ], + ) ) - ) - if len(chars) == 1: - lca_vec[i] = list(chars)[0] + if len(chars) == 1: + lca_vec[i] = list(chars)[0] return lca_vec diff --git a/cassiopeia/solver/missing_data_methods.py b/cassiopeia/solver/missing_data_methods.py index eac5249c..4cf10351 100644 --- a/cassiopeia/solver/missing_data_methods.py +++ b/cassiopeia/solver/missing_data_methods.py @@ -1,4 +1,5 @@ """This file contains included missing data imputation methods.""" + from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -56,50 +57,57 @@ def assign_missing_average( sample_names, missing ) - def score_side(subset_character_matrix, missing_sample): + def score_side(subset_character_states, query_states, weights): score = 0 - for char in range(character_matrix.shape[1]): - state = character_array[missing_sample, char] - if state != missing_state_indicator and state != 0: - all_states = unravel_ambiguous_states( - subset_character_matrix[:, char] - ) - state_counts = np.unique(all_states, return_counts=True) - - if is_ambiguous_state(state): - ambiguous_states = [s for s in state if s != 0] - for ambiguous_state in ambiguous_states: - ind = np.where(state_counts[0] == ambiguous_state) - if len(ind[0]) > 0: - if weights: - score += ( - weights[char][ambiguous_state] - * state_counts[1][ind[0][0]] - ) - else: - score += state_counts[1][ind[0][0]] - + for char in range(len(subset_character_states)): + + query_state = [ + q + for q in query_states[char] + if q != 0 and q != missing_state_indicator + ] + all_states = np.array(subset_character_states[char]) + for q in query_state: + if weights: + score += weights[char][q] * np.count_nonzero( + all_states == q + ) else: - ind = np.where(state_counts[0] == state) - if len(ind[0]) > 0: - if weights: - score += ( - weights[char][state] - * state_counts[1][ind[0][0]] - ) - else: - score += state_counts[1][ind[0][0]] + score += np.count_nonzero(all_states == q) return score subset_character_array_left = character_array[left_indices, :] subset_character_array_right = character_array[right_indices, :] + all_left_states = [ + unravel_ambiguous_states(subset_character_array_left[:, char]) + for char in range(subset_character_array_left.shape[1]) + ] + all_right_states = [ + unravel_ambiguous_states(subset_character_array_right[:, char]) + for char in range(subset_character_array_right.shape[1]) + ] + for sample_index in missing_indices: - left_score = score_side(subset_character_array_left, sample_index) - right_score = score_side(subset_character_array_right, sample_index) - if left_score / len(left_set) > right_score / len(right_set): + all_states_for_sample = [ + unravel_ambiguous_states([character_array[sample_index, char]]) + for char in range(character_array.shape[1]) + ] + + left_score = score_side( + np.array(all_left_states, dtype=object), + np.array(all_states_for_sample, dtype=object), + weights, + ) + right_score = score_side( + np.array(all_right_states, dtype=object), + np.array(all_states_for_sample, dtype=object), + weights, + ) + + if (left_score / len(left_set)) > (right_score / len(right_set)): left_set.append(sample_names[sample_index]) else: right_set.append(sample_names[sample_index]) diff --git a/pyproject.toml b/pyproject.toml index 4a84f3b5..d1d9fdc1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,7 @@ nbsphinx = {version = "*", optional = true} nbsphinx-link = {version = "*", optional = true} networkx = "==3.1" ngs-tools = ">=1.5.6" -numba = ">=0.51.0" +numba = ">=0.51.0,<0.59.0" numpy = ">=1.22" opencv-python = {version = ">=4.5.4.60", optional = true} pandas = ">=1.1.4" @@ -75,7 +75,7 @@ pyvista = {version = "=0.41.0", optional = true} scanpydoc = {version = ">=0.5", optional = true} scikit-image = {version = ">=0.19.1", optional = true} scikit-learn = {version = ">=1.0.2", optional = true} -scipy = ">=1.2.0" +scipy = ">=1.2.0,<=1.11.4" sphinx = {version = ">=3.4", optional = true} sphinx-autodoc-typehints = {version = "*", optional = true} sphinx-gallery = {version = ">0.6", optional = true} diff --git a/test/data_tests/data_utilities_test.py b/test/data_tests/data_utilities_test.py index dcbcec2a..94902faa 100755 --- a/test/data_tests/data_utilities_test.py +++ b/test/data_tests/data_utilities_test.py @@ -316,6 +316,17 @@ def test_lca_characters_ambiguous(self): ) self.assertEqual(ret_vec, [1, 2, 3, 0, 5]) + def test_lca_characters_ambiguous_and_missing(self): + vecs = [ + [(1, 1), (0, 2), (3, 0), (4,), (5,)], + [1, -1, -1, 3, -1], + [1, -1, (3, 0), 2, -1], + ] + ret_vec = data_utilities.get_lca_characters( + vecs, missing_state_indicator=-1 + ) + self.assertEqual(ret_vec, [1, (0,2), (3,0), 0, 5]) + def test_resolve_most_abundant(self): state = (1, 2, 3, 3) self.assertEqual(data_utilities.resolve_most_abundant(state), 3)