diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index c9efae806..9ef1ba4dd 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -239,6 +239,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = True, crop_patterns: bool = False, + store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, max_batch_size: int = None, @@ -295,6 +296,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + store_initial_arrays: bool + If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional if not none, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional @@ -534,8 +537,10 @@ def preprocess( else: self._object = obj - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type + if store_initial_arrays: + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] self._num_voxels = self._object.shape[1] @@ -565,10 +570,14 @@ def preprocess( # initialize probe self._probes_all = [] - self._probes_all_initial = [] - self._probes_all_initial_aperture = [] list_Q = isinstance(self._probe_init, (list, tuple)) + if store_initial_arrays: + self._probes_all_initial = [] + self._probes_all_initial_aperture = [] + else: + self._probes_all_initial_aperture = [None] * self._num_measurements + for index in range(self._num_measurements): _probe, self._semiangle_cutoff = self._initialize_probe( self._probe_init[index] if list_Q else self._probe_init, @@ -579,8 +588,9 @@ def preprocess( ) self._probes_all.append(_probe) - self._probes_all_initial.append(_probe.copy()) - self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) + if store_initial_arrays: + self._probes_all_initial.append(_probe.copy()) + self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) del self._probe_init diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 2e887739f..01326a08b 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -219,6 +219,7 @@ def preprocess( progress_bar: bool = True, object_fov_mask: np.ndarray = True, crop_patterns: bool = False, + store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, max_batch_size: int = None, @@ -285,6 +286,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + store_initial_arrays: bool + If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional if not none, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional @@ -597,8 +600,9 @@ def preprocess( else: self._object = obj - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type + if store_initial_arrays: + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type self._object_shape = self._object.shape[-2:] # center probe positions @@ -627,10 +631,14 @@ def preprocess( # initialize probe self._probes_all = [] - self._probes_all_initial = [] - self._probes_all_initial_aperture = [] list_Q = isinstance(self._probe_init, (list, tuple)) + if store_initial_arrays: + self._probes_all_initial = [] + self._probes_all_initial_aperture = [] + else: + self._probes_all_initial_aperture = [None] * self._num_measurements + for index in range(self._num_measurements): _probe, self._semiangle_cutoff = self._initialize_probe( self._probe_init[index] if list_Q else self._probe_init, @@ -641,8 +649,9 @@ def preprocess( ) self._probes_all.append(_probe) - self._probes_all_initial.append(_probe.copy()) - self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) + if store_initial_arrays: + self._probes_all_initial.append(_probe.copy()) + self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) del self._probe_init diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index 119fc3a3c..329443564 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -278,6 +278,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, max_batch_size: int = None, @@ -347,6 +348,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + store_initial_arrays: bool + If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional If not None, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional @@ -516,8 +519,9 @@ def preprocess( self._object_type, ) - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type + if store_initial_arrays: + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type self._object_shape = self._object.shape[-2:] # center probe positions @@ -553,8 +557,11 @@ def preprocess( device=self._device, )._evaluate_ctf() - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + if store_initial_arrays: + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + else: + self._probe_initial_aperture = None # precompute propagator arrays self._propagator_arrays = self._precompute_propagator_arrays( diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index bd650a931..d3213b49d 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -224,6 +224,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, max_batch_size: int = None, @@ -293,6 +294,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + store_initial_arrays: bool + If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional if not none, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional @@ -461,8 +464,9 @@ def preprocess( self._object_type, ) - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type + if store_initial_arrays: + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type self._object_shape = self._object.shape # center probe positions @@ -498,8 +502,11 @@ def preprocess( device=self._device, )._evaluate_ctf() - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + if store_initial_arrays: + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + else: + self._probe_initial_aperture = None if object_fov_mask is None or plot_probe_overlaps: # overlaps diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index db17cb1a8..3a245d367 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -252,6 +252,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, max_batch_size: int = None, @@ -321,6 +322,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + store_initial_arrays: bool + If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional If not None, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional @@ -490,8 +493,9 @@ def preprocess( self._object_type, ) - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type + if store_initial_arrays: + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type self._object_shape = self._object.shape[-2:] # center probe positions @@ -527,8 +531,11 @@ def preprocess( device=self._device, )._evaluate_ctf() - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + if store_initial_arrays: + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + else: + self._probe_initial_aperture = None # precompute propagator arrays self._propagator_arrays = self._precompute_propagator_arrays( diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index 060a151aa..cdf32e0ae 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -266,6 +266,7 @@ def preprocess( vectorized_com_calculation: bool = True, device: str = None, clear_fft_cache: bool = None, + store_initial_arrays: bool = True, **kwargs, ): """ @@ -306,7 +307,9 @@ def preprocess( device: str, optional if not none, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional - if true, and device = 'gpu', clears the cached fft plan at the end of function calls + If True, and device = 'gpu', clears the cached fft plan at the end of function calls + store_initial_arrays: bool, optional + If True, stores a copy of the arrays necessary to reinitialize in reconstruct Returns -------- @@ -330,15 +333,15 @@ def preprocess( ) # extract calibrations - self._intensities = self._extract_intensities_and_calibrations_from_datacube( + intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube, require_calibrations=True, ) - self._intensities = xp.asarray(self._intensities) + intensities = xp.asarray(intensities) - self._region_of_interest_shape = np.array(self._intensities.shape[-2:]) - self._scan_shape = np.array(self._intensities.shape[:2]) + self._region_of_interest_shape = np.array(intensities.shape[-2:]) + self._scan_shape = np.array(intensities.shape[:2]) # descan correction if descan_correction_fit_function is not None: @@ -350,7 +353,7 @@ def preprocess( _, _, ) = self._calculate_intensities_center_of_mass( - self._intensities, + intensities, dp_mask=None, fit_function=descan_correction_fit_function, com_shifts=None, @@ -360,8 +363,8 @@ def preprocess( com_fitted_x = asnumpy(com_fitted_x) com_fitted_y = asnumpy(com_fitted_y) - intensities = asnumpy(self._intensities) - intensities_shifted = np.zeros_like(intensities) + intensities_np = asnumpy(intensities) + intensities_shifted = np.zeros_like(intensities_np) center_x = com_fitted_x.mean() center_y = com_fitted_y.mean() @@ -369,7 +372,7 @@ def preprocess( for rx in range(intensities_shifted.shape[0]): for ry in range(intensities_shifted.shape[1]): intensity_shifted = get_shifted_ar( - intensities[rx, ry], + intensities_np[rx, ry], -com_fitted_x[rx, ry] + center_x, -com_fitted_y[rx, ry] + center_y, bilinear=True, @@ -378,12 +381,12 @@ def preprocess( intensities_shifted[rx, ry] = intensity_shifted - self._intensities = xp.asarray(intensities_shifted, xp.float32) + intensities = xp.asarray(intensities_shifted, xp.float32) if dp_mask is not None: self._dp_mask = xp.asarray(dp_mask) else: - dp_mean = self._intensities.mean((0, 1)) + dp_mean = intensities.mean((0, 1)) self._dp_mask = dp_mean >= (xp.max(dp_mean) * threshold_intensity) # select virtual detector pixels @@ -454,7 +457,7 @@ def preprocess( # Collect BF images all_bfs = xp.moveaxis( - self._intensities[:, :, self._xy_inds[:, 0], self._xy_inds[:, 1]], + intensities[:, :, self._xy_inds[:, 0], self._xy_inds[:, 1]], (0, 1, 2), (1, 2, 0), ) @@ -646,13 +649,16 @@ def preprocess( Gs = xp.fft.fft2(self._stack_BF_shifted) self._xy_shifts = ( - -self._probe_angles * defocus_guess / xp.array(self._scan_sampling) + -self._probe_angles + * defocus_guess + / xp.array(self._scan_sampling, dtype=xp.float32) ) if rotation_guess: angle = xp.deg2rad(rotation_guess) rotation_matrix = xp.array( - [[np.cos(angle), np.sin(angle)], [-np.sin(angle), np.cos(angle)]] + [[np.cos(angle), np.sin(angle)], [-np.sin(angle), np.cos(angle)]], + dtype=xp.float32, ) self._xy_shifts = xp.dot(self._xy_shifts, rotation_matrix) @@ -693,11 +699,12 @@ def preprocess( / self._mask_sum ) - self._recon_BF_initial = self._recon_BF.copy() - self._stack_BF_shifted_initial = self._stack_BF_shifted.copy() - self._stack_mask_initial = self._stack_mask.copy() - self._recon_mask_initial = self._recon_mask.copy() - self._xy_shifts_initial = self._xy_shifts.copy() + if store_initial_arrays: + self._recon_BF_initial = self._recon_BF.copy() + self._stack_BF_shifted_initial = self._stack_BF_shifted.copy() + self._stack_mask_initial = self._stack_mask.copy() + self._recon_mask_initial = self._recon_mask.copy() + self._xy_shifts_initial = self._xy_shifts.copy() self.recon_BF = asnumpy(self._recon_BF) @@ -2727,7 +2734,7 @@ def _crop_padded_object( pad_x_left = np.round( self._object_padding_px[0] / 2 * self._kde_upsample_factor ).astype("int") - pad_x_right = pad_x - pad_x_left + pad_x_right = pad_x_left - pad_x pad_y = np.round( self._object_padding_px[1] * self._kde_upsample_factor @@ -2735,20 +2742,27 @@ def _crop_padded_object( pad_y_left = np.round( self._object_padding_px[1] / 2 * self._kde_upsample_factor ).astype("int") - pad_y_right = pad_y - pad_y_left + pad_y_right = pad_y_left - pad_y else: pad_x_left = self._object_padding_px[0] // 2 - pad_x_right = self._object_padding_px[0] - pad_x_left + pad_x_right = pad_x_left - self._object_padding_px[0] pad_y_left = self._object_padding_px[1] // 2 - pad_y_right = self._object_padding_px[1] - pad_y_left + pad_y_right = pad_y_left - self._object_padding_px[1] pad_x_left -= remaining_padding - pad_x_right -= remaining_padding + pad_x_right += remaining_padding pad_y_left -= remaining_padding - pad_y_right -= remaining_padding + pad_y_right += remaining_padding + + sx = slice( + pad_x_left if pad_x_left else None, pad_x_right if pad_x_right else None + ) + sy = slice( + pad_y_left if pad_y_left else None, pad_y_right if pad_y_right else None + ) - return asnumpy(padded_object[pad_x_left:-pad_x_right, pad_y_left:-pad_y_right]) + return asnumpy(padded_object[sx, sy]) def _visualize_figax( self, diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py index 652e1046e..9ef93af22 100644 --- a/py4DSTEM/process/phase/parameter_optimize.py +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -381,49 +381,24 @@ def visualize( fig = plt.figure(figsize=figsize) ax = fig.add_subplot(spec[0]) - skopt_plot_gaussian_process(self._skopt_result, ax=ax, **kwargs) + skopt_plot_gaussian_process( + self._skopt_result, ax=ax, show_title=False, **kwargs + ) if plot_convergence: ax = fig.add_subplot(spec[1]) skopt_plot_convergence(self._skopt_result, ax=ax) - else: - if plot_convergence: - figsize = kwargs.pop("figsize", (4 * ndims, 4 * (ndims + 0.5))) - spec = GridSpec( - nrows=ndims + 1, - ncols=ndims, - height_ratios=[2] * ndims + [1], - hspace=0.15, - ) - else: - figsize = kwargs.pop("figsize", (4 * ndims, 4 * ndims)) - spec = GridSpec(nrows=ndims, ncols=ndims, hspace=0.15) + spec.tight_layout(fig) + else: if plot_evaluations: - axs = skopt_plot_evaluations(self._skopt_result) + skopt_plot_evaluations(self._skopt_result) elif plot_objective: cmap = kwargs.pop("cmap", "magma") - axs = skopt_plot_objective(self._skopt_result, cmap=cmap, **kwargs) + skopt_plot_objective(self._skopt_result, cmap=cmap, **kwargs) elif plot_convergence: skopt_plot_convergence(self._skopt_result) - return self - - fig = axs[0, 0].figure - fig.set_size_inches(figsize) - for i in range(ndims): - for j in range(ndims): - ax = axs[i, j] - ax.remove() - ax.figure = fig - fig.add_axes(ax) - ax.set_subplotspec(spec[i, j]) - - if plot_convergence: - ax = fig.add_subplot(spec[ndims, :]) - skopt_plot_convergence(self._skopt_result, ax=ax) - - spec.tight_layout(fig) return self diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 36e0c598a..22fda4ac9 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -1406,7 +1406,7 @@ def _normalize_diffraction_intensities( ) crop_w = np.minimum(crop_y, crop_x) - region_of_interest_shape = (crop_w * 2, crop_w * 2) + diffraction_intensities_shape_crop = (crop_w * 2, crop_w * 2) amplitudes = np.zeros( ( number_of_patterns, @@ -1424,9 +1424,10 @@ def _normalize_diffraction_intensities( else: crop_mask = None - region_of_interest_shape = diffraction_intensities.shape[-2:] + diffraction_intensities_shape_crop = diffraction_intensities.shape[-2:] amplitudes = np.zeros( - (number_of_patterns,) + region_of_interest_shape, dtype=np.float32 + (number_of_patterns,) + diffraction_intensities_shape_crop, + dtype=np.float32, ) counter = 0 @@ -1449,7 +1450,9 @@ def _normalize_diffraction_intensities( ) if crop_patterns: - intensities = intensities[crop_mask].reshape(region_of_interest_shape) + intensities = intensities[crop_mask].reshape( + diffraction_intensities_shape_crop + ) mean_intensity += np.sum(intensities) amplitudes[counter] = np.sqrt(np.maximum(intensities, 0)) @@ -1457,6 +1460,8 @@ def _normalize_diffraction_intensities( mean_intensity /= amplitudes.shape[0] + self._diffraction_intensities_shape_crop = diffraction_intensities_shape_crop + return amplitudes, mean_intensity, crop_mask def show_complex_CoM( @@ -1589,6 +1594,7 @@ def to_h5(self, group): "data_transpose": self._rotation_best_transpose, "positions_px": asnumpy(self._positions_px), "region_of_interest_shape": self._region_of_interest_shape, + "amplitudes_shape": self._amplitudes_shape, "num_diffraction_patterns": self._num_diffraction_patterns, "sampling": self.sampling, "angular_sampling": self.angular_sampling, @@ -1730,6 +1736,7 @@ def _populate_instance(self, group): self._positions_px = xp.asarray(preprocess_md["positions_px"]) self._angular_sampling = preprocess_md["angular_sampling"] self._region_of_interest_shape = preprocess_md["region_of_interest_shape"] + self._amplitudes_shape = preprocess_md["amplitudes_shape"] self._num_diffraction_patterns = preprocess_md["num_diffraction_patterns"] self._positions_mask = preprocess_md["positions_mask"] diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index fa0b1db9f..0c1273123 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -277,6 +277,11 @@ def _reset_reconstruction( else: self.error_iterations = [] self._exit_waves = None + else: + # reset=False first start + if not hasattr(self, "error"): + self.error_iterations = [] + self._exit_waves = None @property def object_fft(self): @@ -1000,17 +1005,6 @@ def _initialize_probe( vacuum_probe_intensity, dtype=xp.float32 ) - sx, sy = vacuum_probe_intensity.shape - tx, ty = region_of_interest_shape - if sx != tx or sy != ty: - vacuum_probe_intensity = bilinear_resample( - vacuum_probe_intensity, - output_size=(tx, ty), - vectorized=True, - conserve_array_sums=True, - xp=xp, - ) - probe_x0, probe_y0 = get_CoM( vacuum_probe_intensity, device=device, @@ -1025,7 +1019,18 @@ def _initialize_probe( if crop_patterns: vacuum_probe_intensity = vacuum_probe_intensity[crop_mask].reshape( - region_of_interest_shape + self._diffraction_intensities_shape_crop + ) + + sx, sy = vacuum_probe_intensity.shape + tx, ty = region_of_interest_shape + if sx != tx or sy != ty and self._resample_exit_waves is True: + vacuum_probe_intensity = bilinear_resample( + vacuum_probe_intensity, + output_size=(tx, ty), + vectorized=True, + conserve_array_sums=True, + xp=xp, ) _probe = ( @@ -3413,6 +3418,14 @@ def _reset_reconstruction( self._exit_waves = [None] * len(self._probes_all) else: self._exit_waves = None + else: + # reset=False first start + if not hasattr(self, "error"): + self.error_iterations = [] + if use_projection_scheme: + self._exit_waves = [None] * len(self._probes_all) + else: + self._exit_waves = None def _return_single_probe(self, probe=None): """Current probe estimate""" diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index f3b2991ab..6bd84c374 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -234,6 +234,7 @@ def preprocess( object_fov_mask: np.ndarray = True, crop_patterns: bool = False, main_tilt_axis: str = "vertical", + store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, max_batch_size: int = None, @@ -294,6 +295,8 @@ def preprocess( The default, 'vertical' (first scan dimension), results in object size (q,p,q), 'horizontal' (second scan dimension) results in object size (p,p,q), any other value (e.g. None) results in object size (max(p,q),p,q). + store_initial_arrays: bool + If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional if not none, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional @@ -528,8 +531,10 @@ def preprocess( main_tilt_axis, ) - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type + if store_initial_arrays: + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] self._num_voxels = self._object.shape[0] @@ -559,10 +564,14 @@ def preprocess( # initialize probe self._probes_all = [] - self._probes_all_initial = [] - self._probes_all_initial_aperture = [] list_Q = isinstance(self._probe_init, (list, tuple)) + if store_initial_arrays: + self._probes_all_initial = [] + self._probes_all_initial_aperture = [] + else: + self._probes_all_initial_aperture = [None] * self._num_measurements + for index in range(self._num_measurements): _probe, self._semiangle_cutoff = self._initialize_probe( self._probe_init[index] if list_Q else self._probe_init, @@ -573,8 +582,9 @@ def preprocess( ) self._probes_all.append(_probe) - self._probes_all_initial.append(_probe.copy()) - self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) + if store_initial_arrays: + self._probes_all_initial.append(_probe.copy()) + self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) del self._probe_init diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index e1f14a90f..973ca8ece 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -206,6 +206,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, max_batch_size: int = None, @@ -265,6 +266,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + store_initial_arrays: bool + If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional if not none, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional @@ -434,8 +437,10 @@ def preprocess( self._object_type, ) - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type + if store_initial_arrays: + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape # center probe positions @@ -471,8 +476,11 @@ def preprocess( device=device, )._evaluate_ctf() - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + if store_initial_arrays: + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + else: + self._probe_initial_aperture = None if object_fov_mask is None or plot_probe_overlaps: # overlaps diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index a5a541795..5742ff7e7 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1874,25 +1874,40 @@ def bilinearly_interpolate_array( dx = xa - xF dy = ya - yF - all_inds = [ - [xF, yF], - [xF + 1, yF], - [xF, yF + 1], - [xF + 1, yF + 1], - ] - - all_weights = [ - (1 - dx) * (1 - dy), - (dx) * (1 - dy), - (1 - dx) * (dy), - (dx) * (dy), - ] + # all_inds = [ + # [xF, yF], + # [xF + 1, yF], + # [xF, yF + 1], + # [xF + 1, yF + 1], + # ] + + # all_weights = [ + # (1 - dx) * (1 - dy), + # (dx) * (1 - dy), + # (1 - dx) * (dy), + # (dx) * (dy), + # ] raveled_image = image.ravel() intensities = xp.zeros(xa.shape, dtype=xp.float32) # filter_weights = xp.zeros(xa.shape, dtype=xp.float32) - for inds, weights in zip(all_inds, all_weights): + # for inds, weights in zip(all_inds, all_weights): + for basis_index in range(4): + match basis_index: + case 0: + inds = [xF, yF] + weights = (1 - dx) * (1 - dy) + case 1: + inds = [xF + 1, yF] + weights = (dx) * (1 - dy) + case 2: + inds = [xF, yF + 1] + weights = (1 - dx) * (dy) + case 3: + inds = [xF + 1, yF + 1] + weights = (dx) * (dy) + intensities += ( raveled_image[ xp.ravel_multi_index( @@ -1940,33 +1955,28 @@ def lanczos_interpolate_array( dx = xa - xF dy = ya - yF - all_inds = [] - all_weights = [] + raveled_image = image.ravel() + intensities = xp.zeros(xa.shape, dtype=xp.float32) + filter_weights = xp.zeros(xa.shape, dtype=xp.float32) for i in range(-alpha + 1, alpha + 1): for j in range(-alpha + 1, alpha + 1): - all_inds.append([xF + i, yF + j]) - all_weights.append( - (xp.sinc(i - dx) * xp.sinc((i - dx) / alpha)) - * (xp.sinc(j - dy) * xp.sinc((i - dy) / alpha)) + inds = [xF + i, yF + j] + weights = (xp.sinc(i - dx) * xp.sinc((i - dx) / alpha)) * ( + xp.sinc(j - dy) * xp.sinc((i - dy) / alpha) ) - raveled_image = image.ravel() - intensities = xp.zeros(xa.shape, dtype=xp.float32) - filter_weights = xp.zeros(xa.shape, dtype=xp.float32) - - for inds, weights in zip(all_inds, all_weights): - intensities += ( - raveled_image[ - xp.ravel_multi_index( - inds, - image.shape, - mode=["wrap", "wrap"], - ) - ] - * weights - ) - filter_weights += weights + intensities += ( + raveled_image[ + xp.ravel_multi_index( + inds, + image.shape, + mode=["wrap", "wrap"], + ) + ] + * weights + ) + filter_weights += weights return intensities / filter_weights @@ -2080,25 +2090,40 @@ def bilinear_kernel_density_estimate( dx = xa.ravel() - xF dy = ya.ravel() - yF - all_inds = [ - [xF, yF], - [xF + 1, yF], - [xF, yF + 1], - [xF + 1, yF + 1], - ] + # all_inds = [ + # [xF, yF], + # [xF + 1, yF], + # [xF, yF + 1], + # [xF + 1, yF + 1], + # ] - all_weights = [ - (1 - dx) * (1 - dy), - (dx) * (1 - dy), - (1 - dx) * (dy), - (dx) * (dy), - ] + # all_weights = [ + # (1 - dx) * (1 - dy), + # (dx) * (1 - dy), + # (1 - dx) * (dy), + # (dx) * (dy), + # ] raveled_intensities = intensities.ravel() pix_count = xp.zeros(np.prod(output_shape), dtype=xp.float32) pix_output = xp.zeros(np.prod(output_shape), dtype=xp.float32) - for inds, weights in zip(all_inds, all_weights): + # for inds, weights in zip(all_inds, all_weights): + for basis_index in range(4): + match basis_index: + case 0: + inds = [xF, yF] + weights = (1 - dx) * (1 - dy) + case 1: + inds = [xF + 1, yF] + weights = (dx) * (1 - dy) + case 2: + inds = [xF, yF + 1] + weights = (1 - dx) * (dy) + case 3: + inds = [xF + 1, yF + 1] + weights = (dx) * (dy) + inds_1D = xp.ravel_multi_index( inds, output_shape, @@ -2185,38 +2210,33 @@ def lanczos_kernel_density_estimate( dx = xa.ravel() - xF dy = ya.ravel() - yF - all_inds = [] - all_weights = [] + raveled_intensities = intensities.ravel() + pix_count = xp.zeros(np.prod(output_shape), dtype=xp.float32) + pix_output = xp.zeros(np.prod(output_shape), dtype=xp.float32) for i in range(-alpha + 1, alpha + 1): for j in range(-alpha + 1, alpha + 1): - all_inds.append([xF + i, yF + j]) - all_weights.append( - (xp.sinc(i - dx) * xp.sinc((i - dx) / alpha)) - * (xp.sinc(j - dy) * xp.sinc((i - dy) / alpha)) + inds = [xF + i, yF + j] + weights = (xp.sinc(i - dx) * xp.sinc((i - dx) / alpha)) * ( + xp.sinc(j - dy) * xp.sinc((i - dy) / alpha) ) - raveled_intensities = intensities.ravel() - pix_count = xp.zeros(np.prod(output_shape), dtype=xp.float32) - pix_output = xp.zeros(np.prod(output_shape), dtype=xp.float32) - - for inds, weights in zip(all_inds, all_weights): - inds_1D = xp.ravel_multi_index( - inds, - output_shape, - mode=["wrap", "wrap"], - ) + inds_1D = xp.ravel_multi_index( + inds, + output_shape, + mode=["wrap", "wrap"], + ) - pix_count += xp.bincount( - inds_1D, - weights=weights, - minlength=np.prod(output_shape), - ) - pix_output += xp.bincount( - inds_1D, - weights=weights * raveled_intensities, - minlength=np.prod(output_shape), - ) + pix_count += xp.bincount( + inds_1D, + weights=weights, + minlength=np.prod(output_shape), + ) + pix_output += xp.bincount( + inds_1D, + weights=weights * raveled_intensities, + minlength=np.prod(output_shape), + ) # reshape 1D arrays to 2D pix_count = xp.reshape( diff --git a/py4DSTEM/visualize/show.py b/py4DSTEM/visualize/show.py index 7430992e0..dcb1cf285 100644 --- a/py4DSTEM/visualize/show.py +++ b/py4DSTEM/visualize/show.py @@ -78,6 +78,7 @@ def show( show_fft=False, apply_hanning_window=True, show_cbar=False, + interpolation=None, **kwargs, ): """ @@ -615,7 +616,14 @@ def show( # Plot the image if not hist: - cax = ax.matshow(_ar, vmin=vmin, vmax=vmax, cmap=cm, **kwargs) + cax = ax.matshow( + _ar, + vmin=vmin, + vmax=vmax, + cmap=cm, + interpolation=interpolation, + **kwargs, + ) if np.any(_ar.mask): mask_display = np.ma.array(data=_ar.data, mask=~_ar.mask) ax.matshow( @@ -623,7 +631,7 @@ def show( ) if show_cbar: ax_divider = make_axes_locatable(ax) - c_axis = ax_divider.append_axes("right", size="7%") + c_axis = ax_divider.append_axes("right", size="5%", pad="2.5%") fig.colorbar(cax, cax=c_axis) # ...or, plot its histogram else: @@ -806,6 +814,7 @@ def show( ax.set_yticks([]) # Show or return + returnval = [] if returnfig: returnval.append((fig, ax)) @@ -822,6 +831,7 @@ def show( returnval.append(cax) if len(returnval) == 0: if figax is None: + plt.tight_layout() plt.show() return elif (len(returnval)) == 1: