diff --git a/pyxem/utils/plotting_utils.py b/pyxem/utils/plotting_utils.py index b3f932304..22155f623 100644 --- a/pyxem/utils/plotting_utils.py +++ b/pyxem/utils/plotting_utils.py @@ -1,5 +1,7 @@ import matplotlib.pyplot as plt import numpy as np +from tqdm import tqdm +from hyperspy.utils.markers import point from pyxem.utils.polar_transform_utils import ( get_template_cartesian_coordinates, get_template_polar_coordinates, @@ -113,3 +115,138 @@ def plot_template_over_pattern( color=marker_color, ) return (ax, im, sp) + +def plot_templates_over_signal( + signal, + library, + result: dict, + phase_key_dict: dict, + n_best: int = None, + find_direct_beam: bool = False, + direct_beam_position: tuple[int, int] = None, + marker_colors: list[str] = None, + marker_type: str = "x", + size_factor: float = 1.0, + verbose: bool = True, + **plot_kwargs + ): + + if n_best is None: + n_best = result["template_index"].shape[2] + + if direct_beam_position is None and not find_direct_beam: + direct_beam_position = (0, 0) + + if marker_colors is None: + marker_colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] + + if len(marker_colors) < n_best: + print("Warning: not enough colors in `marker_colors` for `n_best` different colored marks") + + # Add markers as iterable, recommended in hyperspy docs. + # Using a genenerator will hopefully reduce memory. + # To avoid scope errors, pass all variables as inputs + def _get_markers_iter( + signal, + library, + result, + phase_key_dict, + n_best, + find_direct_beam, + direct_beam_position, + marker_colors, + marker_type, + size_factor, + verbose, + ): + + # Hyperspy wants one marker for all pixels in the navigation space, + # so we generate all the data for a given solution and then yield them + + # Allocate space for all navigator pixels to potentially have the maximum amount of simulated diffraction spots + max_marker_count = max(len(sim) for lib in library.values() for sim in lib["simulations"]) + + shape = (signal.axes_manager[1].size, signal.axes_manager[0].size, max_marker_count) + + # Explicit zeroes instead of empty, since we won't fill all elements in the final axis + marker_data_x = np.zeros(shape) + marker_data_y = np.zeros(shape) + marker_data_i = np.zeros(shape) + + for n in range(n_best): + color = marker_colors[n % len(marker_colors)] + + # Generate data for a given solution index. + x_iter = range(signal.axes_manager[0].size) + if verbose: + x_iter = tqdm(x_iter) + + for px in x_iter: + for py in range(signal.axes_manager[1].size): + + sim_sol_index = result["template_index"][py, px, n] + mirrored_sol = result["mirrored_template"][py, px, n] + in_plane_angle = result["orientation"][py, px, n, 0] + + phase_key = result["phase_index"][py, px, n] + phase = phase_key_dict[phase_key] + simulations = library[phase]["simulations"] + pattern = simulations[sim_sol_index] + + if find_direct_beam: + x, y = find_beam_center_blur(signal.inav[px, py], 1) + + # The result of `find_beam_center_blur` is in a corner. + # Move to center of image + x -= signal.axes_manager[2].size // 2 + y -= signal.axes_manager[3].size // 2 + direct_beam_position = (x, y) + + x, y, intensities = get_template_cartesian_coordinates( + pattern, + center=direct_beam_position, + in_plane_angle=in_plane_angle, + mirrored=mirrored_sol, + ) + + x *= signal.axes_manager[2].scale + y *= signal.axes_manager[3].scale + + marker_count = len(x) + marker_data_x[py, px, :marker_count] = x + marker_data_y[py, px, :marker_count] = y + marker_data_i[py, px, :marker_count] = intensities + + marker_kwargs = { + "color": color, + "marker": marker_type, + "label": f"Solution index: {n}", + } + + # Plot for the given solution index + for i in range(max_marker_count): + yield point( + marker_data_x[..., i], + marker_data_y[..., i], + size = 4 * np.sqrt(marker_data_i[..., i]) * size_factor, + **marker_kwargs + ) + # We only need one set of labels per solution + if i == 0: + marker_kwargs.pop("label") + + signal.plot(**plot_kwargs) + signal.add_marker(_get_markers_iter( + signal, + library, + result, + phase_key_dict, + n_best, + find_direct_beam, + direct_beam_position, + marker_colors, + marker_type, + size_factor, + verbose, + )) + plt.gcf().legend()