Skip to content

Commit

Permalink
Run black
Browse files Browse the repository at this point in the history
  • Loading branch information
viljarjf committed Nov 8, 2023
1 parent 24f13b9 commit 014e1e0
Showing 1 changed file with 61 additions and 50 deletions.
111 changes: 61 additions & 50 deletions pyxem/utils/plotting_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def plot_template_over_pattern(
marker_color="red",
marker_type="x",
size_factor=1.0,
**kwargs
**kwargs,
):
"""
A quick utility function to plot a simulated pattern over an experimental image
Expand Down Expand Up @@ -116,22 +116,23 @@ def plot_template_over_pattern(
)
return (ax, im, sp)


def plot_templates_over_signal(
signal,
library,
result: dict,
phase_key_dict: dict,
n_best: int = None,
signal,
library,
result: dict,
phase_key_dict: dict,
n_best: int = None,
find_direct_beam: bool = False,
direct_beam_position: tuple[int, int] = None,
marker_colors: list[str] = None,
marker_type: str = "x",
size_factor: float = 1.0,
verbose: bool = True,
**plot_kwargs
):
**plot_kwargs,
):
"""
Display an interactive plot of the diffraction signal,
Display an interactive plot of the diffraction signal,
with simulated diffraction patterns corresponding to template matching results displayed on top.
Parameters
Expand All @@ -142,7 +143,7 @@ def plot_templates_over_signal(
The library of simulated diffraction patterns.
result : dict
Template matching results dictionary containing keys: phase_index, template_index,
orientation, correlation, and mirrored_template.
orientation, correlation, and mirrored_template.
Returned from pyxem.utils.indexation_utils.index_dataset_with_template_rotation.
phase_key_dict: dictionary
A small dictionary to translate the integers in the phase_index array
Expand Down Expand Up @@ -174,37 +175,46 @@ def plot_templates_over_signal(

if direct_beam_position is None and not find_direct_beam:
direct_beam_position = (0, 0)

if marker_colors is None:
marker_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
marker_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

if len(marker_colors) < n_best:
print("Warning: not enough colors in `marker_colors` for `n_best` different colored marks. Colors will loop")
print(
"Warning: not enough colors in `marker_colors` for `n_best` different colored marks. Colors will loop"
)

# Add markers as iterable, recommended in hyperspy docs.
# Using a genenerator will hopefully reduce memory.
# To avoid scope errors, pass all variables as inputs
def _get_markers_iter(
signal,
library,
result,
phase_key_dict,
n_best,
find_direct_beam,
direct_beam_position,
marker_colors,
marker_type,
size_factor,
signal,
library,
result,
phase_key_dict,
n_best,
find_direct_beam,
direct_beam_position,
marker_colors,
marker_type,
size_factor,
verbose,
):

):
# Hyperspy wants one marker for all pixels in the navigation space,
# so we generate all the data for a given solution and then yield them

# Allocate space for all navigator pixels to potentially have the maximum amount of simulated diffraction spots
max_marker_count = max(sim.intensities.size for lib in library.values() for sim in lib["simulations"])
max_marker_count = max(
sim.intensities.size
for lib in library.values()
for sim in lib["simulations"]
)

shape = (signal.axes_manager[1].size, signal.axes_manager[0].size, max_marker_count)
shape = (
signal.axes_manager[1].size,
signal.axes_manager[0].size,
max_marker_count,
)

# Explicit zeroes instead of empty, since we won't fill all elements in the final axis
marker_data_x = np.zeros(shape)
Expand All @@ -221,7 +231,6 @@ def _get_markers_iter(

for px in x_iter:
for py in range(signal.axes_manager[1].size):

sim_sol_index = result["template_index"][py, px, n]
mirrored_sol = result["mirrored_template"][py, px, n]
in_plane_angle = result["orientation"][py, px, n, 0]
Expand All @@ -233,8 +242,8 @@ def _get_markers_iter(

if find_direct_beam:
x, y = find_beam_center_blur(signal.inav[px, py], 1)
# The result of `find_beam_center_blur` is in a corner.

# The result of `find_beam_center_blur` is in a corner.
# Move to center of image
x -= signal.axes_manager[2].size // 2
y -= signal.axes_manager[3].size // 2
Expand All @@ -254,37 +263,39 @@ def _get_markers_iter(
marker_data_x[py, px, :marker_count] = x
marker_data_y[py, px, :marker_count] = y
marker_data_i[py, px, :marker_count] = intensities

marker_kwargs = {
"color": color,
"color": color,
"marker": marker_type,
"label": f"Solution index: {n}",
}

# Plot for the given solution index
for i in range(max_marker_count):
yield point(
marker_data_x[..., i],
marker_data_y[..., i],
size = 4 * np.sqrt(marker_data_i[..., i]) * size_factor,
**marker_kwargs
)
marker_data_x[..., i],
marker_data_y[..., i],
size=4 * np.sqrt(marker_data_i[..., i]) * size_factor,
**marker_kwargs,
)
# We only need one set of labels per solution
if i == 0:
marker_kwargs.pop("label")

signal.plot(**plot_kwargs)
signal.add_marker(_get_markers_iter(
signal,
library,
result,
phase_key_dict,
n_best,
find_direct_beam,
direct_beam_position,
marker_colors,
marker_type,
size_factor,
verbose,
))
signal.add_marker(
_get_markers_iter(
signal,
library,
result,
phase_key_dict,
n_best,
find_direct_beam,
direct_beam_position,
marker_colors,
marker_type,
size_factor,
verbose,
)
)
plt.gcf().legend()

0 comments on commit 014e1e0

Please sign in to comment.