Skip to content

Commit

Permalink
Merge pull request #1986 from larrybradley/segm_labels-dtype
Browse files Browse the repository at this point in the history
Ensure labels dtype matches the segmentation data
  • Loading branch information
larrybradley authored Jan 3, 2025
2 parents 08ac8a2 + 201a7e7 commit 3d06161
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ Bug Fixes
representation instead of ``str`` representation when using NumPy
2.0+. [#1956]

- Fixed a bug to ensure that the dtype of the ``SegmentationImage``
``labels`` always matches the image dtype. [#1986]

API Changes
^^^^^^^^^^^

Expand Down
4 changes: 2 additions & 2 deletions photutils/segmentation/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _get_labels(data):
:meth:`remove_masked_labels` on a masked version of the
segmentation array.
"""
# np.unique also sorts elements
# np.unique preserves dtype and also sorts elements
return np.unique(data[data != 0])

@lazyproperty
Expand Down Expand Up @@ -209,7 +209,7 @@ def labels(self):
for label, slc in zip(labels_all, self._raw_slices, strict=True):
if slc is not None:
labels.append(label)
return np.array(labels)
return np.array(labels, dtype=self._data.dtype)

return self._get_labels(self.data)

Expand Down
2 changes: 1 addition & 1 deletion photutils/segmentation/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def _detect_sources(data, threshold, npixels, footprint, inverse_mask, *,
# NOTE: recasting segment_img to int and using output=segment_img
# gives similar performance
segment_img, nlabels = ndi_label(segment_img, structure=footprint)
labels = np.arange(nlabels) + 1
labels = np.arange(nlabels, dtype=segment_img.dtype) + 1

# remove objects with less than npixels
# NOTE: making cutout images and setting their pixels to 0 is
Expand Down
3 changes: 3 additions & 0 deletions photutils/segmentation/tests/test_detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ def test_detection(self):
segm = detect_sources(self.data, threshold=0.9, npixels=2)
assert_equal(segm.data, self.refdata)

assert segm.data.dtype == np.int32
assert segm.labels.dtype == np.int32

segm = detect_sources(self.data << u.uJy, threshold=0.9 * u.uJy,
npixels=2)
assert_equal(segm.data, self.refdata)
Expand Down

0 comments on commit 3d06161

Please sign in to comment.