diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 2c9b596c..cfdd841c 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -46,7 +46,7 @@ jobs:
- os: ubuntu-latest
python-version: 3.8
OLDEST_SUPPORTED_VERSION: true
- DEPENDENCIES: diffpy.structure==3.0.2 matplotlib==3.5 numpy==1.17.3 orix==0.9.0 scipy==1.8 tqdm==4.9
+ DEPENDENCIES: diffpy.structure==3.0.2 matplotlib==3.5 numpy==1.17.3 orix==0.12.1 scipy==1.8 tqdm==4.9
LABEL: -oldest
steps:
- uses: actions/checkout@v4
diff --git a/diffsims/crystallography/_diffracting_vector.py b/diffsims/crystallography/_diffracting_vector.py
new file mode 100644
index 00000000..be9a61f3
--- /dev/null
+++ b/diffsims/crystallography/_diffracting_vector.py
@@ -0,0 +1,194 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017-2024 The diffsims developers
+#
+# This file is part of diffsims.
+#
+# diffsims is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# diffsims is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with diffsims. If not, see .
+
+from diffsims.crystallography import ReciprocalLatticeVector
+import numpy as np
+from orix.vector.miller import _transform_space
+from orix.quaternion import Rotation
+
+
+class DiffractingVector(ReciprocalLatticeVector):
+ r"""Reciprocal lattice vectors :math:`(hkl)` for use in electron
+ diffraction analysis and simulation.
+
+ All lengths are assumed to be given in Å or inverse Å.
+
+ This extends the :class:`ReciprocalLatticeVector` class. `DiffractingVector`
+ focus on the subset of reciprocal lattice vectors that are relevant for
+ electron diffraction based on the intersection of the Ewald sphere with the
+ reciprocal lattice.
+
+ This class is only used internally to store the DiffractionVectors generated from the
+ :class:`~diffsims.simulations.DiffractionSimulation` class. It is not (currently)
+ intended to be used directly by the user.
+
+ Parameters
+ ----------
+ phase : orix.crystal_map.Phase
+ A phase with a crystal lattice and symmetry.
+ xyz : numpy.ndarray, list, or tuple, optional
+ Cartesian coordinates of indices of reciprocal lattice vector(s)
+ ``hkl``. Default is ``None``. This, ``hkl``, or ``hkil`` is
+ required.
+ hkl : numpy.ndarray, list, or tuple, optional
+ Indices of reciprocal lattice vector(s). Default is ``None``.
+ This, ``xyz``, or ``hkil`` is required.
+ hkil : numpy.ndarray, list, or tuple, optional
+ Indices of reciprocal lattice vector(s), often preferred over
+ ``hkl`` in trigonal and hexagonal lattices. Default is ``None``.
+ This, ``xyz``, or ``hkl`` is required.
+ intensity : numpy.ndarray, list, or tuple, optional
+ Intensity of the diffraction vector(s). Default is ``None``.
+ rotation : orix.quaternion.Rotation, optional
+ Rotation matrix previously applied to the reciprocal lattice vector(s) and the
+ lattice of the phase. Default is ``None`` which corresponds to the
+ identity matrix.
+
+
+ Examples
+ --------
+ >>> from diffpy.structure import Atom, Lattice, Structure
+ >>> from orix.crystal_map import Phase
+ >>> from diffsims.crystallography import DiffractingVector
+ >>> phase = Phase(
+ ... "al",
+ ... space_group=225,
+ ... structure=Structure(
+ ... lattice=Lattice(4.04, 4.04, 4.04, 90, 90, 90),
+ ... atoms=[Atom("Al", [0, 0, 1])],
+ ... ),
+ ... )
+ >>> rlv = DiffractingVector(phase, hkl=[[1, 1, 1], [2, 0, 0]])
+ >>> rlv
+ ReciprocalLatticeVector (2,), al (m-3m)
+ [[1. 1. 1.]
+ [2. 0. 0.]]
+
+ """
+
+ def __init__(self, phase, xyz=None, hkl=None, hkil=None, intensity=None):
+ super().__init__(phase, xyz=xyz, hkl=hkl, hkil=hkil)
+ if intensity is None:
+ self._intensity = np.full(self.shape, np.nan)
+ elif len(intensity) != self.size:
+ raise ValueError("Length of intensity array must match number of vectors")
+ else:
+ self._intensity = np.array(intensity)
+
+ def __getitem__(self, key):
+ new_data = self.data[key]
+ dv_new = self.__class__(self.phase, xyz=new_data)
+
+ if np.isnan(self.structure_factor).all():
+ dv_new._structure_factor = np.full(dv_new.shape, np.nan, dtype="complex128")
+
+ else:
+ dv_new._structure_factor = self.structure_factor[key]
+ if np.isnan(self.theta).all():
+ dv_new._theta = np.full(dv_new.shape, np.nan)
+ else:
+ dv_new._theta = self.theta[key]
+ if np.isnan(self.intensity).all():
+ dv_new._intensity = np.full(dv_new.shape, np.nan)
+ else:
+ slic = self.intensity[key]
+ if not hasattr(slic, "__len__"):
+ slic = np.array(
+ [
+ slic,
+ ]
+ )
+ dv_new._intensity = slic
+
+ return dv_new
+
+ @property
+ def basis_rotation(self):
+ """
+ Returns the lattice basis rotation.
+ """
+ return Rotation.from_matrix(self.phase.structure.lattice.baserot)
+
+ def rotate_with_basis(self, rotation):
+ """Rotate both vectors and the basis with a given `Rotation`.
+ This differs from simply multiplying with a `Rotation`,
+ as that would NOT update the basis.
+
+ Parameters
+ ----------
+ rot : orix.quaternion.Rotation
+ A rotation to apply to vectors and the basis.
+
+ Returns
+ -------
+ DiffractingVector
+ A new DiffractingVector with the rotated vectors and basis. This maintains
+ the hkl indices of the vectors, but the underlying vector xyz coordinates
+ are rotated by the given rotation.
+
+ Notes
+ -----
+ Rotating the lattice basis may lead to undefined behavior in orix as it violates
+ the assumption that the basis is aligned with the crystal axes. Particularly,
+ applying symmetry operations to the phase may lead to unexpected results.
+ """
+
+ if rotation.size != 1:
+ raise ValueError("Rotation must be a single rotation")
+ # rotate basis
+ new_phase = self.phase.deepcopy()
+ br = new_phase.structure.lattice.baserot
+ # In case the base rotation is set already
+ new_br = br @ rotation.to_matrix().squeeze()
+ new_phase.structure.lattice.setLatPar(baserot=new_br)
+ # rotate vectors
+ vecs = ~rotation * self.to_miller()
+ return ReciprocalLatticeVector(new_phase, xyz=vecs.data)
+
+ @property
+ def intensity(self):
+ return self._intensity
+
+ @intensity.setter
+ def intensity(self, value):
+ if not hasattr(value, "__len__"):
+ value = np.array(
+ [
+ value,
+ ]
+ * self.size
+ )
+ if len(value) != self.size:
+ raise ValueError("Length of intensity array must match number of vectors")
+ self._intensity = np.array(value)
+
+ def calculate_structure_factor(self):
+ raise NotImplementedError(
+ "Structure factor calculation not implemented for DiffractionVector. "
+ "Use ReciprocalLatticeVector instead."
+ )
+
+ def to_flat_polar(self):
+ """Return the vectors in polar coordinates as projected onto the x,y plane"""
+ flat_self = self.flatten()
+ r = np.linalg.norm(flat_self.data[:, :2], axis=1)
+ theta = np.arctan2(
+ flat_self.data[:, 1],
+ flat_self.data[:, 0],
+ )
+ return r, theta
diff --git a/diffsims/crystallography/reciprocal_lattice_vector.py b/diffsims/crystallography/reciprocal_lattice_vector.py
index 8ce16ba7..12f190f2 100644
--- a/diffsims/crystallography/reciprocal_lattice_vector.py
+++ b/diffsims/crystallography/reciprocal_lattice_vector.py
@@ -119,7 +119,6 @@ def __init__(self, phase, xyz=None, hkl=None, hkil=None):
self._coordinate_format = "hkl"
xyz = _transform_space(hkl, "r", "c", phase.structure.lattice)
super().__init__(xyz)
-
self._theta = np.full(self.shape, np.nan)
self._structure_factor = np.full(self.shape, np.nan, dtype="complex128")
@@ -1023,7 +1022,7 @@ def symmetrise(self, return_multiplicity=False, return_index=False):
return new_out
@classmethod
- def from_highest_hkl(cls, phase, hkl):
+ def from_highest_hkl(cls, phase, hkl, include_zero_vector=False):
"""Create a set of unique reciprocal lattice vectors from three
highest indices.
@@ -1033,6 +1032,8 @@ def from_highest_hkl(cls, phase, hkl):
A phase with a crystal lattice and symmetry.
hkl : numpy.ndarray, list, or tuple
Three highest reciprocal lattice vector indices.
+ include_zero_vector : bool
+ If ``True``, include the zero vector (000) in the set of vectors.
Examples
--------
@@ -1067,10 +1068,14 @@ def from_highest_hkl(cls, phase, hkl):
"""
idx = _get_indices_from_highest(highest_indices=hkl)
- return cls(phase, hkl=idx).unique()
+ new = cls(phase, hkl=idx).unique()
+ if include_zero_vector:
+ new_data = np.vstack((new.hkl, np.zeros(3, dtype=int)))
+ new = ReciprocalLatticeVector(phase, hkl=new_data)
+ return new
@classmethod
- def from_min_dspacing(cls, phase, min_dspacing=0.7):
+ def from_min_dspacing(cls, phase, min_dspacing=0.7, include_zero_vector=False):
"""Create a set of unique reciprocal lattice vectors with a
a direct space interplanar spacing greater than a lower
threshold.
@@ -1083,6 +1088,8 @@ def from_min_dspacing(cls, phase, min_dspacing=0.7):
Smallest interplanar spacing to consider. Default is 0.7,
in the unit used to define the lattice parameters in
``phase``, which is assumed to be Ångström.
+ include_zero_vector: bool
+ If ``True``, include the zero vector (000) in the set of vectors.
Examples
--------
@@ -1128,7 +1135,11 @@ def from_min_dspacing(cls, phase, min_dspacing=0.7):
dspacing = 1 / phase.structure.lattice.rnorm(hkl)
idx = dspacing >= min_dspacing
hkl = hkl[idx]
- return cls(phase, hkl=hkl).unique()
+ new = cls(phase, hkl=hkl).unique()
+ if include_zero_vector:
+ new_data = np.vstack((new.hkl, np.zeros(3, dtype=int)))
+ new = cls(phase, hkl=new_data)
+ return new
@classmethod
def from_miller(cls, miller):
diff --git a/diffsims/generators/__init__.py b/diffsims/generators/__init__.py
index 56b17d09..ab61bf8a 100644
--- a/diffsims/generators/__init__.py
+++ b/diffsims/generators/__init__.py
@@ -26,6 +26,7 @@
rotation_list_generators,
sphere_mesh_generators,
zap_map_generator,
+ simulation_generator,
)
__all__ = [
@@ -33,5 +34,6 @@
"library_generator",
"rotation_list_generators",
"sphere_mesh_generators",
+ "simulation_generator",
"zap_map_generator",
]
diff --git a/diffsims/generators/simulation_generator.py b/diffsims/generators/simulation_generator.py
new file mode 100644
index 00000000..0a701e3c
--- /dev/null
+++ b/diffsims/generators/simulation_generator.py
@@ -0,0 +1,428 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017-2024 The diffsims developers
+#
+# This file is part of diffsims.
+#
+# diffsims is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# diffsims is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with diffsims. If not, see .
+
+"""Kinematic Diffraction Simulation Generator."""
+
+from typing import Union, Sequence
+import numpy as np
+
+from orix.quaternion import Rotation
+from orix.crystal_map import Phase
+
+from diffsims.crystallography._diffracting_vector import DiffractingVector
+from diffsims.utils.shape_factor_models import (
+ linear,
+ atanc,
+ lorentzian,
+ sinc,
+ sin2c,
+ lorentzian_precession,
+ _shape_factor_precession,
+)
+
+from diffsims.utils.sim_utils import (
+ get_electron_wavelength,
+ get_kinematical_intensities,
+ is_lattice_hexagonal,
+ get_points_in_sphere,
+ get_intensities_params,
+)
+
+_shape_factor_model_mapping = {
+ "linear": linear,
+ "atanc": atanc,
+ "sinc": sinc,
+ "sin2c": sin2c,
+ "lorentzian": lorentzian,
+}
+
+from diffsims.simulations import Simulation1D, Simulation2D
+
+__all__ = ["SimulationGenerator"]
+
+
+class SimulationGenerator:
+ """
+ A class for generating kinematic diffraction simulations.
+ """
+
+ def __repr__(self):
+ return (
+ f"SimulationGenerator(accelerating_voltage={self.accelerating_voltage}, "
+ f"scattering_params={self.scattering_params}, "
+ f"approximate_precession={self.approximate_precession})"
+ )
+
+ def __init__(
+ self,
+ accelerating_voltage: float = 200,
+ scattering_params: str = "lobato",
+ precession_angle: float = 0,
+ shape_factor_model: str = "lorentzian",
+ approximate_precession: bool = True,
+ minimum_intensity: float = 1e-20,
+ **kwargs,
+ ):
+ """
+ Parameters
+ ----------
+ accelerating_voltage
+ The accelerating voltage of the electrons in keV.
+ scattering_params
+ The scattering parameters to use. One of 'lobato', 'xtables'
+ precession_angle
+ The precession angle in degrees. If 0, no precession is applied.
+ shape_factor_model
+ The shape factor model to use. One of 'linear', 'atanc', 'sinc', 'sin2c', 'lorentzian'
+ approximate_precession
+ If True, the precession is approximated by a Lorentzian function.
+ minimum_intensity
+ The minimum intensity of a reflection to be included in the profile.
+ kwargs
+ Keyword arguments to pass to the shape factor model.
+
+ """
+ self.accelerating_voltage = accelerating_voltage
+ self.precession_angle = np.abs(precession_angle)
+ self.approximate_precession = approximate_precession
+ if isinstance(shape_factor_model, str):
+ if shape_factor_model in _shape_factor_model_mapping.keys():
+ self.shape_factor_model = _shape_factor_model_mapping[
+ shape_factor_model
+ ]
+ else:
+ raise NotImplementedError(
+ f"{shape_factor_model} is not a recognized shape factor "
+ f"model, choose from: {_shape_factor_model_mapping.keys()} "
+ f"or provide your own function."
+ )
+ else:
+ self.shape_factor_model = shape_factor_model
+ self.minimum_intensity = minimum_intensity
+ self.shape_factor_kwargs = kwargs
+ if scattering_params in ["lobato", "xtables", None]:
+ self.scattering_params = scattering_params
+ else:
+ raise NotImplementedError(
+ "The scattering parameters `{}` is not implemented. "
+ "See documentation for available "
+ "implementations.".format(scattering_params)
+ )
+
+ @property
+ def wavelength(self):
+ return get_electron_wavelength(self.accelerating_voltage)
+
+ def calculate_diffraction2d(
+ self,
+ phase: Union[Phase, Sequence[Phase]],
+ rotation: Union[Rotation, Sequence[Rotation]] = Rotation.from_euler(
+ (0, 0, 0), degrees=True
+ ),
+ reciprocal_radius: float = 1.0,
+ with_direct_beam: bool = True,
+ max_excitation_error: float = 1e-2,
+ shape_factor_width: float = None,
+ debye_waller_factors: dict = None,
+ ):
+ """Calculates the diffraction pattern for one or more phases given a list
+ of rotations for each phase.
+
+ Parameters
+ ----------
+ phase:
+ The phase(s) for which to derive the diffraction pattern.
+ reciprocal_radius
+ The maximum radius of the sphere of reciprocal space to
+ sample, in reciprocal Angstroms.
+ rotation
+ The Rotation object(s) to apply to the structure and then
+ calculate the diffraction pattern.
+ with_direct_beam
+ If True, the direct beam is included in the simulated
+ diffraction pattern. If False, it is not.
+ max_excitation_error
+ The cut-off for geometric excitation error in the z-direction
+ in units of reciprocal Angstroms. Spots with a larger distance
+ from the Ewald sphere are removed from the pattern.
+ Related to the extinction distance and roughly equal to 1/thickness.
+ shape_factor_width
+ Determines the width of the reciprocal rel-rod, for fine-grained
+ control. If not set will be set equal to max_excitation_error.
+ debye_waller_factors
+ Maps element names to their temperature-dependent Debye-Waller factors.
+
+ Returns
+ -------
+ diffsims.sims.diffraction_simulation.DiffractionSimulation
+ The data associated with this structure and diffraction setup.
+ """
+ if isinstance(phase, Phase):
+ phase = [phase]
+ if isinstance(rotation, Rotation):
+ rotation = [rotation]
+ if len(phase) != len(rotation):
+ raise ValueError(
+ "The number of phases and rotations must be equal. "
+ f"Got {len(phase)} phases and {len(rotation)} rotations."
+ )
+
+ if debye_waller_factors is None:
+ debye_waller_factors = {}
+ # Specify variables used in calculation
+ wavelength = self.wavelength
+
+ # Rotate using all the rotations in the list
+ vectors = []
+ for p, rotate in zip(phase, rotation):
+ recip = DiffractingVector.from_min_dspacing(
+ p,
+ min_dspacing=1 / reciprocal_radius,
+ include_zero_vector=with_direct_beam,
+ )
+ phase_vectors = []
+ for rot in rotate:
+ # Calculate the reciprocal lattice vectors that intersect the Ewald sphere.
+ (
+ intersected_vectors,
+ hkl,
+ shape_factor,
+ ) = self.get_intersecting_reflections(
+ recip,
+ rot,
+ wavelength,
+ max_excitation_error,
+ shape_factor_width=shape_factor_width,
+ with_direct_beam=with_direct_beam,
+ )
+
+ # Calculate diffracted intensities based on a kinematic model.
+ intensities = get_kinematical_intensities(
+ p.structure,
+ hkl,
+ intersected_vectors.gspacing,
+ prefactor=shape_factor,
+ scattering_params=self.scattering_params,
+ debye_waller_factors=debye_waller_factors,
+ )
+
+ # Threshold peaks included in simulation as factor of zero beam intensity.
+ peak_mask = intensities > np.max(intensities) * self.minimum_intensity
+ intensities = intensities[peak_mask]
+ intersected_vectors = intersected_vectors[peak_mask]
+ intersected_vectors.intensity = intensities
+ phase_vectors.append(intersected_vectors)
+ vectors.append(phase_vectors)
+
+ if len(phase) == 1:
+ vectors = vectors[0]
+ phase = phase[0]
+ rotation = rotation[0]
+ if rotation.size == 1:
+ vectors = vectors[0]
+
+ # Create a simulation object
+ sim = Simulation2D(
+ phases=phase,
+ coordinates=vectors,
+ rotations=rotation,
+ simulation_generator=self,
+ reciprocal_radius=reciprocal_radius,
+ )
+ return sim
+
+ def calculate_diffraction1d(
+ self,
+ phase: Phase,
+ reciprocal_radius: float = 1.0,
+ minimum_intensity: float = 1e-3,
+ debye_waller_factors: dict = None,
+ ):
+ """Calculates the 1-D profile of the diffraction pattern for one phases.
+
+ This is useful for plotting the diffracting reflections for some phases.
+
+ Parameters
+ ----------
+ phase:
+ The phase for which to derive the diffraction pattern.
+ reciprocal_radius
+ The maximum radius of the sphere of reciprocal space to
+ sample, in reciprocal Angstroms.
+ minimum_intensity
+ The minimum intensity of a reflection to be included in the profile.
+ debye_waller_factors
+ Maps element names to their temperature-dependent Debye-Waller factors.
+ """
+ latt = phase.structure.lattice
+
+ # Obtain crystallographic reciprocal lattice points within range
+ recip_latt = latt.reciprocal()
+ spot_indices, _, spot_distances = get_points_in_sphere(
+ recip_latt, reciprocal_radius
+ )
+
+ ##spot_indicies is a numpy.array of the hkls allowed in the recip radius
+ g_indices, multiplicities, g_hkls = get_intensities_params(
+ recip_latt, reciprocal_radius
+ )
+
+ i_hkl = get_kinematical_intensities(
+ phase.structure,
+ g_indices,
+ np.asarray(g_hkls),
+ prefactor=multiplicities,
+ scattering_params=self.scattering_params,
+ debye_waller_factors=debye_waller_factors,
+ )
+
+ if is_lattice_hexagonal(latt):
+ # Use Miller-Bravais indices for hexagonal lattices.
+ g_indices = np.array(
+ [
+ g_indices[:, 0],
+ g_indices[:, 1],
+ g_indices[:, 0] - g_indices[:, 1],
+ g_indices[:, 2],
+ ]
+ ).T
+
+ hkls_labels = ["".join([str(int(x)) for x in xs]) for xs in g_indices]
+
+ peaks = []
+ for l, i, g in zip(hkls_labels, i_hkl, g_hkls):
+ peaks.append((l, [i, g]))
+
+ # Scale intensities so that the max intensity is 100.
+
+ max_intensity = max([v[1][0] for v in peaks])
+ reciporical_spacing = []
+ intensities = []
+ hkls = []
+ for p in peaks:
+ label, v = p # (label, [intensity,g])
+ if v[0] / max_intensity * 100 > minimum_intensity and (label != "000"):
+ reciporical_spacing.append(v[1])
+ intensities.append(v[0])
+ hkls.append(label)
+
+ intensities = np.asarray(intensities) / max(intensities) * 100
+
+ return Simulation1D(
+ phase=phase,
+ reciprocal_spacing=reciporical_spacing,
+ intensities=intensities,
+ hkl=hkls,
+ reciprocal_radius=reciprocal_radius,
+ wavelength=self.wavelength,
+ )
+
+ def get_intersecting_reflections(
+ self,
+ recip: DiffractingVector,
+ rot: np.ndarray,
+ wavelength: float,
+ max_excitation_error: float,
+ shape_factor_width: float = None,
+ with_direct_beam: bool = True,
+ ):
+ """Calculates the reciprocal lattice vectors that intersect the Ewald sphere.
+
+ Parameters
+ ----------
+ recip
+ The reciprocal lattice vectors to rotate.
+ rot
+ The rotation matrix to apply to the reciprocal lattice vectors.
+ wavelength
+ The wavelength of the electrons in Angstroms.
+ max_excitation_error
+ The cut-off for geometric excitation error in the z-direction
+ in units of reciprocal Angstroms. Spots with a larger distance
+ from the Ewald sphere are removed from the pattern.
+ Related to the extinction distance and roungly equal to 1/thickness.
+ shape_factor_width
+ Determines the width of the reciprocal rel-rod, for fine-grained
+ control. If not set will be set equal to max_excitation_error.
+ """
+ initial_hkl = recip.hkl
+ rotated_vectors = recip.rotate_with_basis(rotation=rot)
+ rotated_phase = rotated_vectors.phase
+ rotated_vectors = rotated_vectors.data
+ if with_direct_beam:
+ rotated_vectors = np.vstack([rotated_vectors.data, [0, 0, 0]])
+ initial_hkl = np.vstack([initial_hkl, [0, 0, 0]])
+ # Identify the excitation errors of all points (distance from point to Ewald sphere)
+ r_sphere = 1 / wavelength
+ r_spot = np.sqrt(np.sum(np.square(rotated_vectors[:, :2]), axis=1))
+ z_spot = rotated_vectors[:, 2]
+
+ z_sphere = -np.sqrt(r_sphere**2 - r_spot**2) + r_sphere
+ excitation_error = z_sphere - z_spot
+
+ # determine the pre-selection reflections
+ if self.precession_angle == 0:
+ intersection = np.abs(excitation_error) < max_excitation_error
+ else:
+ # only consider points that intersect the ewald sphere at some point
+ # the center point of the sphere
+ P_z = r_sphere * np.cos(np.deg2rad(self.precession_angle))
+ P_t = r_sphere * np.sin(np.deg2rad(self.precession_angle))
+ # the extremes of the ewald sphere
+ z_surf_up = P_z - np.sqrt(r_sphere**2 - (r_spot + P_t) ** 2)
+ z_surf_do = P_z - np.sqrt(r_sphere**2 - (r_spot - P_t) ** 2)
+ intersection = (z_spot - max_excitation_error <= z_surf_up) & (
+ z_spot + max_excitation_error >= z_surf_do
+ )
+
+ # select these reflections
+ intersected_vectors = rotated_vectors[intersection]
+ intersected_vectors = DiffractingVector(
+ phase=rotated_phase,
+ xyz=intersected_vectors,
+ )
+ excitation_error = excitation_error[intersection]
+ r_spot = r_spot[intersection]
+ hkl = initial_hkl[intersection]
+
+ if shape_factor_width is None:
+ shape_factor_width = max_excitation_error
+ # select and evaluate shape factor model
+ if self.precession_angle == 0:
+ # calculate shape factor
+ shape_factor = self.shape_factor_model(
+ excitation_error, shape_factor_width, **self.shape_factor_kwargs
+ )
+ else:
+ if self.approximate_precession:
+ shape_factor = lorentzian_precession(
+ excitation_error,
+ shape_factor_width,
+ r_spot,
+ np.deg2rad(self.precession_angle),
+ )
+ else:
+ shape_factor = _shape_factor_precession(
+ excitation_error,
+ r_spot,
+ np.deg2rad(self.precession_angle),
+ self.shape_factor_model,
+ shape_factor_width,
+ **self.shape_factor_kwargs,
+ )
+ return intersected_vectors, hkl, shape_factor
diff --git a/diffsims/simulations/__init__.py b/diffsims/simulations/__init__.py
new file mode 100644
index 00000000..91ea1d10
--- /dev/null
+++ b/diffsims/simulations/__init__.py
@@ -0,0 +1,27 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017-2024 The diffsims developers
+#
+# This file is part of diffsims.
+#
+# diffsims is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# diffsims is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with diffsims. If not, see .
+
+"""Kinematic Diffraction Simulation Results."""
+
+from diffsims.simulations.simulation2d import Simulation2D
+from diffsims.simulations.simulation1d import Simulation1D
+
+__all__ = [
+ "Simulation1D",
+ "Simulation2D",
+]
diff --git a/diffsims/simulations/simulation1d.py b/diffsims/simulations/simulation1d.py
new file mode 100644
index 00000000..49391eff
--- /dev/null
+++ b/diffsims/simulations/simulation1d.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017-2024 The diffsims developers
+#
+# This file is part of diffsims.
+#
+# diffsims is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# diffsims is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with diffsims. If not, see .
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+import matplotlib.pyplot as plt
+from orix.crystal_map import Phase
+from diffsims.utils.sim_utils import get_electron_wavelength
+
+# to avoid circular imports
+if TYPE_CHECKING: # pragma: no cover
+ from diffsims.generators.simulation_generator import SimulationGenerator
+
+
+class Simulation1D:
+ """Holds the result of a 1D simulation for some phase"""
+
+ def __init__(
+ self,
+ phase: Phase,
+ reciprocal_spacing: np.ndarray,
+ intensities: np.ndarray,
+ hkl: np.ndarray,
+ reciprocal_radius: float,
+ wavelength: float,
+ ):
+ """Initializes the DiffractionSimulation object with data values for
+ the coordinates, indices, intensities, calibration and offset.
+
+ Parameters
+ ----------
+ phase
+ The phase of the simulation
+ reciprocal_spacing
+ The spacing of the reciprocal lattice vectors in A^-1
+ intensities
+ The intensities of the diffraction spots
+ hkl
+ The hkl indices of the diffraction spots
+ reciprocal_radius
+ The radius which the reciprocal lattice spacings are plotted out to
+ wavelength
+ The wavelength of the beam in A^-1
+ """
+ self.phase = phase
+ self.reciprocal_spacing = reciprocal_spacing
+ self.intensities = intensities
+ self.hkl = hkl
+ self.reciprocal_radius = reciprocal_radius
+ self.wavelength = wavelength
+
+ def __repr__(self):
+ return f"Simulation1D(name: {self.phase.name}, wavelength: {self.wavelength})"
+
+ @property
+ def theta(self):
+ return np.arctan2(np.array(self.reciprocal_spacing), 1 / self.wavelength)
+
+ def plot(self, ax=None, annotate_peaks=False, fontsize=12, with_labels=True):
+ """Plots the 1D diffraction pattern,"""
+ if ax is None:
+ fig, ax = plt.subplots(1, 1)
+ for g, i, hkls in zip(self.reciprocal_spacing, self.intensities, self.hkl):
+ label = hkls
+ ax.plot([g, g], [0, i], color="k", linewidth=3, label=label)
+ if annotate_peaks:
+ ax.annotate(label, xy=[g, i], xytext=[g, i], fontsize=fontsize)
+
+ if with_labels:
+ ax.set_xlabel("A ($^{-1}$)")
+ ax.set_ylabel("Intensities (scaled)")
+ return ax
diff --git a/diffsims/simulations/simulation2d.py b/diffsims/simulations/simulation2d.py
new file mode 100644
index 00000000..9852c5c4
--- /dev/null
+++ b/diffsims/simulations/simulation2d.py
@@ -0,0 +1,746 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017-2024 The diffsims developers
+#
+# This file is part of diffsims.
+#
+# diffsims is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# diffsims is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with diffsims. If not, see .
+
+from typing import Union, Sequence, TYPE_CHECKING, Any
+import copy
+
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.widgets import Slider
+from orix.crystal_map import Phase
+from orix.quaternion import Rotation
+from orix.vector import Vector3d
+
+from diffsims.crystallography._diffracting_vector import DiffractingVector
+from diffsims.pattern.detector_functions import add_shot_and_point_spread
+
+# to avoid circular imports
+if TYPE_CHECKING: # pragma: no cover
+ from diffsims.generators.simulation_generator import SimulationGenerator
+
+__all__ = [
+ "Simulation2D",
+ "get_closest",
+]
+
+
+class PhaseGetter:
+ """A class for getting the phases of a simulation library.
+
+ Parameters
+ ----------
+ simulation : Simulation2D
+ The simulation to get from.
+ """
+
+ def __init__(self, simulation):
+ self.simulation = simulation
+
+ def __getitem__(self, item):
+ all_phases = self.simulation.phases
+ if isinstance(all_phases, Phase):
+ raise ValueError("Only one phase in the simulation")
+ elif isinstance(item, str):
+ ind = [phase.name for phase in all_phases].index(item)
+ elif isinstance(item, (int, slice)):
+ ind = item
+ else:
+ raise ValueError("Item must be a string or integer")
+ new_coords = self.simulation.coordinates[ind]
+ new_rotations = self.simulation.rotations[ind]
+ new_phases = all_phases[ind]
+ return Simulation2D(
+ phases=new_phases,
+ coordinates=new_coords,
+ rotations=new_rotations,
+ simulation_generator=self.simulation.simulation_generator,
+ )
+
+
+class RotationGetter:
+ """A class for getting a Rotation of a simulation library.
+
+ Parameters
+ ----------
+ simulation : Simulation2D
+ The simulation to get from.
+ """
+
+ def __init__(self, simulation):
+ self.simulation = simulation
+
+ def __getitem__(self, item):
+ all_phases = self.simulation.phases
+ if self.simulation.current_size == 1:
+ raise ValueError("Only one rotation in the simulation")
+ elif isinstance(all_phases, Phase): # only one phase in the simulation
+ coords = self.simulation.coordinates[item]
+ phases = self.simulation.phases
+ rotations = self.simulation.rotations[item]
+ else: # multiple phases in the simulation
+ coords = [c[item] for c in self.simulation.coordinates]
+ phases = self.simulation.phases
+ rotations = [rot[item] for rot in self.simulation.rotations]
+ return Simulation2D(
+ phases=phases,
+ coordinates=coords,
+ rotations=rotations,
+ simulation_generator=self.simulation.simulation_generator,
+ )
+
+
+class Simulation2D:
+ """Holds the result of a kinematic diffraction simulation for some phase
+ and rotation. This class is iterable and can be used to iterate through
+ simulations of different phases and rotations.
+ """
+
+ def __init__(
+ self,
+ phases: Sequence[Phase],
+ coordinates: Union[
+ DiffractingVector,
+ Sequence[DiffractingVector],
+ Sequence[Sequence[DiffractingVector]],
+ ],
+ rotations: Union[Rotation, Sequence[Rotation]],
+ simulation_generator: "SimulationGenerator",
+ reciprocal_radius: float = 1.0,
+ ):
+ """Initializes the DiffractionSimulation object with data values for
+ the coordinates, indices, intensities, calibration and offset.
+
+ Parameters
+ ----------
+ coordinates
+ The list of DiffractingVector objects for each phase and rotation. If there
+ are multiple phases, then this should be a list of lists of DiffractingVector objects.
+ If there is only one phase, then this should be a list of DiffractingVector objects.
+ rotations
+ The list of Rotation objects for each phase. If there are multiple phases, then this should
+ be a list of Rotation objects. If there is only one phase, then this should be a single
+ Rotation object.
+ phases
+ The list of Phase objects for each phase. If there is only one phase, then this should be
+ a single Phase object.
+ simulation_generator
+ The SimulationGenerator object used to generate the diffraction patterns.
+
+ """
+ # Basic data
+ if isinstance(rotations, Rotation) and rotations.size == 1:
+ if not isinstance(coordinates, DiffractingVector):
+ raise ValueError(
+ "If there is only one rotation, then the coordinates must be a DiffractingVector object"
+ )
+ elif isinstance(rotations, Rotation):
+ coordinates = np.array(coordinates, dtype=object)
+ if coordinates.size != rotations.size:
+ raise ValueError(
+ f"The number of rotations: {rotations.size} must match the number of "
+ f"coordinates {coordinates.size}"
+ )
+ else: # iterable of Rotation
+ rotations = np.array(rotations, dtype=object)
+ coordinates = np.array(coordinates, dtype=object)
+ phases = np.array(phases)
+ if rotations.size != phases.size:
+ raise ValueError(
+ f"The number of rotations: {rotations.size} must match the number of "
+ f"phases {phases.size}"
+ )
+
+ for r, c in zip(rotations, coordinates):
+ if isinstance(c, DiffractingVector):
+ c = np.array(
+ [
+ c,
+ ]
+ )
+ if r.size != len(c):
+ raise ValueError(
+ f"The number of rotations: {r.size} must match the number of "
+ f"coordinates {c.shape[0]}"
+ )
+ self.phases = phases
+ self.rotations = rotations
+ self.coordinates = coordinates
+ self.simulation_generator = simulation_generator
+
+ # for interactive plotting and iterating through the Simulations
+ self.phase_index = 0
+ self.rotation_index = 0
+ self._rot_plot = None
+ self._diff_plot = None
+ self.reciporical_radius = reciprocal_radius
+
+ # for slicing a simulation
+ self.iphase = PhaseGetter(self)
+ self.irot = RotationGetter(self)
+ self._rotation_slider = None
+ self._phase_slider = None
+
+ def get_simulation(self, item):
+ """Return the rotation and the phase index of the simulation"""
+ if self.has_multiple_phases:
+ cumsum = np.cumsum(self._num_rotations())
+ ind = np.searchsorted(cumsum, item, side="right")
+ cumsum = np.insert(cumsum, 0, 0)
+ num_rot = cumsum[ind]
+ if self.has_multiple_rotations[ind]:
+ return (
+ self.rotations[ind][item - num_rot],
+ ind,
+ self.coordinates[ind][item - num_rot],
+ )
+ else:
+ return self.rotations[ind], ind, self.coordinates[ind]
+ elif self.has_multiple_rotations:
+ return self.rotations[item], 0, self.coordinates[item]
+ else:
+ return self.rotations[item], 0, self.coordinates
+
+ def _num_rotations(self):
+ if self.has_multiple_phases:
+ return [r.size for r in self.rotations]
+ else:
+ return self.rotations.size
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.phase_index == self.num_phases:
+ self.phase_index = 0
+ raise StopIteration
+ else:
+ if self.has_multiple_phases:
+ coords = self.coordinates[self.phase_index]
+ else:
+ coords = self.coordinates
+ if self.has_multiple_rotations:
+ coords = coords[self.rotation_index]
+ else:
+ coords = coords
+ if self.rotation_index + 1 == self.current_size:
+ self.rotation_index = 0
+ self.phase_index += 1
+ else:
+ self.rotation_index += 1
+ return coords
+
+ @property
+ def current_size(self):
+ """Returns the number of rotations in the current phase"""
+ if self.has_multiple_phases:
+ return self.rotations[self.phase_index].size
+ else:
+ return self.rotations.size
+
+ def deepcopy(self):
+ return copy.deepcopy(self)
+
+ def _get_transformed_coordinates(
+ self,
+ angle: float,
+ center: Sequence = (0, 0),
+ mirrored: bool = False,
+ units: str = "real",
+ calibration: float = None,
+ ):
+ """Translate, rotate or mirror the pattern spot coordinates"""
+
+ coords = self.get_current_coordinates()
+
+ if units != "real":
+ center = np.array(center)
+ coords.data = coords.data / calibration
+ transformed_coords = coords
+ cx, cy = center
+ x = transformed_coords.data[:, 0]
+ y = transformed_coords.data[:, 1]
+ mirrored_factor = -1 if mirrored else 1
+ theta = mirrored_factor * np.arctan2(y, x) + np.deg2rad(angle)
+ rd = np.sqrt(x**2 + y**2)
+ transformed_coords[:, 0] = rd * np.cos(theta) + cx
+ transformed_coords[:, 1] = rd * np.sin(theta) + cy
+ return transformed_coords
+
+ @property
+ def current_phase(self):
+ if self.has_multiple_phases:
+ return self.phases[self.phase_index]
+ else:
+ return self.phases
+
+ def rotate_shift_coordinates(
+ self, angle: float, center: Sequence = (0, 0), mirrored: bool = False
+ ):
+ """Rotate, flip or shift patterns in-plane
+
+ Parameters
+ ----------
+ angle
+ In plane rotation angle in degrees
+ center
+ Center coordinate of the patterns
+ mirrored
+ Mirror across the x-axis
+ """
+ coords_new = self._get_transformed_coordinates(
+ angle, center, mirrored, units="real"
+ )
+ return coords_new
+
+ def polar_flatten_simulations(self, radial_axes=None, azimuthal_axes=None):
+ """Flattens the simulations into polar coordinates for use in template matching.
+ The resulting arrays are of shape (n_simulations, n_spots) where n_spots is the
+ maximum number of spots in any simulation.
+
+
+ Returns
+ -------
+ r_templates, theta_templates, intensities_templates
+ """
+
+ flattened_vectors = [sim for sim in self]
+ max_num_spots = max([v.size for v in flattened_vectors])
+
+ r_templates = np.zeros((len(flattened_vectors), max_num_spots))
+ theta_templates = np.zeros((len(flattened_vectors), max_num_spots))
+ intensities_templates = np.zeros((len(flattened_vectors), max_num_spots))
+ for i, v in enumerate(flattened_vectors):
+ r, t = v.to_flat_polar()
+ if radial_axes is not None and azimuthal_axes is not None:
+ r = get_closest(radial_axes, r)
+ t = get_closest(azimuthal_axes, t)
+ r = r[r < len(radial_axes)]
+ t = t[t < len(azimuthal_axes)]
+ r_templates[i, : len(r)] = r
+ theta_templates[i, : len(t)] = t
+ intensities_templates[i, : len(v.intensity)] = v.intensity
+ if radial_axes is not None and azimuthal_axes is not None:
+ r_templates = np.array(r_templates, dtype=int)
+ theta_templates = np.array(theta_templates, dtype=int)
+
+ return r_templates, theta_templates, intensities_templates
+
+ def get_diffraction_pattern(
+ self,
+ shape=None,
+ sigma=10,
+ direct_beam_position=None,
+ in_plane_angle=0,
+ calibration=0.01,
+ mirrored=False,
+ ):
+ """Returns the diffraction data as a numpy array with
+ two-dimensional Gaussians representing each diffracted peak. Should only
+ be used for qualitative work.
+
+ Parameters
+ ----------
+ shape : tuple of ints
+ The size of a side length (in pixels)
+ sigma : float
+ Standard deviation of the Gaussian function to be plotted (in pixels).
+ direct_beam_position: 2-tuple of ints, optional
+ The (x,y) coordinate in pixels of the direct beam. Defaults to
+ the center of the image.
+ in_plane_angle: float, optional
+ In plane rotation of the pattern in degrees
+ mirrored: bool, optional
+ Whether the pattern should be flipped over the x-axis,
+ corresponding to the inverted orientation
+
+ Returns
+ -------
+ diffraction-pattern : numpy.array
+ The simulated electron diffraction pattern, normalised.
+
+ Notes
+ -----
+ If don't know the exact calibration of your diffraction signal using 1e-2
+ produces reasonably good patterns when the lattice parameters are on
+ the order of 0.5nm and the default size and sigma are used.
+ """
+ if direct_beam_position is None:
+ direct_beam_position = (shape[1] // 2, shape[0] // 2)
+ transformed = self._get_transformed_coordinates(
+ in_plane_angle,
+ direct_beam_position,
+ mirrored,
+ units="pixel",
+ calibration=calibration,
+ )
+ in_frame = (
+ (transformed.data[:, 0] >= 0)
+ & (transformed.data[:, 0] < shape[1])
+ & (transformed.data[:, 1] >= 0)
+ & (transformed.data[:, 1] < shape[0])
+ )
+ spot_coords = transformed.data[in_frame].astype(int)
+
+ spot_intens = transformed.intensity[in_frame]
+ pattern = np.zeros(shape)
+ # checks that we have some spots
+ if spot_intens.shape[0] == 0:
+ return pattern
+ else:
+ pattern[spot_coords[:, 0], spot_coords[:, 1]] = spot_intens
+ pattern = add_shot_and_point_spread(pattern.T, sigma, shot_noise=False)
+ return np.divide(pattern, np.max(pattern))
+
+ @property
+ def num_phases(self):
+ """Returns the number of phases in the simulation"""
+ if hasattr(self.phases, "__len__"):
+ return len(self.phases)
+ else:
+ return 1
+
+ @property
+ def has_multiple_phases(self):
+ """Returns True if the simulation has multiple phases"""
+ return self.num_phases > 1
+
+ @property
+ def has_multiple_rotations(self):
+ """Returns True if the simulation has multiple rotations"""
+ if isinstance(self.rotations, Rotation):
+ return self.rotations.size > 1
+ else:
+ return [r.size > 1 for r in self.rotations]
+
+ def get_current_coordinates(self):
+ """Returns the coordinates of the current phase and rotation"""
+ if self.has_multiple_phases:
+ return copy.deepcopy(
+ self.coordinates[self.phase_index][self.rotation_index]
+ )
+ elif not self.has_multiple_phases and self.has_multiple_rotations:
+ return copy.deepcopy(self.coordinates[self.rotation_index])
+ else:
+ return copy.deepcopy(self.coordinates)
+
+ def get_current_rotation_matrix(self):
+ """Returns the current rotation matrix based on the phase and rotation index"""
+ if self.has_multiple_phases:
+ return copy.deepcopy(
+ self.rotations[self.phase_index].to_matrix()[self.rotation_index]
+ )
+ else:
+ return copy.deepcopy(self.rotations.to_matrix()[self.rotation_index])
+
+ def plot_rotations(self, beam_direction: Vector3d = Vector3d.zvector()):
+ """Plots the rotations of the current phase in stereographic projection"""
+ if self.has_multiple_phases:
+ rots = self.rotations[self.phase_index]
+ else:
+ rots = self.rotations
+ vect_rot = rots * beam_direction
+ facecolor = ["k"] * rots.size
+ facecolor[self.rotation_index] = "r" # highlight the current rotation
+ fig = vect_rot.scatter(
+ grid=True,
+ facecolor=facecolor,
+ return_figure=True,
+ )
+ pointer = vect_rot[self.rotation_index]
+ _plot = fig.axes[0]
+ _plot.scatter(pointer.data[0][0], pointer.data[0][1], color="r")
+ _plot = fig.axes[0]
+ _plot.set_title("Rotations" + self.current_phase.name)
+
+ def _get_spots(
+ self,
+ in_plane_angle,
+ direct_beam_position,
+ mirrored,
+ units,
+ calibration,
+ include_direct_beam,
+ ):
+ """Returns the spots of the current phase and rotation for plotting"""
+ coords = self._get_transformed_coordinates(
+ in_plane_angle,
+ direct_beam_position,
+ mirrored,
+ units=units,
+ calibration=calibration,
+ )
+ if include_direct_beam:
+ spots = coords.data[:, :2]
+ spots = np.concatenate((spots, np.array([direct_beam_position])))
+ intensity = np.concatenate((coords.intensity, np.array([1])))
+ else:
+ spots = coords.data[:, :2]
+ intensity = coords.intensity
+ return spots, intensity, coords
+
+ def _get_labels(self, coords, intensity, min_label_intensity, xlim, ylim):
+ condition = (
+ (coords.data[:, 0] > min(xlim))
+ & (coords.data[:, 0] < max(xlim))
+ & (coords.data[:, 1] > min(ylim))
+ & (coords.data[:, 1] < max(ylim))
+ )
+ in_range_coords = coords.data[condition]
+ millers = np.round(
+ np.matmul(
+ np.matmul(in_range_coords, self.get_current_rotation_matrix().T),
+ coords.phase.structure.lattice.base.T,
+ )
+ ).astype(np.int16)
+ labels = []
+ for miller, coordinate, inten in zip(millers, in_range_coords, intensity):
+ if np.isnan(inten) or inten > min_label_intensity:
+ label = "("
+ for index in miller:
+ if index < 0:
+ label += r"$\bar{" + str(abs(index)) + r"}$"
+ else:
+ label += str(abs(index))
+ label += " "
+ label = label[:-1] + ")"
+ labels.append((coordinate, label))
+ return labels
+
+ def plot(
+ self,
+ size_factor=1,
+ direct_beam_position=None,
+ in_plane_angle=0,
+ mirrored=False,
+ units="real",
+ show_labels=False,
+ label_offset=(0, 0),
+ label_formatting=None,
+ min_label_intensity=0.1,
+ include_direct_beam=True,
+ calibration=0.1,
+ ax=None,
+ interactive=False,
+ **kwargs,
+ ):
+ """A quick-plot function for a simulation of spots
+
+ Parameters
+ ----------
+ size_factor : float, optional
+ linear spot size scaling, default to 1
+ direct_beam_position: 2-tuple of ints, optional
+ The (x,y) coordinate in pixels of the direct beam. Defaults to
+ the center of the image.
+ in_plane_angle: float, optional
+ In plane rotation of the pattern in degrees
+ mirrored: bool, optional
+ Whether the pattern should be flipped over the x-axis,
+ corresponding to the inverted orientation
+ units : str, optional
+ 'real' or 'pixel', only changes scalebars, falls back on 'real', the default
+ show_labels : bool, optional
+ draw the miller indices near the spots
+ label_offset : 2-tuple, optional
+ the relative location of the spot labels. Does nothing if `show_labels`
+ is False.
+ label_formatting : dict, optional
+ keyword arguments passed to `ax.text` for drawing the labels. Does
+ nothing if `show_labels` is False.
+ min_label_intensity : float, optional
+ minimum intensity for a spot to be labelled
+ include_direct_beam : bool, optional
+ whether to include the direct beam in the plot
+ ax : matplotlib Axes, optional
+ axes on which to draw the pattern. If `None`, a new axis is created
+ interactive : bool, optional
+ Whether to add sliders for selecting the rotation and phase. This
+ is an experimental feature and will evolve/change in the future.
+ **kwargs :
+ passed to ax.scatter() method
+
+ Returns
+ -------
+ ax,sp
+
+ Notes
+ -----
+ spot size scales with the square root of the intensity.
+ """
+
+ if label_formatting is None:
+ label_formatting = {}
+ if direct_beam_position is None:
+ direct_beam_position = (0, 0)
+ if ax is None:
+ fig, ax = plt.subplots()
+ ax.set_aspect("equal")
+
+ spots, intensity, coords = self._get_spots(
+ in_plane_angle=in_plane_angle,
+ direct_beam_position=direct_beam_position,
+ mirrored=mirrored,
+ units=units,
+ calibration=calibration,
+ include_direct_beam=include_direct_beam,
+ )
+ sp = ax.scatter(
+ spots[:, 0],
+ spots[:, 1],
+ s=size_factor * np.sqrt(intensity),
+ **kwargs,
+ )
+ ax.set_xlim(-self.reciporical_radius, self.reciporical_radius)
+ ax.set_ylim(-self.reciporical_radius, self.reciporical_radius)
+ texts = []
+ if show_labels:
+ xlim = ax.get_xlim()
+ ylim = ax.get_ylim()
+ labels = self._get_labels(
+ coords, intensity, min_label_intensity, xlim, ylim
+ )
+ # default alignment options
+ if (
+ "ha" not in label_offset
+ and "horizontalalignment" not in label_formatting
+ ):
+ label_formatting["ha"] = "center"
+ if "va" not in label_offset and "verticalalignment" not in label_formatting:
+ label_formatting["va"] = "center"
+ for coordinate, label in labels:
+ texts.append(
+ ax.text(
+ coordinate[0] + label_offset[0],
+ coordinate[1] + label_offset[1],
+ label,
+ **label_formatting,
+ )
+ )
+ if units == "real":
+ ax.set_xlabel(r"$\AA^{-1}$")
+ ax.set_ylabel(r"$\AA^{-1}$")
+ else:
+ ax.set_xlabel("pixels")
+ ax.set_ylabel("pixels")
+ if (
+ interactive and self.has_multiple_rotations or self.has_multiple_phases
+ ): # pragma: no cover
+ axrot = fig.add_axes([0.5, 0.05, 0.4, 0.03])
+ axphase = fig.add_axes([0.1, 0.05, 0.2, 0.03])
+
+ fig.subplots_adjust(left=0.25, bottom=0.25)
+ if self.has_multiple_phases:
+ max_rot = np.max([r.size for r in self.rotations])
+ rotation_slider = Slider(
+ ax=axrot,
+ label="Rotation",
+ valmin=0,
+ valmax=max_rot - 1,
+ valinit=self.rotation_index,
+ valstep=1,
+ orientation="horizontal",
+ )
+ phase_slider = Slider(
+ ax=axphase,
+ label="Phase ",
+ valmin=0,
+ valmax=self.phases.size - 1,
+ valinit=self.phase_index,
+ valstep=1,
+ orientation="horizontal",
+ )
+ else: # self.has_multiple_rotations:
+ rotation_slider = Slider(
+ ax=axrot,
+ label="Rotation",
+ valmin=0,
+ valmax=self.rotations.size - 1,
+ valinit=self.rotation_index,
+ valstep=1,
+ orientation="horizontal",
+ )
+ phase_slider = None
+ self._rotation_slider = rotation_slider
+ self._phase_slider = phase_slider
+
+ def update(val):
+ if self.has_multiple_rotations and self.has_multiple_phases:
+ self.rotation_index = int(rotation_slider.val)
+ self.phase_index = int(phase_slider.val)
+ self._rotation_slider.valmax = (
+ self.rotations[self.phase_index].size - 1
+ )
+ elif self.has_multiple_rotations:
+ self.rotation_index = int(rotation_slider.val)
+ else:
+ self.phase_index = int(phase_slider.val)
+ spots, intensity, coords = self._get_spots(
+ in_plane_angle,
+ direct_beam_position,
+ mirrored,
+ units,
+ calibration,
+ include_direct_beam,
+ )
+ sp.set(
+ offsets=spots,
+ sizes=size_factor * np.sqrt(intensity),
+ )
+ for t in texts:
+ t.remove()
+ texts.clear()
+ if show_labels:
+ xlim = ax.get_xlim()
+ ylim = ax.get_ylim()
+ labels = self._get_labels(
+ coords, intensity, min_label_intensity, xlim, ylim
+ )
+ for coordinate, label in labels:
+ # this could be faster using a TextCollection when available in matplotlib
+ texts.append(
+ ax.text(
+ coordinate[0] + label_offset[0],
+ coordinate[1] + label_offset[1],
+ label,
+ **label_formatting,
+ )
+ )
+ fig.canvas.draw_idle()
+
+ if self._rotation_slider is not None:
+ self._rotation_slider.on_changed(update)
+ if self._phase_slider is not None:
+ self._phase_slider.on_changed(update)
+ return ax, sp
+
+
+def get_closest(array, values):
+ # make sure array is a numpy array
+ array = np.array(array)
+
+ # get insert positions
+ idxs = np.searchsorted(array, values, side="left")
+
+ # find indexes where previous index is closer
+ prev_idx_is_less = (idxs == len(array)) | (
+ np.fabs(values - array[np.maximum(idxs - 1, 0)])
+ < np.fabs(values - array[np.minimum(idxs, len(array) - 1)])
+ )
+ idxs[prev_idx_is_less] -= 1
+
+ return idxs
diff --git a/diffsims/tests/crystallography/test_diffracting_vector.py b/diffsims/tests/crystallography/test_diffracting_vector.py
new file mode 100644
index 00000000..bf4ccdf6
--- /dev/null
+++ b/diffsims/tests/crystallography/test_diffracting_vector.py
@@ -0,0 +1,81 @@
+from diffsims.crystallography import ReciprocalLatticeVector
+from diffsims.crystallography._diffracting_vector import DiffractingVector
+from orix.quaternion import Rotation
+
+import pytest
+import numpy as np
+
+
+class TestDiffractingVector:
+ def test_init(self, ferrite_phase):
+ rlv = DiffractingVector(
+ ferrite_phase, hkl=[[1, 1, 1], [2, 0, 0]], intensity=[1, 2]
+ )
+ assert rlv.phase == ferrite_phase
+ assert rlv.shape == (2,)
+ assert rlv.hkl.shape == (2, 3)
+ assert np.allclose(rlv.hkl, [[1, 1, 1], [2, 0, 0]])
+ assert np.allclose(rlv.intensity, [1, 2])
+
+ def test_init_wrong_intensity_length(self, ferrite_phase):
+ with pytest.raises(ValueError):
+ DiffractingVector(ferrite_phase, hkl=[[1, 1, 1], [2, 0, 0]], intensity=[1])
+
+ def test_add_intensity(self, ferrite_phase):
+ rlv = DiffractingVector.from_min_dspacing(ferrite_phase, 1.5)
+ rlv.intensity = 1
+ assert isinstance(rlv.intensity, np.ndarray)
+ assert np.allclose(rlv.intensity, np.ones(rlv.size))
+
+ def test_add_intensity_error(self, ferrite_phase):
+ rlv = DiffractingVector.from_min_dspacing(ferrite_phase, 1.5)
+ with pytest.raises(ValueError):
+ rlv.intensity = [0, 1]
+
+ def test_slicing(self, ferrite_phase):
+ rlv = DiffractingVector.from_min_dspacing(ferrite_phase, 1.5)
+ rlv.intensity = 1
+ rlv_slice = rlv[0:3]
+ assert rlv_slice.size == 3
+ assert np.allclose(rlv_slice.intensity, np.ones(3))
+
+ def test_structure_factor(self, ferrite_phase):
+ rlv = DiffractingVector.from_min_dspacing(ferrite_phase, 1.5)
+ with pytest.raises(NotImplementedError):
+ rlv.calculate_structure_factor()
+
+ def test_hkl(self, ferrite_phase):
+ rlv = ReciprocalLatticeVector(ferrite_phase, hkl=[[1, 1, 1], [2, 0, 0]])
+ rot = Rotation.from_euler([90, 90, 0], degrees=True)
+ rotated_vectors = (~rot * rlv.to_miller()).data
+ ferrite_phase2 = ferrite_phase.deepcopy()
+ ferrite_phase2.structure.lattice.setLatPar(baserot=rot.to_matrix()[0])
+ dv = DiffractingVector(ferrite_phase2, xyz=rotated_vectors)
+ assert np.allclose(rlv.hkl, dv.hkl)
+
+ def test_flat_polar(self, ferrite_phase):
+ dv = DiffractingVector(ferrite_phase, xyz=[[1, 1, 1], [0.5, -0.5, 0]])
+ r, t = dv.to_flat_polar()
+ assert np.allclose(r, [np.sqrt(2), 0.70710678])
+ assert np.allclose(t, [np.pi / 4, -np.pi / 4])
+ dv = DiffractingVector(
+ ferrite_phase,
+ xyz=[[[1, 1, 1], [0.5, -0.5, 0]], [[1, 1, 1], [0.5, -0.5, 0]]],
+ )
+ r, t = dv.to_flat_polar()
+ assert np.allclose(r, [np.sqrt(2), np.sqrt(2), 0.70710678, 0.70710678])
+ assert np.allclose(t, [np.pi / 4, np.pi / 4, -np.pi / 4, -np.pi / 4])
+
+ def test_get_lattice_basis_rotation(self, ferrite_phase):
+ """Rotation matrix to align the lattice basis with the Cartesian
+ basis is correct.
+ """
+ rlv = DiffractingVector(ferrite_phase, hkl=[[1, 1, 1], [2, 0, 0]])
+ rot = rlv.basis_rotation
+ assert np.allclose(rot.to_matrix(), np.eye(3))
+
+ def test_rotation_with_basis_raises(self, ferrite_phase):
+ rlv = DiffractingVector(ferrite_phase, hkl=[[1, 1, 1], [2, 0, 0]])
+ rot = Rotation.from_euler([[90, 90, 0], [90, 90, 1]], degrees=True)
+ with pytest.raises(ValueError):
+ rlv.rotate_with_basis(rotation=rot)
diff --git a/diffsims/tests/crystallography/test_reciprocal_lattice_vector.py b/diffsims/tests/crystallography/test_reciprocal_lattice_vector.py
index d8717107..bd9ada2f 100644
--- a/diffsims/tests/crystallography/test_reciprocal_lattice_vector.py
+++ b/diffsims/tests/crystallography/test_reciprocal_lattice_vector.py
@@ -20,6 +20,7 @@
import numpy as np
from orix.crystal_map import Phase
from orix.vector import Miller, Vector3d
+
import pytest
from diffsims.crystallography import ReciprocalLatticeVector
@@ -72,12 +73,20 @@ def test_init_raises(self, nickel_phase):
with pytest.raises(ValueError, match="Exactly one of "):
_ = ReciprocalLatticeVector(nickel_phase)
+ @pytest.mark.parametrize("include_zero_vector", [True, False])
@pytest.mark.parametrize("d, desired_size", [(2, 18), (1, 92), (0.5, 750)])
- def test_init_from_min_dspacing(self, ferrite_phase, d, desired_size):
+ def test_init_from_min_dspacing(
+ self, ferrite_phase, d, desired_size, include_zero_vector
+ ):
"""Class method gives desired number of vectors."""
- rlv = ReciprocalLatticeVector.from_min_dspacing(ferrite_phase, d)
+ rlv = ReciprocalLatticeVector.from_min_dspacing(
+ ferrite_phase, d, include_zero_vector=include_zero_vector
+ )
+ if include_zero_vector:
+ desired_size += 1
assert rlv.size == desired_size
+ @pytest.mark.parametrize("include_zero_vector", [True, False])
@pytest.mark.parametrize(
"hkl, desired_highest_hkl, desired_lowest_hkl, desired_size",
[
@@ -93,9 +102,15 @@ def test_init_from_highest_hkl(
desired_highest_hkl,
desired_lowest_hkl,
desired_size,
+ include_zero_vector,
):
"""Class method gives desired number of vectors and indices."""
- rlv = ReciprocalLatticeVector.from_highest_hkl(silicon_carbide_phase, hkl)
+ rlv = ReciprocalLatticeVector.from_highest_hkl(
+ silicon_carbide_phase, hkl, include_zero_vector=include_zero_vector
+ )
+ if include_zero_vector:
+ desired_size += 1
+ desired_lowest_hkl = [0, 0, 0]
assert np.allclose(rlv[0].hkl, desired_highest_hkl)
assert np.allclose(rlv[-1].hkl, desired_lowest_hkl)
assert rlv.size == desired_size
diff --git a/diffsims/tests/generators/old_simulation.npy b/diffsims/tests/generators/old_simulation.npy
new file mode 100644
index 00000000..90008a11
Binary files /dev/null and b/diffsims/tests/generators/old_simulation.npy differ
diff --git a/diffsims/tests/generators/test_simulation_generator.py b/diffsims/tests/generators/test_simulation_generator.py
new file mode 100644
index 00000000..a40c6f50
--- /dev/null
+++ b/diffsims/tests/generators/test_simulation_generator.py
@@ -0,0 +1,339 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017-2024 The diffsims developers
+#
+# This file is part of diffsims.
+#
+# diffsims is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# diffsims is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with diffsims. If not, see .
+
+import numpy as np
+import pytest
+from pathlib import Path
+
+import diffpy.structure
+from orix.crystal_map import Phase
+from orix.quaternion import Rotation
+
+from diffsims.generators.simulation_generator import SimulationGenerator
+from diffsims.utils.shape_factor_models import (
+ linear,
+ binary,
+ sin2c,
+ atanc,
+ lorentzian,
+ _shape_factor_precession,
+)
+from diffsims.simulations import Simulation1D
+from diffsims.utils.sim_utils import is_lattice_hexagonal
+
+TEST_DATA_DIR = Path(__file__).parent
+FILE1 = TEST_DATA_DIR / "old_simulation.npy"
+
+
+@pytest.fixture(params=[(300)])
+def diffraction_calculator(request):
+ return SimulationGenerator(request.param)
+
+
+@pytest.fixture(scope="module")
+def diffraction_calculator_precession_full():
+ return SimulationGenerator(300, precession_angle=0.5, approximate_precession=False)
+
+
+@pytest.fixture(scope="module")
+def diffraction_calculator_precession_simple():
+ return SimulationGenerator(300, precession_angle=0.5, approximate_precession=True)
+
+
+def local_excite(excitation_error, maximum_excitation_error, t):
+ return (np.sin(t) * excitation_error) / maximum_excitation_error
+
+
+@pytest.fixture(scope="module")
+def diffraction_calculator_custom():
+ return SimulationGenerator(300, shape_factor_model=local_excite, t=0.2)
+
+
+def make_phase(lattice_parameter=None):
+ """
+ We construct an Fd-3m silicon (with lattice parameter 5.431 as a default)
+ """
+ if lattice_parameter is not None:
+ a = lattice_parameter
+ else:
+ a = 5.431
+ latt = diffpy.structure.lattice.Lattice(a, a, a, 90, 90, 90)
+ # TODO - Make this construction with internal diffpy syntax
+ atom_list = []
+ for coords in [[0, 0, 0], [0.5, 0, 0.5], [0, 0.5, 0.5], [0.5, 0.5, 0]]:
+ x, y, z = coords[0], coords[1], coords[2]
+ atom_list.append(
+ diffpy.structure.atom.Atom(atype="Si", xyz=[x, y, z], lattice=latt)
+ ) # Motif part A
+ atom_list.append(
+ diffpy.structure.atom.Atom(
+ atype="Si", xyz=[x + 0.25, y + 0.25, z + 0.25], lattice=latt
+ )
+ ) # Motif part B
+ struct = diffpy.structure.Structure(atoms=atom_list, lattice=latt)
+ p = Phase(structure=struct, space_group=227)
+ return p
+
+
+@pytest.fixture()
+def local_structure():
+ return make_phase()
+
+
+@pytest.mark.parametrize("model", [binary, linear, atanc, sin2c, lorentzian])
+def test_shape_factor_precession(model):
+ excitation = np.array([-0.1, 0.1])
+ r = np.array([1, 5])
+ _ = _shape_factor_precession(excitation, r, 0.5, model, 0.1)
+
+
+def test_linear_shape_factor():
+ excitation = np.array([-2, -1, -0.5, 0, 0.5, 1, 2])
+ totest = linear(excitation, 1)
+ np.testing.assert_allclose(totest, np.array([0, 0, 0.5, 1, 0.5, 0, 0]))
+ np.testing.assert_allclose(linear(0.5, 1), 0.5)
+
+
+@pytest.mark.parametrize(
+ "model, expected",
+ [("linear", linear), ("lorentzian", lorentzian), (binary, binary)],
+)
+def test_diffraction_generator_init(model, expected):
+ generator = SimulationGenerator(300, shape_factor_model=model)
+ assert generator.shape_factor_model == expected
+
+
+class TestDiffractionCalculator:
+ def test_init(self, diffraction_calculator: SimulationGenerator):
+ assert diffraction_calculator.scattering_params == "lobato"
+ assert diffraction_calculator.precession_angle == 0
+ assert diffraction_calculator.shape_factor_model == lorentzian
+ assert diffraction_calculator.approximate_precession == True
+ assert diffraction_calculator.minimum_intensity == 1e-20
+
+ def test_matching_results(
+ self, diffraction_calculator: SimulationGenerator, local_structure
+ ):
+ diffraction = diffraction_calculator.calculate_diffraction2d(
+ local_structure, reciprocal_radius=5.0
+ )
+ assert diffraction.coordinates.size == 70
+
+ def test_precession_simple(
+ self, diffraction_calculator_precession_simple, local_structure
+ ):
+ diffraction = diffraction_calculator_precession_simple.calculate_diffraction2d(
+ local_structure,
+ reciprocal_radius=5.0,
+ )
+ assert diffraction.coordinates.size == 250
+
+ def test_precession_full(
+ self, diffraction_calculator_precession_full, local_structure
+ ):
+ diffraction = diffraction_calculator_precession_full.calculate_diffraction2d(
+ local_structure,
+ reciprocal_radius=5.0,
+ )
+ assert diffraction.coordinates.size == 250
+
+ def test_custom_shape_func(self, diffraction_calculator_custom, local_structure):
+ diffraction = diffraction_calculator_custom.calculate_diffraction2d(
+ local_structure,
+ reciprocal_radius=5.0,
+ )
+ assert diffraction.coordinates.size == 52
+
+ def test_appropriate_scaling(self, diffraction_calculator: SimulationGenerator):
+ """Tests that doubling the unit cell halves the pattern spacing."""
+ silicon = make_phase(5)
+ big_silicon = make_phase(10)
+ diffraction = diffraction_calculator.calculate_diffraction2d(
+ phase=silicon, reciprocal_radius=5.0
+ )
+ big_diffraction = diffraction_calculator.calculate_diffraction2d(
+ phase=big_silicon, reciprocal_radius=5.0
+ )
+ indices = [tuple(i) for i in diffraction.coordinates.hkl]
+ big_indices = [tuple(i) for i in big_diffraction.coordinates.hkl]
+ assert (2, 2, 0) in indices
+ assert (2, 2, 0) in big_indices
+ coordinates = diffraction.coordinates[indices.index((2, 2, 0))]
+ big_coordinates = big_diffraction.coordinates[big_indices.index((2, 2, 0))]
+ assert np.allclose(coordinates.data, big_coordinates.data * 2)
+
+ def test_appropriate_intensities(self, diffraction_calculator, local_structure):
+ """Tests the central beam is strongest."""
+ diffraction = diffraction_calculator.calculate_diffraction2d(
+ local_structure, reciprocal_radius=0.5, with_direct_beam=True
+ )
+ indices = [tuple(np.round(i).astype(int)) for i in diffraction.coordinates.hkl]
+ central_beam = indices.index((0, 0, 0))
+
+ smaller = np.greater_equal(
+ diffraction.coordinates.intensity[central_beam],
+ diffraction.coordinates.intensity,
+ )
+ assert np.all(smaller)
+
+ def test_direct_beam(self, diffraction_calculator, local_structure):
+ diffraction = diffraction_calculator.calculate_diffraction2d(
+ local_structure, reciprocal_radius=0.5, with_direct_beam=False
+ )
+ indices = [tuple(np.round(i).astype(int)) for i in diffraction.coordinates.hkl]
+ with pytest.raises(ValueError):
+ indices.index((0, 0, 0))
+
+ def test_shape_factor_strings(self, diffraction_calculator, local_structure):
+ _ = diffraction_calculator.calculate_diffraction2d(
+ local_structure,
+ )
+
+ def test_shape_factor_custom(self, diffraction_calculator, local_structure):
+ t1 = diffraction_calculator.calculate_diffraction2d(
+ local_structure, max_excitation_error=0.02
+ )
+ t2 = diffraction_calculator.calculate_diffraction2d(
+ local_structure, max_excitation_error=0.4
+ )
+ # softly makes sure the two sims are different
+ assert np.sum(t1.coordinates.intensity) != np.sum(t2.coordinates.intensity)
+
+ @pytest.mark.parametrize("is_hex", [True, False])
+ def test_simulate_1d(self, is_hex):
+ generator = SimulationGenerator(300)
+ phase = make_phase()
+ if is_hex:
+ phase.structure.lattice.a = phase.structure.lattice.b
+ phase.structure.lattice.alpha = 90
+ phase.structure.lattice.beta = 90
+ phase.structure.lattice.gamma = 120
+ assert is_lattice_hexagonal(phase.structure.lattice)
+ else:
+ assert not is_lattice_hexagonal(phase.structure.lattice)
+ sim = generator.calculate_diffraction1d(phase, 0.5)
+ assert isinstance(sim, Simulation1D)
+
+ assert len(sim.intensities) == len(sim.reciprocal_spacing)
+ assert len(sim.intensities) == len(sim.hkl)
+ for h in sim.hkl:
+ h = h.replace("-", "")
+ if is_hex:
+ assert len(h) == 4
+ else:
+ assert len(h) == 3
+
+
+def test_multiphase_multirotation_simulation():
+ generator = SimulationGenerator(300)
+ silicon = make_phase(5)
+ big_silicon = make_phase(10)
+ rot = Rotation.from_euler([[0, 0, 0], [0.1, 0.1, 0.1]])
+ rot2 = Rotation.from_euler([[0, 0, 0], [0.1, 0.1, 0.1], [0.2, 0.2, 0.2]])
+ sim = generator.calculate_diffraction2d(
+ [silicon, big_silicon], rotation=[rot, rot2]
+ )
+
+
+def test_multiphase_multirotation_simulation_error():
+ generator = SimulationGenerator(300)
+ silicon = make_phase(5)
+ big_silicon = make_phase(10)
+ rot = Rotation.from_euler([[0, 0, 0], [0.1, 0.1, 0.1]])
+ rot2 = Rotation.from_euler([[0, 0, 0], [0.1, 0.1, 0.1], [0.2, 0.2, 0.2]])
+ with pytest.raises(ValueError):
+ sim = generator.calculate_diffraction2d([silicon, big_silicon], rotation=[rot])
+
+
+@pytest.mark.parametrize("scattering_param", ["lobato", "xtables"])
+def test_param_check(scattering_param):
+ generator = SimulationGenerator(300, scattering_params=scattering_param)
+
+
+@pytest.mark.xfail(raises=NotImplementedError)
+def test_invalid_scattering_params():
+ scattering_param = "_empty"
+ generator = SimulationGenerator(300, scattering_params=scattering_param)
+
+
+@pytest.mark.xfail(faises=NotImplementedError)
+def test_invalid_shape_model():
+ generator = SimulationGenerator(300, shape_factor_model="dracula")
+
+
+def test_same_simulation_results():
+ # This test is to ensure that the new SimulationGenerator produces the same
+ # results as the old DiffractionLibraryGenerator. Based on comments from
+ # viljarjf in https://github.com/pyxem/diffsims/pull/201
+ # Shared parameters
+ latt = diffpy.structure.lattice.Lattice(2.464, 2.464, 6.711, 90, 90, 120)
+ atoms = [
+ diffpy.structure.atom.Atom(atype="C", xyz=[0.0, 0.0, 0.25], lattice=latt),
+ diffpy.structure.atom.Atom(atype="C", xyz=[0.0, 0.0, 0.75], lattice=latt),
+ diffpy.structure.atom.Atom(atype="C", xyz=[1 / 3, 2 / 3, 0.25], lattice=latt),
+ diffpy.structure.atom.Atom(atype="C", xyz=[2 / 3, 1 / 3, 0.75], lattice=latt),
+ ]
+ structure_matrix = diffpy.structure.Structure(atoms=atoms, lattice=latt)
+ calibration = 0.0262
+ reciprocal_radius = 1.6768
+ with_direct_beam = False
+ max_excitation_error = 0.1
+ accelerating_voltage = 200
+ shape = (128, 128)
+ sigma = 1.4
+ generator_kwargs = {
+ "accelerating_voltage": accelerating_voltage,
+ "scattering_params": "lobato",
+ "precession_angle": 0,
+ "shape_factor_model": "lorentzian",
+ "approximate_precession": True,
+ "minimum_intensity": 1e-20,
+ }
+ # The euler angles are different, as orix.Phase enforces x||a, z||c*
+ # euler_angles_old = np.array([[0, 90, 120]])
+ euler_angles_new = np.array([[0, 90, 90]])
+
+ # Old way. For creating the old data.
+ # struct_library = StructureLibrary(["Graphite"], [structure_matrix], [euler_angles_old])
+ # diff_gen = DiffractionGenerator(**generator_kwargs)
+ # lib_gen = DiffractionLibraryGenerator(diff_gen)
+ # diff_lib = lib_gen.get_diffraction_library(struct_library,
+ # calibration=calibration,
+ # reciprocal_radius=reciprocal_radius,
+ # with_direct_beam=with_direct_beam,
+ # max_excitation_error=max_excitation_error,
+ # half_shape=shape[0] // 2,
+ # )
+ # old_data = diff_lib["Graphite"]["simulations"][0].get_diffraction_pattern(shape=shape, sigma=sigma)
+
+ # New
+ p = Phase("Graphite", point_group="6/mmm", structure=structure_matrix)
+ gen = SimulationGenerator(**generator_kwargs)
+ rot = Rotation.from_euler(euler_angles_new, degrees=True)
+ sim = gen.calculate_diffraction2d(
+ phase=p,
+ rotation=rot,
+ reciprocal_radius=reciprocal_radius,
+ max_excitation_error=max_excitation_error,
+ with_direct_beam=with_direct_beam,
+ )
+ new_data = sim.get_diffraction_pattern(
+ shape=shape, sigma=sigma, calibration=calibration
+ )
+ old_data = np.load(FILE1)
+ np.testing.assert_allclose(new_data, old_data, atol=1e-8)
diff --git a/diffsims/tests/simulations/__init__.py b/diffsims/tests/simulations/__init__.py
new file mode 100644
index 00000000..8ff2bc07
--- /dev/null
+++ b/diffsims/tests/simulations/__init__.py
@@ -0,0 +1,17 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017-2024 The diffsims developers
+#
+# This file is part of diffsims.
+#
+# diffsims is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# diffsims is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with diffsims. If not, see .
diff --git a/diffsims/tests/simulations/test_simulations1d.py b/diffsims/tests/simulations/test_simulations1d.py
new file mode 100644
index 00000000..4900e664
--- /dev/null
+++ b/diffsims/tests/simulations/test_simulations1d.py
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017-2024 The diffsims developers
+#
+# This file is part of diffsims.
+#
+# diffsims is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# diffsims is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with diffsims. If not, see .
+import matplotlib.pyplot as plt
+import pytest
+
+from orix.crystal_map import Phase
+import numpy as np
+
+from diffsims.tests.generators.test_simulation_generator import make_phase
+from diffsims.simulations import Simulation1D
+
+
+class TestSingleSimulation:
+ @pytest.fixture
+ def simulation1d(self):
+ al_phase = make_phase()
+ al_phase.name = "Al"
+ hkls = np.array(["100", "110", "111"])
+ magnitudes = np.array([1, 2, 3])
+ inten = np.array([1, 2, 3])
+ recip = 4.0
+
+ return Simulation1D(
+ phase=al_phase,
+ hkl=hkls,
+ reciprocal_spacing=magnitudes,
+ intensities=inten,
+ reciprocal_radius=recip,
+ wavelength=0.025,
+ )
+
+ def test_init(self, simulation1d):
+ assert isinstance(simulation1d, Simulation1D)
+ assert isinstance(simulation1d.phase, Phase)
+ assert isinstance(simulation1d.hkl, np.ndarray)
+ assert isinstance(simulation1d.reciprocal_spacing, np.ndarray)
+ assert isinstance(simulation1d.intensities, np.ndarray)
+ assert isinstance(simulation1d.reciprocal_radius, float)
+
+ @pytest.mark.parametrize("annotate", [True, False])
+ @pytest.mark.parametrize("ax", [None, "new"])
+ @pytest.mark.parametrize("with_labels", [True, False])
+ def test_plot(self, simulation1d, annotate, ax, with_labels):
+ if ax == "new":
+ fig, ax = plt.subplots()
+ fig = simulation1d.plot(annotate_peaks=annotate, ax=ax, with_labels=with_labels)
+
+ def test_repr(self, simulation1d):
+ assert simulation1d.__repr__() == "Simulation1D(name: Al, wavelength: 0.025)"
+
+ def test_theta(self, simulation1d):
+ np.testing.assert_almost_equal(
+ simulation1d.theta, np.array([0.02499479, 0.0499584, 0.07485985])
+ )
diff --git a/diffsims/tests/simulations/test_simulations2d.py b/diffsims/tests/simulations/test_simulations2d.py
new file mode 100644
index 00000000..43d5f419
--- /dev/null
+++ b/diffsims/tests/simulations/test_simulations2d.py
@@ -0,0 +1,458 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017-2024 The diffsims developers
+#
+# This file is part of diffsims.
+#
+# diffsims is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# diffsims is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with diffsims. If not, see .
+
+import numpy as np
+import pytest
+
+from diffpy.structure import Structure, Atom, Lattice
+from orix.crystal_map import Phase
+from orix.quaternion import Rotation
+
+from diffsims.simulations import Simulation2D
+from diffsims.generators.simulation_generator import SimulationGenerator
+from diffsims.crystallography._diffracting_vector import DiffractingVector
+
+
+@pytest.fixture(scope="module")
+def al_phase():
+ p = Phase(
+ name="al",
+ space_group=225,
+ structure=Structure(
+ atoms=[Atom("al", [0, 0, 0])],
+ lattice=Lattice(0.405, 0.405, 0.405, 90, 90, 90),
+ ),
+ )
+ return p
+
+
+class TestSingleSimulation:
+ @pytest.fixture
+ def single_simulation(self, al_phase):
+ gen = SimulationGenerator(accelerating_voltage=200)
+ rot = Rotation.from_axes_angles([1, 0, 0], 45, degrees=True)
+ coords = DiffractingVector(phase=al_phase, xyz=[[1, 0, 0]])
+ sim = Simulation2D(
+ phases=al_phase, simulation_generator=gen, coordinates=coords, rotations=rot
+ )
+ return sim
+
+ def test_init(self, single_simulation):
+ assert isinstance(single_simulation, Simulation2D)
+ assert isinstance(single_simulation.phases, Phase)
+ assert isinstance(single_simulation.simulation_generator, SimulationGenerator)
+ assert isinstance(single_simulation.rotations, Rotation)
+
+ def test_get_simulation(self, single_simulation):
+ rotation, phase, coords = single_simulation.get_simulation(0)
+ assert isinstance(rotation, Rotation)
+ assert phase == 0
+
+ def test_iphase(self, single_simulation):
+ with pytest.raises(ValueError):
+ single_simulation.iphase[0]
+
+ def test_irot(self, single_simulation):
+ with pytest.raises(ValueError):
+ single_simulation.irot[0]
+
+ def test_iter(self, single_simulation):
+ count = 0
+ for sim in single_simulation:
+ count += 1
+ assert isinstance(sim, DiffractingVector)
+ assert count == 1
+
+ def test_plot(self, single_simulation):
+ single_simulation.plot()
+
+ def test_num_rotations(self, single_simulation):
+ assert single_simulation._num_rotations() == 1
+
+ def test_polar_flatten(self, single_simulation):
+ (
+ r_templates,
+ theta_templates,
+ intensities_templates,
+ ) = single_simulation.polar_flatten_simulations()
+ assert r_templates.shape == (1, 1)
+ assert theta_templates.shape == (1, 1)
+ assert intensities_templates.shape == (1, 1)
+
+ def test_polar_flatten_axes(self, single_simulation):
+ radial_axes = np.linspace(0, 1, 10)
+ theta_axes = np.linspace(0, 2 * np.pi, 10)
+ (
+ r_templates,
+ theta_templates,
+ intensities_templates,
+ ) = single_simulation.polar_flatten_simulations(
+ radial_axes=radial_axes, azimuthal_axes=theta_axes
+ )
+ assert r_templates.shape == (1, 1)
+ assert theta_templates.shape == (1, 1)
+ assert intensities_templates.shape == (1, 1)
+
+ def test_deepcopy(self, single_simulation):
+ copied = single_simulation.deepcopy()
+ assert copied is not single_simulation
+
+
+class TestSimulationInitFailures:
+ def test_different_size(self, al_phase):
+ gen = SimulationGenerator(accelerating_voltage=200)
+ rot = Rotation.from_axes_angles([1, 0, 0], 45, degrees=True)
+ coords = DiffractingVector(phase=al_phase, xyz=[[1, 0, 0], [1, 1, 1]])
+ with pytest.raises(ValueError):
+ sim = Simulation2D(
+ phases=al_phase,
+ simulation_generator=gen,
+ coordinates=[coords, coords],
+ rotations=rot,
+ )
+
+ def test_different_size2(self, al_phase):
+ gen = SimulationGenerator(accelerating_voltage=200)
+ rot = Rotation.from_axes_angles([1, 0, 0], (0, 45), degrees=True)
+ coords = DiffractingVector(phase=al_phase, xyz=[[1, 0, 0], [1, 1, 1]])
+ with pytest.raises(ValueError):
+ sim = Simulation2D(
+ phases=al_phase,
+ simulation_generator=gen,
+ coordinates=[coords, coords, coords],
+ rotations=rot,
+ )
+
+ def test_different_size_multiphase(self, al_phase):
+ gen = SimulationGenerator(accelerating_voltage=200)
+ rot = Rotation.from_axes_angles([1, 0, 0], 45, degrees=True)
+ coords = DiffractingVector(phase=al_phase, xyz=[[1, 0, 0], [1, 1, 1]])
+ with pytest.raises(ValueError):
+ sim = Simulation2D(
+ phases=[al_phase, al_phase],
+ simulation_generator=gen,
+ coordinates=[[coords, coords], [coords, coords]],
+ rotations=[rot, rot],
+ )
+
+ def test_different_num_phase(self, al_phase):
+ gen = SimulationGenerator(accelerating_voltage=200)
+ rot = Rotation.from_axes_angles([1, 0, 0], 45, degrees=True)
+ coords = DiffractingVector(phase=al_phase, xyz=[[1, 0, 0], [1, 1, 1]])
+ with pytest.raises(ValueError):
+ sim = Simulation2D(
+ phases=[al_phase, al_phase],
+ simulation_generator=gen,
+ coordinates=[[coords, coords], [coords, coords], [coords, coords]],
+ rotations=[rot, rot],
+ )
+
+ def test_different_num_phase_and_rot(self, al_phase):
+ gen = SimulationGenerator(accelerating_voltage=200)
+ rot = Rotation.from_axes_angles([1, 0, 0], 45, degrees=True)
+ coords = DiffractingVector(phase=al_phase, xyz=[[1, 0, 0], [1, 1, 1]])
+ with pytest.raises(ValueError):
+ sim = Simulation2D(
+ phases=[al_phase, al_phase],
+ simulation_generator=gen,
+ coordinates=[[coords, coords], [coords, coords], [coords, coords]],
+ rotations=[rot, rot, rot],
+ )
+
+
+class TestSinglePhaseMultiSimulation:
+ @pytest.fixture
+ def al_phase(self):
+ p = Phase(
+ name="al",
+ space_group=225,
+ structure=Structure(
+ atoms=[Atom("al", [0, 0, 0])],
+ lattice=Lattice(0.405, 0.405, 0.405, 90, 90, 90),
+ ),
+ )
+ return p
+
+ @pytest.fixture
+ def multi_simulation(self, al_phase):
+ gen = SimulationGenerator(accelerating_voltage=200)
+ rot = Rotation.from_axes_angles([1, 0, 0], (0, 15, 30, 45), degrees=True)
+ coords = DiffractingVector(
+ phase=al_phase,
+ xyz=[[1, 0, 0], [0, 1, 0], [1, 1, 0], [1, 1, 1]],
+ intensity=[1, 2, 3, 4],
+ )
+
+ vectors = [coords, coords, coords, coords]
+
+ sim = Simulation2D(
+ phases=al_phase,
+ simulation_generator=gen,
+ coordinates=vectors,
+ rotations=rot,
+ )
+ return sim
+
+ def test_get_simulation(self, multi_simulation):
+ for i in range(4):
+ rotation, phase, coords = multi_simulation.get_simulation(i)
+ assert isinstance(rotation, Rotation)
+ assert phase == 0
+
+ def test_get_current_rotation(self, multi_simulation):
+ rot = multi_simulation.get_current_rotation_matrix()
+ np.testing.assert_array_equal(rot, multi_simulation.rotations[0].to_matrix()[0])
+
+ def test_init(self, multi_simulation):
+ assert isinstance(multi_simulation, Simulation2D)
+ assert isinstance(multi_simulation.phases, Phase)
+ assert isinstance(multi_simulation.simulation_generator, SimulationGenerator)
+ assert isinstance(multi_simulation.rotations, Rotation)
+ assert isinstance(multi_simulation.coordinates, np.ndarray)
+
+ def test_iphase(self, multi_simulation):
+ with pytest.raises(ValueError):
+ multi_simulation.iphase[0]
+
+ def test_irot(self, multi_simulation):
+ sliced_sim = multi_simulation.irot[0]
+ assert isinstance(sliced_sim, Simulation2D)
+ assert isinstance(sliced_sim.phases, Phase)
+ assert sliced_sim.rotations.size == 1
+ assert sliced_sim.coordinates.size == 4
+
+ def test_irot_slice(self, multi_simulation):
+ sliced_sim = multi_simulation.irot[0:2]
+ assert isinstance(sliced_sim, Simulation2D)
+ assert isinstance(sliced_sim.phases, Phase)
+ assert sliced_sim.rotations.size == 2
+ assert sliced_sim.coordinates.size == 2
+
+ def test_plot(self, multi_simulation):
+ multi_simulation.plot()
+
+ def test_plot_rotation(self, multi_simulation):
+ multi_simulation.plot_rotations()
+
+ def test_iter(self, multi_simulation):
+ multi_simulation.phase_index = 0
+ multi_simulation.rotation_index = 0
+ count = 0
+ for sim in multi_simulation:
+ count += 1
+ assert isinstance(sim, DiffractingVector)
+ assert count == 4
+
+ def test_polar_flatten(self, multi_simulation):
+ (
+ r_templates,
+ theta_templates,
+ intensities_templates,
+ ) = multi_simulation.polar_flatten_simulations()
+ assert r_templates.shape == (4, 4)
+ assert theta_templates.shape == (4, 4)
+ assert intensities_templates.shape == (4, 4)
+
+
+class TestMultiPhaseMultiSimulation:
+ @pytest.fixture
+ def al_phase(self):
+ p = Phase(
+ name="al",
+ space_group=225,
+ structure=Structure(
+ atoms=[Atom("al", [0, 0, 0])],
+ lattice=Lattice(0.405, 0.405, 0.405, 90, 90, 90),
+ ),
+ )
+ return p
+
+ @pytest.fixture
+ def multi_simulation(self, al_phase):
+ gen = SimulationGenerator(accelerating_voltage=200)
+ rot = Rotation.from_axes_angles([1, 0, 0], (0, 15, 30, 45), degrees=True)
+ rot2 = rot
+ coords = DiffractingVector(
+ phase=al_phase,
+ xyz=[
+ [1, 0, 0],
+ [0, -0.3, 0],
+ [1 / 0.405, 1 / -0.405, 0],
+ [0.1, -0.1, -0.3],
+ ],
+ )
+ coords.intensity = 1
+ vectors = [coords, coords, coords, coords]
+ al_phase2 = al_phase.deepcopy()
+ al_phase2.name = "al2"
+ sim = Simulation2D(
+ phases=[al_phase, al_phase2],
+ simulation_generator=gen,
+ coordinates=[vectors, vectors],
+ rotations=[rot, rot2],
+ )
+ return sim
+
+ def test_init(self, multi_simulation):
+ assert isinstance(multi_simulation, Simulation2D)
+ assert isinstance(multi_simulation.phases, np.ndarray)
+ assert isinstance(multi_simulation.simulation_generator, SimulationGenerator)
+ assert isinstance(multi_simulation.rotations, np.ndarray)
+ assert isinstance(multi_simulation.coordinates, np.ndarray)
+
+ def test_get_simulation(self, multi_simulation):
+ for i in range(4):
+ rotation, phase, coords = multi_simulation.get_simulation(i)
+ assert isinstance(rotation, Rotation)
+ assert phase == 0
+ for i in range(4, 8):
+ rotation, phase, coords = multi_simulation.get_simulation(i)
+ assert isinstance(rotation, Rotation)
+ assert phase == 1
+
+ def test_iphase(self, multi_simulation):
+ phase_slic = multi_simulation.iphase[0]
+ assert isinstance(phase_slic, Simulation2D)
+ assert isinstance(phase_slic.phases, Phase)
+ assert phase_slic.rotations.size == 4
+
+ def test_iphase_str(self, multi_simulation):
+ phase_slic = multi_simulation.iphase["al"]
+ assert isinstance(phase_slic, Simulation2D)
+ assert isinstance(phase_slic.phases, Phase)
+ assert phase_slic.rotations.size == 4
+ assert phase_slic.phases.name == "al"
+
+ def test_iphase_error(self, multi_simulation):
+ with pytest.raises(ValueError):
+ phase_slic = multi_simulation.iphase[3.1]
+
+ def test_irot(self, multi_simulation):
+ sliced_sim = multi_simulation.irot[0]
+ assert isinstance(sliced_sim, Simulation2D)
+ assert isinstance(sliced_sim.phases, np.ndarray)
+ assert sliced_sim.rotations.size == 2
+
+ def test_irot_slice(self, multi_simulation):
+ sliced_sim = multi_simulation.irot[0:2]
+ assert isinstance(sliced_sim, Simulation2D)
+ assert isinstance(sliced_sim.phases, np.ndarray)
+ assert sliced_sim.rotations.size == 2
+
+ @pytest.mark.parametrize("show_labels", [True, False])
+ @pytest.mark.parametrize("units", ["real", "pixel"])
+ @pytest.mark.parametrize("include_zero_beam", [True, False])
+ def test_plot(self, multi_simulation, show_labels, units, include_zero_beam):
+ multi_simulation.phase_index = 0
+ multi_simulation.rotation_index = 0
+ multi_simulation.reciporical_radius = 2
+ multi_simulation.coordinates[0][0].intensity = np.nan
+ multi_simulation.plot(
+ show_labels=show_labels,
+ units=units,
+ min_label_intensity=0.0,
+ include_direct_beam=include_zero_beam,
+ calibration=0.1,
+ )
+
+ def test_plot_rotation(self, multi_simulation):
+ multi_simulation.plot_rotations()
+
+ def test_iter(self, multi_simulation):
+ multi_simulation.phase_index = 0
+ multi_simulation.rotation_index = 0
+ count = 0
+ for sim in multi_simulation:
+ count += 1
+ assert isinstance(sim, DiffractingVector)
+ assert count == 8
+
+ def test_get_diffraction_pattern(self, multi_simulation):
+ # No diffraction spots in this pattern
+ pat = multi_simulation.get_diffraction_pattern(
+ shape=(50, 50), calibration=0.001
+ )
+ assert pat.shape == (50, 50)
+ assert np.max(pat.data) == 0
+
+ def test_get_diffraction_pattern2(self, multi_simulation):
+ pat = multi_simulation.get_diffraction_pattern(
+ shape=(512, 512), calibration=0.01
+ )
+ assert pat.shape == (512, 512)
+ assert np.max(pat.data) == 1
+
+ def test_polar_flatten(self, multi_simulation):
+ (
+ r_templates,
+ theta_templates,
+ intensities_templates,
+ ) = multi_simulation.polar_flatten_simulations()
+ assert r_templates.shape == (8, 4)
+ assert theta_templates.shape == (8, 4)
+ assert intensities_templates.shape == (8, 4)
+
+ def test_rotate_shift_coords(self, multi_simulation):
+ rot = multi_simulation.rotate_shift_coordinates(angle=0.1)
+ assert isinstance(rot, DiffractingVector)
+
+
+class TestMultiPhaseSingleSimulation:
+ @pytest.fixture
+ def al_phase(self):
+ p = Phase(
+ name="al",
+ space_group=225,
+ structure=Structure(
+ atoms=[Atom("al", [0, 0, 0])],
+ lattice=Lattice(0.405, 0.405, 0.405, 90, 90, 90),
+ ),
+ )
+ return p
+
+ @pytest.fixture
+ def multi_simulation(self, al_phase):
+ gen = SimulationGenerator(accelerating_voltage=200)
+ rot = Rotation.from_axes_angles([1, 0, 0], (0,), degrees=True)
+ rot2 = rot
+ coords = DiffractingVector(
+ phase=al_phase,
+ xyz=[
+ [1, 0, 0],
+ [0, -0.3, 0],
+ [1 / 0.405, 1 / -0.405, 0],
+ [0.1, -0.1, -0.3],
+ ],
+ )
+ coords.intensity = 1
+ vectors = coords
+ al_phase2 = al_phase.deepcopy()
+ al_phase2.name = "al2"
+ sim = Simulation2D(
+ phases=[al_phase, al_phase2],
+ simulation_generator=gen,
+ coordinates=[vectors, vectors],
+ rotations=[rot, rot2],
+ )
+ return sim
+
+ def test_get_simulation(self, multi_simulation):
+ for i in range(2):
+ rotation, phase, coords = multi_simulation.get_simulation(i)
+ assert isinstance(rotation, Rotation)
+ assert phase == i
diff --git a/diffsims/utils/shape_factor_models.py b/diffsims/utils/shape_factor_models.py
index b8dc7c20..6aab0902 100644
--- a/diffsims/utils/shape_factor_models.py
+++ b/diffsims/utils/shape_factor_models.py
@@ -17,6 +17,7 @@
# along with diffsims. If not, see .
import numpy as np
+from scipy.integrate import quad
__all__ = [
@@ -217,3 +218,53 @@ def lorentzian_precession(
z = np.sqrt(u**2 + 4 * sigma**2 * excitation_error**2)
fac = (sigma / np.pi) * np.sqrt(2 * (u + z) / z**2)
return fac
+
+
+def _shape_factor_precession(
+ excitation_error, r_spot, phi, shape_function, max_excitation, **kwargs
+):
+ """
+ The rel-rod shape factors for reflections taking into account
+ precession
+
+ Parameters
+ ----------
+ excitation_error : np.ndarray (N,)
+ An array of excitation errors
+ r_spot : np.ndarray (N,)
+ An array representing the distance of spots from the z-axis in A^-1
+ phi : float
+ The precession angle in radians
+ shape_function : callable
+ A function that describes the influence from the rel-rods. Should be
+ in the form func(excitation_error: np.ndarray, max_excitation: float,
+ **kwargs)
+ max_excitation : float
+ Parameter to describe the "extent" of the rel-rods.
+
+ Other parameters
+ ----------------
+ ** kwargs: passed directly to shape_function
+
+ Notes
+ -----
+ * We calculate excitation_error as z_spot - z_sphere so that it is
+ negative when the spot is outside the ewald sphere and positive when inside
+ conform W&C chapter 12, section 12.6
+ * We assume that the sample is a thin infinitely wide slab perpendicular
+ to the optical axis, so that the shape factor function only depends on the
+ distance from each spot to the Ewald sphere parallel to the optical axis.
+ """
+ shf = np.zeros(excitation_error.shape)
+ # loop over all spots
+ for i, (excitation_error_i, r_spot_i) in enumerate(zip(excitation_error, r_spot)):
+
+ def integrand(theta):
+ # Equation 8 in L.Palatinus et al. Acta Cryst. (2019) B75, 512-522
+ S_zero = excitation_error_i
+ variable_term = r_spot_i * (phi) * np.cos(theta)
+ return shape_function(S_zero + variable_term, max_excitation, **kwargs)
+
+ # average factor integrated over the full revolution of the beam
+ shf[i] = (1 / (2 * np.pi)) * quad(integrand, 0, 2 * np.pi)[0]
+ return shf
diff --git a/doc/reference/index.rst b/doc/reference/index.rst
index 91fbe73e..6bfe0428 100644
--- a/doc/reference/index.rst
+++ b/doc/reference/index.rst
@@ -29,5 +29,6 @@ the `demos `_.
libraries
pattern
sims
+ simulations
structure_factor
utils
diff --git a/setup.cfg b/setup.cfg
index 46204e6d..b804aa93 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -28,4 +28,5 @@ known_excludes =
.git/**
doc/build/**
htmlcov/**
- *.code-workspace
\ No newline at end of file
+ *.code-workspace
+ **/*.npy
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 1b9de6c6..1ba7c559 100644
--- a/setup.py
+++ b/setup.py
@@ -78,7 +78,7 @@
"matplotlib >= 3.3",
"numba",
"numpy >= 1.17.3",
- "orix >= 0.9",
+ "orix >= 0.12.1",
"psutil",
"scipy >= 1.8",
"tqdm >= 4.9",