From ef827f89340cf16fa034287f1d3a4bac7fc1a103 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 22 Mar 2024 09:35:29 -0700 Subject: [PATCH 01/14] making kdes less memory-intensive --- py4DSTEM/process/phase/utils.py | 171 ++++++++++++++++++-------------- 1 file changed, 96 insertions(+), 75 deletions(-) diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index a5a541795..7a39d2e63 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1874,25 +1874,40 @@ def bilinearly_interpolate_array( dx = xa - xF dy = ya - yF - all_inds = [ - [xF, yF], - [xF + 1, yF], - [xF, yF + 1], - [xF + 1, yF + 1], - ] - - all_weights = [ - (1 - dx) * (1 - dy), - (dx) * (1 - dy), - (1 - dx) * (dy), - (dx) * (dy), - ] + # all_inds = [ + # [xF, yF], + # [xF + 1, yF], + # [xF, yF + 1], + # [xF + 1, yF + 1], + # ] + + # all_weights = [ + # (1 - dx) * (1 - dy), + # (dx) * (1 - dy), + # (1 - dx) * (dy), + # (dx) * (dy), + # ] raveled_image = image.ravel() intensities = xp.zeros(xa.shape, dtype=xp.float32) # filter_weights = xp.zeros(xa.shape, dtype=xp.float32) - for inds, weights in zip(all_inds, all_weights): + # for inds, weights in zip(all_inds, all_weights): + for basis_index in range(4): + match basis_index: + case 0: + inds = [xF, yF] + weights = (1 - dx) * (1 - dy) + case 1: + inds = [xF + 1, yF] + weights = (dx) * (1 - dy) + case 2: + inds = [xF, yF + 1] + weights = (1 - dx) * (dy) + case 3: + inds = [xF + 1, yF + 1] + weights = (dx) * (dy) + intensities += ( raveled_image[ xp.ravel_multi_index( @@ -1940,33 +1955,29 @@ def lanczos_interpolate_array( dx = xa - xF dy = ya - yF - all_inds = [] - all_weights = [] + raveled_image = image.ravel() + intensities = xp.zeros(xa.shape, dtype=xp.float32) + filter_weights = xp.zeros(xa.shape, dtype=xp.float32) for i in range(-alpha + 1, alpha + 1): for j in range(-alpha + 1, alpha + 1): - all_inds.append([xF + i, yF + j]) - all_weights.append( - (xp.sinc(i - dx) * xp.sinc((i - dx) / alpha)) - * (xp.sinc(j - dy) * xp.sinc((i - dy) / alpha)) - ) - raveled_image = image.ravel() - intensities = xp.zeros(xa.shape, dtype=xp.float32) - filter_weights = xp.zeros(xa.shape, dtype=xp.float32) + inds = [xF + i, yF + j] + weights = (xp.sinc(i - dx) * xp.sinc((i - dx) / alpha)) * ( + xp.sinc(j - dy) * xp.sinc((i - dy) / alpha) + ) - for inds, weights in zip(all_inds, all_weights): - intensities += ( - raveled_image[ - xp.ravel_multi_index( - inds, - image.shape, - mode=["wrap", "wrap"], - ) - ] - * weights - ) - filter_weights += weights + intensities += ( + raveled_image[ + xp.ravel_multi_index( + inds, + image.shape, + mode=["wrap", "wrap"], + ) + ] + * weights + ) + filter_weights += weights return intensities / filter_weights @@ -2080,25 +2091,40 @@ def bilinear_kernel_density_estimate( dx = xa.ravel() - xF dy = ya.ravel() - yF - all_inds = [ - [xF, yF], - [xF + 1, yF], - [xF, yF + 1], - [xF + 1, yF + 1], - ] + # all_inds = [ + # [xF, yF], + # [xF + 1, yF], + # [xF, yF + 1], + # [xF + 1, yF + 1], + # ] - all_weights = [ - (1 - dx) * (1 - dy), - (dx) * (1 - dy), - (1 - dx) * (dy), - (dx) * (dy), - ] + # all_weights = [ + # (1 - dx) * (1 - dy), + # (dx) * (1 - dy), + # (1 - dx) * (dy), + # (dx) * (dy), + # ] raveled_intensities = intensities.ravel() pix_count = xp.zeros(np.prod(output_shape), dtype=xp.float32) pix_output = xp.zeros(np.prod(output_shape), dtype=xp.float32) - for inds, weights in zip(all_inds, all_weights): + # for inds, weights in zip(all_inds, all_weights): + for basis_index in range(4): + match basis_index: + case 0: + inds = [xF, yF] + weights = (1 - dx) * (1 - dy) + case 1: + inds = [xF + 1, yF] + weights = (dx) * (1 - dy) + case 2: + inds = [xF, yF + 1] + weights = (1 - dx) * (dy) + case 3: + inds = [xF + 1, yF + 1] + weights = (dx) * (dy) + inds_1D = xp.ravel_multi_index( inds, output_shape, @@ -2185,38 +2211,33 @@ def lanczos_kernel_density_estimate( dx = xa.ravel() - xF dy = ya.ravel() - yF - all_inds = [] - all_weights = [] + raveled_intensities = intensities.ravel() + pix_count = xp.zeros(np.prod(output_shape), dtype=xp.float32) + pix_output = xp.zeros(np.prod(output_shape), dtype=xp.float32) for i in range(-alpha + 1, alpha + 1): for j in range(-alpha + 1, alpha + 1): - all_inds.append([xF + i, yF + j]) - all_weights.append( - (xp.sinc(i - dx) * xp.sinc((i - dx) / alpha)) - * (xp.sinc(j - dy) * xp.sinc((i - dy) / alpha)) + inds = [xF + i, yF + j] + weights = (xp.sinc(i - dx) * xp.sinc((i - dx) / alpha)) * ( + xp.sinc(j - dy) * xp.sinc((i - dy) / alpha) ) - raveled_intensities = intensities.ravel() - pix_count = xp.zeros(np.prod(output_shape), dtype=xp.float32) - pix_output = xp.zeros(np.prod(output_shape), dtype=xp.float32) - - for inds, weights in zip(all_inds, all_weights): - inds_1D = xp.ravel_multi_index( - inds, - output_shape, - mode=["wrap", "wrap"], - ) + inds_1D = xp.ravel_multi_index( + inds, + output_shape, + mode=["wrap", "wrap"], + ) - pix_count += xp.bincount( - inds_1D, - weights=weights, - minlength=np.prod(output_shape), - ) - pix_output += xp.bincount( - inds_1D, - weights=weights * raveled_intensities, - minlength=np.prod(output_shape), - ) + pix_count += xp.bincount( + inds_1D, + weights=weights, + minlength=np.prod(output_shape), + ) + pix_output += xp.bincount( + inds_1D, + weights=weights * raveled_intensities, + minlength=np.prod(output_shape), + ) # reshape 1D arrays to 2D pix_count = xp.reshape( From ed4465332f3554c16881a5fafeadacbd5f3e44a5 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 22 Mar 2024 09:35:58 -0700 Subject: [PATCH 02/14] making parallax less memory intensive --- py4DSTEM/process/phase/parallax.py | 45 +++++++++++++++++------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index 060a151aa..327cabac9 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -266,6 +266,7 @@ def preprocess( vectorized_com_calculation: bool = True, device: str = None, clear_fft_cache: bool = None, + store_initial_arrays: bool = True, **kwargs, ): """ @@ -306,7 +307,9 @@ def preprocess( device: str, optional if not none, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional - if true, and device = 'gpu', clears the cached fft plan at the end of function calls + If True, and device = 'gpu', clears the cached fft plan at the end of function calls + store_initial_arrays: bool, optional + If True, stores a copy of the arrays necessary to reinitialize in reconstruct Returns -------- @@ -330,15 +333,15 @@ def preprocess( ) # extract calibrations - self._intensities = self._extract_intensities_and_calibrations_from_datacube( + intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube, require_calibrations=True, ) - self._intensities = xp.asarray(self._intensities) + intensities = xp.asarray(intensities) - self._region_of_interest_shape = np.array(self._intensities.shape[-2:]) - self._scan_shape = np.array(self._intensities.shape[:2]) + self._region_of_interest_shape = np.array(intensities.shape[-2:]) + self._scan_shape = np.array(intensities.shape[:2]) # descan correction if descan_correction_fit_function is not None: @@ -350,7 +353,7 @@ def preprocess( _, _, ) = self._calculate_intensities_center_of_mass( - self._intensities, + intensities, dp_mask=None, fit_function=descan_correction_fit_function, com_shifts=None, @@ -360,8 +363,8 @@ def preprocess( com_fitted_x = asnumpy(com_fitted_x) com_fitted_y = asnumpy(com_fitted_y) - intensities = asnumpy(self._intensities) - intensities_shifted = np.zeros_like(intensities) + intensities_np = asnumpy(intensities) + intensities_shifted = np.zeros_like(intensities_np) center_x = com_fitted_x.mean() center_y = com_fitted_y.mean() @@ -369,7 +372,7 @@ def preprocess( for rx in range(intensities_shifted.shape[0]): for ry in range(intensities_shifted.shape[1]): intensity_shifted = get_shifted_ar( - intensities[rx, ry], + intensities_np[rx, ry], -com_fitted_x[rx, ry] + center_x, -com_fitted_y[rx, ry] + center_y, bilinear=True, @@ -378,12 +381,12 @@ def preprocess( intensities_shifted[rx, ry] = intensity_shifted - self._intensities = xp.asarray(intensities_shifted, xp.float32) + intensities = xp.asarray(intensities_shifted, xp.float32) if dp_mask is not None: self._dp_mask = xp.asarray(dp_mask) else: - dp_mean = self._intensities.mean((0, 1)) + dp_mean = intensities.mean((0, 1)) self._dp_mask = dp_mean >= (xp.max(dp_mean) * threshold_intensity) # select virtual detector pixels @@ -454,7 +457,7 @@ def preprocess( # Collect BF images all_bfs = xp.moveaxis( - self._intensities[:, :, self._xy_inds[:, 0], self._xy_inds[:, 1]], + intensities[:, :, self._xy_inds[:, 0], self._xy_inds[:, 1]], (0, 1, 2), (1, 2, 0), ) @@ -646,13 +649,16 @@ def preprocess( Gs = xp.fft.fft2(self._stack_BF_shifted) self._xy_shifts = ( - -self._probe_angles * defocus_guess / xp.array(self._scan_sampling) + -self._probe_angles + * defocus_guess + / xp.array(self._scan_sampling, dtype=xp.float32) ) if rotation_guess: angle = xp.deg2rad(rotation_guess) rotation_matrix = xp.array( - [[np.cos(angle), np.sin(angle)], [-np.sin(angle), np.cos(angle)]] + [[np.cos(angle), np.sin(angle)], [-np.sin(angle), np.cos(angle)]], + dtype=xp.float32, ) self._xy_shifts = xp.dot(self._xy_shifts, rotation_matrix) @@ -693,11 +699,12 @@ def preprocess( / self._mask_sum ) - self._recon_BF_initial = self._recon_BF.copy() - self._stack_BF_shifted_initial = self._stack_BF_shifted.copy() - self._stack_mask_initial = self._stack_mask.copy() - self._recon_mask_initial = self._recon_mask.copy() - self._xy_shifts_initial = self._xy_shifts.copy() + if store_initial_arrays: + self._recon_BF_initial = self._recon_BF.copy() + self._stack_BF_shifted_initial = self._stack_BF_shifted.copy() + self._stack_mask_initial = self._stack_mask.copy() + self._recon_mask_initial = self._recon_mask.copy() + self._xy_shifts_initial = self._xy_shifts.copy() self.recon_BF = asnumpy(self._recon_BF) From 0db84be5a67bca157628d8a537fe56c1412260e8 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 22 Mar 2024 10:13:42 -0700 Subject: [PATCH 03/14] fix zero padding cropping, cleaned up viz --- py4DSTEM/process/phase/parallax.py | 42 +++++++++++++----------------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index 327cabac9..d1567eebc 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -2535,25 +2535,12 @@ def aberration_correct( # plotting if plot_corrected_phase: figsize = kwargs.pop("figsize", (6, 6)) - cmap = kwargs.pop("cmap", "magma") - fig, ax = plt.subplots(figsize=figsize) - cropped_object = self._crop_padded_object( - self._recon_phase_corrected, upsampled=upsampled - ) - - extent = [ - 0, - sy * cropped_object.shape[1], - sx * cropped_object.shape[0], - 0, - ] - - ax.imshow( - cropped_object, - extent=extent, - cmap=cmap, + self._visualize_figax( + fig, + ax, + upsampled=upsampled, **kwargs, ) @@ -2734,7 +2721,7 @@ def _crop_padded_object( pad_x_left = np.round( self._object_padding_px[0] / 2 * self._kde_upsample_factor ).astype("int") - pad_x_right = pad_x - pad_x_left + pad_x_right = pad_x_left - pad_x pad_y = np.round( self._object_padding_px[1] * self._kde_upsample_factor @@ -2742,20 +2729,27 @@ def _crop_padded_object( pad_y_left = np.round( self._object_padding_px[1] / 2 * self._kde_upsample_factor ).astype("int") - pad_y_right = pad_y - pad_y_left + pad_y_right = pad_y_left - pad_y else: pad_x_left = self._object_padding_px[0] // 2 - pad_x_right = self._object_padding_px[0] - pad_x_left + pad_x_right = pad_x_left - self._object_padding_px[0] pad_y_left = self._object_padding_px[1] // 2 - pad_y_right = self._object_padding_px[1] - pad_y_left + pad_y_right = pad_y_left - self._object_padding_px[1] pad_x_left -= remaining_padding - pad_x_right -= remaining_padding + pad_x_right += remaining_padding pad_y_left -= remaining_padding - pad_y_right -= remaining_padding + pad_y_right += remaining_padding + + sx = slice( + pad_x_left if pad_x_left else None, pad_x_right if pad_x_right else None + ) + sy = slice( + pad_y_left if pad_y_left else None, pad_y_right if pad_y_right else None + ) - return asnumpy(padded_object[pad_x_left:-pad_x_right, pad_y_left:-pad_y_right]) + return asnumpy(padded_object[sx, sy]) def _visualize_figax( self, From a972b165f0da8d8610e9522923f311bfc308e0fe Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Wed, 27 Mar 2024 09:59:20 -0700 Subject: [PATCH 04/14] fix plotting for BO --- py4DSTEM/process/phase/parameter_optimize.py | 4 +++- py4DSTEM/process/phase/utils.py | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py index 652e1046e..44deba2f5 100644 --- a/py4DSTEM/process/phase/parameter_optimize.py +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -381,7 +381,9 @@ def visualize( fig = plt.figure(figsize=figsize) ax = fig.add_subplot(spec[0]) - skopt_plot_gaussian_process(self._skopt_result, ax=ax, **kwargs) + skopt_plot_gaussian_process( + self._skopt_result, ax=ax, show_title=False, **kwargs + ) if plot_convergence: ax = fig.add_subplot(spec[1]) diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 7a39d2e63..5742ff7e7 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1961,7 +1961,6 @@ def lanczos_interpolate_array( for i in range(-alpha + 1, alpha + 1): for j in range(-alpha + 1, alpha + 1): - inds = [xF + i, yF + j] weights = (xp.sinc(i - dx) * xp.sinc((i - dx) / alpha)) * ( xp.sinc(j - dy) * xp.sinc((i - dy) / alpha) From f1330b78e4348d4dd8c0a6ab5d6bb2d048471caf Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sat, 30 Mar 2024 11:45:55 -0700 Subject: [PATCH 05/14] resample patterns condition change --- py4DSTEM/process/phase/ptychographic_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index fa0b1db9f..b6e22de0e 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -1002,7 +1002,7 @@ def _initialize_probe( sx, sy = vacuum_probe_intensity.shape tx, ty = region_of_interest_shape - if sx != tx or sy != ty: + if sx != tx or sy != ty and self._resample_exit_waves is True: vacuum_probe_intensity = bilinear_resample( vacuum_probe_intensity, output_size=(tx, ty), From 8397960903c565f404bd4011b8a591da650309c3 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sat, 30 Mar 2024 14:02:25 -0700 Subject: [PATCH 06/14] ROI and crop patterns --- py4DSTEM/process/phase/phase_base_class.py | 13 ++++++---- .../process/phase/ptychographic_methods.py | 24 +++++++++---------- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index 36e0c598a..f99fe2e9e 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -1406,7 +1406,7 @@ def _normalize_diffraction_intensities( ) crop_w = np.minimum(crop_y, crop_x) - region_of_interest_shape = (crop_w * 2, crop_w * 2) + diffraction_intensities_shape_crop = (crop_w * 2, crop_w * 2) amplitudes = np.zeros( ( number_of_patterns, @@ -1424,9 +1424,10 @@ def _normalize_diffraction_intensities( else: crop_mask = None - region_of_interest_shape = diffraction_intensities.shape[-2:] + diffraction_intensities_shape_crop = diffraction_intensities.shape[-2:] amplitudes = np.zeros( - (number_of_patterns,) + region_of_interest_shape, dtype=np.float32 + (number_of_patterns,) + diffraction_intensities_shape_crop, + dtype=np.float32, ) counter = 0 @@ -1449,7 +1450,9 @@ def _normalize_diffraction_intensities( ) if crop_patterns: - intensities = intensities[crop_mask].reshape(region_of_interest_shape) + intensities = intensities[crop_mask].reshape( + diffraction_intensities_shape_crop + ) mean_intensity += np.sum(intensities) amplitudes[counter] = np.sqrt(np.maximum(intensities, 0)) @@ -1457,6 +1460,8 @@ def _normalize_diffraction_intensities( mean_intensity /= amplitudes.shape[0] + self._diffraction_intensities_shape_crop = diffraction_intensities_shape_crop + return amplitudes, mean_intensity, crop_mask def show_complex_CoM( diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index b6e22de0e..14d8a82b1 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -1000,17 +1000,6 @@ def _initialize_probe( vacuum_probe_intensity, dtype=xp.float32 ) - sx, sy = vacuum_probe_intensity.shape - tx, ty = region_of_interest_shape - if sx != tx or sy != ty and self._resample_exit_waves is True: - vacuum_probe_intensity = bilinear_resample( - vacuum_probe_intensity, - output_size=(tx, ty), - vectorized=True, - conserve_array_sums=True, - xp=xp, - ) - probe_x0, probe_y0 = get_CoM( vacuum_probe_intensity, device=device, @@ -1025,7 +1014,18 @@ def _initialize_probe( if crop_patterns: vacuum_probe_intensity = vacuum_probe_intensity[crop_mask].reshape( - region_of_interest_shape + self._diffraction_intensities_shape_crop + ) + + sx, sy = vacuum_probe_intensity.shape + tx, ty = region_of_interest_shape + if sx != tx or sy != ty and self._resample_exit_waves is True: + vacuum_probe_intensity = bilinear_resample( + vacuum_probe_intensity, + output_size=(tx, ty), + vectorized=True, + conserve_array_sums=True, + xp=xp, ) _probe = ( From 29f1596260ecdcfec1d81d66ac47a4aaf7749389 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 8 Apr 2024 11:31:43 -0700 Subject: [PATCH 07/14] corrected phase bugfix --- py4DSTEM/process/phase/parallax.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index d1567eebc..cdf32e0ae 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -2535,12 +2535,25 @@ def aberration_correct( # plotting if plot_corrected_phase: figsize = kwargs.pop("figsize", (6, 6)) + cmap = kwargs.pop("cmap", "magma") + fig, ax = plt.subplots(figsize=figsize) - self._visualize_figax( - fig, - ax, - upsampled=upsampled, + cropped_object = self._crop_padded_object( + self._recon_phase_corrected, upsampled=upsampled + ) + + extent = [ + 0, + sy * cropped_object.shape[1], + sx * cropped_object.shape[0], + 0, + ] + + ax.imshow( + cropped_object, + extent=extent, + cmap=cmap, **kwargs, ) From dce15c8401dd17be589f405ce184fa9a576bdf3b Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 3 May 2024 14:59:45 -0700 Subject: [PATCH 08/14] amplitudes_shape bug for read-write --- py4DSTEM/process/phase/phase_base_class.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index f99fe2e9e..22fda4ac9 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -1594,6 +1594,7 @@ def to_h5(self, group): "data_transpose": self._rotation_best_transpose, "positions_px": asnumpy(self._positions_px), "region_of_interest_shape": self._region_of_interest_shape, + "amplitudes_shape": self._amplitudes_shape, "num_diffraction_patterns": self._num_diffraction_patterns, "sampling": self.sampling, "angular_sampling": self.angular_sampling, @@ -1735,6 +1736,7 @@ def _populate_instance(self, group): self._positions_px = xp.asarray(preprocess_md["positions_px"]) self._angular_sampling = preprocess_md["angular_sampling"] self._region_of_interest_shape = preprocess_md["region_of_interest_shape"] + self._amplitudes_shape = preprocess_md["amplitudes_shape"] self._num_diffraction_patterns = preprocess_md["num_diffraction_patterns"] self._positions_mask = preprocess_md["positions_mask"] From 14c884f1ceed0d8150b9f0378b45e5234497f774 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 3 May 2024 15:54:47 -0700 Subject: [PATCH 09/14] store_initial_arrays flag --- .../process/phase/mixedstate_ptychography.py | 15 +++++++++---- .../process/phase/ptychographic_tomography.py | 22 ++++++++++++++----- .../process/phase/singleslice_ptychography.py | 16 ++++++++++---- 3 files changed, 39 insertions(+), 14 deletions(-) diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py index bd650a931..d3213b49d 100644 --- a/py4DSTEM/process/phase/mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -224,6 +224,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, max_batch_size: int = None, @@ -293,6 +294,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + store_initial_arrays: bool + If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional if not none, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional @@ -461,8 +464,9 @@ def preprocess( self._object_type, ) - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type + if store_initial_arrays: + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type self._object_shape = self._object.shape # center probe positions @@ -498,8 +502,11 @@ def preprocess( device=self._device, )._evaluate_ctf() - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + if store_initial_arrays: + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + else: + self._probe_initial_aperture = None if object_fov_mask is None or plot_probe_overlaps: # overlaps diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index f3b2991ab..6bd84c374 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -234,6 +234,7 @@ def preprocess( object_fov_mask: np.ndarray = True, crop_patterns: bool = False, main_tilt_axis: str = "vertical", + store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, max_batch_size: int = None, @@ -294,6 +295,8 @@ def preprocess( The default, 'vertical' (first scan dimension), results in object size (q,p,q), 'horizontal' (second scan dimension) results in object size (p,p,q), any other value (e.g. None) results in object size (max(p,q),p,q). + store_initial_arrays: bool + If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional if not none, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional @@ -528,8 +531,10 @@ def preprocess( main_tilt_axis, ) - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type + if store_initial_arrays: + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] self._num_voxels = self._object.shape[0] @@ -559,10 +564,14 @@ def preprocess( # initialize probe self._probes_all = [] - self._probes_all_initial = [] - self._probes_all_initial_aperture = [] list_Q = isinstance(self._probe_init, (list, tuple)) + if store_initial_arrays: + self._probes_all_initial = [] + self._probes_all_initial_aperture = [] + else: + self._probes_all_initial_aperture = [None] * self._num_measurements + for index in range(self._num_measurements): _probe, self._semiangle_cutoff = self._initialize_probe( self._probe_init[index] if list_Q else self._probe_init, @@ -573,8 +582,9 @@ def preprocess( ) self._probes_all.append(_probe) - self._probes_all_initial.append(_probe.copy()) - self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) + if store_initial_arrays: + self._probes_all_initial.append(_probe.copy()) + self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) del self._probe_init diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py index e1f14a90f..973ca8ece 100644 --- a/py4DSTEM/process/phase/singleslice_ptychography.py +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -206,6 +206,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, max_batch_size: int = None, @@ -265,6 +266,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + store_initial_arrays: bool + If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional if not none, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional @@ -434,8 +437,10 @@ def preprocess( self._object_type, ) - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type + if store_initial_arrays: + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape # center probe positions @@ -471,8 +476,11 @@ def preprocess( device=device, )._evaluate_ctf() - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + if store_initial_arrays: + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + else: + self._probe_initial_aperture = None if object_fov_mask is None or plot_probe_overlaps: # overlaps From 071d555d3dda72c71746cec845285a8d8938bd23 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 3 May 2024 16:29:44 -0700 Subject: [PATCH 10/14] adding rest classes --- .../magnetic_ptychographic_tomography.py | 22 ++++++++++++++----- .../process/phase/magnetic_ptychography.py | 21 +++++++++++++----- .../mixedstate_multislice_ptychography.py | 15 +++++++++---- .../process/phase/multislice_ptychography.py | 15 +++++++++---- 4 files changed, 53 insertions(+), 20 deletions(-) diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py index c9efae806..9ef1ba4dd 100644 --- a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -239,6 +239,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = True, crop_patterns: bool = False, + store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, max_batch_size: int = None, @@ -295,6 +296,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + store_initial_arrays: bool + If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional if not none, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional @@ -534,8 +537,10 @@ def preprocess( else: self._object = obj - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type + if store_initial_arrays: + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] self._num_voxels = self._object.shape[1] @@ -565,10 +570,14 @@ def preprocess( # initialize probe self._probes_all = [] - self._probes_all_initial = [] - self._probes_all_initial_aperture = [] list_Q = isinstance(self._probe_init, (list, tuple)) + if store_initial_arrays: + self._probes_all_initial = [] + self._probes_all_initial_aperture = [] + else: + self._probes_all_initial_aperture = [None] * self._num_measurements + for index in range(self._num_measurements): _probe, self._semiangle_cutoff = self._initialize_probe( self._probe_init[index] if list_Q else self._probe_init, @@ -579,8 +588,9 @@ def preprocess( ) self._probes_all.append(_probe) - self._probes_all_initial.append(_probe.copy()) - self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) + if store_initial_arrays: + self._probes_all_initial.append(_probe.copy()) + self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) del self._probe_init diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py index 2e887739f..01326a08b 100644 --- a/py4DSTEM/process/phase/magnetic_ptychography.py +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -219,6 +219,7 @@ def preprocess( progress_bar: bool = True, object_fov_mask: np.ndarray = True, crop_patterns: bool = False, + store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, max_batch_size: int = None, @@ -285,6 +286,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + store_initial_arrays: bool + If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional if not none, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional @@ -597,8 +600,9 @@ def preprocess( else: self._object = obj - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type + if store_initial_arrays: + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type self._object_shape = self._object.shape[-2:] # center probe positions @@ -627,10 +631,14 @@ def preprocess( # initialize probe self._probes_all = [] - self._probes_all_initial = [] - self._probes_all_initial_aperture = [] list_Q = isinstance(self._probe_init, (list, tuple)) + if store_initial_arrays: + self._probes_all_initial = [] + self._probes_all_initial_aperture = [] + else: + self._probes_all_initial_aperture = [None] * self._num_measurements + for index in range(self._num_measurements): _probe, self._semiangle_cutoff = self._initialize_probe( self._probe_init[index] if list_Q else self._probe_init, @@ -641,8 +649,9 @@ def preprocess( ) self._probes_all.append(_probe) - self._probes_all_initial.append(_probe.copy()) - self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) + if store_initial_arrays: + self._probes_all_initial.append(_probe.copy()) + self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) del self._probe_init diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py index 119fc3a3c..329443564 100644 --- a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -278,6 +278,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, max_batch_size: int = None, @@ -347,6 +348,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + store_initial_arrays: bool + If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional If not None, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional @@ -516,8 +519,9 @@ def preprocess( self._object_type, ) - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type + if store_initial_arrays: + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type self._object_shape = self._object.shape[-2:] # center probe positions @@ -553,8 +557,11 @@ def preprocess( device=self._device, )._evaluate_ctf() - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + if store_initial_arrays: + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + else: + self._probe_initial_aperture = None # precompute propagator arrays self._propagator_arrays = self._precompute_propagator_arrays( diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py index db17cb1a8..3a245d367 100644 --- a/py4DSTEM/process/phase/multislice_ptychography.py +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -252,6 +252,7 @@ def preprocess( force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, crop_patterns: bool = False, + store_initial_arrays: bool = True, device: str = None, clear_fft_cache: bool = None, max_batch_size: int = None, @@ -321,6 +322,8 @@ def preprocess( If None, probe_overlap intensity is thresholded crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + store_initial_arrays: bool + If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. device: str, optional If not None, overwrites self._device to set device preprocess will be perfomed on. clear_fft_cache: bool, optional @@ -490,8 +493,9 @@ def preprocess( self._object_type, ) - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type + if store_initial_arrays: + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type self._object_shape = self._object.shape[-2:] # center probe positions @@ -527,8 +531,11 @@ def preprocess( device=self._device, )._evaluate_ctf() - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + if store_initial_arrays: + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + else: + self._probe_initial_aperture = None # precompute propagator arrays self._propagator_arrays = self._precompute_propagator_arrays( From a56d23bc275253efa279311628e3555f6ca7adf4 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 3 May 2024 17:03:19 -0700 Subject: [PATCH 11/14] interpolation=None, fixed show_cbar --- py4DSTEM/visualize/show.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/visualize/show.py b/py4DSTEM/visualize/show.py index 7430992e0..78c09885c 100644 --- a/py4DSTEM/visualize/show.py +++ b/py4DSTEM/visualize/show.py @@ -78,6 +78,7 @@ def show( show_fft=False, apply_hanning_window=True, show_cbar=False, + interpolation=None, **kwargs, ): """ @@ -615,7 +616,14 @@ def show( # Plot the image if not hist: - cax = ax.matshow(_ar, vmin=vmin, vmax=vmax, cmap=cm, **kwargs) + cax = ax.matshow( + _ar, + vmin=vmin, + vmax=vmax, + cmap=cm, + interpolation=interpolation, + **kwargs, + ) if np.any(_ar.mask): mask_display = np.ma.array(data=_ar.data, mask=~_ar.mask) ax.matshow( @@ -623,7 +631,7 @@ def show( ) if show_cbar: ax_divider = make_axes_locatable(ax) - c_axis = ax_divider.append_axes("right", size="7%") + c_axis = ax_divider.append_axes("right", size="5%", pad="2.5%") fig.colorbar(cax, cax=c_axis) # ...or, plot its histogram else: @@ -806,6 +814,8 @@ def show( ax.set_yticks([]) # Show or return + fig.tight_layout() + returnval = [] if returnfig: returnval.append((fig, ax)) From a8b4a2f03d41315b6cffba5c503cddfe616376ea Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 4 May 2024 10:58:05 -0700 Subject: [PATCH 12/14] this is why I didnt want to touch show.. --- py4DSTEM/visualize/show.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/visualize/show.py b/py4DSTEM/visualize/show.py index 78c09885c..dcb1cf285 100644 --- a/py4DSTEM/visualize/show.py +++ b/py4DSTEM/visualize/show.py @@ -814,7 +814,6 @@ def show( ax.set_yticks([]) # Show or return - fig.tight_layout() returnval = [] if returnfig: @@ -832,6 +831,7 @@ def show( returnval.append(cax) if len(returnval) == 0: if figax is None: + plt.tight_layout() plt.show() return elif (len(returnval)) == 1: From f7a307154b9fb5b82e8f96d2729b1c05313e377d Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sun, 5 May 2024 20:55:08 -0700 Subject: [PATCH 13/14] fix BO visualize --- py4DSTEM/process/phase/parameter_optimize.py | 35 +++----------------- 1 file changed, 4 insertions(+), 31 deletions(-) diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py index 44deba2f5..9ef93af22 100644 --- a/py4DSTEM/process/phase/parameter_optimize.py +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -389,43 +389,16 @@ def visualize( ax = fig.add_subplot(spec[1]) skopt_plot_convergence(self._skopt_result, ax=ax) - else: - if plot_convergence: - figsize = kwargs.pop("figsize", (4 * ndims, 4 * (ndims + 0.5))) - spec = GridSpec( - nrows=ndims + 1, - ncols=ndims, - height_ratios=[2] * ndims + [1], - hspace=0.15, - ) - else: - figsize = kwargs.pop("figsize", (4 * ndims, 4 * ndims)) - spec = GridSpec(nrows=ndims, ncols=ndims, hspace=0.15) + spec.tight_layout(fig) + else: if plot_evaluations: - axs = skopt_plot_evaluations(self._skopt_result) + skopt_plot_evaluations(self._skopt_result) elif plot_objective: cmap = kwargs.pop("cmap", "magma") - axs = skopt_plot_objective(self._skopt_result, cmap=cmap, **kwargs) + skopt_plot_objective(self._skopt_result, cmap=cmap, **kwargs) elif plot_convergence: skopt_plot_convergence(self._skopt_result) - return self - - fig = axs[0, 0].figure - fig.set_size_inches(figsize) - for i in range(ndims): - for j in range(ndims): - ax = axs[i, j] - ax.remove() - ax.figure = fig - fig.add_axes(ax) - ax.set_subplotspec(spec[i, j]) - - if plot_convergence: - ax = fig.add_subplot(spec[ndims, :]) - skopt_plot_convergence(self._skopt_result, ax=ax) - - spec.tight_layout(fig) return self From 0aa9bab7602420ff21f84e9fed2bb5888a1147cc Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 6 May 2024 09:47:50 -0700 Subject: [PATCH 14/14] guarding against reset=False users --- py4DSTEM/process/phase/ptychographic_methods.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index 14d8a82b1..0c1273123 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -277,6 +277,11 @@ def _reset_reconstruction( else: self.error_iterations = [] self._exit_waves = None + else: + # reset=False first start + if not hasattr(self, "error"): + self.error_iterations = [] + self._exit_waves = None @property def object_fft(self): @@ -3413,6 +3418,14 @@ def _reset_reconstruction( self._exit_waves = [None] * len(self._probes_all) else: self._exit_waves = None + else: + # reset=False first start + if not hasattr(self, "error"): + self.error_iterations = [] + if use_projection_scheme: + self._exit_waves = [None] * len(self._probes_all) + else: + self._exit_waves = None def _return_single_probe(self, probe=None): """Current probe estimate"""