diff --git a/src/measurements.rs b/src/measurements.rs index 9c996d9..ba95173 100644 --- a/src/measurements.rs +++ b/src/measurements.rs @@ -1,21 +1,39 @@ use ndarray::{s, Array3, ArrayBase, Axis, Data, Ix3, Zip}; +use num_traits::{Bounded, FromPrimitive, NumAssignOps, ToPrimitive, Unsigned}; use crate::Mask; -const BACKGROUND: u16 = 0; -const FOREGROUND: u16 = 1; +pub trait LabelType: + Copy + FromPrimitive + ToPrimitive + Ord + Unsigned + NumAssignOps + Bounded +{ + fn background() -> Self; + fn foreground() -> Self; +} + +impl LabelType for T +where + T: Copy + FromPrimitive + ToPrimitive + Ord + Unsigned + NumAssignOps + Bounded, +{ + fn background() -> Self { + T::zero() + } + fn foreground() -> Self { + T::one() + } +} /// Calculates the histogram of a label image. /// -/// * `labels` - `u16` 3D labels image, returned by the `label` function. +/// * `labels` - 3D labels image, returned by the `label` function. /// * `nb_features` - Number of unique labels, returned by the `label` function. pub fn label_histogram(labels: &ArrayBase, nb_features: usize) -> Vec where - S: Data, + S: Data, + S::Elem: LabelType, { let mut count = vec![0; nb_features + 1]; Zip::from(labels).for_each(|&l| { - count[l as usize] += 1; + count[l.to_usize().unwrap()] += 1; }); count } @@ -24,19 +42,20 @@ where /// /// Ignores the background label. A blank label image will return None. /// -/// * `labels` - `u16` 3D labels image, returned by the `label` function. +/// * `labels` - 3D labels image, returned by the `label` function. /// * `nb_features` - Number of unique labels, returned by the `label` function. pub fn most_frequent_label( labels: &ArrayBase, nb_features: usize, -) -> Option<(u16, usize)> +) -> Option<(S::Elem, usize)> where - S: Data, + S: Data, + S::Elem: LabelType, { let hist = label_histogram(labels, nb_features); let (max, max_index) = hist[1..].iter().enumerate().fold((0, 0), |acc, (i, &nb)| acc.max((nb, i))); - (max > 0).then(|| ((max_index + 1) as u16, max)) + (max > 0).then(|| (S::Elem::from_usize(max_index + 1).unwrap(), max)) } /// Returns a new mask, containing the biggest zone of `mask`. @@ -44,6 +63,8 @@ where /// * `mask` - Binary image to be labeled and studied. /// * `structure` - Structuring element used for the labeling. Must be 3x3x3 (e.g. the result /// of [`Kernel3d::generate`](crate::Kernel3d::generate)) and centrosymmetric. The center must be `true`. +/// +/// The labeling is done using `u16`, this may be too small when `mask` has more than [`u16::MAX`] elements. pub fn largest_connected_components( mask: &ArrayBase, structure: &ArrayBase, @@ -51,7 +72,7 @@ pub fn largest_connected_components( where S: Data, { - let (labels, nb_features) = label(mask, structure); + let (labels, nb_features) = label::<_, u16>(mask, structure); let (right_label, _) = most_frequent_label(&labels, nb_features)?; Some(labels.mapv(|l| l == right_label)) } @@ -63,19 +84,35 @@ where /// * `mask` - Binary image to be labeled. `false` values are considered the background. /// * `structure` - Structuring element used for the labeling. Must be 3x3x3 (e.g. the result /// of [`Kernel3d::generate`](crate::Kernel3d::generate)) and centrosymmetric. The center must be `true`. -pub fn label(data: &ArrayBase, structure: &ArrayBase) -> (Array3, usize) +/// +/// The return type of `label` can be specified using turbofish syntax: +/// +/// ``` +/// // Will use `u16` as the label type +/// ndarray_ndimage::label::<_, u16>( +/// &ndarray::Array3::from_elem((100, 100, 100), true), +/// &ndarray_ndimage::Kernel3d::Star.generate() +/// ); +/// ``` +/// +/// As a rough rule of thumb, the maximum value of the label type should be larger than `data.len()`. +/// This is the worst case, the exact bound will depend on the kernel used. If the label type overflows +/// while assigning labels, a panic will occur. +pub fn label(data: &ArrayBase, structure: &ArrayBase) -> (Array3, usize) where S: Data, + O: LabelType, { assert!(structure.shape() == &[3, 3, 3], "`structure` must be size 3 in all dimensions"); assert!(structure == structure.slice(s![..;-1, ..;-1, ..;-1]), "`structure is not symmetric"); let len = data.dim().2; - let mut line_buffer = vec![BACKGROUND; len + 2]; - let mut neighbors = vec![BACKGROUND; len + 2]; + let mut line_buffer = vec![O::background(); len + 2]; + let mut neighbors = vec![O::background(); len + 2]; - let mut next_region = FOREGROUND + 1; - let mut equivalences: Vec<_> = (0..next_region).collect(); + let mut next_region = O::foreground() + O::one(); + let mut equivalences: Vec<_> = + (0..next_region.to_usize().unwrap()).map(|x| O::from_usize(x).unwrap()).collect(); // We only handle 3D data for now, but this algo can handle N-dimensional data. // https://github.com/scipy/scipy/blob/v0.16.1/scipy/ndimage/src/_ni_label.pyx @@ -102,10 +139,10 @@ where let use_previous = structure[(1, 1, 0)]; let width = data.dim().0 as isize; let height = data.dim().1 as isize; - let mut labels = Array3::from_elem(data.dim(), BACKGROUND); + let mut labels = Array3::from_elem(data.dim(), O::background()); Zip::indexed(data.lanes(Axis(2))).for_each(|idx, data| { for (&v, b) in data.iter().zip(&mut line_buffer[1..]) { - *b = if !v { BACKGROUND } else { FOREGROUND } + *b = if !v { O::background() } else { O::foreground() } } let mut needs_self_labeling = true; @@ -154,7 +191,7 @@ where // Compact and apply the equivalences let nb_features = compact_equivalences(&mut equivalences, next_region); - labels.mapv_inplace(|l| equivalences[l as usize]); + labels.mapv_inplace(|l| equivalences[l.to_usize().unwrap()]); (labels, nb_features) } @@ -173,18 +210,21 @@ fn is_valid(idx: &[usize; 2], coords: &[isize; 2], dims: &[isize; 2]) -> Option< .and_then(|x| valid(idx[1], coords[1], dims[1]).and_then(|y| Some((x, y)))) } -fn label_line_with_neighbor( - line: &mut [u16], - neighbors: &[u16], - equivalences: &mut Vec, +fn label_line_with_neighbor( + line: &mut [O], + neighbors: &[O], + equivalences: &mut Vec, kernel: [bool; 3], use_previous: bool, label_unlabeled: bool, - mut next_region: u16, -) -> u16 { + mut next_region: O, +) -> O +where + O: LabelType, +{ let mut previous = line[0]; for (n, l) in neighbors.windows(3).zip(&mut line[1..]) { - if *l != BACKGROUND { + if *l != O::background() { for (&n, &k) in n.iter().zip(&kernel) { if k { *l = take_label_or_merge(*l, n, equivalences); @@ -195,10 +235,11 @@ fn label_line_with_neighbor( *l = take_label_or_merge(*l, previous, equivalences); } // Still needs a label? - if *l == FOREGROUND { + if *l == O::foreground() { *l = next_region; equivalences.push(next_region); - next_region += 1; + assert!(next_region < O::max_value(), "Overflow when assigning label"); + next_region += O::one(); } } } @@ -208,11 +249,14 @@ fn label_line_with_neighbor( } /// Take the label of a neighbor, or mark them for merging -fn take_label_or_merge(current: u16, neighbor: u16, equivalences: &mut [u16]) -> u16 { - if neighbor == BACKGROUND { +fn take_label_or_merge(current: O, neighbor: O, equivalences: &mut [O]) -> O +where + O: LabelType, +{ + if neighbor == O::background() { current - } else if current == FOREGROUND { - neighbor // neighbor is not BACKGROUND + } else if current == O::foreground() { + neighbor // neighbor is not background } else if current != neighbor { mark_for_merge(neighbor, current, equivalences) } else { @@ -221,57 +265,63 @@ fn take_label_or_merge(current: u16, neighbor: u16, equivalences: &mut [u16]) -> } /// Mark two labels to be merged -fn mark_for_merge(mut a: u16, mut b: u16, equivalences: &mut [u16]) -> u16 { +fn mark_for_merge(mut a: O, mut b: O, equivalences: &mut [O]) -> O +where + O: LabelType, +{ // Find smallest root for each of a and b let original_a = a; - while a != equivalences[a as usize] { - a = equivalences[a as usize]; + while a != equivalences[a.to_usize().unwrap()] { + a = equivalences[a.to_usize().unwrap()]; } let original_b = b; - while b != equivalences[b as usize] { - b = equivalences[b as usize]; + while b != equivalences[b.to_usize().unwrap()] { + b = equivalences[b.to_usize().unwrap()]; } let lowest_label = a.min(b); // Merge roots - equivalences[a as usize] = lowest_label; - equivalences[b as usize] = lowest_label; + equivalences[a.to_usize().unwrap()] = lowest_label; + equivalences[b.to_usize().unwrap()] = lowest_label; // Merge every step to minlabel a = original_a; while a != lowest_label { let a_copy = a; - a = equivalences[a as usize]; - equivalences[a_copy as usize] = lowest_label; + a = equivalences[a.to_usize().unwrap()]; + equivalences[a_copy.to_usize().unwrap()] = lowest_label; } b = original_b; while b != lowest_label { let b_copy = b; - b = equivalences[b as usize]; - equivalences[b_copy as usize] = lowest_label; + b = equivalences[b.to_usize().unwrap()]; + equivalences[b_copy.to_usize().unwrap()] = lowest_label; } lowest_label } /// Compact the equivalences vector -fn compact_equivalences(equivalences: &mut [u16], next_region: u16) -> usize { - let no_labelling = next_region == 2; +fn compact_equivalences(equivalences: &mut [O], next_region: O) -> usize +where + O: LabelType, +{ + let no_labelling = next_region == O::from_usize(2).unwrap(); let mut dest_label = if no_labelling { 0 } else { 1 }; - for i in 2..next_region as usize { - if equivalences[i] == i as u16 { - equivalences[i] = dest_label; - dest_label += 1; + for i in 2..next_region.to_usize().unwrap() { + if equivalences[i] == O::from_usize(i).unwrap() { + equivalences[i] = O::from_usize(dest_label).unwrap(); + dest_label = dest_label + 1; } else { // We've compacted every label below this, and equivalences has an invariant that it // always points downward. Therefore, we can fetch the final label by two steps of // indirection. - equivalences[i] = equivalences[equivalences[i] as usize]; + equivalences[i] = equivalences[equivalences[i].to_usize().unwrap()]; } } if no_labelling { 0 } else { - *equivalences.iter().max().unwrap() as usize + equivalences.iter().max().unwrap().to_usize().unwrap() } } diff --git a/tests/measurements.rs b/tests/measurements.rs index 51da409..1d16128 100644 --- a/tests/measurements.rs +++ b/tests/measurements.rs @@ -8,7 +8,7 @@ use ndarray_ndimage::{ fn test_label_0() { let star = Kernel3d::Star.generate(); let data = Array3::zeros((3, 3, 3)); - let (labels, nb_features) = label(&data.mapv(|v| v > 0), &star); + let (labels, nb_features) = label::<_, u16>(&data.mapv(|v| v > 0), &star); assert_eq!(labels, data); assert_eq!(nb_features, 0); assert_eq!(label_histogram(&labels, nb_features), vec![27]); @@ -23,7 +23,7 @@ fn test_label_2() { [[0, 0, 0], [0, 0, 0], [0, 0, 0]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]], ]); - let (labels, nb_features) = label(&data.mapv(|v| v > 0), &star); + let (labels, nb_features) = label::<_, u16>(&data.mapv(|v| v > 0), &star); assert_eq!(labels, data); assert_eq!(nb_features, 1); assert_eq!(label_histogram(&labels, nb_features), vec![18, 9]); @@ -43,7 +43,7 @@ fn test_label_3() { [[0, 0, 0], [0, 0, 0], [0, 0, 0]], [[2, 2, 2], [2, 2, 2], [2, 2, 2]], ]); - let (labels, nb_features) = label(&data.mapv(|v| v > 0), &star); + let (labels, nb_features) = label::<_, u16>(&data.mapv(|v| v > 0), &star); assert_eq!(labels, gt); assert_eq!(nb_features, 2); assert_eq!(label_histogram(&labels, nb_features), vec![12, 6, 9]); @@ -91,7 +91,7 @@ fn test_label_4() { [[0, 0, 2, 2], [0, 0, 2, 0], [0, 0, 0, 0]], [[3, 3, 0, 0], [3, 0, 0, 0], [0, 0, 0, 0]], ]); - let (labels, nb_features) = label(&data.mapv(|v| v >= 0.7), &star); + let (labels, nb_features) = label::<_, u16>(&data.mapv(|v| v >= 0.7), &star); assert_eq!(labels, gt); assert_eq!(nb_features, 3); assert_eq!(label_histogram(>, nb_features), vec![71, 127, 3, 3]); @@ -177,7 +177,7 @@ fn test_label_5() { [3, 0, 0, 0, 0], ], ]); - let (labels, nb_features) = label(&data.mapv(|v| v >= 0.7).view(), &star.view()); + let (labels, nb_features) = label::<_, u16>(&data.mapv(|v| v >= 0.7).view(), &star.view()); assert_eq!(labels, gt); assert_eq!(nb_features, 3); assert_eq!(label_histogram(>, nb_features), vec![113, 33, 2, 2]); @@ -255,28 +255,61 @@ fn test_label_different_kernels() { [[0, 0, 0, 0], [0, 0, 6, 0], [0, 0, 5, 0]], ]); { - let (labels, nb_features) = label(&data.mapv(|v| v > 0), &Kernel3d::Star.generate()); + let (labels, nb_features) = + label::<_, u16>(&data.mapv(|v| v > 0), &Kernel3d::Star.generate()); assert_eq!(labels, star_result); assert_eq!(nb_features, 6); } { - let (labels, nb_features) = label(&data.mapv(|v| v > 0), &Kernel3d::Ball.generate()); + let (labels, nb_features) = + label::<_, u16>(&data.mapv(|v| v > 0), &Kernel3d::Ball.generate()); assert_eq!(labels, ball_result); assert_eq!(nb_features, 2); } { - let (labels, nb_features) = label(&data.mapv(|v| v > 0), &Kernel3d::Full.generate()); + let (labels, nb_features) = + label::<_, u16>(&data.mapv(|v| v > 0), &Kernel3d::Full.generate()); assert_eq!(labels, full_result); assert_eq!(nb_features, 1); } { - let (labels, nb_features) = label(&data.mapv(|v| v > 0), &odd1_kernel); + let (labels, nb_features) = label::<_, u16>(&data.mapv(|v| v > 0), &odd1_kernel); assert_eq!(labels, odd1_result); assert_eq!(nb_features, 2); } { - let (labels, nb_features) = label(&data.mapv(|v| v > 0), &odd2_kernel); + let (labels, nb_features) = label::<_, u16>(&data.mapv(|v| v > 0), &odd2_kernel); assert_eq!(labels, odd2_result); assert_eq!(nb_features, 6); } } + +#[test] +fn test_label_u8() { + let data = arr3(&[ + [[0, 1, 0, 0], [1, 1, 0, 0], [0, 0, 0, 1]], + [[0, 0, 0, 0], [1, 0, 1, 0], [0, 0, 0, 0]], + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 0]], + [[0, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0]], + ]); + let star_result = arr3(&[ + [[0, 1, 0, 0], [1, 1, 0, 0], [0, 0, 0, 2]], + [[0, 0, 0, 0], [1, 0, 3, 0], [0, 0, 0, 0]], + [[4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 0, 0]], + [[0, 0, 0, 0], [0, 0, 6, 0], [0, 0, 6, 0]], + ]); + + let (labels, nb_features) = label::<_, u8>(&data.mapv(|v| v > 0), &Kernel3d::Star.generate()); + assert_eq!(labels, star_result); + assert_eq!(nb_features, 6); +} + +#[should_panic] +#[test] +fn test_label_u8_panic() { + let mut unconnected_kernel = Array3::from_elem((3, 3, 3), false); + unconnected_kernel[(1, 1, 1)] = true; + + // Try to label 1000 items, this would require 1000 labels and u8 only has 255, therefore we overflow and panic + label::<_, u8>(&Array3::from_elem((10, 10, 10), true), &unconnected_kernel); +}