Change to use different method of long-docs generation.
JoeZiminski committed May 20, 2024
1 parent 46b6aa5 commit 817c4e6
Showing 21 changed files with 1,005 additions and 24 deletions.
16 changes: 7 additions & 9 deletions doc/conf.py
Original file line number Diff line number Diff line change
@@ -133,16 +133,14 @@
'filename_pattern': '/plot_',
}

#if "-D" in sys.argv:
# key_value = sys.argv.index("-D") + 1
# name_and_value=sys.argv.pop(key_value)
# sys.argv.pop(sys.argv.index("-D"))
# param_name, param_values = name_and_value.split("=")
# assert param_name == "long_builds"
# long_buildnames = param_values.split(",")
# if "handle_drift" in long_buildnames:
if tags.has("handle_drift") or tags.has("all_long_plot"):
sphinx_gallery_conf["filename_pattern"] += '|/long_plot_handle_drift'

if (handle_drift_path := (Path('long_tutorials/handle_drift'))).is_dir():
shutil.rmtree(handle_drift_path)

sphinx_gallery_conf['examples_dirs'].append('../examples/long_tutorials/handle_drift')
sphinx_gallery_conf["gallery_dirs"].append(handle_drift_path.as_posix())


intersphinx_mapping = {
"neo": ("https://neo.readthedocs.io/en/latest/", None),
Expand Down
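For context, the `tags.has(...)` check in the conf.py hunk above relies on Sphinx tags, which are set on the command line (e.g. `sphinx-build -t handle_drift ...`). Below is a minimal, self-contained sketch of that conditional gallery extension; the `active_tags` set and the example paths are stand-ins for illustration, not part of the actual conf.py:

```python
# Stand-in for Sphinx's `tags` object; in a real build these are set
# via `sphinx-build -t handle_drift ...` (hypothetical invocation).
active_tags = {"handle_drift"}

sphinx_gallery_conf = {
    "filename_pattern": "/plot_",
    "examples_dirs": ["../examples/tutorials"],   # assumed default dirs
    "gallery_dirs": ["tutorials"],
}

# Mirror of the conf.py logic: only pull in the long-running tutorial
# when one of the opt-in tags is active.
if {"handle_drift", "all_long_plot"} & active_tags:
    sphinx_gallery_conf["filename_pattern"] += "|/long_plot_handle_drift"
    sphinx_gallery_conf["examples_dirs"].append("../examples/long_tutorials/handle_drift")
    sphinx_gallery_conf["gallery_dirs"].append("long_tutorials/handle_drift")

print(sphinx_gallery_conf["filename_pattern"])
```

This keeps the default (fast) docs build untouched while letting CI or a maintainer opt in to the ~25-minute tutorial build.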
1 change: 1 addition & 0 deletions doc/how_to/index.rst
@@ -12,3 +12,4 @@ Guides on how to solve specific, short problems in SpikeInterface. Learn how to.
load_matlab_data
combine_recordings
process_by_channel_group
/long_tutorials/handle_drift/plot_handle_drift.rst
71 changes: 71 additions & 0 deletions doc/long_tutorials/handle_drift/index.rst
@@ -0,0 +1,71 @@
:orphan:

Handle Drift Tutorial
---------------------

This tutorial is not meant to be displayed in a
sphinx gallery, and the generated index.rst is not
meant to be linked to from any toctree.

Instead, sphinx-gallery is used to
build this page automatically, which
takes a long time (~25 minutes), and the
rst file that sphinx-gallery generates
(TODO: fill in filename) is linked to
manually.


.. raw:: html

<div class="sphx-glr-thumbnails">

.. thumbnail-parent-div-open
.. raw:: html

<div class="sphx-glr-thumbcontainer" tooltip="Spikeinterface offers a very flexible framework to handle drift as a preprocessing step. If you...">

.. only:: html

.. image:: /long_tutorials/handle_drift/images/thumb/sphx_glr_plot_handle_drift_thumb.png
:alt:

:ref:`sphx_glr_long_tutorials_handle_drift_plot_handle_drift.py`

.. raw:: html

<div class="sphx-glr-thumbnail-title">Handle motion/drift with spikeinterface NEW</div>
</div>


.. thumbnail-parent-div-close
.. raw:: html

</div>


.. toctree::
:hidden:

/long_tutorials/handle_drift/plot_handle_drift


.. only:: html

.. container:: sphx-glr-footer sphx-glr-footer-gallery

.. container:: sphx-glr-download sphx-glr-download-python

:download:`Download all examples in Python source code: handle_drift_python.zip </long_tutorials/handle_drift/handle_drift_python.zip>`

.. container:: sphx-glr-download sphx-glr-download-jupyter

:download:`Download all examples in Jupyter notebooks: handle_drift_jupyter.zip </long_tutorials/handle_drift/handle_drift_jupyter.zip>`


.. only:: html

.. rst-class:: sphx-glr-signature

`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
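The index.rst above is only built when the long-tutorial tag is passed to Sphinx. A hypothetical invocation (the tag name is taken from the conf.py check in this commit; paths are assumed):

```shell
# Default build skips the slow tutorial; passing -t handle_drift makes
# conf.py add the long gallery. Stored here as a string for illustration.
CMD="sphinx-build -b html -t handle_drift doc doc/_build/html"
echo "$CMD"
```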
187 changes: 187 additions & 0 deletions doc/long_tutorials/handle_drift/plot_handle_drift.ipynb
@@ -0,0 +1,187 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# Handle motion/drift with spikeinterface NEW\n\nSpikeinterface offers a very flexible framework to handle drift as a preprocessing step.\nIf you want to know more, please read the\n`motion_correction` section of the documentation.\n\nHere is a short demo on how to handle drift using the high-level function\n:py:func:`~spikeinterface.preprocessing.correct_motion()`.\n\nThis function takes a preprocessed recording as input and then internally runs\nseveral steps (it can be slow!) and returns a lazy\nrecording that interpolates the traces on-the-fly to compensate for the motion.\n\nInternally this function runs the following steps:\n\n| **1.** ``localize_peaks()``\n| **2.** ``select_peaks()`` (optional)\n| **3.** ``estimate_motion()``\n| **4.** ``interpolate_motion()``\n\nAll these sub-steps can be run with different methods and have many parameters.\nThe high-level function provides 3 pre-defined \"presets\".\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, we import the packages and generate a test recording.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from pathlib import Path\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport shutil\nimport spikeinterface.full as si"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from spikeinterface.extractors import toy_example\nfrom spikeinterface.generation.drifting_generator import generate_drifting_recording\n\n# TODO: add a note that it must be run in a if __name__ == \"__main__\" block.\n# TODO: is there currently any way to compute accuracy of method based on\n# drift-corrected vs. original static recording?\n\n_, raw_rec, _ = generate_drifting_recording(\n num_units=300,\n duration=1000,\n generate_sorting_kwargs=dict(firing_rates=(15, 25), refractory_period_ms=4.0),\n seed=42,\n)\nprint(raw_rec)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We preprocess the recording with a bandpass filter and a common median reference.\nNote that it is better not to whiten the recording before motion estimation,\nas this gives a better estimate of the peak locations!\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def preprocess_chain(rec):\n rec = si.bandpass_filter(rec, freq_min=300.0, freq_max=6000.0)\n rec = si.common_reference(rec, reference=\"global\", operator=\"median\")\n return rec\n\n\nrec = preprocess_chain(raw_rec)\n\njob_kwargs = dict(n_jobs=40, chunk_duration=\"1s\", progress_bar=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run motion correction with one function!\n\nCorrecting for drift is easy! You just need to run a single function.\nWe will try this function with 3 presets.\n\nInternally, a preset is a dictionary of dictionaries containing all the parameters for every step.\n\nHere we also save the motion correction results into a folder to be able to load them later.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# internally, we can explore a preset like this\n# every parameter can be overwritten at runtime\nfrom spikeinterface.preprocessing.motion import motion_options_preset\n\nprint(motion_options_preset[\"kilosort_like\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's try these 3 presets\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"some_presets = (\"rigid_fast\", \"kilosort_like\", \"nonrigid_accurate\")\nresults = {preset: {} for preset in some_presets} # TODO: RENAME VAR"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"and compute the motion with each of the 3 presets\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"for preset in some_presets:\n print(\"Computing with\", preset)\n\n recording_corrected, motion_info = si.correct_motion( # TODO: RECORDING_CORRECTED UNUSED\n rec, preset=preset, output_motion_info=True, **job_kwargs\n )\n results[preset][\"motion_info\"] = motion_info"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plot the results\n\nWe load back the results and use the widgets module to explore the estimated drift motion.\n\nFor all methods we have 4 plots:\n  * **top left:** time vs estimated peak depth\n  * **top right:** time vs peak depth after motion correction\n  * **bottom left:** the average motion vector across depths (and all motion vectors across spatial depths for non-rigid estimation)\n  * **bottom right:** if motion correction is non-rigid, the motion vector across depths is plotted as a map, with the color code representing the motion in micrometers.\n\nA few comments on the figures:\n  * The preset **'rigid_fast'** has only one motion vector for the entire probe because it is a \"rigid\" case.\n    The motion amplitude is globally underestimated because it averages across depths.\n    However, the corrected peaks are flatter than the non-corrected ones, so the job is partially done.\n    The big jump at t=600s when the probe starts moving is recovered quite well.\n  * The preset **kilosort_like** gives better results because it is a non-rigid case.\n    The motion vector is computed for different depths.\n    The corrected peak locations are flatter than in the rigid case.\n    The motion vector map is still a bit noisy at some depths (e.g. around 1000um).\n  * The preset **nonrigid_accurate** seems to give the best results on this recording.\n    The motion vector seems less noisy globally, but it is not \"perfect\" (see the top of the probe, 3200um to 3800um).\n    Also note that in the first part of the recording, before the imposed motion (0-600s), we clearly have non-rigid motion:\n    the upper part of the probe (2000-3000um) experiences some drift, while the lower part (0-1000um) is relatively stable.\n    The method defined by this preset is able to capture this.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"for preset in some_presets:\n fig = plt.figure(figsize=(14, 8))\n si.plot_motion(\n results[preset][\"motion_info\"],\n figure=fig,\n depth_lim=(400, 600),\n color_amplitude=True,\n amplitude_cmap=\"inferno\",\n scatter_decimate=10,\n )\n fig.suptitle(f\"{preset=}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plot peak localization\n\nWe can also use the internal extra results (peaks and peak locations) to check whether putative\nclusters have a lower spatial spread after the motion correction.\n\nHere we plot the estimated peak locations (left) and the corrected peak locations\n(right) on top of the probe.\nThe color encodes the peak amplitude.\n\nWe can see here that some clusters seem to be more compact on the 'y' axis, especially\nfor the preset \"nonrigid_accurate\".\n\nBe aware that there are two ways to correct for the motion:\n  1. Interpolate traces and detect/localize peaks again (`interpolate_recording()`)\n  2. Compensate for drift directly on the peak locations (`correct_motion_on_peaks()`)\n\nCase 1 is used before running a spike sorter and case 2 is used here to display the results.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks\n\nfor preset in some_presets:\n\n fig, axs = plt.subplots(ncols=2, figsize=(12, 8), sharey=True)\n\n ax = axs[0]\n si.plot_probe_map(rec, ax=ax)\n\n motion_info = results[preset][\"motion_info\"]\n\n peaks = motion_info[\"peaks\"]\n sr = rec.get_sampling_frequency()\n time_lim0 = 0\n time_lim1 = 50\n mask = (peaks[\"sample_index\"] > int(sr * time_lim0)) & (peaks[\"sample_index\"] < int(sr * time_lim1))\n sl = slice(None, None, 5)\n amps = np.abs(peaks[\"amplitude\"][mask][sl])\n amps /= np.quantile(amps, 0.95)\n c = plt.get_cmap(\"inferno\")(amps)\n\n color_kargs = dict(alpha=0.2, s=2, c=c)\n\n loc = motion_info[\"peak_locations\"]\n ax.scatter(loc[\"x\"][mask][sl], loc[\"y\"][mask][sl], **color_kargs)\n\n loc2 = correct_motion_on_peaks(\n motion_info[\"peaks\"],\n motion_info[\"peak_locations\"],\n rec.sampling_frequency,\n motion_info[\"motion\"],\n motion_info[\"temporal_bins\"],\n motion_info[\"spatial_bins\"],\n direction=\"y\",\n )\n\n ax = axs[1]\n si.plot_probe_map(rec, ax=ax)\n ax.scatter(loc2[\"x\"][mask][sl], loc2[\"y\"][mask][sl], **color_kargs)\n\n ax.set_ylim(400, 600)\n fig.suptitle(f\"{preset=}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Accuracy and Run Times\n\nPresets and related methods differ in accuracy but also in computation speed.\nIt is good to keep this in mind!\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"run_times = []\nfor preset in some_presets:\n run_times.append(results[preset][\"motion_info\"][\"run_times\"])\nkeys = run_times[0].keys()\n\nbottom = np.zeros(len(run_times))\nfig, ax = plt.subplots()\nfor k in keys:\n rtimes = np.array([rt[k] for rt in run_times])\n if np.any(rtimes > 0.0):\n ax.bar(some_presets, rtimes, bottom=bottom, label=k)\n bottom += rtimes\nax.legend()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
Handle motion/drift with spikeinterface
===================================
Handle motion/drift with spikeinterface NEW
===========================================
Spikeinterface offers a very flexible framework to handle drift as a preprocessing step.
If you want to know more, please read the
@@ -42,19 +42,9 @@
# drift-corrected vs. original static recording?

_, raw_rec, _ = generate_drifting_recording(
num_units=25,
duration=10,
generate_sorting_kwargs=dict(firing_rates=(5, 10), refractory_period_ms=2.0),
generate_displacement_vector_kwargs=dict(motion_list=[
dict(
drift_mode="zigzag",
amplitude_factor=1.0,
non_rigid_gradient=None,
t_start_drift=1,
t_end_drift=None,
period_s=1,
),
]),
num_units=300,
duration=1000,
generate_sorting_kwargs=dict(firing_rates=(15, 25), refractory_period_ms=4.0),
seed=42,
)
print(raw_rec)
1 change: 1 addition & 0 deletions doc/long_tutorials/handle_drift/plot_handle_drift.py.md5
@@ -0,0 +1 @@
eea0824e90479d36d258993e35fe83b8