Skip to content

Commit

Permalink
Merge branch 'main' into toggle_shuttercount
Browse files Browse the repository at this point in the history
  • Loading branch information
JackEAllen authored Jun 24, 2024
2 parents bf6fd5b + 7af6d74 commit ed3c260
Show file tree
Hide file tree
Showing 10 changed files with 110 additions and 73 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#2213 : unit tests have been added to check that the Time of Flight modes behave correctly when switching between stacks
1 change: 1 addition & 0 deletions docs/release_notes/next/fix-2219-live-view-del-dir
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#2219: Live viewer: informative error if live directory deleted
2 changes: 0 additions & 2 deletions environment-dev.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
name: mantidimaging-dev
channels:
- mantidimaging/label/unstable
- dtasev
- astra-toolbox
- conda-forge
- ccpi
Expand Down Expand Up @@ -29,7 +28,6 @@ dependencies:
- pyfakefs==5.3.*
- parameterized==0.9.*
- pyinstaller==6.1.*
- sarepy=2020.07 # For building old docs
- make==4.3
- ruff=0.3.3
- pre-commit==3.5.*
Expand Down
63 changes: 13 additions & 50 deletions mantidimaging/core/operations/flat_fielding/flat_fielding.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

from mantidimaging import helper as h
from mantidimaging.core.operations.base_filter import BaseFilter, FilterGroup
from mantidimaging.core.parallel import utility as pu, shared as ps
from mantidimaging.core.utility.progress_reporting import Progress
from mantidimaging.core.parallel import shared as ps
from mantidimaging.gui.utility.qt_helpers import Type
from mantidimaging.gui.widgets.dataset_selector import DatasetSelectorWidgetView

Expand Down Expand Up @@ -75,7 +74,6 @@ def filter_func(images: ImageStack,
use_dark: bool = True,
progress=None) -> ImageStack:
"""Do background correction with flat and dark images.
:param images: Sample data which is to be processed. Expected in radiograms
:param flat_before: Flat (open beam) image to use in normalization, collected before the sample was imaged
:param flat_after: Flat (open beam) image to use in normalization, collected after the sample was imaged
Expand All @@ -87,7 +85,6 @@ def filter_func(images: ImageStack,
:return: Filtered data (stack of images)
"""
h.check_data_stack(images)

if selected_flat_fielding not in ["Both, concatenated", "Only Before", "Only After"]:
raise ValueError(f"Invalid flat fielding method: {selected_flat_fielding}")

Expand Down Expand Up @@ -130,19 +127,26 @@ def filter_func(images: ImageStack,
raise ValueError(
f"Incorrect shape of the flat image ({flat_avg.shape}) or dark image ({dark_avg.shape}) "
f"which should match the shape of the sample images ({images.data.shape[1:]})")

if not (images.data.shape[1:] == flat_avg.shape == dark_avg.shape):
raise ValueError(f"Not all images are the expected shape: {images.data.shape[1:]}, instead "
f"flat had shape: {flat_avg.shape}, and dark had shape: {dark_avg.shape}")

progress = Progress.ensure_instance(progress,
num_steps=images.data.shape[0],
task_name='Background Correction')
_execute(images, flat_avg, dark_avg, progress)
params = {'flat_avg': flat_avg, 'dark_avg': dark_avg}
ps.run_compute_func(FlatFieldFilter._compute_flat_field, len(images.data), [images.shared_array], params)

h.check_data_stack(images)
return images

@staticmethod
def _compute_flat_field(index: int, array: np.ndarray, params: dict):
flat_avg = params['flat_avg']
dark_avg = params['dark_avg']

norm_divide = flat_avg - dark_avg
norm_divide[norm_divide == 0] = MINIMUM_PIXEL_VALUE
array[index] -= dark_avg
array[index] /= norm_divide

@staticmethod
def register_gui(form, on_change, view) -> dict[str, Any]:
from mantidimaging.gui.utility import add_property_to_form
Expand Down Expand Up @@ -276,44 +280,3 @@ def validate_execute_kwargs(kwargs):
@staticmethod
def group_name() -> FilterGroup:
return FilterGroup.Basic


def _divide(data, norm_divide):
np.true_divide(data, norm_divide, out=data)


def _subtract(data, dark=None):
# specify out to do in place, otherwise the data is copied
np.subtract(data, dark, out=data)


def _norm_divide(flat: np.ndarray, dark: np.ndarray) -> np.ndarray:
# subtract dark from flat
return np.subtract(flat, dark)


def _execute(images: ImageStack, flat=None, dark=None, progress=None):
with progress:
progress.update(msg="Applying background correction")

if images.uses_shared_memory:
shared_dark = pu.copy_into_shared_memory(dark)
norm_divide = pu.copy_into_shared_memory(_norm_divide(flat, dark))
else:
shared_dark = pu.SharedArray(dark, None)
norm_divide = pu.SharedArray(_norm_divide(flat, dark), None)

# prevent divide-by-zero issues, and negative pixels make no sense
norm_divide.array[norm_divide.array == 0] = MINIMUM_PIXEL_VALUE

# subtract the dark from all images
do_subtract = ps.create_partial(_subtract, fwd_function=ps.inplace_second_2d)
arrays = [images.shared_array, shared_dark]
ps.execute(do_subtract, arrays, images.data.shape[0], progress)

# divide the data by (flat - dark)
do_divide = ps.create_partial(_divide, fwd_function=ps.inplace_second_2d)
arrays = [images.shared_array, norm_divide]
ps.execute(do_divide, arrays, images.data.shape[0], progress)

return images
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def filter_func(images: ImageStack, sigma=3, size=21, window_dim=1, filtering_di
:param window_dim: Whether to perform the median on 1D or 2D view of the
data.
:param filtering_dim: Whether to use a 1D or 2D low-pass filter. This
uses different Sarepy methods.
uses different Algotom methods.
:return: The ImageStack object with the stripes removed using the
filtering and sorting technique.
Expand Down Expand Up @@ -119,7 +119,7 @@ def register_gui(form, on_change, view):
form=form,
on_change=on_change,
tooltip="Whether to use a 1D or 2D low-pass filter. "
"This uses different Sarepy methods")
"This uses different Algotom methods")
return {'sigma': sigma, 'size': size, 'window_dim': window_dim, 'filtering_dim': filtering_dim}

@staticmethod
Expand Down
3 changes: 3 additions & 0 deletions mantidimaging/gui/windows/live_viewer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@ def _handle_directory_change(self) -> None:
self.add_sub_directory(this_dir)

self.clear_deleted_sub_directories(directory_path)
if not self.sub_directories:
raise FileNotFoundError(f"Live directory not found: {self.directory}"
"\nHas it been deleted?")
self.find_sub_directories(directory_path)
self.sort_sub_directory_by_modified_time()

Expand Down
16 changes: 16 additions & 0 deletions mantidimaging/gui/windows/live_viewer/test/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,22 @@ def test_WHEN_find_images_deleted_file_THEN_handles_error(self):
images = [image.image_path for image in images_datas]
self._file_list_count_equal(images, file_list[1:])

def test_WHEN_empty_dir_THEN_no_files(self):
self.watcher.changed_directory = self.top_path
self.watcher._handle_directory_change()

emitted_images = self._get_recent_emitted_files()
self._file_list_count_equal(emitted_images, [])

def test_WHEN_missing_dir_THEN_useful_error(self):
self.fs.rmdir(self.top_path)
self.watcher.changed_directory = self.top_path

with self.assertRaises(FileNotFoundError) as context_manager:
self.watcher._handle_directory_change()
self.assertIn("Live directory not found", str(context_manager.exception))
self.assertIn(str(self.top_path), str(context_manager.exception))

def test_WHEN_find_sub_directories_called_THEN_finds_subdirs(self):
self._file_in_sequence(self.top_path, self.watcher.sub_directories.keys())
self.assertEqual(len(self.watcher.sub_directories), 1)
Expand Down
11 changes: 9 additions & 2 deletions mantidimaging/gui/windows/spectrum_viewer/spectrum_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,19 @@ def __init__(self, name: str, sensible_roi: SensibleROI, *args, **kwargs):

def onChangeColor(self):
current_color = QColor(*self._colour)
selected_color = QColorDialog.getColor(current_color)
if selected_color.isValid():
selected_color = self.openColorDialog(current_color)
color_valid = self.check_color_valid(selected_color)
if color_valid:
new_color = (selected_color.red(), selected_color.green(), selected_color.blue(), 255)
self._colour = new_color
self.sig_colour_change.emit(self._name, new_color)

def openColorDialog(self, current_color) -> QColor:
return QColorDialog.getColor(current_color)

def check_color_valid(self, get_colour) -> bool:
return get_colour.isValid()

def contextMenuEnabled(self):
return True

Expand Down
58 changes: 45 additions & 13 deletions mantidimaging/gui/windows/spectrum_viewer/test/presenter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
from unittest import mock

import numpy as np
from PyQt5.QtWidgets import QPushButton, QActionGroup, QGroupBox, QAction, QCheckBox
from parameterized import parameterized

Expand Down Expand Up @@ -308,16 +309,55 @@ def test_WHEN_tof_unit_selected_THEN_model_mode_changes(self, mode_text, expecte
self.presenter.handle_tof_unit_change_via_menu()
self.assertEqual(self.presenter.model.tof_mode, expected_mode)

@parameterized.expand([
(None, ToFUnitMode.IMAGE_NUMBER),
(np.arange(1, 10), ToFUnitMode.WAVELENGTH),
])
@mock.patch("mantidimaging.gui.windows.spectrum_viewer.model.SpectrumViewerWindowModel.get_stack_time_of_flight")
def test_WHEN_no_spectrum_data_THEN_mode_is_image_index(self, get_stack_time_of_flight):
def test_WHEN_data_loaded_THEN_relevant_mode_set(self, tof_data, expected_tof_mode, get_stack_time_of_flight):
self.presenter.model.set_stack(generate_images())
self.presenter.get_dataset_id_for_stack = mock.Mock(return_value=uuid.uuid4())
self.presenter.main_window.get_stack = mock.Mock(return_value=generate_images())
get_stack_time_of_flight.return_value = None
get_stack_time_of_flight.return_value = tof_data
self.view.tof_units_mode = "Wavelength"
self.presenter.refresh_spectrum_plot = mock.Mock()
self.presenter.handle_sample_change(uuid.uuid4())
self.assertEqual(self.presenter.model.tof_mode, ToFUnitMode.IMAGE_NUMBER)
self.assertEqual(self.presenter.model.tof_mode, expected_tof_mode)

@parameterized.expand([
(None, "Image Index", ToFUnitMode.IMAGE_NUMBER, np.arange(1, 10), [False, True], ToFUnitMode.WAVELENGTH),
(np.arange(1, 10), "Wavelength", ToFUnitMode.WAVELENGTH, None, [True, False], ToFUnitMode.IMAGE_NUMBER),
(None, "Image Index", ToFUnitMode.IMAGE_NUMBER, None, [False, False], ToFUnitMode.IMAGE_NUMBER),
(np.arange(1, 10), "Wavelength", ToFUnitMode.WAVELENGTH, np.arange(2, 20), [True,
True], ToFUnitMode.WAVELENGTH),
(np.arange(1, 10), "Energy", ToFUnitMode.ENERGY, np.arange(2, 20), [True, True], ToFUnitMode.ENERGY),
(np.arange(1, 10), "Time of Flight (\u03BCs)", ToFUnitMode.TOF_US, np.arange(2, 20), [True,
True], ToFUnitMode.TOF_US)
])
@mock.patch("mantidimaging.gui.windows.spectrum_viewer.model.SpectrumViewerWindowModel.get_stack_time_of_flight")
def test_WHEN_switch_between_no_spectra_to_spectra_files_THEN_tof_modes_availability_set(
self, tof_data_before, tof_mode_text_before, tof_mode_before, tof_data_after, expected_calls, expected_mode,
get_stack_time_of_flight):
self.presenter.model.set_stack(generate_images())
self.presenter.get_dataset_id_for_stack = mock.Mock(return_value=uuid.uuid4())
self.presenter.main_window.get_stack = mock.Mock(return_value=generate_images())
get_stack_time_of_flight.return_value = tof_data_before
self.presenter.model.tof_mode = tof_mode_before
self.view.tof_units_mode = tof_mode_text_before
self.presenter.refresh_spectrum_plot = mock.Mock()
self.presenter.handle_sample_change(uuid.uuid4())

get_stack_time_of_flight.return_value = tof_data_after
self.presenter.handle_sample_change(uuid.uuid4())
expected_calls = [mock.call(b) for b in expected_calls]
self.view.tof_mode_select_group.setEnabled.assert_has_calls(expected_calls)
self.view.tofPropertiesGroupBox.setEnabled.assert_has_calls(expected_calls)
self.assertEqual(self.presenter.model.tof_mode, expected_mode)

def test_WHEN_no_stack_available_THEN_units_menu_disabled(self):
self.presenter.current_stack_uuid = uuid.uuid4()
self.presenter.handle_sample_change(None)
self.view.tof_mode_select_group.setEnabled.assert_called_once_with(False)

def test_WHEN_tof_flight_path_changed_THEN_unit_conversion_flight_path_set(self):
self.view.flightPathSpinBox = mock.Mock()
Expand All @@ -342,12 +382,7 @@ def test_WHEN_menu_option_selected_THEN_menu_option_changed(self):
self.presenter.check_action = mock.Mock()
self.view.tof_mode_select_group.actions = mock.Mock(return_value=menu_options)
self.presenter.change_selected_menu_option("opt2")
calls = [
mock.call(menu_options[0], False),
mock.call(menu_options[1], True),
mock.call(menu_options[2], False),
mock.call(menu_options[3], False)
]
calls = [mock.call(menu_options[a], b) for a, b in [(0, False), (1, True), (2, False), (3, False)]]
self.presenter.check_action.assert_has_calls(calls)

def test_WHEN_roi_changed_via_spinboxes_THEN_roi_adjusted(self):
Expand Down Expand Up @@ -395,10 +430,7 @@ def spec_roi_mock(name):
self.presenter.redraw_all_rois()
self.assertEqual(self.presenter.model.get_roi("all"), SensibleROI(0, 0, 10, 8))
self.assertEqual(self.presenter.model.get_roi("roi"), SensibleROI(1, 4, 3, 2))
calls = [
mock.call("all", mock.ANY),
mock.call("roi", mock.ANY),
]
calls = [mock.call(a, b) for a, b in [("all", mock.ANY), ("roi", mock.ANY)]]
self.view.set_spectrum.assert_has_calls(calls)

@parameterized.expand([("roi", "roi_clicked", "roi_clicked"), ("roi", ROI_RITS, "roi")])
Expand Down
24 changes: 20 additions & 4 deletions mantidimaging/gui/windows/spectrum_viewer/test/spectrum_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import uuid
import numpy as np
from unittest import mock

from PyQt5.QtGui import QColor
from parameterized import parameterized

from pyqtgraph import Point
Expand All @@ -20,11 +22,25 @@
@start_qapplication
class SpectrumROITest(unittest.TestCase):

def setUp(self) -> None:
self.roi = SensibleROI(10, 20, 30, 40)
self.spectrum_roi = SpectrumROI("test_roi", self.roi)

def test_WHEN_initialise_THEN_pos_and_size_correct(self):
roi = SensibleROI(10, 20, 30, 40)
spectrum_roi = SpectrumROI("", roi)
self.assertEqual(spectrum_roi.getState()["pos"], Point(10, 20))
self.assertEqual(spectrum_roi.getState()["size"], Point(20, 20))
self.assertEqual(self.spectrum_roi.getState()["pos"], Point(10, 20))
self.assertEqual(self.spectrum_roi.getState()["size"], Point(20, 20))

def test_WHEN_colour_changed_THEN_roi_colour_is_set(self):
colour = (1, 2, 58, 255)
self.spectrum_roi.openColorDialog = mock.Mock(return_value=QColor(*colour))
self.spectrum_roi.onChangeColor()
self.assertEqual(self.spectrum_roi.colour, colour)

def test_WHEN_colour_is_not_valid_THEN_roi_colour_is_unchanged(self):
colour = (10, 20, 480, 255)
self.spectrum_roi.openColorDialog = mock.Mock(return_value=QColor(*colour))
self.spectrum_roi.onChangeColor()
self.assertEqual(self.spectrum_roi.colour, (0, 0, 0, 255))


@mock_versions
Expand Down

0 comments on commit ed3c260

Please sign in to comment.