Skip to content

Commit

Permalink
Live visualization (#35)
Browse files Browse the repository at this point in the history
* Fixed a few bugs in multipeak phasing

* Adaptive alien removal for multipeak

Before phasing on a new peak, it compares the shared-object projected diffraction to the actual measured diffraction (both with a 1-pixel gaussian blur). A difference map is defined as the absolute difference between the two blurred images divided by the blurred projection. A mask is generated from all the voxels where the difference map is greater than twice its own median.
While the peak is being phased, the Fourier constraint is enforced using the measurement outside the mask and the projection inside.

* Moved live viewing to its own CoupledRec method so that iterate() isn't so crowded.

* Added a bunch of comments for multipeak stuff

* Adds feedback from the shared object when applying the Fourier constraint. The amount of feedback depends on that peak's confidence value.

* Updated parameters in config_mp:

Removed:
- adaptive_weights (replaced by adapt_trigger)

Added:
- adapt_trigger: Puts the adaptive weighting in the operation/trigger format
- adapt_power: Basically how harshly bad peaks are punished
- adapt_alien_start: When to start removing aliens
- adapt_alien_threshold: Voxels are masked where the error map exceeds this times its median

Renamed:
- weight_X << mp_weight_X
- adapt_threshold_X << peak_threshold_X

* Updated documentation for multipeak configuration

* Added live display trigger for both single and multipeak.

* Fixed colorscales on multipeak live visualization to match single-peak
  • Loading branch information
jacione authored Nov 4, 2024
1 parent 9b24f13 commit 0fe08f7
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 16 deletions.
74 changes: 58 additions & 16 deletions cohere_core/controller/phasing.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def __init__(self, params, data_file, pkg, **kwargs):
self.new_alg,
self.twin_operation,
self.average_operation,
self.progress_operation]
self.progress_operation,
self.live_operation]

params['init_guess'] = params.get('init_guess', 'random')
if params['init_guess'] == 'AI_guess':
Expand Down Expand Up @@ -149,6 +150,13 @@ def init_dev(self, device_id):
self.dims = devlib.dims(self.data)
print('data shape', self.dims)

if self.params.get("live_trigger", None) is not None:
self.fig, self.axs = plt.subplots(2, 2, figsize=(12, 13), layout="constrained")
plt.show(block=False)
else:
self.fig = None
self.axs = None

if self.need_save_data:
self.saved_data = devlib.copy(self.data)
self.need_save_data = False
Expand Down Expand Up @@ -272,6 +280,9 @@ def iterate(self):
mx = devlib.amax(devlib.absolute(self.ds_image))
self.ds_image = self.ds_image / mx

if self.params.get("live_trigger", None) is not None:
plt.show()

return 0

def save_res(self, save_dir, only_image=False):
Expand Down Expand Up @@ -348,7 +359,7 @@ def modulus(self):
ratio = self.get_ratio(self.iter_data, devlib.absolute(self.rs_amplitudes))
error = dvut.get_norm(devlib.where((self.rs_amplitudes != 0), (devlib.absolute(self.rs_amplitudes) - self.iter_data),
0)) / dvut.get_norm(self.iter_data)
self.errs.append(error)
self.errs.append(error.get())
self.rs_amplitudes *= ratio

def set_prev_pc(self):
Expand Down Expand Up @@ -395,10 +406,41 @@ def average_operation(self):
def progress_operation(self):
print(f'------iter {self.iter} error {self.errs[-1]}')

def live_operation(self):
self.shift_to_center()
half = self.dims[0] // 2
qtr = self.dims[0] // 4
plt.suptitle(
f"iteration: {self.iter}\n"
f"error: {self.errs[-1]}\n"
)
[[ax.clear() for ax in row] for row in self.axs]
img = self.ds_image[qtr:-qtr, qtr:-qtr, half]
self.axs[0][0].set(title="Amplitude", xticks=[], yticks=[])
self.axs[0][0].imshow(devlib.absolute(img).get(), cmap="gray")
self.axs[0][1].set(title="Phase", xticks=[], yticks=[])
self.axs[0][1].imshow(devlib.angle(img).get(), cmap="hsv", interpolation_stage="rgba")

self.axs[1][0].set(title="Error", xlim=(0,self.iter_no), xlabel="Iteration", yscale="log")
self.axs[1][0].plot(self.errs[1:])
self.axs[1][1].set(title="Support", xticks=[], yticks=[])
self.axs[1][1].imshow(self.support[qtr:-qtr, qtr:-qtr, half].get(), cmap="gray")

plt.draw()
plt.pause(0.15)

def get_ratio(self, divident, divisor):
ratio = devlib.where((divisor > 1e-9), divident / divisor, 0.0)
return ratio

def shift_to_center(self):
ind = devlib.center_of_mass(self.support.astype('int32'))
shift_dist = (self.dims[0]//2) - devlib.round(devlib.array(ind))
shift_dist = devlib.to_numpy(shift_dist).tolist()
axis = tuple(range(len(self.ds_image.shape)))
self.ds_image = devlib.roll(self.ds_image, shift_dist, axis=axis)
self.support = devlib.roll(self.support, shift_dist, axis=axis)


class Peak:
"""
Expand Down Expand Up @@ -539,7 +581,7 @@ def __init__(self, params, peak_dirs, pkg, **kwargs):

self.params["switch_peak_trigger"] = self.params.get("switch_peak_trigger", [0, 5])
self.params["adapt_trigger"] = self.params.get("adapt_trigger", [])
self.params["calc_strain"] = self.params.get("calc_strain", True)
self.params["calc_strain"] = self.params.get("calc_strain", False)

self.peak_dirs = peak_dirs
self.er_iter = False # Indicates whether the last iteration done was ER
Expand Down Expand Up @@ -594,8 +636,7 @@ def init_dev(self, device_id):
self.dims = self.data.shape
self.n_voxels = self.dims[0]*self.dims[1]*self.dims[2]

live_view = self.params.get("live_view", False)
if live_view:
if self.params.get("live_trigger", None) is not None:
self.fig, self.axs = plt.subplots(2, 2, figsize=(12, 13), layout="constrained")
plt.show(block=False)
else:
Expand Down Expand Up @@ -747,7 +788,7 @@ def iterate(self):

print('iterate took ', (time.time() - start_t), ' sec')

if self.params.get("live_view", False):
if self.params.get("live_trigger", None) is not None:
plt.show()

if devlib.hasnan(self.ds_image):
Expand Down Expand Up @@ -968,7 +1009,7 @@ def progress_operation(self):
prg += f"| LEHD {self.ctrl_error[-1][3]:0.6f} "
print(prg)

def update_live(self):
def live_operation(self):
half = self.dims[0] // 2
qtr = self.dims[0] // 4
plt.suptitle(
Expand All @@ -978,18 +1019,19 @@ def update_live(self):
f"peak weight: {self.peak_objs[self.pk].weight:0.3f}\n"
f"confidence threshold: {self.peak_threshold[self.iter]}\n"
)

[[ax.clear() for ax in row] for row in self.axs]
self.axs[0][0].set_title("XY amplitude")
self.axs[0][0].imshow(devlib.absolute(self.ds_image[qtr:-qtr, qtr:-qtr, half]).get())
self.axs[1][0].set_title("XY phase")
self.axs[1][0].imshow(devlib.angle(self.ds_image[qtr:-qtr, half, qtr:-qtr]).get())

s = 125
self.axs[0][1].set_title("Measurement")
self.axs[0][1].imshow(devlib.log(devlib.ifftshift(self.peak_objs[self.pk].res_data)[s]+1).get())
self.axs[0][0].imshow(devlib.absolute(self.ds_image[qtr:-qtr, qtr:-qtr, half]).get(), cmap="gray")
self.axs[0][1].set_title("XY phase")
self.axs[0][1].imshow(devlib.angle(self.ds_image[qtr:-qtr, half, qtr:-qtr]).get(), cmap="hsv",
interpolation_stage="rgba")

self.axs[1][0].set_title("Measurement")
meas = devlib.sum(devlib.ifftshift(self.peak_objs[self.pk].res_data), axis=0)
self.axs[1][0].imshow(devlib.sqrt(meas).get(), cmap="magma")
self.axs[1][1].set_title("Fourier Constraint")
self.axs[1][1].imshow(devlib.log(devlib.ifftshift(self.iter_data)[s]+1).get())
data = devlib.sum(devlib.ifftshift(self.iter_data), axis=0)
self.axs[1][1].imshow(devlib.sqrt(data).get(), cmap="magma")
plt.setp(self.fig.get_axes(), xticks=[], yticks=[])

plt.draw()
Expand Down
10 changes: 10 additions & 0 deletions docs/source/config_rec.rst
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,16 @@ progress

progress_trigger = [0, 20]

live viewing
++++++++++++
| This feature allows for a live view of the amplitude, phase, support, and error as the reconstruction develops. With adaptive multipeak phasing, this will instead show the amplitude, phase, measured diffraction pattern, and adapted diffraction pattern. These are shown using a central slice cropped to half the full array size.
- live_trigger:
| Defines when to update the live view.
| example:
::

live_trigger = [0, 10]

GA
++
- ga_generations:
Expand Down

0 comments on commit 0fe08f7

Please sign in to comment.