
Commit

Changed the polyfit to compute function
Signed-off-by: ashmeigh <[email protected]>
ashmeigh authored and samtygier-stfc committed Jul 15, 2024
1 parent cb05980 commit a5e7ada
Showing 1 changed file with 38 additions and 27 deletions.
65 changes: 38 additions & 27 deletions mantidimaging/core/rotation/polyfit_correlation.py
@@ -1,9 +1,8 @@
# Copyright (C) 2024 ISIS Rutherford Appleton Laboratory UKRI
# SPDX-License-Identifier: GPL-3.0-or-later
from __future__ import annotations

from logging import getLogger
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import numpy as np

@@ -26,15 +25,30 @@ def do_calculate_correlation_err(store: np.ndarray, search_index: int, p0_and_18


def find_center(images: ImageStack, progress: Progress) -> tuple[ScalarCoR, Degrees]:
# assume the ROI is the full image, i.e. the slices are ALL rows of the image
slices = np.arange(images.height)
shift = pu.create_array((images.height, ))
if images is None or images.proj180deg is None:
raise ValueError("images and images.proj180deg cannot be None")

# Assume the ROI is the full image, i.e. the slices are ALL rows of the image
slices = np.arange(images.height)
shift = pu.create_array((images.height, ), dtype=np.float32)
search_range = get_search_range(images.width)
min_correlation_error = pu.create_array((len(search_range), images.height))
min_correlation_error = pu.create_array((len(search_range), images.height), dtype=np.float32)
shared_search_range = pu.create_array((len(search_range), ), dtype=np.int32)
shared_search_range.array[:] = np.asarray(search_range, dtype=np.int32)
_calculate_correlation_error(images, shared_search_range, min_correlation_error, progress)

# Copy projections to shared memory
shared_projections = pu.create_array((2, images.height, images.width), dtype=np.float32)
shared_projections.array[0][:] = images.projection(0)
shared_projections.array[1][:] = np.fliplr(images.proj180deg.data[0])

# Prepare parameters for the compute function
params = {
'image_width': images.width,
}
ps.run_compute_func(compute_correlation_error,
len(search_range), [min_correlation_error, shared_projections, shared_search_range],
params,
progress=progress)

# Originally the output of do_search is stored in dimensions
# corresponding to (search_range, square sum). This is awkward to navigate
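
Note: the refactor above swaps the old create_partial/ps.execute pair for a per-index compute function driven by ps.run_compute_func. Below is a minimal, serial sketch of that calling convention; run_compute_func_sketch and square_row are illustrative stand-ins (the real ps.run_compute_func also takes a progress argument and runs over shared memory in worker processes), not the mantidimaging API itself.

from typing import Any, Callable

import numpy as np


def run_compute_func_sketch(func: Callable, num_operations: int, arrays: list[Any], params: dict[str, Any]) -> None:
    # The real helper parallelises this loop; a plain loop is enough to show the contract.
    for index in range(num_operations):
        func(index, arrays, params)


def square_row(index: int, arrays: list[Any], params: dict[str, Any]) -> None:
    # Each call works on one index of the shared arrays, mirroring compute_correlation_error above.
    out, data = arrays
    out[index] = data[index] ** params['exponent']


data = np.arange(12, dtype=np.float32).reshape(4, 3)
out = np.zeros_like(data)
run_compute_func_sketch(square_row, data.shape[0], [out, data], {'exponent': 2})
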
@@ -46,44 +60,41 @@ def find_center(images: ImageStack, progress: Progress) -> tuple[ScalarCoR, Degrees]:
m = par[0]
q = par[1]
LOG.debug(f"m={m}, q={q}")

theta = Degrees(np.rad2deg(np.arctan(0.5 * m)))
offset = np.round(m * images.height * 0.5 + q) * 0.5
LOG.info(f"found offset: {-offset} and tilt {theta}")
return ScalarCoR(images.h_middle + -offset), theta

return ScalarCoR(images.h_middle + -offset), theta
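
The offset-to-centre arithmetic above is unchanged by this commit and can be checked by hand. A quick sketch with made-up fit results (m and q stand for the slope and intercept of the polyfit over per-row shifts; height and h_middle are assumed values for a 512-pixel image):

import numpy as np

m, q = 0.002, 4.1                              # made-up polyfit slope and intercept
height, h_middle = 512, 255.5                  # assumed image height and horizontal middle
theta_deg = np.rad2deg(np.arctan(0.5 * m))     # tilt of the rotation axis, about 0.057 degrees
offset = np.round(m * height * 0.5 + q) * 0.5  # round(0.512 + 4.1) * 0.5 == 2.5
cor = h_middle + -offset                       # centre of rotation == 253.0
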

def _calculate_correlation_error(images, shared_search_range, min_correlation_error, progress):
# if the projections are passed in the partial they are copied to every process on every iteration
# this makes the multiprocessing significantly slower
# so they are copied into a shared array to avoid that copying
shared_projections = pu.create_array((2, images.height, images.width))
shared_projections.array[0][:] = images.projection(0)
shared_projections.array[1][:] = np.fliplr(images.proj180deg.data[0])

do_search_partial = ps.create_partial(do_calculate_correlation_err, ps.inplace3, image_width=images.width)
def compute_correlation_error(index: int, arrays: list[Any], params: dict[str, Any]):
min_correlation_error = arrays[0]
shared_projections = arrays[1]
shared_search_range = arrays[2]
image_width = params['image_width']

arrays = [min_correlation_error, shared_search_range, shared_projections]
ps.execute(do_search_partial,
arrays,
num_operations=min_correlation_error.array.shape[0],
progress=progress,
msg="Finding correlation on row")
search_index = shared_search_range[index]
do_calculate_correlation_err(min_correlation_error[index], search_index,
(shared_projections[0], shared_projections[1]), image_width)


def _find_shift(images: ImageStack, search_range: range, min_correlation_error: np.ndarray, shift: np.ndarray):
# Then we just find the index of the minimum one (minimum error)
min_correlation_error = np.transpose(min_correlation_error)
# argmin returns a list of where the minimum argument is found
# just in case that happens - get the first minimum one, should be close enough
for row in range(images.height):
# then we just find the index of the minimum one (minimum error)
min_arg_positions = min_correlation_error[row].argmin()
# argmin returns a list of where the minimum argument is found
# just in case that happens - get the first minimum one, should be close enough
min_arg = min_arg_positions if isinstance(min_arg_positions, np.int64) else min_arg_positions[0]
# and we get which search range is at that index
# And we get which search range is at that index
# that is the number that we then pass into polyfit
shift[row] = search_range[min_arg]

return shift
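
A small worked example of the per-row minimum search in _find_shift, using made-up error values (after the transpose, each row holds one error per candidate shift):

import numpy as np

search_range = range(-2, 3)                      # what get_search_range(4) would produce
errors = np.array([[5.0, 1.0, 3.0, 4.0, 2.0],    # row 0: smallest error at index 1
                   [2.0, 6.0, 0.5, 7.0, 9.0]])   # row 1: smallest error at index 2
shift = np.empty(errors.shape[0])
for row in range(errors.shape[0]):
    min_arg = errors[row].argmin()
    shift[row] = search_range[min_arg]
# shift == [-1., 0.]: the candidate shift with the lowest error for each row
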


def get_search_range(width):
def get_search_range(width: int) -> range:
tmin = -width // 2
tmax = width - width // 2
search_range = range(tmin, tmax + 1)
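
For reference, get_search_range spans roughly half the image width on either side of zero. A small worked example (the width value is arbitrary):

width = 6
tmin = -width // 2               # -3
tmax = width - width // 2        # 3
list(range(tmin, tmax + 1))      # [-3, -2, -1, 0, 1, 2, 3]
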
