Skip to content

Commit

Permalink
Make register_ROIs work with rigid correction
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanbb committed Jan 5, 2025
1 parent 5e74a7c commit 1ceea4f
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions caiman/base/rois.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,8 @@ def register_ROIs(A1,
"max_shifts": (10, 10),
"shifts_opencv": True,
"upsample_factor_grid": 4,
"interp_shifts_precisely": True
"interp_shifts_precisely": True,
"max_deviation_rigid": 2
# any other argument to tile_and_correct can also be used in align_options
}

Expand All @@ -442,15 +443,23 @@ def register_ROIs(A1,
template2, shifts, _, _ = tile_and_correct(template2, template1 - template1.min(),
add_to_movie=template2.min(), **align_options)

patch_centers = get_patch_centers(dims, overlaps=align_options["overlaps"], strides=align_options["strides"],
shifts_opencv=align_options["shifts_opencv"], upsample_factor_grid=align_options["upsample_factor_grid"])
patch_grid = tuple(len(centers) for centers in patch_centers)
_sh_ = np.stack(shifts, axis=0)
shifts_x = np.reshape(_sh_[:, 1], patch_grid, order='C').astype(np.float32)
shifts_y = np.reshape(_sh_[:, 0], patch_grid, order='C').astype(np.float32)
if align_options["max_deviation_rigid"] == 0:
# repeat rigid shifts to size of the image
shifts_x_full = np.full(dims, -shifts[1])
shifts_y_full = np.full(dims, -shifts[0])
else:
# piecewise - interpolate from patches to get shifts per pixel
patch_centers = get_patch_centers(dims, overlaps=align_options["overlaps"], strides=align_options["strides"],
shifts_opencv=align_options["shifts_opencv"],
upsample_factor_grid=align_options["upsample_factor_grid"])
patch_grid = tuple(len(centers) for centers in patch_centers)
_sh_ = np.stack(shifts, axis=0)
shifts_x = np.reshape(_sh_[:, 1], patch_grid, order='C').astype(np.float32)
shifts_y = np.reshape(_sh_[:, 0], patch_grid, order='C').astype(np.float32)

shifts_x_full = interpolate_shifts(-shifts_x, patch_centers, tuple(range(d) for d in dims))
shifts_y_full = interpolate_shifts(-shifts_y, patch_centers, tuple(range(d) for d in dims))

shifts_x_full = interpolate_shifts(-shifts_x, patch_centers, tuple(range(d) for d in dims))
shifts_y_full = interpolate_shifts(-shifts_y, patch_centers, tuple(range(d) for d in dims))
x_remap = (shifts_x_full + x_grid).astype(np.float32)
y_remap = (shifts_y_full + y_grid).astype(np.float32)

Expand Down

0 comments on commit 1ceea4f

Please sign in to comment.