Skip to content

Commit

Permalink
Add function for interactive template matching plot
Browse files Browse the repository at this point in the history
  • Loading branch information
viljarjf committed Nov 7, 2023
1 parent da9a045 commit 03e7e8d
Showing 1 changed file with 137 additions and 0 deletions.
137 changes: 137 additions & 0 deletions pyxem/utils/plotting_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from hyperspy.utils.markers import point
from pyxem.utils.polar_transform_utils import (
get_template_cartesian_coordinates,
get_template_polar_coordinates,
Expand Down Expand Up @@ -113,3 +115,138 @@ def plot_template_over_pattern(
color=marker_color,
)
return (ax, im, sp)

def plot_templates_over_signal(
signal,
library,
result: dict,
phase_key_dict: dict,
n_best: int = None,
find_direct_beam: bool = False,
direct_beam_position: tuple[int, int] = None,
marker_colors: list[str] = None,
marker_type: str = "x",
size_factor: float = 1.0,
verbose: bool = True,
**plot_kwargs
):

if n_best is None:
n_best = result["template_index"].shape[2]

if direct_beam_position is None and not find_direct_beam:
direct_beam_position = (0, 0)

if marker_colors is None:
marker_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

if len(marker_colors) < n_best:
print("Warning: not enough colors in `marker_colors` for `n_best` different colored marks")

# Add markers as iterable, recommended in hyperspy docs.
# Using a genenerator will hopefully reduce memory.
# To avoid scope errors, pass all variables as inputs
def _get_markers_iter(
signal,
library,
result,
phase_key_dict,
n_best,
find_direct_beam,
direct_beam_position,
marker_colors,
marker_type,
size_factor,
verbose,
):

# Hyperspy wants one marker for all pixels in the navigation space,
# so we generate all the data for a given solution and then yield them

# Allocate space for all navigator pixels to potentially have the maximum amount of simulated diffraction spots
max_marker_count = max(len(sim) for lib in library.values() for sim in lib["simulations"])

shape = (signal.axes_manager[1].size, signal.axes_manager[0].size, max_marker_count)

# Explicit zeroes instead of empty, since we won't fill all elements in the final axis
marker_data_x = np.zeros(shape)
marker_data_y = np.zeros(shape)
marker_data_i = np.zeros(shape)

for n in range(n_best):
color = marker_colors[n % len(marker_colors)]

# Generate data for a given solution index.
x_iter = range(signal.axes_manager[0].size)
if verbose:
x_iter = tqdm(x_iter)

for px in x_iter:
for py in range(signal.axes_manager[1].size):

sim_sol_index = result["template_index"][py, px, n]
mirrored_sol = result["mirrored_template"][py, px, n]
in_plane_angle = result["orientation"][py, px, n, 0]

phase_key = result["phase_index"][py, px, n]
phase = phase_key_dict[phase_key]
simulations = library[phase]["simulations"]
pattern = simulations[sim_sol_index]

if find_direct_beam:
x, y = find_beam_center_blur(signal.inav[px, py], 1)

# The result of `find_beam_center_blur` is in a corner.
# Move to center of image
x -= signal.axes_manager[2].size // 2
y -= signal.axes_manager[3].size // 2
direct_beam_position = (x, y)

x, y, intensities = get_template_cartesian_coordinates(
pattern,
center=direct_beam_position,
in_plane_angle=in_plane_angle,
mirrored=mirrored_sol,
)

x *= signal.axes_manager[2].scale
y *= signal.axes_manager[3].scale

marker_count = len(x)
marker_data_x[py, px, :marker_count] = x
marker_data_y[py, px, :marker_count] = y
marker_data_i[py, px, :marker_count] = intensities

marker_kwargs = {
"color": color,
"marker": marker_type,
"label": f"Solution index: {n}",
}

# Plot for the given solution index
for i in range(max_marker_count):
yield point(
marker_data_x[..., i],
marker_data_y[..., i],
size = 4 * np.sqrt(marker_data_i[..., i]) * size_factor,
**marker_kwargs
)
# We only need one set of labels per solution
if i == 0:
marker_kwargs.pop("label")

signal.plot(**plot_kwargs)
signal.add_marker(_get_markers_iter(
signal,
library,
result,
phase_key_dict,
n_best,
find_direct_beam,
direct_beam_position,
marker_colors,
marker_type,
size_factor,
verbose,
))
plt.gcf().legend()

0 comments on commit 03e7e8d

Please sign in to comment.