Black
patrick-nicodemus committed Jun 17, 2024
1 parent d00228b commit c3bc5e6
Showing 8 changed files with 263 additions and 215 deletions.
67 changes: 34 additions & 33 deletions src/cajal/combined_slb_qgw.py
@@ -14,10 +14,22 @@
from scipy.spatial.distance import squareform
from tqdm.notebook import tqdm

from .qgw import (Array, _init_qgw_pool, _quantized_gw_index, _tuple_set_of,
quantized_icdm, slb_parallel_memory)
from .run_gw import (DistanceMatrix, Distribution, Matrix, _batched, cell_iterator_csv,
uniform)
from .qgw import (
Array,
_init_qgw_pool,
_quantized_gw_index,
_tuple_set_of,
quantized_icdm,
slb_parallel_memory,
)
from .run_gw import (
DistanceMatrix,
Distribution,
Matrix,
_batched,
cell_iterator_csv,
uniform,
)

# A BooleanSquareMatrix is a square matrix of booleans.
BooleanSquareMatrix = NewType("BooleanSquareMatrix", npt.NDArray[np.bool_])
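Aside (not part of the commit): a minimal sketch of how a NewType alias over a NumPy boolean array like this is typically used. The call is an identity function at runtime and exists only for the type checker; the array shape and contents here are placeholders.

```python
import numpy as np
import numpy.typing as npt
from typing import NewType

BooleanSquareMatrix = NewType("BooleanSquareMatrix", npt.NDArray[np.bool_])

# An N x N "is this pairwise value known yet?" mask; the NewType wrapper adds no runtime cost.
gw_known = BooleanSquareMatrix(np.zeros((5, 5), dtype=np.bool_))
```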
@@ -165,9 +177,7 @@ def cdf(
return probabilities


def _nn_indices_from_slb(
slb_dmat: DistanceMatrix,
nearest_neighbors: int):
def _nn_indices_from_slb(slb_dmat: DistanceMatrix, nearest_neighbors: int):
N = slb_dmat.shape[0]
# The following line of code was changed from:
# ind_y = np.argsort(slb_dmat, axis=1)[:, 1 : nearest_neighbors + 1]
@@ -182,11 +192,7 @@ def _nn_indices_from_slb(
return b


def _sample_indices_by_bin(
slb_dmat: DistanceMatrix,
slb_bins: int,
sn: SamplingNumber
):
def _sample_indices_by_bin(slb_dmat: DistanceMatrix, slb_bins: int, sn: SamplingNumber):
slb_quantile_bins = np.quantile(
squareform(slb_dmat, force="tovector"),
np.arange(slb_bins + 1).astype(float) / float(slb_bins),
Expand Down Expand Up @@ -217,13 +223,13 @@ def _get_initial_indices(


def _indices_from_cdf_prob(
N : int,
X : npt.NDArray[np.int_],
Y : npt.NDArray[np.int_],
cdf_prob : Array,
nearest_neighbors : int,
accuracy : float,
exp_decay : float,
N: int,
X: npt.NDArray[np.int_],
Y: npt.NDArray[np.int_],
cdf_prob: Array,
nearest_neighbors: int,
accuracy: float,
exp_decay: float,
) -> list[tuple[int, int]]:
# This array is crucial. It contains the list of cell pair indices
# to be computed in order of priority - in *descending* order of
@@ -243,9 +249,7 @@
np.searchsorted(-cdf_prob[undershooting_prob_indices], -0.5)
)
total_expected_injuries = np.sum(cdf_prob)
incremental_expected_injuries = np.cumsum(
cdf_prob[undershooting_prob_indices]
)
incremental_expected_injuries = np.cumsum(cdf_prob[undershooting_prob_indices])
acceptable_injuries = (nearest_neighbors * N) * (1 - accuracy)
acceptable_injury_index = int(
np.searchsorted(
@@ -263,15 +267,12 @@
return list(_tuple_set_of(X[indices], Y[indices]))


def cutoff_of(
estimator_matrix : DistanceMatrix,
nn : int
):
def cutoff_of(estimator_matrix: DistanceMatrix, nn: int):
"""Return the vector of current cutoffs for the nn-th nearest neighbor."""
return np.sort(estimator_matrix, axis=1)[:, nn + 1]


def unknown_indices_of(gw_known : BooleanSquareMatrix):
def unknown_indices_of(gw_known: BooleanSquareMatrix):
"""Return upper-triangular indices (i,j) where gw is unknown."""
Xuk_ts, Yuk_ts = np.nonzero(np.logical_not(gw_known))
upper_triangular = Xuk_ts <= Yuk_ts
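Aside (not part of the commit): cutoff_of is short enough to check by hand. A small worked example, assuming a zero-diagonal estimator matrix so that column nn + 1 of each row-sorted row is the distance to that cell's nn-th nearest neighbor.

```python
import numpy as np

est = np.array(
    [
        [0.0, 1.0, 4.0],
        [1.0, 0.0, 2.0],
        [4.0, 2.0, 0.0],
    ]
)
nn = 1
# Row-wise sorting gives [[0, 1, 4], [0, 1, 2], [0, 2, 4]];
# taking column nn + 1 = 2 yields each cell's nn-th nearest-neighbor cutoff.
print(np.sort(est, axis=1)[:, nn + 1])  # -> [4. 2. 4.]
```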
@@ -281,12 +282,12 @@ def unknown_indices_of(gw_known : BooleanSquareMatrix):


def estimator_matrix_of(
slb_dmat : DistanceMatrix,
gw_dmat : DistanceMatrix,
gw_known : BooleanSquareMatrix,
ed : _Error_Distribution,
Xuk : npt.NDArray[np.int_],
Yuk : npt.NDArray[np.int_],
slb_dmat: DistanceMatrix,
gw_dmat: DistanceMatrix,
gw_known: BooleanSquareMatrix,
ed: _Error_Distribution,
Xuk: npt.NDArray[np.int_],
Yuk: npt.NDArray[np.int_],
):
"""
Compute a best-estimate gw matrix.
73 changes: 38 additions & 35 deletions src/cajal/partial.py
@@ -3,6 +3,7 @@

import numpy as np
from typing import Optional, Any

# import numpy.typing as npt
from copy import copy
from scipy.spatial.distance import euclidean
@@ -12,13 +13,14 @@
from .run_gw import uniform, gw_pairwise_parallel, DistanceMatrix


def dist(n1 : NeuronNode, n2 : NeuronNode):
def dist(n1: NeuronNode, n2: NeuronNode):
"""Compute the Euclidean distance between two NeuronNodes."""
return euclidean(
np.array(n1.coord_triple),
np.array(n2.coord_triple),
)


def lengthtree(tree: NeuronTree):
"""Compute the length of the tree."""
ell = list([t for t in tree])
@@ -30,8 +32,7 @@ def lengthtree(tree: NeuronTree):
children = [length_tree_dict[i] for i in child_ids]
p1 = tree.root.coord_triple
length = sum(
[euclidean(p1, ct.root.coord_triple)
for ct in tree.child_subgraphs]
[euclidean(p1, ct.root.coord_triple) for ct in tree.child_subgraphs]
)
length += sum([child["length"] for child in children])
new_dict = {"length": length, "children": children}
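Aside (not part of the commit): lengthtree builds a nested dict keyed by "length" and "children" that mirrors the subtree structure, as the line above shows. A hypothetical helper that flattens such a dict, of the kind one might use in sanity checks like lengthtree_test below:

```python
from typing import Any


def subtree_lengths(lt: dict[str, Any]) -> list[float]:
    """Collect the 'length' entry of this node and of every descendant subtree."""
    lengths = [lt["length"]]
    for child in lt["children"]:
        lengths.extend(subtree_lengths(child))
    return lengths
```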
@@ -44,12 +45,12 @@ def lengthtree_test(tree: NeuronTree):
ell = [(tree, lengthtree(tree))]
while ell:
t, lt = ell.pop()
assert len(t.child_subgraphs) == len(lt['children'])
assert (total_length(t) - lt['length']) < 0.01
ell += list(zip(t.child_subgraphs, lt['children']))
assert len(t.child_subgraphs) == len(lt["children"])
assert (total_length(t) - lt["length"]) < 0.01
ell += list(zip(t.child_subgraphs, lt["children"]))


def newNeuronNode(n1 : NeuronNode, n2 : NeuronNode, d : float):
def newNeuronNode(n1: NeuronNode, n2: NeuronNode, d: float):
"""Create a new neuron node between n1 and n2, at distance d away from n1."""
ct1 = np.array(n1.coord_triple)
ct2 = np.array(n2.coord_triple)
@@ -60,7 +61,7 @@ def newNeuronNode(n1 : NeuronNode, n2 : NeuronNode, d : float):
structure_id=n1.structure_id,
coord_triple=(ct3[0], ct3[1], ct3[2]),
radius=n1.radius,
parent_sample_number=n1.sample_number
parent_sample_number=n1.sample_number,
)
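Aside (not part of the commit): the line of newNeuronNode that computes ct3 is collapsed in this diff. From the docstring ("at distance d away from n1"), it is presumably a linear interpolation along the segment from n1 to n2; a sketch under that assumption:

```python
import numpy as np


def point_at_distance(ct1: np.ndarray, ct2: np.ndarray, d: float) -> np.ndarray:
    """Walk a distance d from ct1 toward ct2 along the straight segment joining them."""
    direction = (ct2 - ct1) / np.linalg.norm(ct2 - ct1)
    return ct1 + d * direction
```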


@@ -71,7 +72,7 @@ def newNeuronNode(n1 : NeuronNode, n2 : NeuronNode, d : float):
# cut = p * lt["length"]
# ltchildren = copy(lt['children'])
# while cut >= 0.0:

# num_children = len(t.child_subgraphs)
# length_of_children = [
# dist(t.root, t.child_subgraphs[i].root) + ltchildren[i]['length']
@@ -104,23 +105,24 @@ def newNeuronNode(n1 : NeuronNode, n2 : NeuronNode, d : float):
# ltchildren = copy(ltchildren[i]['children'])


def trim_swc_no_mutate(t: NeuronTree, lt : dict[str, Any], p : float,
rng: np.random.Generator) -> NeuronTree:
def trim_swc_no_mutate(
t: NeuronTree, lt: dict[str, Any], p: float, rng: np.random.Generator
) -> NeuronTree:
"""Randomly cut off proportion p of the tree t."""
cut = p * lt["length"]
ltchildren = copy(lt['children'])
ltchildren = copy(lt["children"])
t0 = NeuronTree(root=copy(t.root), child_subgraphs=copy(t.child_subgraphs))
children_copied = True
t1 = t0 # t1 *should* be mutated. This is what we want to return.
t1 = t0 # t1 *should* be mutated. This is what we want to return.

while cut > 0.0:
num_children = len(t0.child_subgraphs)
length_of_children = [
dist(t0.root, t0.child_subgraphs[i].root) + ltchildren[i]['length']
dist(t0.root, t0.child_subgraphs[i].root) + ltchildren[i]["length"]
for i in range(num_children)
]
assert(abs(sum(length_of_children)-total_length(t0)) < .01)

assert abs(sum(length_of_children) - total_length(t0)) < 0.01
x = rng.uniform(low=0.0, high=sum(length_of_children))
i = 0
thres = 0.0
@@ -135,12 +137,14 @@ def trim_swc_no_mutate(t: NeuronTree, lt : dict[str, Any], p : float,
cut -= length_of_children[i]

ltchildren.pop(i)
elif cut >= ltchildren[i]['length']:
elif cut >= ltchildren[i]["length"]:
if not children_copied:
t0.child_subgraphs = copy(t0.child_subgraphs)
children_copied = True
t0.child_subgraphs[i] = NeuronTree(
root=newNeuronNode(t0.root, t0.child_subgraphs[i].root, cut - ltchildren[i]['length']),
root=newNeuronNode(
t0.root, t0.child_subgraphs[i].root, cut - ltchildren[i]["length"]
),
child_subgraphs=[],
)
cut = 0.0
@@ -150,10 +154,11 @@ def trim_swc_no_mutate(t: NeuronTree, lt : dict[str, Any], p : float,
children_copied = True
t0.child_subgraphs[i] = copy(t0.child_subgraphs[i])
t0 = t0.child_subgraphs[i]
ltchildren = copy(ltchildren[i]['children'])
ltchildren = copy(ltchildren[i]["children"])
children_copied = False
return t1


def trim_swc(t: NeuronTree, params: list[float], rng: np.random.Generator):
"""
Make trimmed copies of t.
@@ -168,35 +173,36 @@ def trim_swc(t: NeuronTree, params: list[float], rng: np.random.Generator):
return ell


def test_trim_nt(nt : NeuronTree):
def test_trim_nt(nt: NeuronTree):
"""Test trim_swc."""
params = [.1, .2, .3, .4]
params = [0.1, 0.2, 0.3, 0.4]
rng = np.random.default_rng(seed=0)
trees = trim_swc(nt, params, rng)
nt_len = total_length(nt)
nt_trim_lens = list(map(total_length, trees))
for i in range(4):
if (nt_len * (1 - (params[i]))) < 0.99 * nt_trim_lens[i] or \
(nt_len * (1 - (params[i]))) > 1.01 * nt_trim_lens[i]:
if (nt_len * (1 - (params[i]))) < 0.99 * nt_trim_lens[i] or (
nt_len * (1 - (params[i]))
) > 1.01 * nt_trim_lens[i]:
print("nt_len", nt_len)
print("i", i)
print("nt_trim_len", nt_trim_lens[i])


def partial_matching_analysis(
infolder : str,
firstkcells : int,
samplepts : int,
parameters : list[float],
gw_dist_csv : str,
seed : Optional[int] = 0
infolder: str,
firstkcells: int,
samplepts: int,
parameters: list[float],
gw_dist_csv: str,
seed: Optional[int] = 0,
):
"""Carries out a preset analysis routine to assess partial matching on neurons."""
rng = np.random.default_rng(seed)
num_params = len(parameters)
pe = preprocessor_eu(structure_ids=[1, 3, 4], soma_component_only=True)
ci = it.islice(cell_iterator(infolder), firstkcells)
dms : list[DistanceMatrix] = []
dms: list[DistanceMatrix] = []
names = []
for cell_name, cell in ci:
pe_cell = pe(cell)
@@ -209,9 +215,6 @@ def partial_matching_analysis(
cells = [(dm, u) for dm in dms]
assert len(cells) == num_params * firstkcells
a = gw_pairwise_parallel(
cells=cells,
num_processes=14,
names=names,
gw_dist_csv=gw_dist_csv)
cells=cells, num_processes=14, names=names, gw_dist_csv=gw_dist_csv
)
return a[0]
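Aside (not part of the commit): a minimal usage sketch of the routine above. The folder path, cell count, sample size, trim proportions, and output path are placeholder values; the function returns the first element of the tuple produced by gw_pairwise_parallel, presumably the pairwise Gromov-Wasserstein distance matrix across all trimmed copies.

```python
from cajal.partial import partial_matching_analysis

result = partial_matching_analysis(
    infolder="swc/",               # directory of SWC reconstructions (placeholder)
    firstkcells=10,                # only the first 10 cells are used (placeholder)
    samplepts=50,                  # points sampled per cell (placeholder)
    parameters=[0.1, 0.2, 0.3],    # proportions of each tree to trim away
    gw_dist_csv="gw_partial.csv",  # CSV written by gw_pairwise_parallel (placeholder)
    seed=0,
)
```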
