Skip to content

Commit

Permalink
refactored roi_ranges out of the model
Browse files Browse the repository at this point in the history
  • Loading branch information
ashmeigh committed Jan 10, 2025
1 parent 3f099d2 commit c7ce611
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 63 deletions.
19 changes: 9 additions & 10 deletions mantidimaging/gui/windows/spectrum_viewer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,13 @@ class SpectrumViewerWindowModel:
_normalise_stack: ImageStack | None = None
tof_range: tuple[int, int] = (0, 0)
tof_plot_range: tuple[float, float] | tuple[int, int] = (0, 0)
_roi_ranges: dict[str, SensibleROI]
tof_mode: ToFUnitMode = ToFUnitMode.WAVELENGTH
tof_data: np.ndarray | None = None
tof_range_full: tuple[int, int] = (0, 0)

def __init__(self, presenter: SpectrumViewerWindowPresenter):
self.presenter = presenter
self._roi_id_counter = 0
self._roi_ranges = {}
self.special_roi_list = [ROI_ALL]

self.units = UnitConversion()
Expand Down Expand Up @@ -128,8 +126,6 @@ def set_stack(self, stack: ImageStack | None) -> None:
self.tof_range = (0, stack.data.shape[0] - 1)
self.tof_range_full = self.tof_range
self.tof_data = self.get_stack_time_of_flight()
height, width = self.get_image_shape()
self._roi_ranges[ROI_ALL] = SensibleROI.from_list([0, 0, width, height])

def set_normalise_stack(self, normalise_stack: ImageStack | None) -> None:
self._normalise_stack = normalise_stack
Expand Down Expand Up @@ -337,14 +333,14 @@ def save_csv(self,
csv_output.write(outfile)
self.save_roi_coords(self.get_roi_coords_filename(path))

def save_single_rits_spectrum(self, path: Path, error_mode: ErrorMode) -> None:
def save_single_rits_spectrum(self, path: Path, error_mode: ErrorMode, roi: SensibleROI) -> None:
"""
Saves the spectrum for the RITS ROI to a RITS file.
@param path: The path to save the CSV file to.
@param error_mode: Which version (standard deviation or propagated) of the error to use in the RITS export.
"""
self.save_rits_roi(path, error_mode, self._roi_ranges[ROI_RITS])
self.save_rits_roi(path, error_mode, roi)

def save_rits_roi(self, path: Path, error_mode: ErrorMode, roi: SensibleROI, normalise: bool = False) -> None:
"""
Expand Down Expand Up @@ -401,6 +397,7 @@ def save_rits_images(self,
error_mode: ErrorMode,
bin_size: int,
step: int,
roi: SensibleROI,
normalise: bool = False,
progress: Progress | None = None) -> None:
"""
Expand All @@ -420,11 +417,13 @@ def save_rits_images(self,
error_mode (ErrorMode): The error mode to use when saving the images.
bin_size (int): The size of the sub-regions.
step (int): The step size to use when sliding the window across the ROI.
roi (SensibleROI): The parent ROI to be subdivided.
normalise (bool): If True, the images will be normalised.
progress (Progress | None): Optional progress reporter.
Returns:
None
"""
roi = self._roi_ranges[ROI_RITS]
left, top, right, bottom = roi
x_iterations = min(ceil((right - left) / step), ceil((right - left - bin_size) / step) + 1)
y_iterations = min(ceil((bottom - top) / step), ceil((bottom - top - bin_size) / step) + 1)
Expand Down Expand Up @@ -463,16 +462,17 @@ def get_roi_coords_filename(self, path: Path) -> Path:
"""
return path.with_stem(f"{path.stem}_roi_coords")

def save_roi_coords(self, path: Path) -> None:
def save_roi_coords(self, path: Path, rois: dict[str, SensibleROI]) -> None:
"""
Save the coordinates of the ROIs to a csv file (ROI name, x_min, x_max, y_min, y_max)
following Pascal VOC format.
@param path: The path to save the CSV file to.
@param rois: A dictionary of ROI names and their coordinates.
"""
with open(path, encoding='utf-8', mode='w') as f:
csv_writer = csv.DictWriter(f, fieldnames=["ROI", "X Min", "X Max", "Y Min", "Y Max"])
csv_writer.writeheader()
for roi_name, coords in self._roi_ranges.items():
for roi_name, coords in rois.items():
csv_writer.writerow({
"ROI": roi_name,
"X Min": coords.left,
Expand All @@ -494,7 +494,6 @@ def remove_all_roi(self) -> None:
Remove all ROIs from the model
"""
self._roi_id_counter = 0
self._roi_ranges = {}

def set_relevant_tof_units(self) -> None:
if self._stack is not None:
Expand Down
24 changes: 14 additions & 10 deletions mantidimaging/gui/windows/spectrum_viewer/presenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,23 +275,24 @@ def handle_rits_export(self) -> None:
if path is None:
LOG.debug("No path selected, aborting export")
return
run_function = partial(self.model.save_rits_images,
path,
error_mode,
self.view.bin_size,
self.view.bin_step,
normalise=self.view.shuttercount_norm_enabled())

run_function = partial(
self.model.save_rits_images,
path,
error_mode,
self.view.bin_size,
self.view.bin_step,
normalise=self.view.shuttercount_norm_enabled(),
)
start_async_task_view(self.view, run_function, self._async_save_done)

else:
path = self.view.get_rits_export_filename()
if path is None:
LOG.debug("No path selected, aborting export")
return
if path and path.suffix != ".dat":
path = path.with_suffix(".dat")
self.model.save_single_rits_spectrum(path, error_mode)
roi = self.view.spectrum_widget.get_roi(ROI_RITS)
self.model.save_single_rits_spectrum(path, error_mode, roi)

def _async_save_done(self, task: TaskWorkerThread) -> None:
if task.error is not None:
Expand Down Expand Up @@ -349,7 +350,10 @@ def change_roi_colour(self, roi_name: str, new_colour: tuple[int, int, int]) ->
self.view.on_visibility_change()

def add_rits_roi(self) -> None:
roi = self.model._roi_ranges.setdefault(ROI_RITS, SensibleROI.from_list([0, 0, *self.model.get_image_shape()]))
"""
Add the RITS ROI to the spectrum widget and initialize it with default dimensions.
"""
roi = SensibleROI.from_list([0, 0, *self.model.get_image_shape()])
self.view.spectrum_widget.add_roi(roi, ROI_RITS)
self.view.set_spectrum(ROI_RITS,
self.model.get_spectrum(roi, self.spectrum_mode, self.view.shuttercount_norm_enabled()))
Expand Down
56 changes: 21 additions & 35 deletions mantidimaging/gui/windows/spectrum_viewer/test/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ def _set_sample_stack(self, with_tof=False, with_shuttercount=False):
mock_shuttercounts.get_column.return_value = np.arange(5, 15)
stack._shutter_count_file = mock_shuttercounts
self.model.set_stack(stack)
height, width = stack.data.shape[1], stack.data.shape[2]
self.model._roi_ranges["roi"] = SensibleROI.from_list([0, 0, width, height])
return stack, spectrum

def _set_normalise_stack(self, with_shuttercount=False):
Expand Down Expand Up @@ -151,11 +149,8 @@ def test_normalise_issue(self):
self.assertEqual("", self.model.normalise_issue())

def test_set_stack_sets_roi(self):
self._set_sample_stack()
roi_all = self.model._roi_ranges["all"]
roi = self.model._roi_ranges["roi"]

self.assertEqual(roi_all, roi)
stack, _ = self._set_sample_stack()
roi_all = SensibleROI.from_list([0, 0, stack.data.shape[2], stack.data.shape[1]])
self.assertEqual(roi_all.top, 0)
self.assertEqual(roi_all.left, 0)
self.assertEqual(roi_all.right, 12)
Expand Down Expand Up @@ -228,14 +223,12 @@ def test_save_rits_roi_dat(self):
stack, _ = self._set_sample_stack(with_tof=True)
norm = ImageStack(np.full([10, 11, 12], 2))
stack.data[:, :, :5] *= 2

self.model._roi_ranges["rits_roi"] = SensibleROI.from_list([0, 0, 10, 11])
self.model.set_normalise_stack(norm)
self.model._roi_ranges["ROI_RITS"] = SensibleROI.from_list([0, 0, 10, 11])
roi = SensibleROI.from_list([0, 0, 10, 11])
mock_stream, mock_path = self._make_mock_path_stream()

with mock.patch.object(self.model, "save_roi_coords"):
self.model.save_rits_roi(mock_path, ErrorMode.STANDARD_DEVIATION, self.model._roi_ranges["ROI_RITS"])
self.model.save_rits_roi(mock_path, ErrorMode.STANDARD_DEVIATION, roi)
mock_path.open.assert_called_once_with("w")
self.assertIn("0.0\t0.0\t0.0", mock_stream.captured[0])
self.assertIn("100000.0\t0.75\t0.25", mock_stream.captured[1])
Expand All @@ -253,11 +246,12 @@ def test_save_rits_data_errors(self, _, error_mode, expected_error):
stack.data[:, :, :5] *= 2
self.model.set_normalise_stack(norm)

self.model._roi_ranges["ROI_RITS"] = SensibleROI.from_list([0, 0, 10, 11])
roi = SensibleROI.from_list([0, 0, 10, 11])
mock_stream, mock_path = self._make_mock_path_stream()

with mock.patch.object(self.model, "save_roi_coords"):
with mock.patch.object(self.model, "export_spectrum_to_rits") as mock_export:
self.model.save_rits_roi(mock_path, error_mode, self.model._roi_ranges["ROI_RITS"])
self.model.save_rits_roi(mock_path, error_mode, roi)

calculated_errors = mock_export.call_args[0][3]
np.testing.assert_allclose(expected_error, calculated_errors, atol=1e-4)
Expand All @@ -279,7 +273,6 @@ def test_save_rits_no_norm_err(self):
mock_inst_log = mock.create_autospec(InstrumentLog, source_file="", instance=True)
stack.log_file = mock_inst_log
roi = SensibleROI.from_list([0, 0, 12, 11])
self.model._roi_ranges["ROI_RITS"] = roi

mock_stream, mock_path = self._make_mock_path_stream()
with mock.patch.object(self.model, "save_roi_coords"):
Expand All @@ -299,8 +292,9 @@ def test_save_rits_no_tof_err(self):

def test_WHEN_save_csv_called_THEN_save_roi_coords_called_WITH_correct_args(self):
path = Path("test_file.csv")
rois = {"roi1": SensibleROI.from_list([0, 0, 10, 10]), "roi2": SensibleROI.from_list([10, 10, 20, 20])}
with mock.patch('builtins.open', mock.mock_open()) as mock_open:
self.model.save_roi_coords(path)
self.model.save_roi_coords(path, rois)
mock_open.assert_called_once_with(path, encoding='utf-8', mode='w')

def test_WHEN_get_roi_coords_filename_called_THEN_correct_filename_returned(self):
Expand Down Expand Up @@ -388,12 +382,8 @@ def test_WHEN_stack_value_set_THEN_can_export_returns_(self, _, image_stack, exp

def test_WHEN_remove_all_rois_called_THEN_all_but_default_rois_removed(self):
self.model.set_stack(generate_images())
rois = ["new_roi", "new_roi_2"]
for roi in rois:
self.model._roi_ranges[roi] = SensibleROI.from_list([0, 0, 10, 10])
self.assertListEqual(list(self.model._roi_ranges.keys()), ["all"] + rois)
self.model.remove_all_roi()
self.assertListEqual(list(self.model._roi_ranges.keys()), [])
self.assertEqual(self.model._roi_id_counter, 0)

def test_WHEN_no_stack_tof_THEN_time_of_flight_none(self):
# No Stack
Expand Down Expand Up @@ -449,43 +439,40 @@ def test_save_rits_images_write_correct_number_of_files(self, _, roi_size, bin_s
stack, _ = self._set_sample_stack(with_tof=True)
norm = ImageStack(np.full([10, 11, 12], 2))
stack.data[:, :, :5] *= 2
roi_name = "rits_roi"
roi = SensibleROI.from_list([0, 0, roi_size, roi_size])
self.model._roi_ranges[roi_name] = roi
self.model.set_normalise_stack(norm)

Mx, My = roi.width, roi.height
x_iterations = min(math.ceil(Mx / step), math.ceil((Mx - bin_size) / step) + 1)
y_iterations = min(math.ceil(My / step), math.ceil((My - bin_size) / step) + 1)
expected_number_of_calls = x_iterations * y_iterations
_, mock_path = self._make_mock_path_stream()
with mock.patch.object(self.model, "save_roi_coords"):
self.model.save_rits_images(mock_path, ErrorMode.STANDARD_DEVIATION, bin_size, step)
self.model.save_rits_images(mock_path, ErrorMode.STANDARD_DEVIATION, bin_size, step, roi)

self.assertEqual(mock_save_rits_roi.call_count, expected_number_of_calls)

@mock.patch.object(SpectrumViewerWindowModel, "save_rits_roi")
def test_save_single_rits_spectrum(self, mock_save_rits_roi):
stack, _ = self._set_sample_stack(with_tof=True)
norm = ImageStack(np.full([10, 11, 12], 2))
stack.data[:, :, :5] *= 2
self.model._roi_ranges["rits_roi"] = SensibleROI.from_list([0, 0, 5, 5])
self.model.set_normalise_stack(norm)
roi = SensibleROI.from_list([0, 0, 5, 5])
_, mock_path = self._make_mock_path_stream()
with mock.patch.object(self.model, "save_roi_coords"):
self.model.save_single_rits_spectrum(mock_path, ErrorMode.STANDARD_DEVIATION)
mock_save_rits_roi.assert_called_once_with(mock_path, mock.ANY, SensibleROI.from_list([0, 0, 5, 5]))
self.model.save_single_rits_spectrum(mock_path, ErrorMode.STANDARD_DEVIATION, roi)
mock_save_rits_roi.assert_called_once_with(mock_path, mock.ANY, roi)

@mock.patch.object(SpectrumViewerWindowModel, "export_spectrum_to_rits")
def test_save_rits_correct_transmission(self, mock_save_rits_roi):
stack, spectrum = self._set_sample_stack(with_tof=True)
norm = ImageStack(np.full([10, 11, 12], 2))
for i in range(10):
stack.data[:, :, i] *= i
self.model._roi_ranges["rits_roi"] = SensibleROI.from_list([1, 0, 6, 4])
self.model.set_normalise_stack(norm)
mock_path = mock.create_autospec(Path, instance=True)

self.model.save_rits_images(mock_path, ErrorMode.STANDARD_DEVIATION, 3, 1)
roi = SensibleROI.from_list([1, 0, 6, 4])
self.model.save_rits_images(mock_path, ErrorMode.STANDARD_DEVIATION, 3, 1, roi)

self.assertEqual(6, len(mock_save_rits_roi.call_args_list))
expected_means = [1, 1.5, 2, 1, 1.5, 2] # running average of [1, 2, 3, 4, 5], divided by 2 for normalisation
Expand Down Expand Up @@ -540,20 +527,19 @@ def test_get_transmission_error_standard_dev(self):
open_shutter_counts = normalise_stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)
average_shutter_counts = sample_shutter_counts[0] / open_shutter_counts[0]
roi = SensibleROI.from_list([0, 0, 5, 5])
self.model._roi_ranges["roi"] = roi

left, top, right, bottom = roi
sample = stack.data[:, top:bottom, left:right]
open = normalise_stack.data[:, top:bottom, left:right]
expected = np.divide(sample, open, out=np.zeros_like(sample), where=open != 0) / average_shutter_counts
expected = np.std(expected, axis=(1, 2))

with (mock.patch.object(self.model,
"get_shuttercount_normalised_correction_parameter",
return_value=average_shutter_counts) as
mock_get_shuttercount_normalised_correction_parameter):
with mock.patch.object(
self.model, "get_shuttercount_normalised_correction_parameter",
return_value=average_shutter_counts) as mock_get_shuttercount_normalised_correction_parameter:
result = self.model.get_transmission_error_standard_dev(roi, normalise_with_shuttercount=True)
mock_get_shuttercount_normalised_correction_parameter.assert_called_once()

self.assertEqual(len(expected), len(result))
np.testing.assert_allclose(expected, result)

Expand Down
13 changes: 5 additions & 8 deletions mantidimaging/gui/windows/spectrum_viewer/test/presenter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,18 +218,19 @@ def test_handle_export_csv(self, path_name: str, mock_save_csv: mock.Mock, mock_
normalise_with_shuttercount=False)

@parameterized.expand(["/fake/path", "/fake/path.dat"])
@mock.patch("mantidimaging.gui.windows.spectrum_viewer.model.SpectrumViewerWindowModel.save_rits_roi")
def test_handle_rits_export(self, path_name: str, mock_save_rits_roi: mock.Mock):
@mock.patch("mantidimaging.gui.windows.spectrum_viewer.model.SpectrumViewerWindowModel.save_single_rits_spectrum")
def test_handle_rits_export(self, path_name: str, mock_save_single_rits_spectrum: mock.Mock):
self.view.get_rits_export_filename = mock.Mock(return_value=Path(path_name))
self.view.transmission_error_mode = "Standard Deviation"

mock_roi = SensibleROI.from_list([0, 0, 5, 5])
self.presenter.model._roi_ranges[ROI_RITS] = mock_roi
self.view.spectrum_widget.get_roi = mock.Mock(return_value=mock_roi)
self.presenter.model.set_stack(generate_images())
self.presenter.handle_rits_export()

self.view.get_rits_export_filename.assert_called_once()
mock_save_rits_roi.assert_called_once_with(Path("/fake/path.dat"), ErrorMode.STANDARD_DEVIATION, mock_roi)
mock_save_single_rits_spectrum.assert_called_once_with(Path("/fake/path.dat"), ErrorMode.STANDARD_DEVIATION,
mock_roi)

def test_WHEN_do_add_roi_called_THEN_new_roi_added(self):
self.view.spectrum_widget.roi_dict = {"all": mock.Mock()}
Expand Down Expand Up @@ -292,19 +293,15 @@ def test_WHEN_ROI_renamed_THEN_roi_renamed(self):
def test_WHEN_invalid_ROI_renamed_THEN_error_raised(self):
rois = ["all", "roi", "roi_1"]
self.view.spectrum_widget.roi_dict = {roi: mock.Mock() for roi in rois}
self.presenter.model._roi_ranges = {roi: mock.Mock() for roi in rois}
self.view.spectrum_widget.rename_roi = mock.Mock(side_effect=KeyError("Invalid ROI"))
self.view.spectrum_widget.rois = {roi: mock.Mock() for roi in rois}
with self.assertRaises(KeyError):
self.presenter.rename_roi("invalid_roi", "new_name")

def test_WHEN_do_remove_roi_called_with_no_arguments_THEN_all_rois_removed(self):
rois = ["all", "roi", "roi_1", "roi_2"]
self.view.spectrum_widget.roi_dict = {roi: mock.Mock() for roi in rois}
self.presenter.model._roi_ranges = {roi: mock.Mock() for roi in rois}
self.presenter.do_remove_roi()
self.assertEqual(self.view.spectrum_widget.roi_dict, {})
self.assertEqual(self.presenter.model._roi_ranges, {})

@parameterized.expand([("Image Index", ToFUnitMode.IMAGE_NUMBER), ("Wavelength", ToFUnitMode.WAVELENGTH),
("Energy", ToFUnitMode.ENERGY), ("Time of Flight (\u03BCs)", ToFUnitMode.TOF_US)])
Expand Down

0 comments on commit c7ce611

Please sign in to comment.