diff --git a/.gitignore b/.gitignore index e966119..d7df25e 100644 --- a/.gitignore +++ b/.gitignore @@ -77,6 +77,7 @@ target/ # Jupyter Notebook .ipynb_checkpoints +.jupyter # IPython profile_default/ @@ -160,4 +161,4 @@ cython_debug/ #.idea/ # MacOS crufts -.DS_Store \ No newline at end of file +.DS_Store diff --git a/README.md b/README.md index 2c1185d..35e78b0 100644 --- a/README.md +++ b/README.md @@ -7,11 +7,11 @@ This package is intended for use on Oscar with GPU support for PyTorch. To use on Oscar, first clone the repo. Then load the anaconda module: ```shell -module load anaconda/2022.05 -source /gpfs/runtime/opt/anaconda/2022.05/etc/profile.d/conda.sh +module load miniconda3/23.11.0s +source /oscar/runtime/software/external/miniconda3/23.11.0/etc/profile.d/conda.sh ``` -If you have never loaded the `anaconda/2022.05` module before, you need to initialize +If you have never loaded the `miniconda3/23.11.0s` module before, you need to initialize the conda module by running the following command. ```shell diff --git a/STEPS.md b/STEPS.md new file mode 100644 index 0000000..632003a --- /dev/null +++ b/STEPS.md @@ -0,0 +1,374 @@ +# Pipeline (under development) + +## Multiple environment setups + +There are at least 3 environments that need to be created to manage steps separately. + +- First load `miniconda3` module: + + ```shell + module load miniconda3/23.11.0s + ``` + +- If you have loaded this before, also do the following. Note: you only need to this **ONCE**. + + ```shell + conda init bash + ``` + +- Then do the following to create all three environments: + + ```shell + conda env create -f environments/suite2p.yml + conda env create -f environments/deepcad.yml + conda env create -f environments/roicat.yml + ``` + +## Steps + +The `scripts/example-steps.sh` show examples of steps to follow for a single animal, single session, single plane. + +The following demonstrates how to run things in batches. + +### 0-makedirs: Prepare directory structure + +**Purpose**: this prepares the directory structure that is expected from all scripts, as well as necessary log folders. Generally the folder structure follows: + +``` +└── SUBJECT + ├── DATE + │ └── STEP + │ └── PLANE + └── multi-session + └── PLANE +``` + +**Task**: + +- Please edit the variables under `# define variables` in `scripts/0-makedirs.sh` before running +- Then run: + +```shell +bash scripts/0-makedirs.sh +``` + +- Then put the raw TIFF data under `0-raw` folders + +**Output**: The following is an example output directory tree from the script: + +
Click to expand example output + +``` +└── MS2457 + ├── 20230902 + │ ├── 0-raw + │ │ ├── plane0 # put your raw TIFF here + │ │ └── plane1 # put your raw TIFF here + │ ├── 1-moco + │ │ ├── plane0 + │ │ └── plane1 + │ ├── 2-deepcad + │ │ ├── plane0 + │ │ └── plane1 + │ ├── 3-suite2p + │ │ ├── plane0 + │ │ └── plane1 + │ └── 4-muse + │ ├── plane0 + │ └── plane1 + . + . + . + └── multi-session + ├── plane0 + └── plane1 + +``` + + +Currently only 1 file per folder is expected. + +You will be reminded at the end to put your raw data in appropriate folders, like this + +``` +Please put raw TIF data in these folders, according to their name +/oscar/data/afleisc2/collab/multiday-reg/data/testdir/MS2457/20230902/0-raw/plane0 +/oscar/data/afleisc2/collab/multiday-reg/data/testdir/MS2457/20230902/0-raw/plane1 +/oscar/data/afleisc2/collab/multiday-reg/data/testdir/MS2457/20230907/0-raw/plane0 +/oscar/data/afleisc2/collab/multiday-reg/data/testdir/MS2457/20230907/0-raw/plane1 +/oscar/data/afleisc2/collab/multiday-reg/data/testdir/MS2457/20230912/0-raw/plane0 +/oscar/data/afleisc2/collab/multiday-reg/data/testdir/MS2457/20230912/0-raw/plane1 +``` + +So that it looks more like this: + +``` +└── MS2457 + ├── 20230902 + │ ├── 0-raw + │ │ ├── plane0 + │ │ │ └── MS2457-20230902-plane0.tif + │ │ └── plane1 + │ │ └── MS2457-20230902-plane1.tif +... +``` + +
+ + +### 1-moco: Motion correction + +**Purpose**: this uses `suite2p` to perform motion correction before denoising in a CPU environment on Oscar, preferably high number of CPUs and sufficient memory. + +**Task**: + +- Please edit the variables under `# define variables` in `scripts/1-moco.sh` before running +- Take a look at `config/1-moco.yml` and change if you need to, refer to [`suite2p` registration parameter documentation](https://suite2p.readthedocs.io/en/latest/settings.html#registration-settings) for more information + +- Then run: + +```shell +sbatch scripts/1-moco.sh +``` + +- You can check whether your job is run with `myq`, take note of the `ID` column, then you can view the log file under `logs/1-moco/logs-.out`. + +**Output**: The following is an example output directory tree from the script: + + +
Click to expand example output + +``` +└── MS2457 + ├── 20230902 + │ ├── 0-raw + │ │ ├── plane0 + │ │ │ └── MS2457-20230902-plane0.tif + │ │ └── plane1 + │ │ └── MS2457-20230902-plane1.tif + │ ├── 1-moco + │ │ ├── plane0 + │ │ │ ├── MS2457-20230902-plane0_mc.tif + │ │ │ └── suite2p-moco-only + │ │ │ ├── data.bin + │ │ │ ├── data_raw.bin + │ │ │ ├── ops.npy + │ │ │ └── reg_tif + │ │ │ ├── file00000_chan0.tif + │ │ │ ├── file001000_chan0.tif + │ │ │ └── file00500_chan0.tif + │ │ └── plane1 + │ │ ├── MS2457-20230902-plane1_mc.tif + │ │ └── suite2p-moco-only + │ │ ├── data.bin + │ │ ├── data_raw.bin + │ │ ├── ops.npy + │ │ └── reg_tif + │ │ ├── file00000_chan0.tif + │ │ ├── file001000_chan0.tif + │ │ └── file00500_chan0.tif +... +``` + +The file `MS2457-20230902-plane0_mc.tif` is created by concatenating files under `suite2p-moco-only/re_tif` folder. + +This file is expected to be here for the next step. + +Currently only 1 file per folder is expected. + +
+ + +### 2-deepcad: Denoising using `deepcad` + +**Purpose**: this uses `deepcad` to perform denoising on the motion-corrected file, in a GPU environment on Oscar, preferably high memory. + +**Task**: + +- Please edit the variables under `# define variables` in `scripts/2-deepcad.sh` before running +- Take a look at `config/2-deepcad.yml` and change if you need to + - [ ] **TODO** Where is documentation for `deepcad` parameters? +- Then run: + +```shell +sbatch scripts/2-deepcad.sh +``` + +- You can check whether your job is run with `myq`, take note of the `ID` column, then you can view the log file under `logs/2-deepcad/logs-.out`. + +**Output**: The following is an example output directory tree from the script: + +
Click to expand example output + +``` +# some folders and files are hidden to prioritize changes +└── MS2457 + ├── 20230902 + │ ├── 0-raw + │ │ ├── plane0 + │ │ │ └── MS2457-20230902-plane0.tif + │ │ └── plane1 + │ │ └── MS2457-20230902-plane1.tif + │ ├── 1-moco + │ │ ├── plane0 + │ │ │ └── MS2457-20230902-plane0_mc.tif + │ │ └── plane1 + │ │ └── MS2457-20230902-plane1_mc.tif + │ ├── 2-deepcad + │ │ ├── plane0 + │ │ │ ├── MS2457-20230902-plane0_mc-dc.tif + │ │ │ └── para.yaml + │ │ └── plane1 + │ │ ├── MS2457-20230902-plane1_mc-dc.tif + │ │ └── para.yaml +... +``` + + +Notice files under `2-deepcad` are created, for example `MS2457-20230902-plane0_mc-dc.tif`. + +This file is expected to be here for the next step. + +
+ + +### 3-suite2p: `suite2p` on denoised data + +**Purpose**: this uses `suite2p` on the denoised data from the `deepcad` step, in a CPU environment on Oscar, preferably high number of CPUs and sufficient memory. + +**Task**: + +- Please edit the variables under `# define variables` in `scripts/3-suite2p.sh` before running +- Take a look at `config/3-suite2p.yml` and change if you need to, refer to [`suite2p` parameter documentation](https://suite2p.readthedocs.io/en/latest/settings.html) + - Currently we still need to re-run registration + - Some hacks in parameters are documented in this config file + - If you're familiar with `suite2p`, use your experience and judgements to determine these +- Then run: + +```shell +sbatch scripts/3-suite2p.sh +``` + +- You can check whether your job is run with `myq`, take note of the `ID` column, then you can view the log file under `logs/3-suite2p/logs-.out`. + +**Output**: The following is an example output directory tree from the script: + + +
Click to expand example output + + +``` +└── MS2457 + ├── 20230902 + │ ├── 0-raw + │ │ ├── plane0 + │ │ │ └── MS2457-20230902-plane0.tif + │ │ └── plane1 + │ │ └── MS2457-20230902-plane1.tif + │ ├── 1-moco + │ │ ├── plane0 + │ │ │ └── MS2457-20230902-plane0_mc.tif + │ │ └── plane1 + │ │ └── MS2457-20230902-plane1_mc.tif + │ ├── 2-deepcad + │ │ ├── plane0 + │ │ │ └── MS2457-20230902-plane0_mc-dc.tif + │ │ └── plane1 + │ │ └── MS2457-20230902-plane1_mc-dc.tif + │ ├── 3-suite2p + │ │ ├── plane0 + │ │ │ ├── F.npy + │ │ │ ├── Fneu.npy + │ │ │ ├── data.bin + │ │ │ ├── figures + │ │ │ │ ├── backgrounds.png + │ │ │ │ └── roi-masks.png + │ │ │ ├── iscell.npy + │ │ │ ├── ops.npy + │ │ │ ├── spks.npy + │ │ │ └── stat.npy + │ │ └── plane1 + │ │ ├── F.npy + │ │ ├── Fneu.npy + │ │ ├── data.bin + │ │ ├── figures + │ │ │ ├── backgrounds.png + │ │ │ └── roi-masks.png + │ │ ├── iscell.npy + │ │ ├── ops.npy + │ │ ├── spks.npy + │ │ └── stat.npy +... +``` + +The folders under `3-suite2p` will be very familiar in terms of namings and organizations for `suite2p`-familiar folks + +
+ +### 4-roicat: Multisession registration using `roicat` + +**Purpose**: this uses `roicat` on the segmented data from `suite2p` step, in a CPU environment on Oscar, preferably high number of CPUs and sufficient memory. GPU is possible but not tested yet. + +**TODO** +- [ ] consider changing "4-muse" to "4-roicat" +- [ ] parameterize to a config file from `scripts/4-roicat.sh` +- [ ] refactor the long `scripts/4-roicat.sh` into smaller chunks +- [ ] save the results from `multi-session` collective folder back to relevant single-session directories, such as ROI mapping and aligned FOV + +**Task**: + +- Please edit the variables under `# define variables` in `scripts/4-roicat.sh` before running + +- Then run: + +```shell +sbatch scripts/4-roicat.sh +``` + +- You can check whether your job is run with `myq`, take note of the `ID` column, then you can view the log file under `logs/4-roicat/logs-.out`. + +**Output**: The following is an example output directory tree from the script: + +
Click to expand example output + +``` +MS2457/multi-session/ +├── plane0 +│ ├── aligned-images.pkl +│ ├── figures +│ │ ├── FOV_clusters_allrois.gif +│ │ ├── FOV_clusters_iscells.gif +│ │ ├── aligned-fov.png +│ │ ├── aligned-rois.png +│ │ ├── cluster-metrics.png +│ │ ├── num-persist-roi-overall.png +│ │ ├── num-persist-roi-per-session.png +│ │ ├── pw-sim-distrib.png +│ │ └── pw-sim-scatter.png +│ ├── finalized-roi.csv +│ ├── roicat-output.pkl +│ └── summary-roi.csv +└── plane1 + ├── aligned-images.pkl + ├── figures + │ ├── FOV_clusters_allrois.gif + │ ├── FOV_clusters_iscells.gif + │ ├── aligned-fov.png + │ ├── aligned-rois.png + │ ├── cluster-metrics.png + │ ├── num-persist-roi-overall.png + │ ├── num-persist-roi-per-session.png + │ ├── pw-sim-distrib.png + │ └── pw-sim-scatter.png + ├── finalized-roi.csv + ├── roicat-output.pkl + └── summary-roi.csv +``` + +The output data of `roicat` is in `roicat-output.pkl` + +For visual inspection: see the figures in `figures` + +For analysis: `finalized-roi.csv` is what you may want + +
diff --git a/cellreg/1-moco.py b/cellreg/1-moco.py new file mode 100644 index 0000000..561538f --- /dev/null +++ b/cellreg/1-moco.py @@ -0,0 +1,151 @@ +import os +import glob +import shutil +from pathlib import Path + +from natsort import natsorted +from tqdm import tqdm +import numpy as np + +import tifffile +import suite2p + +import argparse +import yaml + + +# this stores constant suite2p ops to force overwrite config +CONSTANT_MOCO_SUITE2P_OPS = { + 'do_registration': 1, + 'reg_tif': True, + 'roidetect': False +} + + +def get_tiff_list(folder, ext=['.tif', '.tiff']): + file_list = [] + for x in ext: + file_list.extend(list(folder.glob(f'*{x}'))) + return file_list + +def merge_tiff_stacks(input_tiffs, output_file): + with tifffile.TiffWriter(output_file) as stack: + for filename in tqdm(input_tiffs): + with tifffile.TiffFile(filename) as tif: + for page in tif.pages: + stack.write( + page.asarray(), + photometric='minisblack', + contiguous=True + ) + +if __name__ == "__main__": + parser = argparse.ArgumentParser('Run motion correction via suite2p (moco) [only 1 channel allow]') + + parser.add_argument("--root-datadir", type=str) + parser.add_argument("--subject", type=str) + parser.add_argument("--date", type=str) + parser.add_argument("--plane", type=str) + parser.add_argument("--config", type=str) + parser.add_argument("--cleanup", action="store_true") + + args = parser.parse_args() + print(args) + + # process data paths + ROOT_DATA_DIR = args.root_datadir + SUBJECT_ID = args.subject + DATE_ID = args.date + PLANE_ID = args.plane + MOCO_CFG_PATH = args.config + CLEAN_UP = args.cleanup + + # define paths + EXP_DIR = Path(ROOT_DATA_DIR) / SUBJECT_ID / DATE_ID + RAW_DIR = EXP_DIR / '0-raw' / PLANE_ID + MOCO_DIR = EXP_DIR / '1-moco' / PLANE_ID + MOCO_FILEPATH = MOCO_DIR / f'{SUBJECT_ID}-{DATE_ID}-{PLANE_ID}_mc.tif' + AUX_SUITE2P_DIR = MOCO_DIR / 'suite2p-moco-only' + + # find tiffs + tiff_list = get_tiff_list(RAW_DIR) + tiff_list_str = "\n".join(map(str,tiff_list)) + assert len(tiff_list) == 1, \ + f'Currently only accepting one file inside the raw directory "{RAW_DIR}", ' \ + f'found the following: {tiff_list_str}' + + raw_tiff_path = tiff_list[0] + with tifffile.TiffFile(str(raw_tiff_path)) as tif: + num_frames = len(tif.pages) + page0 = tif.pages[0] + Ly, Lx = page0.shape + print(f'Found raw tiff file ({num_frames} frames of {Ly} x {Lx}): "{raw_tiff_path}"') + + # make db + db = { + 'data_path': str(MOCO_DIR), + 'save_path0': str(AUX_SUITE2P_DIR), + 'tiff_list': tiff_list, + } + print('Data settings: ') + print(db) + + # make ops + ops = suite2p.default_ops() + with open(MOCO_CFG_PATH, 'r') as f: + print(f'Config file loaded from: "{MOCO_CFG_PATH}"') + ops_cfg = yaml.safe_load(f) + ops.update(ops_cfg) + ops.update(CONSTANT_MOCO_SUITE2P_OPS) + ops.update(db) + print('Suite2p configuration') + print(ops) + + # convert tiff to binary + ops = suite2p.io.tiff_to_binary(ops) + f_raw = suite2p.io.BinaryFile( + Ly=Ly, Lx=Lx, + filename=ops['raw_file'] + ) + + # prepare registration file + f_reg = suite2p.io.BinaryFile( + Ly=Ly, Lx=Lx, n_frames = f_raw.shape[0], + filename = ops['reg_file'] + ) + + # registration + suite2p.registration_wrapper( + f_reg, f_raw=f_raw, + f_reg_chan2=None, f_raw_chan2=None, align_by_chan2=False, + refImg=None, ops=ops + ) + + # combine registration batches + REG_TIFF_DIR = Path(ops['reg_file']).parents[0] / 'reg_tif' + reg_tiff_list = natsorted(list(map(str,REG_TIFF_DIR.glob('*.tif*')))) + reg_tiff_list_str = "\n".join(reg_tiff_list) + print(f'Combining the following tiff files into "{MOCO_FILEPATH}":\n{reg_tiff_list_str}') + + merge_tiff_stacks( + input_tiffs=reg_tiff_list, + output_file=MOCO_FILEPATH + ) + + # move files + if CLEAN_UP: + SRC_OUT_DIR = os.path.dirname(ops['reg_file']) + PARENT_SRC_DIR = os.path.dirname(SRC_OUT_DIR) + assert os.path.basename(PARENT_SRC_DIR) == "suite2p", \ + f'To clean up properly, need "{PARENT_SRC_DIR}" to end with a "suite2p" folder.\n' \ + 'Remove "--cleanup" to pass this' + + DST_OUT_DIR = str(AUX_SUITE2P_DIR) + + print(f'Moving files from {SRC_OUT_DIR} to {DST_OUT_DIR}') + shutil.copytree(SRC_OUT_DIR, DST_OUT_DIR, copy_function=shutil.move, dirs_exist_ok=True) + shutil.rmtree(PARENT_SRC_DIR) + + print(f'Finished motion correction. Use {MOCO_FILEPATH} to continue.') + + \ No newline at end of file diff --git a/cellreg/2-deepcad.py b/cellreg/2-deepcad.py new file mode 100644 index 0000000..734954f --- /dev/null +++ b/cellreg/2-deepcad.py @@ -0,0 +1,98 @@ +"""Performs denoising. + +Denoises input images in a given directory with given model and output to an +output directory. +""" + +import os +import glob +import shutil +from pathlib import Path + +from natsort import natsorted +import tifffile + +import argparse +import yaml + +from deepcad.test_collection import testing_class + +def get_tiff_list(folder, ext=['.tif', '.tiff']): + file_list = [] + for x in ext: + file_list.extend(list(folder.glob(f'*{x}'))) + return file_list + +if __name__ == "__main__": + parser = argparse.ArgumentParser('Run denoising with deepcad') + + parser.add_argument("--root-datadir", type=str) + parser.add_argument("--subject", type=str) + parser.add_argument("--date", type=str) + parser.add_argument("--plane", type=str) + parser.add_argument("--config", type=str) + parser.add_argument("--cleanup", action="store_true") + + args = parser.parse_args() + print(args) + + # process data paths + ROOT_DATA_DIR = args.root_datadir + SUBJECT_ID = args.subject + DATE_ID = args.date + PLANE_ID = args.plane + DEEPCAD_CFG_PATH = args.config + CLEAN_UP = args.cleanup + + # define paths + EXP_DIR = Path(ROOT_DATA_DIR) / SUBJECT_ID / DATE_ID + MOCO_DIR = EXP_DIR / '1-moco' / PLANE_ID + DEEPCAD_DIR = EXP_DIR / '2-deepcad' / PLANE_ID + DEEPCAD_FILEPATH = DEEPCAD_DIR / f'{SUBJECT_ID}-{DATE_ID}-{PLANE_ID}_mc-dc.tif' + + # find tiffs + tiff_list = get_tiff_list(MOCO_DIR) + tiff_list_str = "\n".join(map(str,tiff_list)) + assert len(tiff_list) == 1, \ + f'Currently only accepting one file inside the motion-corrected directory "{MOCO_DIR}", ' \ + f'found the following: {tiff_list_str}' + + tiff_path = tiff_list[0] + with tifffile.TiffFile(str(tiff_path)) as tif: + num_frames = len(tif.pages) + page0 = tif.pages[0] + Ly, Lx = page0.shape + print(f'Found motion-corrected tiff file ({num_frames} frames of {Ly} x {Lx}): "{tiff_path}"') + + # configuration + with open(DEEPCAD_CFG_PATH, 'r') as f: + print(f'Config file loaded from: "{DEEPCAD_CFG_PATH}"') + dc_cfg = yaml.safe_load(f) + + dc_cfg['test_datasize'] = num_frames + dc_cfg['datasets_path'] = str(MOCO_DIR) + dc_cfg['output_dir'] = str(DEEPCAD_DIR) + + print(dc_cfg) + + # deepcad inference + tc = testing_class(dc_cfg) + tc.run() + + # get output file + out_tiff=list(map(str,DEEPCAD_DIR.glob('**/*.tif*'))) + out_tiff=[x for x in out_tiff if os.path.dirname(x) != str(DEEPCAD_DIR)] + out_tiff_str = "\n".join(map(str,out_tiff)) + assert len(out_tiff) == 1, \ + f'Currently only accepting one output file in deepcad directory for cleaning up "{DEEPCAD_DIR}", ' \ + f'found the following: {out_tiff_str}' + out_tiff = out_tiff[0] + + shutil.copyfile(out_tiff, DEEPCAD_FILEPATH) + + # clean up + if CLEAN_UP: + REMOVE_DIR = os.path.dirname(out_tiff) + shutil.rmtree(REMOVE_DIR) + + print(f'Finished with denoising using deepcad. Use "{DEEPCAD_FILEPATH}" to continue.') \ No newline at end of file diff --git a/cellreg/3-suite2p.py b/cellreg/3-suite2p.py new file mode 100644 index 0000000..50a4917 --- /dev/null +++ b/cellreg/3-suite2p.py @@ -0,0 +1,131 @@ +import os +import shutil +from pathlib import Path +import numpy as np +import suite2p +import argparse +import yaml +from matplotlib import pyplot as plt + +def get_tiff_list(folder, ext=['.tif', '.tiff']): + file_list = [] + for x in ext: + file_list.extend(list(folder.glob(f'*{x}'))) + return file_list + +if __name__ == "__main__": + parser = argparse.ArgumentParser('Run suite2p') + + parser.add_argument("--root-datadir", type=str) + parser.add_argument("--subject", type=str) + parser.add_argument("--date", type=str) + parser.add_argument("--plane", type=str) + parser.add_argument("--config", type=str) + parser.add_argument("--cleanup", action="store_true") + + args = parser.parse_args() + print(args) + + # process data paths + ROOT_DATA_DIR = args.root_datadir + SUBJECT_ID = args.subject + DATE_ID = args.date + PLANE_ID = args.plane + SUITE2P_CFG_PATH = args.config + CLEAN_UP = args.cleanup + + # define paths + + EXP_DIR = Path(ROOT_DATA_DIR) / SUBJECT_ID / DATE_ID + DEEPCAD_DIR = EXP_DIR / '2-deepcad' / PLANE_ID + SUITE2P_DIR = EXP_DIR / '3-suite2p' / PLANE_ID + AUX_SUITE2P_DIR = SUITE2P_DIR / 'suite2p-post-deepcad' + + # find tiffs (may be redundant) + tiff_list = get_tiff_list(DEEPCAD_DIR) + + # make db + db = { + 'data_path': str(DEEPCAD_DIR), + 'save_path0': str(AUX_SUITE2P_DIR), + 'tiff_list': tiff_list, + } + print('Data settings: ') + print(db) + + # make ops + ops = suite2p.default_ops() + with open(SUITE2P_CFG_PATH, 'r') as f: + print(f'Ops file loaded from: "{SUITE2P_CFG_PATH}"') + ops_cfg = yaml.safe_load(f) + ops.update(ops_cfg) + print('Suite2p configuration') + print(ops) + + # run suite2p + output_ops = suite2p.run_s2p(ops=ops, db=db) + print('Suite2p done!') + + # create figure directory + OUT_DIR = Path(output_ops['save_path']) + FIG_DIR = OUT_DIR / 'figures' + FIG_DIR.mkdir(parents=True, exist_ok=True) + + # plot and save backgrounds + plt.figure(figsize=(15,5)) + plt.subplot(131) + plt.imshow(output_ops['max_proj'], cmap='gray') + plt.title("Registered Image, Max Projection"); + + plt.subplot(132) + plt.imshow(output_ops['meanImg'], cmap='gray') + plt.title("Mean registered image") + + plt.subplot(133) + plt.imshow(output_ops['meanImgE'], cmap='gray') + plt.title("High-pass filtered Mean registered image") + + fig_bg_file = FIG_DIR / 'backgrounds.png' + plt.savefig(fig_bg_file, dpi=300) + print(f'Background images plotted in "{fig_bg_file}"') + + # get roi masks + stats_file = OUT_DIR / 'stat.npy' + iscell = np.load(OUT_DIR / 'iscell.npy', allow_pickle=True)[:, 0].astype(bool) + stats = np.load(stats_file, allow_pickle=True) + + # plot and save roi masks + im = suite2p.ROI.stats_dicts_to_3d_array(stats, Ly=output_ops['Ly'], Lx=output_ops['Lx'], label_id=True) + im[im == 0] = np.nan + + plt.figure(figsize=(20,8)) + plt.subplot(1, 4, 1) + plt.imshow(output_ops['max_proj'], cmap='gray') + plt.title("Registered Image, Max Projection") + + plt.subplot(1, 4, 2) + plt.imshow(np.nanmax(im, axis=0), cmap='jet') + plt.title("All ROIs Found") + + plt.subplot(1, 4, 3) + plt.imshow(np.nanmax(im[~iscell], axis=0, ), cmap='jet') + plt.title("All Non-Cell ROIs") + + plt.subplot(1, 4, 4) + plt.imshow(np.nanmax(im[iscell], axis=0), cmap='jet') + plt.title("All Cell ROIs") + + fig_roi_file = FIG_DIR / 'roi-masks.png' + plt.savefig(fig_roi_file, dpi=300) + print(f'ROI mask plotted in "{fig_roi_file}"') + + # move files + if CLEAN_UP: + SRC_OUT_DIR = output_ops['save_path'] + DST_OUT_DIR = str(SUITE2P_DIR) + print(f'Moving files from {SRC_OUT_DIR} to {DST_OUT_DIR}') + shutil.copytree(SRC_OUT_DIR, DST_OUT_DIR, copy_function=shutil.move, dirs_exist_ok=True) + shutil.rmtree(AUX_SUITE2P_DIR) + + + print(f'Finished with suite2p. Use {SUITE2P_DIR} to continue.') diff --git a/cellreg/4-roicat.py b/cellreg/4-roicat.py new file mode 100644 index 0000000..0673593 --- /dev/null +++ b/cellreg/4-roicat.py @@ -0,0 +1,731 @@ +import os +import glob +import shutil +from pathlib import Path +import copy +import multiprocessing as mp +import tempfile + +import numpy as np +import pandas as pd +from tqdm import tqdm + +import roicat + +import matplotlib.pyplot as plt +import seaborn as sns + +import argparse +import yaml +import pickle + + +PARAMS = { + 'um_per_pixel': 0.7, + 'background_max_percentile': 99.9, + 'suite2p': { # `roicat.data_importing.Data_suite2p` + 'new_or_old_suite2p': 'new', + 'type_meanImg': 'meanImg', + }, + 'fov_augment': { # `aligner.augment_FOV_images` + 'roi_FOV_mixing_factor': 0.5, + 'use_CLAHE': False, + 'CLAHE_grid_size': 1, + 'CLAHE_clipLimit': 1, + 'CLAHE_normalize': True, + }, + 'fit_geometric': { # `aligner.fit_geometric` + 'template': 0, + 'template_method': 'image', + 'mode_transform': 'affine', + 'mask_borders': (5,5,5,5), + 'n_iter': 1000, + 'termination_eps': 1e-6, + 'gaussFiltSize': 15, + 'auto_fix_gaussFilt_step':1, + }, + 'fit_nonrigid': { # `aligner.fit_nonrigid` + 'disable': True, + 'template': 0, + 'template_method': 'image', + 'mode_transform':'createOptFlow_DeepFlow', + 'kwargs_mode_transform':None, + }, + 'roi_blur': { + 'kernel_halfWidth': 2 + } + +} + +if __name__ == "__main__": + parser = argparse.ArgumentParser('Run roicat for multi-session registration') + + parser.add_argument("--root-datadir", type=str) + parser.add_argument("--subject", type=str) + parser.add_argument("--plane", type=str) + parser.add_argument("--config", type=str) + parser.add_argument("-g", "--use-gpu", action="store_true") + parser.add_argument("-v", "--verbose", action="store_true") + + parser.add_argument( + "--max-depth", type=int, default=6, + help='max depth to find suite2p files, relative to `--root-datadir`' + ) + + parser.add_argument( + "--suite2p-subdir", type=str, default='3-suite2p', + help='suite2p subdirectory' + ) + + parser.add_argument( + "--output-topdir", type=str, default='', + help='output top directory, if empty, will use the subject folder' + ) + + args = parser.parse_args() + print(args) + + # process data paths + ROOT_DATA_DIR = args.root_datadir + SUBJECT_ID = args.subject + PLANE_ID = args.plane + ROICAT_CFG_PATH = args.config + SUITE2P_PATH_MAXDEPTH=args.max_depth + USE_GPU=args.use_gpu + VERBOSITY = args.verbose + + OUTPUT_DIR=args.output_topdir + SUITE2P_SUBDIR=args.suite2p_subdir + + # define paths + SUBJECT_DIR = Path(ROOT_DATA_DIR) / SUBJECT_ID + if OUTPUT_DIR in [None, '']: + OUTPUT_DIR = SUBJECT_DIR + COLLECTIVE_MUSE_DIR = OUTPUT_DIR / 'multi-session' / PLANE_ID + COLLECTIVE_MUSE_FIG_DIR = COLLECTIVE_MUSE_DIR / 'figures' + COLLECTIVE_MUSE_FIG_DIR.mkdir(parents=True, exist_ok=True) + + # find suite2p paths + dir_allOuterFolders = str(SUBJECT_DIR) + pathSuffixToStat = 'stat.npy' + pathSuffixToOps = 'ops.npy' + pathShouldHave = fr'{SUITE2P_SUBDIR}/{PLANE_ID}' + + paths_allStat = roicat.helpers.find_paths( + dir_outer=dir_allOuterFolders, + reMatch=pathSuffixToStat, + reMatch_in_path=pathShouldHave, + depth=SUITE2P_PATH_MAXDEPTH, + )[:] + + paths_allStat = [ + x for x in paths_allStat + if pathShouldHave in x + ] + + paths_allOps = np.array([ + Path(path).resolve().parent / pathSuffixToOps + for path in paths_allStat + ])[:] + + + print('Paths to all suite2p STAT files:') + print('\n'.join(['\t- ' + str(x) for x in paths_allStat])) + print('\n') + print('Paths to all suite2p OPS files:') + print('\n'.join(['\t- ' + str(x) for x in paths_allOps])) + print('\n') + + # load data + data = roicat.data_importing.Data_suite2p( + paths_statFiles=paths_allStat[:], + paths_opsFiles=paths_allOps[:], + um_per_pixel=PARAMS['um_per_pixel'], + type_meanImg='meanImg', # will be overwritten in the following cell + **{k: v for k, v in PARAMS['suite2p'].items() if k not in ['type_meanImg']}, + verbose=VERBOSITY, + ) + + assert data.check_completeness(verbose=False)['tracking'],\ + "Data object is missing attributes necessary for tracking." + + # also save iscell paths + data.paths_iscell = [ + Path(x).parent / 'iscell.npy' + for x in data.paths_ops + ] + + # load all background images + background_types = [ + 'meanImg', + 'meanImgE', + 'max_proj', + 'Vcorr', + ] + + FOV_backgrounds = {k: [] for k in background_types} + for ops_file in data.paths_ops: + ops = np.load(ops_file, allow_pickle=True).item() + + im_sz = (ops['Ly'], ops['Lx']) + for bg in background_types: + bg_im = ops[bg] + + if bg_im.shape == im_sz: + FOV_backgrounds[bg].append(bg_im) + continue + + print( + f'\t- File {ops_file}: {bg} shape is {bg_im.shape}, which is cropped from {im_sz}. '\ + '\n\tWill attempt to add empty pixels to recover the original shape.' + ) + + im = np.zeros(im_sz).astype(bg_im.dtype) + cropped_xrange, cropped_yrange = ops['xrange'], ops['yrange'] + im[ + cropped_yrange[0]:cropped_yrange[1], + cropped_xrange[0]:cropped_xrange[1] + ] = bg_im + + FOV_backgrounds[bg].append(im) + + # choice of FOV images to align + data.FOV_images = FOV_backgrounds[PARAMS['suite2p']['type_meanImg']] + + # obtain FOVs + aligner = roicat.tracking.alignment.Aligner(verbose=VERBOSITY) + + FOV_images = aligner.augment_FOV_images( + ims=data.FOV_images, + spatialFootprints=data.spatialFootprints, + **PARAMS['fov_augment'] + ) + + # ALIGN FOV + DISABLE_NONRIGID = PARAMS['fit_nonrigid'].pop('disable') + + # geometric fit + aligner.fit_geometric( + ims_moving=FOV_images, + **PARAMS['fit_geometric'] + ) + aligner.transform_images_geometric(FOV_images) + remap_idx = aligner.remappingIdx_geo + + # non-rigid + if not DISABLE_NONRIGID: + aligner.fit_nonrigid( + ims_moving=aligner.ims_registered_geo, + remappingIdx_init=aligner.remappingIdx_geo, + **PARAMS['fit_nonrigid'] + ) + aligner.transform_images_nonrigid(FOV_images) + remap_idx = aligner.remappingIdx_nonrigid + + # transform ROIs + aligner.transform_ROIs( + ROIs=data.spatialFootprints, + remappingIdx=remap_idx, + normalize=True, + ) + + # transform other backgrounds + aligned_backgrounds = {k: [] for k in background_types} + for bg in background_types: + aligned_backgrounds[bg] = aligner.transform_images( + FOV_backgrounds[bg], + remappingIdx=remap_idx + ) + + plt.figure(figsize=(20,20), layout='tight') + types2plt = background_types + ['ROI'] + nrows = len(types2plt) + ncols = data.n_sessions + + splt_cnt = 1 + for k in types2plt: + image_list = aligned_backgrounds.get(k, aligner.get_ROIsAligned_maxIntensityProjection()) + for s_id, img in enumerate(image_list): + plt.subplot(nrows, ncols, splt_cnt) + plt.imshow( + img, cmap='Greys_r', + vmax=np.percentile( + img, + PARAMS['background_max_percentile'] if k!= "ROI" else 95 + ) + ) + plt.axis('off') + plt.title(f'Aligned {k} [#{s_id}]') + splt_cnt += 1 + + plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'aligned-fov.png') + + # BUILD FEATUREs + + # blur ROI + blurrer = roicat.tracking.blurring.ROI_Blurrer( + frame_shape=(data.FOV_height, data.FOV_width), + plot_kernel=False, + verbose=VERBOSITY, + **PARAMS['roi_blur'] + ) + + blurrer.blur_ROIs( + spatialFootprints=aligner.ROIs_aligned[:], + ) + + # ROInet embedding + # TODO: Parameterize `ROInet_embedder`, `generate_dataloader` + DEVICE = roicat.helpers.set_device(use_GPU=USE_GPU, verbose=VERBOSITY) + dir_temp = tempfile.gettempdir() + + roinet = roicat.ROInet.ROInet_embedder( + device=DEVICE, + dir_networkFiles=dir_temp, + download_method='check_local_first', + download_url='https://osf.io/x3fd2/download', + download_hash='7a5fb8ad94b110037785a46b9463ea94', + forward_pass_version='latent', + verbose=VERBOSITY + ) + + roinet.generate_dataloader( + ROI_images=data.ROI_images, + um_per_pixel=data.um_per_pixel, + pref_plot=False, + jit_script_transforms=False, + batchSize_dataloader=8, + pinMemory_dataloader=True, + numWorkers_dataloader=4, + persistentWorkers_dataloader=True, + prefetchFactor_dataloader=2, + ) + + roinet.generate_latents() + + # Scattering wavelet embedding + # TODO: Parameterize `SWT`, `SWT.transform` + swt = roicat.tracking.scatteringWaveletTransformer.SWT( + kwargs_Scattering2D={'J': 3, 'L': 12}, + image_shape=data.ROI_images[0].shape[1:3], + device=DEVICE, + ) + + swt.transform( + ROI_images=roinet.ROI_images_rs, + batch_size=100, + ) + + # Compute similarities + # TODO: Parameterize `ROI_graph`, `compute_similarity_blockwise`, `make_normalized_similarities` + + sim = roicat.tracking.similarity_graph.ROI_graph( + n_workers=-1, + frame_height=data.FOV_height, + frame_width=data.FOV_width, + block_height=128, + block_width=128, + algorithm_nearestNeigbors_spatialFootprints='brute', + verbose=VERBOSITY, + ) + + s_sf, s_NN, s_SWT, s_sesh = sim.compute_similarity_blockwise( + spatialFootprints=blurrer.ROIs_blurred, + features_NN=roinet.latents, + features_SWT=swt.latents, + ROI_session_bool=data.session_bool, + spatialFootprint_maskPower=1.0, + ) + + sim.make_normalized_similarities( + centers_of_mass=data.centroids, + features_NN=roinet.latents, + features_SWT=swt.latents, + k_max=data.n_sessions*100, + k_min=data.n_sessions*10, + algo_NN='kd_tree', + device=DEVICE, + ) + + # Clustering + # TODO: Parameterize `find_optimal_parameters_for_pruning`? + clusterer = roicat.tracking.clustering.Clusterer( + s_sf=sim.s_sf, + s_NN_z=sim.s_NN_z, + s_SWT_z=sim.s_SWT_z, + s_sesh=sim.s_sesh, + ) + + kwargs_makeConjunctiveDistanceMatrix_best = clusterer.find_optimal_parameters_for_pruning( + n_bins=None, + smoothing_window_bins=None, + kwargs_findParameters={ + 'n_patience': 300, + 'tol_frac': 0.001, + 'max_trials': 1200, + 'max_duration': 60*10, + }, + bounds_findParameters={ + 'power_NN': (0., 5.), + 'power_SWT': (0., 5.), + 'p_norm': (-5, 0), + 'sig_NN_kwargs_mu': (0., 1.0), + 'sig_NN_kwargs_b': (0.00, 1.5), + 'sig_SWT_kwargs_mu': (0., 1.0), + 'sig_SWT_kwargs_b': (0.00, 1.5), + }, + n_jobs_findParameters=-1, + ) + + kwargs_mcdm_tmp = kwargs_makeConjunctiveDistanceMatrix_best ## Use the optimized parameters + + clusterer.plot_distSame(kwargs_makeConjunctiveDistanceMatrix=kwargs_mcdm_tmp) + plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'pw-sim-distrib.png') + + clusterer.plot_similarity_relationships( + plots_to_show=[1,2,3], + max_samples=100000, ## Make smaller if it is running too slow + kwargs_scatter={'s':1, 'alpha':0.2}, + kwargs_makeConjunctiveDistanceMatrix=kwargs_mcdm_tmp + ); + plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'pw-sim-scatter.png') + + clusterer.make_pruned_similarity_graphs( + d_cutoff=None, + kwargs_makeConjunctiveDistanceMatrix=kwargs_mcdm_tmp, + stringency=1.0, + convert_to_probability=False, + ) + + if data.n_sessions >= 8: + labels = clusterer.fit( + d_conj=clusterer.dConj_pruned, + session_bool=data.session_bool, + min_cluster_size=2, + n_iter_violationCorrection=3, + split_intraSession_clusters=True, + cluster_selection_method='leaf', + d_clusterMerge=None, + alpha=0.999, + discard_failed_pruning=False, + n_steps_clusterSplit=100, + ) + + else: + labels = clusterer.fit_sequentialHungarian( + d_conj=clusterer.dConj_pruned, ## Input distance matrix + session_bool=data.session_bool, ## Boolean array of which ROIs belong to which sessions + thresh_cost=0.6, ## Threshold + ) + + quality_metrics = clusterer.compute_quality_metrics() + + labels_squeezed, labels_bySession, labels_bool, labels_bool_bySession, labels_dict = roicat.tracking.clustering.make_label_variants(labels=labels, n_roi_bySession=data.n_roi) + + results = { + "clusters":{ + "labels": labels_squeezed, + "labels_bySession": labels_bySession, + "labels_bool": labels_bool, + "labels_bool_bySession": labels_bool_bySession, + "labels_dict": labels_dict, + }, + "ROIs": { + "ROIs_aligned": aligner.ROIs_aligned, + "ROIs_raw": data.spatialFootprints, + "frame_height": data.FOV_height, + "frame_width": data.FOV_width, + "idx_roi_session": np.where(data.session_bool)[1], + "n_sessions": data.n_sessions, + }, + "input_data": { + "paths_stat": data.paths_stat, + "paths_ops": data.paths_ops, + }, + "quality_metrics": clusterer.quality_metrics if hasattr(clusterer, 'quality_metrics') else None, + } + + run_data = copy.deepcopy({ + 'data': data.serializable_dict, + 'aligner': aligner.serializable_dict, + 'blurrer': blurrer.serializable_dict, + 'roinet': roinet.serializable_dict, + 'swt': swt.serializable_dict, + 'sim': sim.serializable_dict, + 'clusterer': clusterer.serializable_dict, + }) + + iscell_bySession = [np.load(ic_p)[:,0] for ic_p in data.paths_iscell] + + with open(COLLECTIVE_MUSE_DIR / 'roicat-output.pkl', 'wb') as f: + pickle.dump(dict( + run_data = run_data, + results = results, + iscell = iscell_bySession + ), f) + + print(f'Number of clusters: {len(np.unique(results["clusters"]["labels"]))}') + print(f'Number of discarded ROIs: {(results["clusters"]["labels"]==-1).sum()}') + + # Visualize + confidence = (((results['quality_metrics']['cluster_silhouette'] + 1) / 2) * results['quality_metrics']['cluster_intra_means']) + + fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15,7)) + + axs[0,0].hist(results['quality_metrics']['cluster_silhouette'], 50); + axs[0,0].set_xlabel('cluster_silhouette'); + axs[0,0].set_ylabel('cluster counts'); + + axs[0,1].hist(results['quality_metrics']['cluster_intra_means'], 50); + axs[0,1].set_xlabel('cluster_intra_means'); + axs[0,1].set_ylabel('cluster counts'); + + axs[1,0].hist(confidence, 50); + axs[1,0].set_xlabel('confidence'); + axs[1,0].set_ylabel('cluster counts'); + + axs[1,1].hist(results['quality_metrics']['sample_silhouette'], 50); + axs[1,1].set_xlabel('sample_silhouette score'); + axs[1,1].set_ylabel('roi sample counts'); + + fig.savefig(COLLECTIVE_MUSE_FIG_DIR / 'cluster-metrics.png') + + # FOV clusters + FOV_clusters = roicat.visualization.compute_colored_FOV( + spatialFootprints=[r.power(0.8) for r in results['ROIs']['ROIs_aligned']], + FOV_height=results['ROIs']['frame_height'], + FOV_width=results['ROIs']['frame_width'], + labels=results["clusters"]["labels_bySession"], ## cluster labels + # alphas_labels=confidence*1.5, ## Set brightness of each cluster based on some 1-D array + alphas_labels=(clusterer.quality_metrics['cluster_silhouette'] > 0) * (clusterer.quality_metrics['cluster_intra_means'] > 0.4), + # alphas_sf=clusterer.quality_metrics['sample_silhouette'], ## Set brightness of each ROI based on some 1-D array + ) + + FOV_clusters_with_iscell = roicat.visualization.compute_colored_FOV( + spatialFootprints=[r.power(0.8) for r in results['ROIs']['ROIs_aligned']], ## Spatial footprint sparse arrays + FOV_height=results['ROIs']['frame_height'], + FOV_width=results['ROIs']['frame_width'], + labels=results["clusters"]["labels_bySession"], ## cluster labels + # alphas_labels=confidence*1.5, ## Set brightness of each cluster based on some 1-D array + alphas_labels=(clusterer.quality_metrics['cluster_silhouette'] > 0) * (clusterer.quality_metrics['cluster_intra_means'] > 0.4), + alphas_sf=iscell_bySession + # alphas_sf=clusterer.quality_metrics['sample_silhouette'], ## Set brightness of each ROI based on some 1-D array + ) + + roicat.helpers.save_gif( + array=FOV_clusters, + path=str(COLLECTIVE_MUSE_FIG_DIR/ 'FOV_clusters_allrois.gif'), + frameRate=5.0, + loop=0, + ) + + roicat.helpers.save_gif( + array=FOV_clusters_with_iscell, + path=str(COLLECTIVE_MUSE_FIG_DIR/ 'FOV_clusters_iscells.gif'), + frameRate=5.0, + loop=0, + ) + + plt.figure(figsize=(20,10), layout='tight') + roi_image_dict = { + 'all': FOV_clusters, + 'iscell': FOV_clusters_with_iscell + } + nrows = len(roi_image_dict) + ncols = data.n_sessions + + splt_cnt = 1 + for k, image_list in roi_image_dict.items(): + for s_id, img in enumerate(image_list): + plt.subplot(nrows, ncols, splt_cnt) + plt.imshow(img) + plt.axis('off') + plt.title(f'Aligned {k} [#{s_id}]') + splt_cnt += 1 + + plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'aligned-rois.png') + + # save FOVs + num_sessions = data.n_sessions + out_img = [] + for d in range(num_sessions): + out_img.append(dict( + fov = aligner.ims_registered_geo[d], + roi_pre_iscell = FOV_clusters[d], + roi_with_iscell = FOV_clusters_with_iscell[d], + **{k: v[d] for k,v in aligned_backgrounds.items()}, + )) + + with open(COLLECTIVE_MUSE_DIR / 'aligned-images.pkl', 'wb') as f: + pickle.dump(out_img, f) + + + # save summary data + df = pd.DataFrame([ + dict( + session=i, + global_roi=glv, + session_roi=range(len(glv)), + iscell = iscv + ) + for i, (glv, iscv) in enumerate(zip(labels_bySession, iscell_bySession)) + ]).explode(['global_roi','session_roi', 'iscell']).astype({'iscell': 'bool'}) + + df.to_csv(COLLECTIVE_MUSE_DIR / 'summary-roi.csv', index=False) + + # process only iscell for stats + df = ( + df.query('iscell') + .reset_index(drop=True) + ) + + df = df.merge( + ( + df.query('global_roi >= 0') + .groupby('global_roi') + ['session'].agg(lambda x: len(list(x))) + .to_frame('num_sessions') + .reset_index() + ), + how='left' + ) + + df = ( + df.fillna({'num_sessions': 1}) + .astype({'num_sessions': 'int'}) + ) + + # re-indexing + persistent_roi_reindices = ( + df[['num_sessions', 'global_roi']] + .query('global_roi >= 0 and num_sessions > 1') + .drop_duplicates() + .sort_values('num_sessions', ascending=False) + .reset_index(drop=True) + .reset_index() + .set_index('global_roi') + ['index'].to_dict() + ) + + df['reindexed_global_roi'] = df['global_roi'].map(persistent_roi_reindices) + + single_roi_start_indices = df['reindexed_global_roi'].max() + 1 + single_roi_rows = df.query('reindexed_global_roi.isna()').index + num_single_rois = len(single_roi_rows) + + df.loc[single_roi_rows, 'reindexed_global_roi'] = \ + np.arange(num_single_rois) + single_roi_start_indices + + df['reindexed_global_roi'] = df['reindexed_global_roi'].astype('int') + df = df.rename(columns={ + 'global_roi': 'roicat_global_roi', + 'reindexed_global_roi': 'global_roi' + }) + + df.to_csv(COLLECTIVE_MUSE_DIR / 'finalized-roi.csv', index=False) + + # plot persistent ROIs summary + persist_rois = ( + df + .drop_duplicates(['global_roi']) + .value_counts('num_sessions', sort=False) + .to_frame('num_rois') + .reset_index() + ) + + plt.figure(figsize=(4,5)) + ax = sns.barplot( + persist_rois, + x = 'num_sessions', + y = 'num_rois', + hue = 'num_sessions', + facecolor = '#afafaf', + dodge=False, + edgecolor='k' + ) + sns.despine(trim=True, offset=10) + + plt.legend([], [], frameon=False) + [ax.bar_label(c, padding=5, fontsize=10) for c in ax.containers] + plt.xlabel('# sessions') + plt.ylabel('# rois') + plt.title('Persisted ROIs') + plt.tight_layout() + plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'num-persist-roi-overall.png') + + # plot persistent ROIs per sessions + df_sessions = ( + df + .value_counts(['session','num_sessions']) + .to_frame('count') + .reset_index() + ) + + df_total_per_session = ( + df_sessions + .groupby('session') + ['count'].agg('sum') + .to_frame('total_count') + .reset_index() + ) + + df_sessions = df_sessions.merge(df_total_per_session, how='left') + df_sessions['percent'] = 100 * df_sessions['count'] / df_sessions['total_count'] + + plt.figure(figsize=(10,5)) + bar_kwargs = dict( + kind='bar', + stacked=True, + colormap='GnBu', + width=0.7, + edgecolor='k', + ) + + ax1 = plt.subplot(121) + + ( + df_sessions + .pivot(index='session',columns='num_sessions', values='count') + .fillna(0) + .plot( + **bar_kwargs, + xlabel='session ID', + ylabel='# rois', + legend=False, + ax=ax1) + ) + plt.tick_params(rotation=0) + + ax2 = plt.subplot(122) + + ( + df_sessions + .pivot(index='session',columns='num_sessions', values='percent') + .fillna(0) + .plot( + **bar_kwargs, + xlabel='session ID', + ylabel='% roi per session', + ax=ax2 + ) + ) + plt.tick_params(rotation=0) + + leg_handles, leg_labels = plt.gca().get_legend_handles_labels() + plt.legend( + reversed(leg_handles), + reversed(leg_labels), + loc='upper right', + bbox_to_anchor=[1.5,1], + title='# sessions', + edgecolor='k', + ) + + sns.despine(trim=True, offset=10) + + plt.suptitle( + 'Distribution of detected and aligned ROIs across sessions', + ) + plt.tight_layout(w_pad=5) + plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'num-persist-roi-per-session.png') + diff --git a/cellreg/deepcad/test_collection.py b/cellreg/deepcad/test_collection.py index 56e3c0c..3fd1e28 100755 --- a/cellreg/deepcad/test_collection.py +++ b/cellreg/deepcad/test_collection.py @@ -135,6 +135,7 @@ def read_imglist(self): def read_modellist(self): model_path = self.pth_dir + "/" + self.denoise_model + print("MODEL PATH: ", model_path) model_list = list(os.walk(model_path, topdown=False))[-1][-1] model_list.sort() diff --git a/config/1-moco.yml b/config/1-moco.yml new file mode 100644 index 0000000..711dc97 --- /dev/null +++ b/config/1-moco.yml @@ -0,0 +1,11 @@ +# do_registration: 1 # motion correction, will be overwritten in script anyway to have `do_registration: 1` +# reg_tif: True # save tiff registration in batches, will be overwritten in script anyway to have `reg_tif: True` +# roidetect: False # whether to do segmentation after registration, will be overwritten in script anyway to have `roidetect: False` +two_step_registration: 1.0 # important for registration, esp. for low SNR +nimg_init: 1000 # num frames initial to use as reference for registration +block_size: [256, 256] # matters for registration +nonrigid: True # important for registration +keep_movie_raw: True # whether to keep raw binary file +fs: 5 # estimates across planes +combined: False # whether to combine across planes +tau: 0.57 # calcium dynamics for deconv, possibly not correct \ No newline at end of file diff --git a/config/2-deepcad.yml b/config/2-deepcad.yml new file mode 100644 index 0000000..d9053ee --- /dev/null +++ b/config/2-deepcad.yml @@ -0,0 +1,11 @@ +patch_x: 150 +patch_y: 150 +patch_t: 150 +overlap_factor: 0.6 +scale_factor: 1 +pth_dir: "/oscar/data/afleisc2/collab/multiday-reg/MultiSession/deepcad_model" +fmap: 16 +GPU: '0' +num_workers: 0 +visualize_images_per_epoch: False +save_test_images_per_epoch: True \ No newline at end of file diff --git a/config/3-suite2p.yml b/config/3-suite2p.yml new file mode 100644 index 0000000..6b7a8c0 --- /dev/null +++ b/config/3-suite2p.yml @@ -0,0 +1,29 @@ +do_registration: True +two_step_registration: 1.0 # important for registration, esp. for low SNR +nimg_init: 1000 # num frames initial to use as reference for registration +block_size: [256, 256] # matters for registration +nonrigid: True # important for registration +reg_tif: False # save tiff registration in batches +keep_movie_raw: False # whether to keep raw binary file +fs: 50 # HACK! estimates across planes, currently artificially increase, see https://github.com/MouseLand/suite2p/issues/1073 +combined: False # whether to combine across planes +roidetect: True # whether to do segmentation after registration +tau: 0.57 # calcium dynamics for deconv, possibly not correct +# 1Preg: False # DEFAULT, not relevant for 2P +# spatial_hp_reg: 42.0 # DEFAULT, not relevant for 2P +# snr_thresh: 1.2 # DEFAULT, SNR for non-rigid registration +maxregshift: 0.01 # see https://github.com/MouseLand/suite2p/issues/273 +maxregshiftNR: 4.0 # relevant for registration +# smooth_sigma: 1.15 # DEFAULT, relevant for only rigid +# smooth_sigma_time: 0.0 # DEFAULT, relevant for only rigid +# maxregshift: 0.1 # DEFAULT, relevant for non-rigid +# reg_tif_chan2: False # DEFAULT, not relevant, whether to write reg binary of chan2 to tiff file +pre_smooth: 2.0 # UNCLEAR WHETHER or not relevant for 2p registration, can turn to 0.0 +sparse_mode: False +# spatial_scale: 2 # segmentation +max_iterations: 100 # segmentation +threshold_scaling: 2.0 # segmentation +max_overlap: 0.95 # segmentation +# denoise: 1.0 # segmentation ; turn this off post-deepcad +# flow_threshold: 0.0 # use default instead, which is 1.5 +anatomical_only: 1 # segmentation (whether or not to use cellpose) diff --git a/deepcad_model/E_05_Iter_3136.pth b/deepcad_model/E_05_Iter_3136.pth old mode 100755 new mode 100644 diff --git a/deepcad_model/para.yaml b/deepcad_model/para.yaml old mode 100755 new mode 100644 diff --git a/environments/deepcad.yml b/environments/deepcad.yml new file mode 100644 index 0000000..420510a --- /dev/null +++ b/environments/deepcad.yml @@ -0,0 +1,36 @@ +name: deepcad +channels: + - conda-forge + - defaults +dependencies: + - python=3.9 + - ipykernel + - ipython + - pip + - pip: + - -f https://download.pytorch.org/whl/cu111/torch_stable.html + - beautifulsoup4 + - matplotlib + - networkx + - numpy + - pandas + - toml + - tomli + - tqdm + - csbdeep + - gdown + - opencv-python + - opencv-python-headless + - pillow + - pyyaml + - requests + - scikit-image + - scipy + - tifffile + - typer + - typer-cli + - urllib3 + - natsort + - "torch==1.10.1+cu111" + - "torchaudio==0.10.1+cu111" + - "torchvision==0.11.2+cu111" diff --git a/environments/roicat.yml b/environments/roicat.yml new file mode 100644 index 0000000..5d3d3ba --- /dev/null +++ b/environments/roicat.yml @@ -0,0 +1,13 @@ +name: roicat +channels: + - conda-forge + - defaults +dependencies: + - python=3.11 + - ipykernel + - ipython + - pip + - pip: + - "roicat[all]" + - "git+https://github.com/RichieHakim/roiextractors" + diff --git a/environments/suite2p.yml b/environments/suite2p.yml new file mode 100644 index 0000000..72bc834 --- /dev/null +++ b/environments/suite2p.yml @@ -0,0 +1,12 @@ +name: suite2p +channels: + - conda-forge + - defaults +dependencies: + - "python=3.9" + - pip + - ipykernel + - pip: + - "suite2p" + - pyyaml + diff --git a/notebooks/convert-bin2tiff.ipynb b/notebooks/convert-bin2tiff.ipynb new file mode 100644 index 0000000..7bd01ce --- /dev/null +++ b/notebooks/convert-bin2tiff.ipynb @@ -0,0 +1,193 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d197cc2c-bc68-485a-ae34-d54ed2e4a582", + "metadata": {}, + "source": [ + "# Convert `suite2p` binary files to tiff stacks" + ] + }, + { + "cell_type": "markdown", + "id": "ea818e9d-0892-4517-9fbc-283a5c81cc24", + "metadata": {}, + "source": [ + "This is a helper notebook to convert the `suite2p` binary files to tiff stacks (in the same folder).\n", + "\n", + "| Binary input file | Tiff output file | \n", + "| --: | :-- | \n", + "| `data.bin` | `chan1.tiff` | \n", + "| `data_chan2.bin` | `chan2.tiff` |\n", + "| `data_raw.bin` | `chan1_raw.tiff` | \n", + "| `data_chan2_raw.bin` | `chan2_raw.tiff` |\n", + "\n", + "This can be run on any environment as long as it has the necessary packages. If not do:\n", + "\n", + "```bash\n", + "pip install numpy tifffile tqdm\n", + "```\n", + "\n", + "This is preferably running in parallel for faster conversion of multiple files.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb6c11d4-ef92-4e25-a152-05794f15a72d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "import glob\n", + "\n", + "import numpy as np\n", + "from tifffile import TiffWriter\n", + "from tqdm.notebook import tqdm\n", + "from tqdm.contrib.concurrent import process_map" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "038bc5ca-c48a-4dee-a5b4-a1945459a290", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# detect ops files with a certain path pattern\n", + "# alternatively, define `ops_files` as a list pointing to the appropriate `ops.npy` files\n", + "main_dir = '/oscar/data/afleisc2/collab/multiday-reg/data/SD_*'\n", + "\n", + "ops_files = sorted(glob.glob(os.path.join(main_dir, '**/ops.npy'), recursive=True))\n", + "ops_files" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7de0e424-f76a-42b0-89bc-eac6d4b8f356", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "len(ops_files)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "66ab5705-9542-4b0a-8aa8-d5bd7752e4eb", + "metadata": {}, + "outputs": [], + "source": [ + "def bin2tiff(bin_file, tiff_file, ops_file, tqdm_kwargs=dict()):\n", + " \"\"\"Convert a binary mmap file to tiff file from suite2p\"\"\"\n", + " # metadata\n", + " ops = np.load(ops_file, allow_pickle=True).item()\n", + " n_frames, Ly, Lx = ops['nframes'], ops['Ly'], ops['Lx']\n", + "\n", + " # read in binary file\n", + " memmap_obj = np.memmap(\n", + " bin_file,\n", + " mode='r',\n", + " dtype='int16',\n", + " shape=(n_frames, Ly, Lx)\n", + " )\n", + " \n", + " # write to tiff file\n", + " with TiffWriter(tiff_file, bigtiff=True) as f:\n", + " for i in tqdm(range(n_frames), **tqdm_kwargs):\n", + " curr_frame = np.floor(memmap_obj[i]).astype(np.int16)\n", + " f.write(curr_frame, contiguous=True)\n", + " \n", + " # close binary file\n", + " memmap_obj._mmap.close()\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "158b56ec-8cc3-4b1b-ba2c-5dd125908d2b", + "metadata": {}, + "outputs": [], + "source": [ + "def bin2tiff_one_ops(ops_file):\n", + " \"\"\"Conversion function for parallel\"\"\"\n", + " \n", + " ops_dir = os.path.dirname(ops_file)\n", + " fluo_files = {\n", + " 'data.bin': 'chan1.tiff', \n", + " 'data_chan2.bin': 'chan2.tiff',\n", + " 'data_raw.bin': 'chan1_raw.tiff', \n", + " 'data_chan2_raw.bin': 'chan2_raw.tiff',\n", + " }\n", + " \n", + " for bin_file, tiff_file in fluo_files.items():\n", + " bin_file = os.path.join(ops_dir, bin_file)\n", + " tiff_file = os.path.join(ops_dir, tiff_file)\n", + " if (\n", + " (not os.path.exists(bin_file))\n", + " or os.path.exists(tiff_file)\n", + " ):\n", + " continue\n", + " \n", + " bin2tiff(\n", + " bin_file,\n", + " tiff_file,\n", + " ops_file,\n", + " tqdm_kwargs=dict(\n", + " disable = True\n", + " )\n", + " )\n", + "\n", + "process_map(\n", + " bin2tiff_one_ops,\n", + " ops_files,\n", + " max_workers=16\n", + ")\n", + " " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "roicat", + "language": "python", + "name": "roicat" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + }, + "scenes_data": { + "active_scene": "Default Scene", + "init_scene": "", + "scenes": [ + "Default Scene" + ] + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/convert-flowregtiff2suite2ptiff.ipynb b/notebooks/convert-flowregtiff2suite2ptiff.ipynb new file mode 100644 index 0000000..711cf1a --- /dev/null +++ b/notebooks/convert-flowregtiff2suite2ptiff.ipynb @@ -0,0 +1,121 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5c04ceeb-e83d-4118-8843-4b0e9afd465a", + "metadata": {}, + "source": [ + "# Convert 2-channel TIFF from `flow_registration` to `suite2p` \n", + "\n", + "This notebook converts the TIFF stack output from `flow_registration` with dimension `frame x spatial_1 x spatial_2 x channel` to what `suite2p` expects for a 2-channel TIFF stack: `frame x channel x spatial_1 x spatial_2`. \n", + "\n", + "This basically just transpose the dimensions and saves as an OME.TIFF\n", + "\n", + "## Requirements\n", + "\n", + "- `numpy`\n", + "- `pyometiff`\n", + "- `scikit-image`\n", + "\n", + "If these are not installed, do this in the activated environment:\n", + "\n", + "```bash\n", + "pip install numpy pyometiff scikit-image\n", + "```\n", + "\n", + "## Caution on memory\n", + "\n", + "This current reads in the whole TIFF stack in memory instead doing it in frame batches, so you should sufficient memory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "db413ba4-ece8-451e-a8a6-8b96961be99d", + "metadata": {}, + "outputs": [], + "source": [ + "# %pip install scikit-image pyometiff numpy\n", + "\n", + "import os\n", + "import glob\n", + "\n", + "import numpy as np\n", + "import skimage.io as skio\n", + "from pyometiff import OMETIFFWriter" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ee4e1ffe-3cfd-42fb-9f92-6dd843470547", + "metadata": {}, + "outputs": [], + "source": [ + "# define paths\n", + "input_tiff_path = \"PATH/TO/FLOW/REG/compensated.TIFF\" # input\n", + "output_tiff_path = \"out-data/flowreg_tcyx_4s2p.ome.tiff\" # output\n", + "output_dim_order = \"TCYX\" # what suite2p wants; don't change" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aacdafc4-e405-4c51-a6c2-f184047bdff6", + "metadata": {}, + "outputs": [], + "source": [ + "# read\n", + "X = skio.imread(input_tiff_path) # flowreg is \"TYXC\"\n", + "X = np.transpose(X, (0,3,1,2))\n", + "\n", + "# write\n", + "writer = OMETIFFWriter(\n", + " fpath=output_tiff_path,\n", + " dimension_order=output_dim_order,\n", + " array=X,\n", + " metadata=dict(),\n", + " explicit_tiffdata=False,\n", + " bigtiff=True\n", + ")\n", + "\n", + "writer.write()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + }, + "scenes_data": { + "active_scene": "Default Scene", + "init_scene": "", + "scenes": [ + "Default Scene" + ] + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/inspect-muse.ipynb b/notebooks/inspect-muse.ipynb new file mode 100644 index 0000000..da99a07 --- /dev/null +++ b/notebooks/inspect-muse.ipynb @@ -0,0 +1,429 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "39dd752a-c720-4f52-a149-62668664bfd5", + "metadata": {}, + "source": [ + "# Inspect multisession outputs\n", + "\n", + "This is currently being tested to inspect the output of processed `roicat` outputs.\n", + "\n", + "This currently uses `tk` and cannot run on Oscar headless.\n", + "\n", + "Note: this is now superseded by `refine-muse.ipynb`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "020f3892-fe74-493e-81c3-eb99e5a10adc", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import glob\n", + "import warnings\n", + "\n", + "from pathlib import Path\n", + "import pickle\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "import skimage\n", + "import matplotlib.pyplot as plt\n", + "\n", + "%matplotlib tk" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "86808689-73d8-45a2-bdf4-8928b39fa932", + "metadata": {}, + "outputs": [], + "source": [ + "muse_dir = Path('data/SD_0664/multi-session/plane1/')\n", + "roicat_out_file = muse_dir / 'roicat-output.pkl'\n", + "aligned_img_file = muse_dir / 'aligned-images.pkl'\n", + "roi_table_file = muse_dir / 'finalized-roi.csv'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37fe3dc0-6aec-4c7b-91f1-5579991bb8b0", + "metadata": {}, + "outputs": [], + "source": [ + "fuse_cfg = dict(\n", + " background_choice = 'max_proj', # max_proj, meanImg, meanImgE, Vcorr\n", + " fuse_alpha = 0.4, # fusing ROI mask with background\n", + " fuse_bg_max_p = 98, # for clipping background max value with percentile\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7f967bc2-e41d-45aa-806a-400fb7a9a7ec", + "metadata": {}, + "outputs": [], + "source": [ + "!ls $muse_dir" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4af7a27e-e508-4ea9-a13f-be920421fd49", + "metadata": {}, + "outputs": [], + "source": [ + "with open(roicat_out_file, 'rb') as f:\n", + " roicat_out = pickle.load(f)\n", + "\n", + "with open(aligned_img_file, 'rb') as f:\n", + " aligned_images = pickle.load(f)\n", + "\n", + "aligned_rois = roicat_out['results']['ROIs']['ROIs_aligned']\n", + "\n", + "# TODO: unclear if this is the right order\n", + "image_dims = (\n", + " roicat_out['results']['ROIs']['frame_width'],\n", + " roicat_out['results']['ROIs']['frame_height']\n", + ")\n", + "roi_table = pd.read_csv(roi_table_file)\n", + "\n", + "num_sessions = len(aligned_rois)\n", + "assert all(num_sessions == np.array([len(aligned_images), roi_table['session'].nunique()]))\n", + "\n", + "num_global_rois = roi_table['global_roi'].nunique()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b03da632-c02f-42f2-8171-3fbd6ebcd0d4", + "metadata": {}, + "outputs": [], + "source": [ + "global_roi_colors = np.random.rand(num_global_rois, 3).round(3)\n", + "\n", + "roi_table['global_roi_color'] = roi_table['global_roi'].apply(\n", + " lambda x: global_roi_colors[x]\n", + ")\n", + "roi_table" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "532ad296-d743-435c-953f-a8f8521acfe4", + "metadata": {}, + "outputs": [], + "source": [ + "def throw_invalid_footprint_warning(footprints, roi_table, session):\n", + " invalid_max_idx = footprints.max(axis=1).toarray().squeeze() <= 0\n", + " if sum(invalid_max_idx) == 0:\n", + " return\n", + " \n", + " invalid_max_idx = np.where(invalid_max_idx)[0]\n", + " \n", + " roi_record = roi_table.query(\n", + " 'session == @session and ' \\\n", + " 'session_roi in @invalid_max_idx and' \\\n", + " '(num_sessions > 1 or roicat_global_roi >= 0)' \n", + " )\n", + " if len(roi_record) > 0:\n", + " warnings.warn(\n", + " 'The following ROI records have invalid footprints, i.e. '\\\n", + " '(a) max = 0, (b) persist for more than 1 sessions or `roicat_global_roi=-1`:',\n", + " roi_record.drop(columns='global_roi_color').to_dict('records')\n", + " )\n", + " \n", + "def select_session_rois(footprints, roi_table, session):\n", + " session_iscell_rois = (\n", + " roi_table.query('session == @session')\n", + " .sort_values(by='session_roi')\n", + " .reset_index(drop=True)\n", + " )\n", + " roi_colors = np.stack(session_iscell_rois['global_roi_color'].to_list())\n", + " roi_iscell_idx = session_iscell_rois['session_roi'].to_list()\n", + "\n", + " footprints = footprints[roi_iscell_idx]\n", + " assert footprints.shape[0] == len(roi_iscell_idx)\n", + " \n", + " return footprints, roi_colors \n", + "\n", + "def normalize_footprint(X):\n", + " max_X = X.max(axis=1).toarray()\n", + " X = X.multiply(1.0 / max_X)\n", + " return X\n", + " \n", + "def sparse2image(sparse_footprints, image_dims):\n", + " image = (\n", + " (sparse_footprints > 0)\n", + " .T\n", + " .reshape(*image_dims)\n", + " .toarray()\n", + " )\n", + " return image\n", + "\n", + "def sparse2contour(sparse_footprints, image_dims):\n", + " image = (\n", + " (sparse_footprints > 0)\n", + " .T\n", + " .reshape(*image_dims)\n", + " .toarray()\n", + " )\n", + " contour = skimage.measure.find_contours(image)\n", + " return contour\n", + "\n", + "def color_footprint_one_channel(sparse_footprints, color_vec, image_dims):\n", + " image = np.array(\n", + " (sparse_footprints > 0)\n", + " .multiply(color_vec)\n", + " .sum(axis=0)\n", + " .T\n", + " .reshape(*image_dims)\n", + " )\n", + " return image\n", + "\n", + "def color_footprint(sparse_footprints, color_matrix, image_dims):\n", + " image = np.stack([\n", + " color_footprint_one_channel(\n", + " sparse_footprints,\n", + " color_vec.reshape(-1,1),\n", + " image_dims\n", + " )\n", + " for color_vec in color_matrix.T\n", + " ], axis=-1)\n", + " image = np.clip(image, a_min=0.0, a_max=1.0)\n", + " return image\n", + "\n", + "def fuse_rois_in_background(rois, background, alpha=0.5, background_max_percentile=99):\n", + " background = (background - background.min()) / \\\n", + " (np.percentile(background, background_max_percentile) - background.min())\n", + " background = np.clip(background, a_min=0, a_max=1.0)\n", + " background = skimage.color.gray2rgb(background)\n", + " \n", + " fused_image = background * (1 - alpha) + rois * alpha\n", + " fused_image = np.clip(fused_image, a_min=0, a_max=1.0) \n", + " return fused_image\n", + " \n", + "def compute_session_fused_footprints(\n", + " images, footprints, roi_table, session, \n", + " background_choice = 'max_proj',\n", + " fuse_alpha = 0.5,\n", + " fuse_bg_max_p = 99,\n", + "):\n", + " background_images = images[session]\n", + " if background_choice not in background_images:\n", + " warnings.warn(f'{background_choice} not in the aligned images. Using \"fov\" field instead')\n", + " assert 'fov' in background_images, '\"fov\" not found in aligned images'\n", + " background_choice = 'fov'\n", + " background_image = background_images[background_choice]\n", + " \n", + " sparse_footprints = footprints[session]\n", + " throw_invalid_footprint_warning(sparse_footprints, roi_table, session) \n", + " \n", + " sparse_footprints, roi_colors = select_session_rois(sparse_footprints, roi_table, session)\n", + " sparse_footprints = normalize_footprint(sparse_footprints)\n", + " colored_footprints = color_footprint(sparse_footprints, roi_colors, image_dims)\n", + " fused_footprints = fuse_rois_in_background(\n", + " colored_footprints,\n", + " background_image,\n", + " alpha=fuse_alpha, \n", + " background_max_percentile=fuse_bg_max_p\n", + " )\n", + " return fused_footprints" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ea78c1b-7776-4230-b8ae-d20318c1c0d3", + "metadata": {}, + "outputs": [], + "source": [ + "fuse_kwargs = dict(\n", + " images=aligned_images,\n", + " footprints=aligned_rois,\n", + " roi_table=roi_table,\n", + " **fuse_cfg\n", + ")\n", + "\n", + "fused_footprints = [\n", + " compute_session_fused_footprints(\n", + " session=session_idx,\n", + " **fuse_kwargs\n", + " )\n", + " for session_idx in range(num_sessions)\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e995a99-4110-4335-962d-fd81764a49d3", + "metadata": {}, + "outputs": [], + "source": [ + "roi_contours = []\n", + "for session_rois in aligned_rois:\n", + " roi_contours.append([\n", + " sparse2contour(roi_footprint, image_dims) \n", + " for roi_footprint in session_rois\n", + " ])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b8fd153-b55e-4841-a3e8-d1257763afc6", + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(20,10))\n", + "for i in range(num_sessions):\n", + " plt.subplot(1,num_sessions,i+1)\n", + " plt.imshow(fused_footprints[i])\n", + " plt.axis('off')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9966739-17a4-482e-929b-1e5460328587", + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib tk\n", + "\n", + "contour_kwargs = dict(c='r', lw=2, alpha=0.8)\n", + "num_cols = min(num_sessions, 6)\n", + "\n", + "\n", + "fig, axes = plt.subplots(\n", + " int(np.ceil(num_sessions/num_cols)), num_cols,\n", + " figsize=(20,10),\n", + " sharex=True,\n", + " sharey=True,\n", + ")\n", + "ax_list = axes.flatten()\n", + "\n", + "for i in range(num_sessions):\n", + " ax_list[i].imshow(fused_footprints[i])\n", + " ax_list[i].set_axis_off()\n", + " ax_list[i].set_title(f'session #{i+1}')\n", + " ax_list[i].session_index = i\n", + " \n", + "def find_tagged_objects(fig, value, key='tag'):\n", + " objs = fig.findobj(\n", + " match=lambda x:\n", + " False if not hasattr(x, key) \n", + " else value in getattr(x, key)\n", + " )\n", + " return objs \n", + "\n", + " \n", + "def onclick(event, tag='highlight'): \n", + " # remove previous highlighted objects\n", + " for x in find_tagged_objects(fig, value=tag):\n", + " x.remove()\n", + " \n", + " # double click: reset title\n", + " # not working, TBD\n", + " # if event.dblclick:\n", + " # for ax in ax_list:\n", + " # ax.set_title(f'session #{i+1}')\n", + " # return\n", + " \n", + " # get data\n", + " ix, iy, ax = event.xdata, event.ydata, event.inaxes\n", + " session = ax.session_index\n", + " \n", + " # obtain session ROI index\n", + " flat_idx = np.ravel_multi_index((round(iy),round(ix)), image_dims)\n", + " session_roi = aligned_rois[session][:,flat_idx].nonzero()[0]\n", + " if len(session_roi) == 0:\n", + " return\n", + " session_roi = session_roi[0] # just select first one if there are ovelap\n", + " select_global_roi = (\n", + " roi_table.query('session == @session and session_roi == @session_roi')\n", + " ['global_roi'].to_list()\n", + " )\n", + " assert len(select_global_roi) == 1\n", + " select_global_roi = select_global_roi[0]\n", + " \n", + " # obtain contours\n", + " select_contours = {\n", + " r['session']: dict(\n", + " contour = roi_contours[r['session']][r['session_roi']],\n", + " **r,\n", + " )\n", + " for _, r in roi_table.query('global_roi == @select_global_roi').iterrows()\n", + " }\n", + " \n", + " # plot contours\n", + " for session, ax in enumerate(ax_list):\n", + " if session not in select_contours:\n", + " ax.set_title(f'session #{session+1} [NOT FOUND]')\n", + " continue\n", + " \n", + " session_contours = select_contours[session]['contour']\n", + " select_session_roi = select_contours[session]['session_roi']\n", + " \n", + " for c in session_contours:\n", + " c_handles = ax.plot(c[:,1],c[:,0], **contour_kwargs)\n", + " for ch in c_handles:\n", + " ch.tag = tag\n", + " \n", + " ax.set_title(f'session #{session+1} [id={select_session_roi} | ID={select_global_roi}]')\n", + " \n", + " plt.show()\n", + " \n", + "cid = fig.canvas.mpl_connect('button_press_event', onclick)\n", + "\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "datavis", + "language": "python", + "name": "datavis" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.9" + }, + "scenes_data": { + "active_scene": "Default Scene", + "init_scene": "", + "scenes": [ + "Default Scene" + ] + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/refine-muse.ipynb b/notebooks/refine-muse.ipynb new file mode 100644 index 0000000..460a7f8 --- /dev/null +++ b/notebooks/refine-muse.ipynb @@ -0,0 +1,821 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b58dc7ec-7a67-49dc-a26d-b4dfdf90ab3d", + "metadata": { + "tags": [] + }, + "source": [ + "# Inspect and refine multisession outputs\n", + "\n", + "## Description\n", + "\n", + "This is currently being tested to inspect and refine the output of processed `roicat` outputs.\n", + "\n", + "Very **EXPERIMENTAL**. Use with caution\n", + "\n", + "## Requirements\n", + "\n", + "- `numpy`\n", + "- `pandas`\n", + "- `scikit-image`\n", + "- `matplotlib` with `tk` (can change to others like `qt` if needed, but not tested)\n", + "\n", + "## Limitations\n", + "\n", + "This currently uses `tk` and cannot run on Oscar headless. Not sure about Open Ondemand Jupyter.\n", + "\n", + "So either run this locally or use Open Ondemand Desktop session.\n", + "\n", + "The rendering is currently not optimized and will be slow (as it replots many things depending on the operation).\n", + "\n", + "This only plots the `iscell` ones.\n", + "\n", + "## Instructions\n", + "\n", + "### Paths and config\n", + "\n", + "Define the parameters and paths in `Define paths and parameters`\n", + "\n", + "Then run the whole notebook. \n", + "\n", + "### GUI interactions\n", + "\n", + "- Note: \n", + " - Sometimes it might take some time to render\n", + " - The buttons can sometimes be not very responsive, click hard on them\n", + "- To simply inspect: \n", + " - Click on an ROI of a cetain session\n", + " - If available, **red** contours of the same ROI will show up in other sessions\n", + " - The titles will change:\n", + " - Lowercase `id` is the session local ROI index\n", + " - Uppercase `ID` is the global experiment ROI index\n", + " - If ROI cannot be found in a given session based on `roicat` output, the title will say `NOT FOUND`\n", + "- At any point, to clear current selections and requests, click on `Clear`\n", + "- To **chain** ROIs across sessions: \n", + " - First select an ROI so that the **red** contours show up\n", + " - Then click `Chain Request`\n", + " - Then click on another ROI in other sessions\n", + " - **Light blue** contours will show up\n", + " - Then click `Chain`\n", + " - Wait a bit, the colors will change so they will match\n", + " - Click `Clear` and then re-select the ROIs to make sure they now have the **same** global ID (i.e. *chained*)\n", + " - Note: only ROIs that do not appear in the same sessions can be chained\n", + "- To **unchain** ROIs of a certain session (only one at a time):\n", + " - First select an ROI so that the **red** contours show up\n", + " - Then click `Unchain Request`\n", + " - Then enter the session number (starts from 1) to unchain in the box next to it\n", + " - Then press the `Enter` key (or `Return` key on Mac)\n", + " - Wait a bit, the color of the requested ROI to unchain will change\n", + " - Click `Clear` and then re-select the ROIs to make sure they now have **different** global IDs (i.e. *unchained*)\n", + "- If you're satisfied, click `Save`, then the `refined-roi.csv` file will appear\n", + " - Check the same `muse_dir` directory to make sure a file `refined-roi.csv` appears\n", + " - Note: this will not save the color columns\n", + "- To re-inspect the refined output, re-run the notebook with `plot_refined = True`\n" + ] + }, + { + "cell_type": "markdown", + "id": "2da3aecc-b2a0-45aa-8d57-60cd9e001af0", + "metadata": {}, + "source": [ + "## Import" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "020f3892-fe74-493e-81c3-eb99e5a10adc", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import glob\n", + "import warnings\n", + "from datetime import datetime\n", + "\n", + "from pathlib import Path\n", + "import pickle\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "import skimage\n", + "\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.widgets import Button, MultiCursor, TextBox\n", + "\n", + "# change to other backend if needed\n", + "%matplotlib tk" + ] + }, + { + "cell_type": "markdown", + "id": "f83286b4-a951-4e31-a4f2-67bb6e66902c", + "metadata": {}, + "source": [ + "## Define paths and parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37fe3dc0-6aec-4c7b-91f1-5579991bb8b0", + "metadata": {}, + "outputs": [], + "source": [ + "# Define where multisession directory is (with specified plane)\n", + "# e.g:`path/to/SUBJECT_ID/multi-session/PLANE_ID`\n", + "muse_dir = 'data/SD_0664/multi-session/plane0'\n", + "\n", + "# only if there's a `refined-roi.csv` file\n", + "# if run for first time of a `muse_dir`, use `False`\n", + "# use this to indicate whether you want to replot the refined ROI\n", + "plot_refined = True\n", + "\n", + "# Define visualization configurations (combining ROI masks and background images)\n", + "fuse_cfg = dict(\n", + " background_choice = 'max_proj', # max_proj, meanImg, meanImgE, Vcorr\n", + " fuse_alpha = 0.4, # fusing ROI mask with background\n", + " fuse_bg_max_p = 98, # for clipping background max value with percentile\n", + ")\n", + "\n", + "# use -1, 0 or None to indicate num subplot columns = num sessions\n", + "# use integer > 2 otherwise \n", + "column_wrap = None" + ] + }, + { + "cell_type": "markdown", + "id": "bd081a74-4912-4d92-8708-0df1c09a34ef", + "metadata": {}, + "source": [ + "## Initialize" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e61a9bb-ed7b-48f4-b30b-4c3d8fb57514", + "metadata": {}, + "outputs": [], + "source": [ + "# define paths\n", + "muse_dir = Path(muse_dir)\n", + "roicat_out_file = muse_dir / 'roicat-output.pkl'\n", + "aligned_img_file = muse_dir / 'aligned-images.pkl'\n", + "roi_table_file = muse_dir / 'finalized-roi.csv'\n", + "save_roi_file = muse_dir / 'refined-roi.csv'\n", + "\n", + "if plot_refined:\n", + " assert save_roi_file.exists(), f'The refined file \"{save_roi_file}\" does not exist to plot'\n", + " roi_table_file = save_roi_file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4af7a27e-e508-4ea9-a13f-be920421fd49", + "metadata": {}, + "outputs": [], + "source": [ + "# read in\n", + "with open(roicat_out_file, 'rb') as f:\n", + " roicat_out = pickle.load(f)\n", + "\n", + "with open(aligned_img_file, 'rb') as f:\n", + " aligned_images = pickle.load(f)\n", + "\n", + "aligned_rois = roicat_out['results']['ROIs']['ROIs_aligned']\n", + "\n", + "# TODO: unclear if this is the right order\n", + "image_dims = (\n", + " roicat_out['results']['ROIs']['frame_width'],\n", + " roicat_out['results']['ROIs']['frame_height']\n", + ")\n", + "roi_table = pd.read_csv(roi_table_file)\n", + "\n", + "num_sessions = len(aligned_rois)\n", + "assert all(num_sessions == np.array([len(aligned_images), roi_table['session'].nunique()]))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6158c18-ca02-4fe1-bf93-1b982aaf597f", + "metadata": {}, + "outputs": [], + "source": [ + "# assign unique colors\n", + "num_global_rois = roi_table['global_roi'].nunique()\n", + "max_global_roi_id = roi_table['global_roi'].max()\n", + "\n", + "global_roi_colors = np.random.rand(max_global_roi_id + 1, 3).round(3)\n", + "\n", + "roi_table['global_roi_color'] = roi_table['global_roi'].apply(\n", + " lambda x: global_roi_colors[x]\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7e763fd9-5cd7-4cac-8e17-8a533884b7f6", + "metadata": {}, + "outputs": [], + "source": [ + "# check, and recount if needed, the `num_sessions` column\n", + "def check_num_sessions(df):\n", + " \"Check `num_sessions` column and count if inconsistent\"\n", + " \n", + " assert all(df.groupby('global_roi')['num_sessions'].nunique() == 1), \\\n", + " 'There should only be one value for `num_sessions` per `global_roi`'\n", + "\n", + " nses_file_df = (\n", + " df.filter(['global_roi', 'num_sessions'])\n", + " .drop_duplicates()\n", + " .set_index('global_roi')\n", + " .sort_index()\n", + " )\n", + "\n", + " nses_recount_df = (\n", + " df.groupby('global_roi')\n", + " ['session'].nunique()\n", + " .to_frame('num_sessions')\n", + " )\n", + "\n", + " nses_diff = nses_file_df.compare(nses_recount_df)\n", + " if len(nses_diff) > 0:\n", + " print(\n", + " '`num_sessions` is incorrect for `global_roi`: ', \n", + " nses_diff.index.to_list()\n", + " )\n", + " print('-> Dropping current `num_sessions` column and re-count in data frame')\n", + "\n", + " df = (\n", + " df.drop(columns='num_sessions')\n", + " .merge(nses_recount_df.reset_index(), on='global_roi')\n", + " )\n", + " return df\n", + "\n", + "roi_table = check_num_sessions(roi_table)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b03da632-c02f-42f2-8171-3fbd6ebcd0d4", + "metadata": {}, + "outputs": [], + "source": [ + "# save backup of modifiable columns \n", + "backup_tag = datetime.now().strftime('pre[%Y%m%d_%H%M%S]')\n", + "\n", + "roi_table[f'{backup_tag}::num_sessions'] = roi_table['num_sessions'].copy()\n", + "roi_table[f'{backup_tag}::global_roi'] = roi_table['global_roi'].copy()\n", + "roi_table[f'{backup_tag}::global_roi_color'] = roi_table['global_roi_color'].copy()\n", + "\n", + "roi_table" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "532ad296-d743-435c-953f-a8f8521acfe4", + "metadata": {}, + "outputs": [], + "source": [ + "# functions\n", + "# TODO: these should be in a source file\n", + "def throw_invalid_footprint_warning(footprints, roi_table, session):\n", + " invalid_max_idx = footprints.max(axis=1).toarray().squeeze() <= 0\n", + " if sum(invalid_max_idx) == 0:\n", + " return\n", + " \n", + " invalid_max_idx = np.where(invalid_max_idx)[0]\n", + " \n", + " roi_record = roi_table.query(\n", + " 'session == @session and ' \\\n", + " 'session_roi in @invalid_max_idx and' \\\n", + " '(num_sessions > 1 or roicat_global_roi >= 0)' \n", + " )\n", + " if len(roi_record) > 0:\n", + " warnings.warn(\n", + " 'The following ROI records have invalid footprints, i.e. '\\\n", + " '(a) max = 0, (b) persist for more than 1 sessions or `roicat_global_roi=-1`:',\n", + " roi_record.drop(columns='global_roi_color').to_dict('records')\n", + " )\n", + " \n", + "def select_session_rois(footprints, roi_table, session):\n", + " session_iscell_rois = (\n", + " roi_table.query('session == @session')\n", + " .sort_values(by='session_roi')\n", + " .reset_index(drop=True)\n", + " )\n", + " roi_colors = np.stack(session_iscell_rois['global_roi_color'].to_list())\n", + " roi_iscell_idx = session_iscell_rois['session_roi'].to_list()\n", + "\n", + " footprints = footprints[roi_iscell_idx]\n", + " assert footprints.shape[0] == len(roi_iscell_idx)\n", + " \n", + " return footprints, roi_colors \n", + "\n", + "def normalize_footprint(X):\n", + " max_X = X.max(axis=1).toarray()\n", + " X = X.multiply(1.0 / max_X)\n", + " return X\n", + " \n", + "def sparse2image(sparse_footprints, image_dims):\n", + " image = (\n", + " (sparse_footprints > 0)\n", + " .T\n", + " .reshape(*image_dims)\n", + " .toarray()\n", + " )\n", + " return image\n", + "\n", + "def sparse2contour(sparse_footprints, image_dims):\n", + " image = (\n", + " (sparse_footprints > 0)\n", + " .T\n", + " .reshape(*image_dims)\n", + " .toarray()\n", + " )\n", + " contour = skimage.measure.find_contours(image)\n", + " return contour\n", + "\n", + "def color_footprint_one_channel(sparse_footprints, color_vec, image_dims):\n", + " image = np.array(\n", + " (sparse_footprints > 0)\n", + " .multiply(color_vec)\n", + " .sum(axis=0)\n", + " .T\n", + " .reshape(*image_dims)\n", + " )\n", + " return image\n", + "\n", + "def color_footprint(sparse_footprints, color_matrix, image_dims):\n", + " image = np.stack([\n", + " color_footprint_one_channel(\n", + " sparse_footprints,\n", + " color_vec.reshape(-1,1),\n", + " image_dims\n", + " )\n", + " for color_vec in color_matrix.T\n", + " ], axis=-1)\n", + " image = np.clip(image, a_min=0.0, a_max=1.0)\n", + " return image\n", + "\n", + "def fuse_rois_in_background(rois, background, alpha=0.5, background_max_percentile=99):\n", + " background = (background - background.min()) / \\\n", + " (np.percentile(background, background_max_percentile) - background.min())\n", + " background = np.clip(background, a_min=0, a_max=1.0)\n", + " background = skimage.color.gray2rgb(background)\n", + " \n", + " fused_image = background * (1 - alpha) + rois * alpha\n", + " fused_image = np.clip(fused_image, a_min=0, a_max=1.0) \n", + " return fused_image\n", + " \n", + "def compute_session_fused_footprints(\n", + " images, footprints, roi_table, session, \n", + " background_choice = 'max_proj',\n", + " fuse_alpha = 0.5,\n", + " fuse_bg_max_p = 99,\n", + "):\n", + " background_images = images[session]\n", + " if background_choice not in background_images:\n", + " warnings.warn(f'{background_choice} not in the aligned images. Using \"fov\" field instead')\n", + " assert 'fov' in background_images, '\"fov\" not found in aligned images'\n", + " background_choice = 'fov'\n", + " background_image = background_images[background_choice]\n", + " \n", + " sparse_footprints = footprints[session]\n", + " throw_invalid_footprint_warning(sparse_footprints, roi_table, session) \n", + " \n", + " sparse_footprints, roi_colors = select_session_rois(sparse_footprints, roi_table, session)\n", + " sparse_footprints = normalize_footprint(sparse_footprints)\n", + " colored_footprints = color_footprint(sparse_footprints, roi_colors, image_dims)\n", + " fused_footprints = fuse_rois_in_background(\n", + " colored_footprints,\n", + " background_image,\n", + " alpha=fuse_alpha, \n", + " background_max_percentile=fuse_bg_max_p\n", + " )\n", + " return fused_footprints" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ea78c1b-7776-4230-b8ae-d20318c1c0d3", + "metadata": {}, + "outputs": [], + "source": [ + "# fuse ROI and background\n", + "fuse_kwargs = dict(\n", + " images=aligned_images,\n", + " footprints=aligned_rois,\n", + " roi_table=roi_table,\n", + " **fuse_cfg\n", + ")\n", + "\n", + "fused_footprints = [\n", + " compute_session_fused_footprints(\n", + " session=session_idx,\n", + " **fuse_kwargs\n", + " )\n", + " for session_idx in range(num_sessions)\n", + "]\n", + "\n", + "# get contours\n", + "roi_contours = []\n", + "for session_rois in aligned_rois:\n", + " roi_contours.append([\n", + " sparse2contour(roi_footprint, image_dims) \n", + " for roi_footprint in session_rois\n", + " ])\n", + "\n", + "# plot to see what they look like\n", + "# plt.figure(figsize=(20,10))\n", + "# for i in range(num_sessions):\n", + "# plt.subplot(1,num_sessions,i+1)\n", + "# plt.imshow(fused_footprints[i])\n", + "# plt.axis('off')\n", + "# plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "5ad4f64f-c89a-4a8d-8000-0d7fd52f1c88", + "metadata": {}, + "source": [ + "## Refine GUI" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9966739-17a4-482e-929b-1e5460328587", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: set variables inside objects instead of modifying global objects\n", + "# - roi_table\n", + "# - aligned_images\n", + "# - aligned_rois\n", + "# - highlight_contour_kwargs\n", + "# - request_contour_kwargs\n", + "# - save_roi_file\n", + "# TODO: figure out how to optimize re-drawing bc very slow right now\n", + "# TODO: warn when an operation is not valid\n", + "\n", + "highlight_contour_kwargs = dict(c='r', lw=2, alpha=0.8)\n", + "request_contour_kwargs = dict(c='#67a9cf', lw=2, alpha=0.8, ls='--')\n", + "\n", + "if column_wrap is None:\n", + " column_wrap = num_sessions\n", + "if column_wrap < 2:\n", + " column_wrap = num_sessions\n", + "num_cols = min(num_sessions, column_wrap)\n", + "\n", + "fig, axes = plt.subplots(\n", + " int(np.ceil(num_sessions/num_cols)), num_cols,\n", + " figsize=(20,10),\n", + " sharex=True,\n", + " sharey=True,\n", + ")\n", + "ax_list = axes.flatten()\n", + "\n", + "for i in range(num_sessions):\n", + " ax_list[i].imshow(fused_footprints[i])\n", + " ax_list[i].set_axis_off()\n", + " ax_list[i].set_title(f'session #{i+1}')\n", + " ax_list[i].session_index = i\n", + " \n", + "def find_tagged_objects(fig, value, key='tag'):\n", + " objs = fig.findobj(\n", + " match=lambda x:\n", + " False if not hasattr(x, key) \n", + " else value in getattr(x, key)\n", + " )\n", + " return objs \n", + "\n", + "class MuseStateCallBacks:\n", + " state = dict(\n", + " chain_request = False,\n", + " chain = False,\n", + " unchain_request = False,\n", + " # unchain = False,\n", + " )\n", + " selection = dict(\n", + " session = None,\n", + " session_roi = None,\n", + " global_roi = None\n", + " )\n", + " chain_selection = dict( \n", + " session = None,\n", + " session_roi = None,\n", + " global_roi = None\n", + " )\n", + " \n", + " def __init__(self, save_path):\n", + " self.save_path = save_path\n", + " \n", + " def refuse_footprint(self):\n", + " fuse_kwargs = dict(\n", + " images=aligned_images,\n", + " footprints=aligned_rois,\n", + " roi_table=roi_table,\n", + " **fuse_cfg\n", + " )\n", + "\n", + " fused_footprints = [\n", + " compute_session_fused_footprints(\n", + " session=session_idx,\n", + " **fuse_kwargs\n", + " )\n", + " for session_idx in range(num_sessions)\n", + " ]\n", + " \n", + " \n", + " for i in range(num_sessions):\n", + " ax_list[i].imshow(fused_footprints[i])\n", + " \n", + " plt.draw()\n", + " \n", + " def clear(self, event):\n", + " self.reset_state()\n", + " clear_objs = find_tagged_objects(fig, value='highlight') +\\\n", + " find_tagged_objects(fig, value='chain_request')\n", + " \n", + " for x in clear_objs:\n", + " x.remove() \n", + " for i, ax in enumerate(ax_list):\n", + " ax.set_title(f'session #{i+1}')\n", + " \n", + " def reset_state(self):\n", + " for k in self.state.keys():\n", + " self.state[k] = False\n", + " \n", + " def unchain(self, unchain_session):\n", + " print(self.state, self.selection)\n", + " if not self.state['unchain_request'] or None in self.selection.values():\n", + " self.reset_state()\n", + " return\n", + " unchain_session = int(unchain_session) - 1\n", + " curr_global_roi = self.selection['global_roi']\n", + " unchain_roi_idx = (\n", + " roi_table\n", + " .query('global_roi == @curr_global_roi and session == @unchain_session')\n", + " .index.to_list()\n", + " )\n", + " if len(unchain_roi_idx) < 1:\n", + " self.reset_state()\n", + " return\n", + " unchain_roi_idx = list(unchain_roi_idx)[0]\n", + " if roi_table.loc[unchain_roi_idx, 'num_sessions'] < 2:\n", + " self.reset_state()\n", + " return\n", + "\n", + " next_global_roi = roi_table['global_roi'].max() + 1\n", + " \n", + " # decrement `num_sessions` for the ROI that used to be associated with the unchained one\n", + " roi_table.loc[roi_table['global_roi'] == curr_global_roi, 'num_sessions'] -= 1 \n", + " \n", + " # unchained ROI becomes a new ROI \n", + " roi_table.loc[unchain_roi_idx, 'global_roi'] = next_global_roi\n", + " roi_table.loc[unchain_roi_idx, 'num_sessions'] = 1\n", + " roi_table.at[unchain_roi_idx, 'global_roi_color'] = np.random.rand(3)\n", + " \n", + " self.reset_state()\n", + " self.refuse_footprint()\n", + " \n", + " def chain(self, event):\n", + " print(self.state, self.selection)\n", + " if (\n", + " not self.state['chain_request'] or \\\n", + " None in self.selection.values()\n", + " ):\n", + " self.reset_state()\n", + " return\n", + " \n", + " if None in self.chain_selection.values():\n", + " return\n", + " \n", + " curr_glob_roi = self.selection['global_roi']\n", + " curr_glob_roi_rows = roi_table.query('global_roi == @curr_glob_roi')\n", + " \n", + " chain_glob_roi = self.chain_selection['global_roi']\n", + " chain_glob_roi_rows = roi_table.query('global_roi == @chain_glob_roi')\n", + " \n", + " if curr_glob_roi == chain_glob_roi:\n", + " return\n", + " \n", + " print(curr_glob_roi_rows, chain_glob_roi_rows)\n", + " shared_sessions = (\n", + " set(curr_glob_roi_rows['session'].to_list())\n", + " .intersection(chain_glob_roi_rows['session'].to_list())\n", + " )\n", + " union_sessions = (\n", + " set(curr_glob_roi_rows['session'].to_list())\n", + " .union(chain_glob_roi_rows['session'].to_list())\n", + " )\n", + " if len(shared_sessions) > 0:\n", + " return\n", + " \n", + " min_glob_roi = min(curr_glob_roi, chain_glob_roi)\n", + " glob_color = roi_table.query('global_roi == @min_glob_roi')['global_roi_color'].iloc[0]\n", + " concat_idx = list(curr_glob_roi_rows.index) + list(chain_glob_roi_rows.index)\n", + " roi_table.loc[concat_idx, 'num_sessions'] = len(union_sessions)\n", + " roi_table.loc[concat_idx, 'global_roi'] = min_glob_roi\n", + " for idx in concat_idx:\n", + " roi_table.at[idx, 'global_roi_color'] = glob_color\n", + " \n", + " self.reset_state()\n", + " self.refuse_footprint()\n", + " \n", + " def unchain_request(self, event):\n", + " self.reset_state()\n", + " if None in self.selection.values():\n", + " return\n", + " self.state['unchain_request'] = True\n", + " \n", + " def chain_request(self, event):\n", + " self.refuse_footprint()\n", + " if None in self.selection.values():\n", + " return\n", + " self.state['chain_request'] = True\n", + " \n", + " def onclick(self, event, tag='highlight'): \n", + " # get data\n", + " ix, iy, ax = event.xdata, event.ydata, event.inaxes\n", + " if not hasattr(ax, 'session_index'):\n", + " return\n", + " \n", + " contour_kwargs = highlight_contour_kwargs.copy()\n", + " if self.state['chain_request']:\n", + " tag = 'chain_request'\n", + " contour_kwargs = request_contour_kwargs.copy()\n", + "\n", + " # remove previous highlighted objects \n", + " for x in find_tagged_objects(fig, value=tag):\n", + " x.remove() \n", + " session = ax.session_index\n", + "\n", + " # obtain session ROI index\n", + " flat_idx = np.ravel_multi_index((round(iy),round(ix)), image_dims)\n", + " session_roi = aligned_rois[session][:,flat_idx].nonzero()[0]\n", + " if len(session_roi) == 0:\n", + " return\n", + " session_roi = session_roi[0] # just select first one if there are ovelap\n", + " select_global_roi = (\n", + " roi_table.query('session == @session and session_roi == @session_roi')\n", + " ['global_roi'].to_list()\n", + " )\n", + " if len(select_global_roi) != 1:\n", + " return\n", + " select_global_roi = select_global_roi[0]\n", + " \n", + " if not (self.state['chain_request']):\n", + " self.selection['session'] = session\n", + " self.selection['session_roi'] = session_roi\n", + " self.selection['global_roi'] = select_global_roi\n", + " else:\n", + " self.chain_selection['session'] = session\n", + " self.chain_selection['session_roi'] = session_roi\n", + " self.chain_selection['global_roi'] = select_global_roi\n", + " \n", + " # obtain contours\n", + " select_contours = {\n", + " r['session']: dict(\n", + " contour = roi_contours[r['session']][r['session_roi']],\n", + " **r,\n", + " )\n", + " for _, r in roi_table.query('global_roi == @select_global_roi').iterrows()\n", + " }\n", + "\n", + " # plot contours\n", + " for session, ax in enumerate(ax_list):\n", + " if session not in select_contours:\n", + " ax.set_title(f'session #{session+1} [NOT FOUND]')\n", + " continue\n", + "\n", + " session_contours = select_contours[session]['contour']\n", + " select_session_roi = select_contours[session]['session_roi']\n", + "\n", + " for c in session_contours:\n", + " c_handles = ax.plot(c[:,1],c[:,0], **contour_kwargs)\n", + " for ch in c_handles:\n", + " ch.tag = tag\n", + "\n", + " ax.set_title(f'session #{session+1} [id={select_session_roi} | ID={select_global_roi}]')\n", + "\n", + " plt.draw()\n", + " \n", + " def save(self, event):\n", + " # currently avoid saving color columns\n", + " # TODO: if a file exists, warn before saving\n", + " color_columns = (\n", + " roi_table\n", + " .filter(regex='.*global_roi_color.*')\n", + " .columns\n", + " .to_list()\n", + " )\n", + " \n", + " (\n", + " check_num_sessions(roi_table)\n", + " .drop(columns=color_columns)\n", + " .to_csv(\n", + " self.save_path,\n", + " index=False\n", + " )\n", + " )\n", + " \n", + "\n", + "muse_cb = MuseStateCallBacks(\n", + " save_path = save_roi_file,\n", + ")\n", + "cid = fig.canvas.mpl_connect('button_press_event', muse_cb.onclick)\n", + "\n", + "ax_buttons = dict(\n", + " chain_request = fig.add_axes([0.10, 0.05, 0.08, 0.05]),\n", + " chain = fig.add_axes([0.20, 0.05, 0.08, 0.05]),\n", + " unchain_request = fig.add_axes([0.40, 0.05, 0.08, 0.05]),\n", + " clear = fig.add_axes([0.75, 0.05, 0.08, 0.05]),\n", + " save = fig.add_axes([0.9, 0.05, 0.08, 0.05]),\n", + ")\n", + "\n", + "buttons = {\n", + " k: Button(v, k.replace('_', ' ').title())\n", + " for k, v in ax_buttons.items()\n", + "}\n", + "buttons['clear'].on_clicked(muse_cb.clear)\n", + "buttons['chain_request'].on_clicked(muse_cb.chain_request)\n", + "buttons['chain'].on_clicked(muse_cb.chain)\n", + "buttons['unchain_request'].on_clicked(muse_cb.unchain_request)\n", + "buttons['save'].on_clicked(muse_cb.save)\n", + "\n", + "ax_text = fig.add_axes([0.52, 0.05, 0.05, 0.05])\n", + "text_box = TextBox(ax_text, f\"Session \\n (1-{num_sessions}) \", textalignment=\"left\")\n", + "text_box.on_submit(muse_cb.unchain)\n", + "\n", + "multi = MultiCursor(None, ax_list, color='r', lw=0.5, horizOn=True, vertOn=True)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "1fdbd953-2acf-4b4a-8b34-035895cd4451", + "metadata": {}, + "source": [ + "Check the modification dates of the `*-roi.csv` files. If they are not what's expected, try using what's inside the `MuseStateCallBacks.save` function explicitly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb1175f3-1987-4f5f-b07c-cdd4d7f328c8", + "metadata": {}, + "outputs": [], + "source": [ + "!ls -lthr $muse_dir" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + }, + "scenes_data": { + "active_scene": "Default Scene", + "init_scene": "", + "scenes": [ + "Default Scene" + ] + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/roicat-standalone.ipynb b/notebooks/roicat-standalone.ipynb new file mode 100644 index 0000000..5bc3fa3 --- /dev/null +++ b/notebooks/roicat-standalone.ipynb @@ -0,0 +1,942 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "7241263e-8942-41a0-8c5b-7393431912b3", + "metadata": {}, + "source": [ + "# ROICAT Standalone\n", + "\n", + "This is a standalone notebook to use `roicat` for cell tracking after `suite2p`.\n", + "\n", + "This was taken from `cellreg/4-roicat.py` but will not be as updated, and may be removed.\n", + "\n", + "This can be run on Oscar or locally." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be3ba033-a6d2-4d5e-a1ca-f9bd6805735c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "import glob\n", + "import shutil\n", + "from pathlib import Path\n", + "import copy\n", + "import multiprocessing as mp\n", + "import tempfile\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "from tqdm import tqdm\n", + "\n", + "import roicat\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", + "import argparse\n", + "import yaml\n", + "import pickle\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac423786-68b2-4495-a6c5-c23f0a676a3d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "PARAMS = {\n", + " 'um_per_pixel': 0.7,\n", + " 'background_max_percentile': 99.9,\n", + " 'suite2p': { # `roicat.data_importing.Data_suite2p`\n", + " 'new_or_old_suite2p': 'new',\n", + " 'type_meanImg': 'meanImgE',\n", + " },\n", + " 'fov_augment': { # `aligner.augment_FOV_images`\n", + " 'roi_FOV_mixing_factor': 0.5,\n", + " 'use_CLAHE': False,\n", + " 'CLAHE_grid_size': 1,\n", + " 'CLAHE_clipLimit': 1,\n", + " 'CLAHE_normalize': True,\n", + " },\n", + " 'fit_geometric': { # `aligner.fit_geometric`\n", + " 'template': 0, \n", + " 'template_method': 'image', \n", + " 'mode_transform': 'affine',\n", + " 'mask_borders': (5,5,5,5), \n", + " 'n_iter': 1000,\n", + " 'termination_eps': 1e-6, \n", + " 'gaussFiltSize': 15,\n", + " 'auto_fix_gaussFilt_step':1,\n", + " },\n", + " 'fit_nonrigid': { # `aligner.fit_nonrigid`\n", + " 'disable': True,\n", + " 'template': 0,\n", + " 'template_method': 'image',\n", + " 'mode_transform':'createOptFlow_DeepFlow',\n", + " 'kwargs_mode_transform':None,\n", + " },\n", + " 'roi_blur': {\n", + " 'kernel_halfWidth': 2\n", + " }\n", + " \n", + "}\n", + "\n", + "DISABLE_NONRIGID = PARAMS['fit_nonrigid'].pop('disable')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf4851d0-cb6a-4a04-816b-eb21990cc88a", + "metadata": {}, + "outputs": [], + "source": [ + "ROOT_DATA_DIR = '/oscar/data/afleisc2/sdaste/ROICat-test/'\n", + "SUBJECT_ID = 'SD_0664'\n", + "PLANE_ID = 'plane2'\n", + "SUITE2P_PATH_MAXDEPTH=6\n", + "USE_GPU = False\n", + "VERBOSITY = True\n", + "OUTPUT_DIR='../data'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7d73c56-82cc-4a14-a045-fc15ad3e3957", + "metadata": {}, + "outputs": [], + "source": [ + "# define paths\n", + "SUBJECT_DIR = Path(ROOT_DATA_DIR) / SUBJECT_ID\n", + "COLLECTIVE_MUSE_DIR = Path(OUTPUT_DIR) / SUBJECT_ID / 'multi-session' / PLANE_ID\n", + "COLLECTIVE_MUSE_FIG_DIR = COLLECTIVE_MUSE_DIR / 'figures'\n", + "COLLECTIVE_MUSE_FIG_DIR.mkdir(parents=True, exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0bacdcf2-a84f-4227-b13f-a0668b9bde24", + "metadata": {}, + "outputs": [], + "source": [ + "# find suite2p paths\n", + "dir_allOuterFolders = str(SUBJECT_DIR)\n", + "pathSuffixToStat = 'stat.npy'\n", + "pathSuffixToOps = 'ops.npy'\n", + "pathShouldHave = fr'suite2p/{PLANE_ID}'\n", + "\n", + "paths_allStat = roicat.helpers.find_paths(\n", + " dir_outer=dir_allOuterFolders,\n", + " reMatch=pathSuffixToStat,\n", + " reMatch_in_path=pathShouldHave, \n", + " depth=SUITE2P_PATH_MAXDEPTH,\n", + ")[:]\n", + "\n", + "paths_allStat = [\n", + " x for x in paths_allStat \n", + " if pathShouldHave in x\n", + "]\n", + "\n", + "paths_allOps = np.array([\n", + " Path(path).resolve().parent / pathSuffixToOps\n", + " for path in paths_allStat\n", + "])[:]\n", + "\n", + "\n", + "print('Paths to all suite2p STAT files:')\n", + "print('\\n'.join(['\\t-' + str(x) for x in paths_allStat]))\n", + "print('\\n')\n", + "print('Paths to all suite2p OPS files:')\n", + "print('\\n'.join(['\\t-' + str(x) for x in paths_allOps]))\n", + "print('\\n')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b95bba8-9e0d-4b0b-9d4f-a527f41622b9", + "metadata": {}, + "outputs": [], + "source": [ + "# load data\n", + "data = roicat.data_importing.Data_suite2p(\n", + " paths_statFiles=paths_allStat[:],\n", + " paths_opsFiles=paths_allOps[:],\n", + " um_per_pixel=PARAMS['um_per_pixel'],\n", + " type_meanImg='meanImg', # will be overwritten in the following cell\n", + " **{k: v for k, v in PARAMS['suite2p'].items() if k not in ['type_meanImg']},\n", + " verbose=VERBOSITY,\n", + ")\n", + "\n", + "assert data.check_completeness(verbose=False)['tracking'],\\\n", + " \"Data object is missing attributes necessary for tracking.\"\n", + "\n", + "# also save iscell paths\n", + "data.paths_iscell = [\n", + " Path(x).parent / 'iscell.npy'\n", + " for x in data.paths_ops\n", + "]\n", + "\n", + "# load all background images\n", + "background_types = [\n", + " 'meanImg',\n", + " 'meanImgE',\n", + " 'max_proj',\n", + " 'Vcorr',\n", + "]\n", + "\n", + "FOV_backgrounds = {k: [] for k in background_types}\n", + "for ops_file in data.paths_ops:\n", + " ops = np.load(ops_file, allow_pickle=True).item()\n", + " \n", + " im_sz = (ops['Ly'], ops['Lx']) \n", + " for bg in background_types:\n", + " bg_im = ops[bg]\n", + " \n", + " if bg_im.shape == im_sz:\n", + " FOV_backgrounds[bg].append(bg_im)\n", + " continue\n", + "\n", + " print(\n", + " f'\\t- File {ops_file}: {bg} shape is {bg_im.shape}, which is cropped from {im_sz}. '\\\n", + " '\\n\\tWill attempt to add empty pixels to recover the original shape.'\n", + " )\n", + "\n", + " im = np.zeros(im_sz).astype(bg_im.dtype)\n", + " cropped_xrange, cropped_yrange = ops['xrange'], ops['yrange']\n", + " im[\n", + " cropped_yrange[0]:cropped_yrange[1],\n", + " cropped_xrange[0]:cropped_xrange[1]\n", + " ] = bg_im\n", + " \n", + " FOV_backgrounds[bg].append(im)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39220367-d78f-4073-a2ce-7a976fb8dbc2", + "metadata": {}, + "outputs": [], + "source": [ + "{k: [vi.shape for vi in v] for k,v in FOV_backgrounds.items()}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0001a617-a832-4808-951a-09dc7975f551", + "metadata": {}, + "outputs": [], + "source": [ + "# choice of FOV images to align\n", + "data.FOV_images = FOV_backgrounds[PARAMS['suite2p']['type_meanImg']]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ee880cf-1eb1-40b2-b2d5-6e999c80096b", + "metadata": {}, + "outputs": [], + "source": [ + "roicat.visualization.display_toggle_image_stack(data.FOV_images)\n", + "roicat.visualization.display_toggle_image_stack(data.get_maxIntensityProjection_spatialFootprints(), clim=[0,1])\n", + "roicat.visualization.display_toggle_image_stack(np.concatenate(data.ROI_images, axis=0)[:5000], image_size=(200,200))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf6d2e0d-b805-43c5-aa8d-d308a80858cd", + "metadata": {}, + "outputs": [], + "source": [ + "# obtain FOVs\n", + "aligner = roicat.tracking.alignment.Aligner(verbose=VERBOSITY)\n", + "\n", + "FOV_images = aligner.augment_FOV_images(\n", + " ims=data.FOV_images,\n", + " spatialFootprints=data.spatialFootprints,\n", + " **PARAMS['fov_augment']\n", + ")\n", + "\n", + "roicat.visualization.display_toggle_image_stack(FOV_images)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c40093c-4ce4-40f6-9a93-16910cb73397", + "metadata": {}, + "outputs": [], + "source": [ + "# geometric fit\n", + "aligner.fit_geometric(\n", + " ims_moving=FOV_images,\n", + " **PARAMS['fit_geometric']\n", + ")\n", + "aligner.transform_images_geometric(FOV_images)\n", + "remap_idx = aligner.remappingIdx_geo\n", + "\n", + "roicat.visualization.display_toggle_image_stack(aligner.ims_registered_geo)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c5dbc228-a463-4c40-b102-9d4e23a39675", + "metadata": {}, + "outputs": [], + "source": [ + "# non-rigid\n", + "if not DISABLE_NONRIGID:\n", + " aligner.fit_nonrigid(\n", + " ims_moving=aligner.ims_registered_geo,\n", + " remappingIdx_init=aligner.remappingIdx_geo, \n", + " **PARAMS['fit_nonrigid']\n", + " )\n", + " aligner.transform_images_nonrigid(FOV_images)\n", + " remap_idx = aligner.remappingIdx_nonrigid" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "236d1901-1ab7-4187-9bd1-14b377563a98", + "metadata": {}, + "outputs": [], + "source": [ + "# transform ROIs\n", + "aligner.transform_ROIs(\n", + " ROIs=data.spatialFootprints,\n", + " remappingIdx=remap_idx,\n", + " normalize=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "885dfb93-63e5-43fe-aa27-2dffaa178227", + "metadata": {}, + "outputs": [], + "source": [ + "# transform other backgrounds\n", + "aligned_backgrounds = {k: [] for k in background_types}\n", + "for bg in background_types:\n", + " aligned_backgrounds[bg] = aligner.transform_images(\n", + " FOV_backgrounds[bg],\n", + " remappingIdx=remap_idx\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24b42d5f-2e5e-4ed6-bb82-fa83725ec32e", + "metadata": {}, + "outputs": [], + "source": [ + "for bg, im in aligned_backgrounds.items():\n", + " roicat.visualization.display_toggle_image_stack(im)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fa7fbb38-0b8e-421d-b97f-424ba55aec13", + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(20,20), layout='tight')\n", + "types2plt = background_types + ['ROI']\n", + "nrows = len(types2plt)\n", + "ncols = data.n_sessions\n", + "\n", + "splt_cnt = 1\n", + "for k in types2plt:\n", + " image_list = aligned_backgrounds.get(k, aligner.get_ROIsAligned_maxIntensityProjection())\n", + " for s_id, img in enumerate(image_list):\n", + " plt.subplot(nrows, ncols, splt_cnt)\n", + " plt.imshow(\n", + " img, cmap='Greys_r',\n", + " vmax=np.percentile(\n", + " img,\n", + " PARAMS['background_max_percentile'] if k!= \"ROI\" else 95\n", + " )\n", + " )\n", + " plt.axis('off')\n", + " plt.title(f'Aligned {k} [#{s_id}]') \n", + " splt_cnt += 1\n", + "\n", + "plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'aligned-fov.png')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "04c4ce2e-6fac-44ae-ba96-ba046cc55d9c", + "metadata": {}, + "outputs": [], + "source": [ + "# blur ROI\n", + "blurrer = roicat.tracking.blurring.ROI_Blurrer(\n", + " frame_shape=(data.FOV_height, data.FOV_width),\n", + " plot_kernel=False,\n", + " verbose=VERBOSITY,\n", + " **PARAMS['roi_blur']\n", + ")\n", + "\n", + "blurrer.blur_ROIs(\n", + " spatialFootprints=aligner.ROIs_aligned[:],\n", + ")\n", + "\n", + "# ROInet embedding\n", + "# TODO: Parameterize `ROInet_embedder`, `generate_dataloader`\n", + "DEVICE = roicat.helpers.set_device(use_GPU=USE_GPU, verbose=VERBOSITY)\n", + "dir_temp = tempfile.gettempdir()\n", + "\n", + "roinet = roicat.ROInet.ROInet_embedder(\n", + " device=DEVICE,\n", + " dir_networkFiles=dir_temp,\n", + " download_method='check_local_first',\n", + " download_url='https://osf.io/x3fd2/download',\n", + " download_hash='7a5fb8ad94b110037785a46b9463ea94',\n", + " forward_pass_version='latent',\n", + " verbose=VERBOSITY\n", + ")\n", + "\n", + "roinet.generate_dataloader(\n", + " ROI_images=data.ROI_images,\n", + " um_per_pixel=data.um_per_pixel,\n", + " pref_plot=False,\n", + " jit_script_transforms=False,\n", + " batchSize_dataloader=8, \n", + " pinMemory_dataloader=True,\n", + " numWorkers_dataloader=4,\n", + " persistentWorkers_dataloader=True,\n", + " prefetchFactor_dataloader=2,\n", + ")\n", + "\n", + "roinet.generate_latents()\n", + "\n", + "# Scattering wavelet embedding\n", + "# TODO: Parameterize `SWT`, `SWT.transform`\n", + "swt = roicat.tracking.scatteringWaveletTransformer.SWT(\n", + " kwargs_Scattering2D={'J': 3, 'L': 12}, \n", + " image_shape=data.ROI_images[0].shape[1:3],\n", + " device=DEVICE,\n", + ")\n", + "\n", + "swt.transform(\n", + " ROI_images=roinet.ROI_images_rs,\n", + " batch_size=100,\n", + ")\n", + "\n", + "# Compute similarities\n", + "# TODO: Parameterize `ROI_graph`, `compute_similarity_blockwise`, `make_normalized_similarities`\n", + "\n", + "sim = roicat.tracking.similarity_graph.ROI_graph(\n", + " n_workers=-1, \n", + " frame_height=data.FOV_height,\n", + " frame_width=data.FOV_width,\n", + " block_height=128, \n", + " block_width=128, \n", + " algorithm_nearestNeigbors_spatialFootprints='brute',\n", + " verbose=VERBOSITY, \n", + ")\n", + "\n", + "s_sf, s_NN, s_SWT, s_sesh = sim.compute_similarity_blockwise(\n", + " spatialFootprints=blurrer.ROIs_blurred,\n", + " features_NN=roinet.latents,\n", + " features_SWT=swt.latents,\n", + " ROI_session_bool=data.session_bool,\n", + " spatialFootprint_maskPower=1.0,\n", + ")\n", + "\n", + "sim.make_normalized_similarities(\n", + " centers_of_mass=data.centroids,\n", + " features_NN=roinet.latents,\n", + " features_SWT=swt.latents, \n", + " k_max=data.n_sessions*100,\n", + " k_min=data.n_sessions*10,\n", + " algo_NN='kd_tree',\n", + " device=DEVICE,\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c1b1aac-0e82-408c-85c1-9d189c13821e", + "metadata": {}, + "outputs": [], + "source": [ + "# Clustering\n", + "# TODO: Parameterize `find_optimal_parameters_for_pruning`?\n", + "clusterer = roicat.tracking.clustering.Clusterer(\n", + " s_sf=sim.s_sf,\n", + " s_NN_z=sim.s_NN_z,\n", + " s_SWT_z=sim.s_SWT_z,\n", + " s_sesh=sim.s_sesh,\n", + ")\n", + "\n", + "kwargs_makeConjunctiveDistanceMatrix_best = clusterer.find_optimal_parameters_for_pruning(\n", + " n_bins=None, \n", + " smoothing_window_bins=None,\n", + " kwargs_findParameters={\n", + " 'n_patience': 300,\n", + " 'tol_frac': 0.001, \n", + " 'max_trials': 1200, \n", + " 'max_duration': 60*10, \n", + " },\n", + " bounds_findParameters={\n", + " 'power_NN': (0., 5.),\n", + " 'power_SWT': (0., 5.),\n", + " 'p_norm': (-5, 0),\n", + " 'sig_NN_kwargs_mu': (0., 1.0), \n", + " 'sig_NN_kwargs_b': (0.00, 1.5), \n", + " 'sig_SWT_kwargs_mu': (0., 1.0),\n", + " 'sig_SWT_kwargs_b': (0.00, 1.5),\n", + " },\n", + " n_jobs_findParameters=-1,\n", + ")\n", + "\n", + "kwargs_mcdm_tmp = kwargs_makeConjunctiveDistanceMatrix_best ## Use the optimized parameters\n", + "\n", + "clusterer.plot_distSame(kwargs_makeConjunctiveDistanceMatrix=kwargs_mcdm_tmp)\n", + "plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'pw-sim-distrib.png')\n", + "\n", + "clusterer.plot_similarity_relationships(\n", + " plots_to_show=[1,2,3], \n", + " max_samples=100000, ## Make smaller if it is running too slow\n", + " kwargs_scatter={'s':1, 'alpha':0.2},\n", + " kwargs_makeConjunctiveDistanceMatrix=kwargs_mcdm_tmp\n", + ");\n", + "plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'pw-sim-scatter.png')\n", + "\n", + "clusterer.make_pruned_similarity_graphs(\n", + " d_cutoff=None,\n", + " kwargs_makeConjunctiveDistanceMatrix=kwargs_mcdm_tmp,\n", + " stringency=1.0,\n", + " convert_to_probability=False, \n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "91de02af-fa8b-4252-8695-498fe1ac2b72", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "if data.n_sessions >= 8:\n", + " labels = clusterer.fit(\n", + " d_conj=clusterer.dConj_pruned,\n", + " session_bool=data.session_bool,\n", + " min_cluster_size=2,\n", + " n_iter_violationCorrection=3,\n", + " split_intraSession_clusters=True,\n", + " cluster_selection_method='leaf',\n", + " d_clusterMerge=None, \n", + " alpha=0.999, \n", + " discard_failed_pruning=False, \n", + " n_steps_clusterSplit=100,\n", + " )\n", + "\n", + "else:\n", + " labels = clusterer.fit_sequentialHungarian(\n", + " d_conj=clusterer.dConj_pruned, ## Input distance matrix\n", + " session_bool=data.session_bool, ## Boolean array of which ROIs belong to which sessions\n", + " thresh_cost=0.6, ## Threshold \n", + " )\n", + "\n", + "quality_metrics = clusterer.compute_quality_metrics()\n", + "\n", + "labels_squeezed, labels_bySession, labels_bool, labels_bool_bySession, labels_dict = roicat.tracking.clustering.make_label_variants(labels=labels, n_roi_bySession=data.n_roi)\n", + "\n", + "results = {\n", + " \"clusters\":{\n", + " \"labels\": labels_squeezed,\n", + " \"labels_bySession\": labels_bySession,\n", + " \"labels_bool\": labels_bool,\n", + " \"labels_bool_bySession\": labels_bool_bySession,\n", + " \"labels_dict\": labels_dict,\n", + " },\n", + " \"ROIs\": {\n", + " \"ROIs_aligned\": aligner.ROIs_aligned,\n", + " \"ROIs_raw\": data.spatialFootprints,\n", + " \"frame_height\": data.FOV_height,\n", + " \"frame_width\": data.FOV_width,\n", + " \"idx_roi_session\": np.where(data.session_bool)[1],\n", + " \"n_sessions\": data.n_sessions,\n", + " },\n", + " \"input_data\": {\n", + " \"paths_stat\": data.paths_stat,\n", + " \"paths_ops\": data.paths_ops,\n", + " },\n", + " \"quality_metrics\": clusterer.quality_metrics if hasattr(clusterer, 'quality_metrics') else None,\n", + "}\n", + "\n", + "run_data = copy.deepcopy({\n", + " 'data': data.serializable_dict,\n", + " 'aligner': aligner.serializable_dict,\n", + " 'blurrer': blurrer.serializable_dict,\n", + " 'roinet': roinet.serializable_dict,\n", + " 'swt': swt.serializable_dict,\n", + " 'sim': sim.serializable_dict,\n", + " 'clusterer': clusterer.serializable_dict,\n", + "})\n", + "\n", + "iscell_bySession = [np.load(ic_p)[:,0] for ic_p in data.paths_iscell]\n", + "\n", + "with open(COLLECTIVE_MUSE_DIR / 'roicat-output.pkl', 'wb') as f:\n", + " pickle.dump(dict(\n", + " run_data = run_data,\n", + " results = results,\n", + " iscell = iscell_bySession\n", + " ), f)\n", + "\n", + "\n", + "print(f'Number of clusters: {len(np.unique(results[\"clusters\"][\"labels\"]))}')\n", + "print(f'Number of discarded ROIs: {(results[\"clusters\"][\"labels\"]==-1).sum()}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac4f230d-4aa0-40dd-86ef-3deafb75580b", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Visualize\n", + "confidence = (((results['quality_metrics']['cluster_silhouette'] + 1) / 2) * results['quality_metrics']['cluster_intra_means'])\n", + "\n", + "fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15,7))\n", + "\n", + "axs[0,0].hist(results['quality_metrics']['cluster_silhouette'], 50);\n", + "axs[0,0].set_xlabel('cluster_silhouette');\n", + "axs[0,0].set_ylabel('cluster counts');\n", + "\n", + "axs[0,1].hist(results['quality_metrics']['cluster_intra_means'], 50);\n", + "axs[0,1].set_xlabel('cluster_intra_means');\n", + "axs[0,1].set_ylabel('cluster counts');\n", + "\n", + "axs[1,0].hist(confidence, 50);\n", + "axs[1,0].set_xlabel('confidence');\n", + "axs[1,0].set_ylabel('cluster counts');\n", + "\n", + "axs[1,1].hist(results['quality_metrics']['sample_silhouette'], 50);\n", + "axs[1,1].set_xlabel('sample_silhouette score');\n", + "axs[1,1].set_ylabel('roi sample counts');\n", + "\n", + "fig.savefig(COLLECTIVE_MUSE_FIG_DIR / 'cluster-metrics.png')\n", + "\n", + "# FOV clusters\n", + "FOV_clusters = roicat.visualization.compute_colored_FOV(\n", + " spatialFootprints=[r.power(0.8) for r in results['ROIs']['ROIs_aligned']], \n", + " FOV_height=results['ROIs']['frame_height'],\n", + " FOV_width=results['ROIs']['frame_width'],\n", + " labels=results[\"clusters\"][\"labels_bySession\"], ## cluster labels\n", + " # alphas_labels=confidence*1.5, ## Set brightness of each cluster based on some 1-D array\n", + " alphas_labels=(clusterer.quality_metrics['cluster_silhouette'] > 0) * (clusterer.quality_metrics['cluster_intra_means'] > 0.4),\n", + "# alphas_sf=clusterer.quality_metrics['sample_silhouette'], ## Set brightness of each ROI based on some 1-D array\n", + ")\n", + "\n", + "FOV_clusters_with_iscell = roicat.visualization.compute_colored_FOV(\n", + " spatialFootprints=[r.power(0.8) for r in results['ROIs']['ROIs_aligned']], ## Spatial footprint sparse arrays\n", + " FOV_height=results['ROIs']['frame_height'],\n", + " FOV_width=results['ROIs']['frame_width'],\n", + " labels=results[\"clusters\"][\"labels_bySession\"], ## cluster labels\n", + " # alphas_labels=confidence*1.5, ## Set brightness of each cluster based on some 1-D array\n", + " alphas_labels=(clusterer.quality_metrics['cluster_silhouette'] > 0) * (clusterer.quality_metrics['cluster_intra_means'] > 0.4),\n", + " alphas_sf=iscell_bySession\n", + "# alphas_sf=clusterer.quality_metrics['sample_silhouette'], ## Set brightness of each ROI based on some 1-D array\n", + ")\n", + "\n", + "roicat.helpers.save_gif(\n", + " array=FOV_clusters, \n", + " path=str(COLLECTIVE_MUSE_FIG_DIR/ 'FOV_clusters_allrois.gif'),\n", + " frameRate=5.0,\n", + " loop=0,\n", + ")\n", + "\n", + "roicat.helpers.save_gif(\n", + " array=FOV_clusters_with_iscell, \n", + " path=str(COLLECTIVE_MUSE_FIG_DIR/ 'FOV_clusters_iscells.gif'),\n", + " frameRate=5.0,\n", + " loop=0,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a266c7c8-0cbe-4277-b3b4-8ef338ffd005", + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(20,10), layout='tight')\n", + "roi_image_dict = {\n", + " 'all': FOV_clusters,\n", + " 'iscell': FOV_clusters_with_iscell\n", + "}\n", + "nrows = len(roi_image_dict)\n", + "ncols = data.n_sessions\n", + "\n", + "splt_cnt = 1\n", + "for k, image_list in roi_image_dict.items():\n", + " for s_id, img in enumerate(image_list):\n", + " plt.subplot(nrows, ncols, splt_cnt)\n", + " plt.imshow(img)\n", + " plt.axis('off')\n", + " plt.title(f'Aligned {k} [#{s_id}]') \n", + " splt_cnt += 1\n", + "\n", + "plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'aligned-rois.png')\n", + "\n", + "# save FOVs\n", + "num_sessions = data.n_sessions\n", + "out_img = []\n", + "for d in range(num_sessions):\n", + " out_img.append(dict(\n", + " fov = aligner.ims_registered_geo[d],\n", + " roi_pre_iscell = FOV_clusters[d],\n", + " roi_with_iscell = FOV_clusters_with_iscell[d],\n", + " **{k: v[d] for k,v in aligned_backgrounds.items()},\n", + " ))\n", + "\n", + "with open(COLLECTIVE_MUSE_DIR / 'aligned-images.pkl', 'wb') as f:\n", + " pickle.dump(out_img, f)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "027adff5-6ad3-4c96-87d0-47d8c83f7c9a", + "metadata": {}, + "outputs": [], + "source": [ + "# save summary data\n", + "df = pd.DataFrame([\n", + " dict(\n", + " session=i, \n", + " global_roi=glv, \n", + " session_roi=range(len(glv)),\n", + " iscell = iscv\n", + " )\n", + " for i, (glv, iscv) in enumerate(zip(labels_bySession, iscell_bySession))\n", + "]).explode(['global_roi','session_roi', 'iscell']).astype({'iscell': 'bool'})\n", + "\n", + "df.to_csv(COLLECTIVE_MUSE_DIR / 'summary-roi.csv', index=False)\n", + "\n", + "# process only iscell for stats\n", + "df = (\n", + " df.query('iscell')\n", + " .reset_index(drop=True)\n", + ")\n", + "\n", + "df = df.merge(\n", + " (\n", + " df.query('global_roi >= 0')\n", + " .groupby('global_roi')\n", + " ['session'].agg(lambda x: len(list(x)))\n", + " .to_frame('num_sessions')\n", + " .reset_index()\n", + " ),\n", + " how='left'\n", + ")\n", + "\n", + "df = (\n", + " df.fillna({'num_sessions': 1})\n", + " .astype({'num_sessions': 'int'})\n", + ")\n", + "\n", + "# re-indexing\n", + "persistent_roi_reindices = (\n", + " df[['num_sessions', 'global_roi']]\n", + " .query('global_roi >= 0 and num_sessions > 1')\n", + " .drop_duplicates()\n", + " .sort_values('num_sessions', ascending=False)\n", + " .reset_index(drop=True)\n", + " .reset_index()\n", + " .set_index('global_roi')\n", + " ['index'].to_dict()\n", + ")\n", + "\n", + "df['reindexed_global_roi'] = df['global_roi'].map(persistent_roi_reindices)\n", + "\n", + "single_roi_start_indices = df['reindexed_global_roi'].max() + 1\n", + "single_roi_rows = df.query('reindexed_global_roi.isna()').index\n", + "num_single_rois = len(single_roi_rows)\n", + "\n", + "df.loc[single_roi_rows, 'reindexed_global_roi'] = \\\n", + " np.arange(num_single_rois) + single_roi_start_indices\n", + "\n", + "df['reindexed_global_roi'] = df['reindexed_global_roi'].astype('int')\n", + "df = df.rename(columns={\n", + " 'global_roi': 'roicat_global_roi', \n", + " 'reindexed_global_roi': 'global_roi'\n", + "})\n", + "\n", + "df.to_csv(COLLECTIVE_MUSE_DIR / 'finalized-roi.csv', index=False)\n", + "\n", + "# plot persistent ROIs summary\n", + "persist_rois = (\n", + " df\n", + " .drop_duplicates(['global_roi'])\n", + " .value_counts('num_sessions', sort=False)\n", + " .to_frame('num_rois')\n", + " .reset_index()\n", + ")\n", + "\n", + "plt.figure(figsize=(4,5))\n", + "ax = sns.barplot(\n", + " persist_rois, \n", + " x = 'num_sessions',\n", + " y = 'num_rois',\n", + " hue = 'num_sessions',\n", + " facecolor = '#afafaf',\n", + " dodge=False,\n", + " edgecolor='k'\n", + ")\n", + "sns.despine(trim=True, offset=10)\n", + "\n", + "plt.legend([], [], frameon=False)\n", + "[ax.bar_label(c, padding=5, fontsize=10) for c in ax.containers]\n", + "plt.xlabel('# sessions')\n", + "plt.ylabel('# rois')\n", + "plt.title('Persisted ROIs')\n", + "plt.tight_layout()\n", + "plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'num-persist-roi-overall.png')\n", + "\n", + "# plot persistent ROIs per sessions\n", + "df_sessions = (\n", + " df\n", + " .value_counts(['session','num_sessions'])\n", + " .to_frame('count')\n", + " .reset_index()\n", + ")\n", + "\n", + "df_total_per_session = (\n", + " df_sessions\n", + " .groupby('session')\n", + " ['count'].agg('sum')\n", + " .to_frame('total_count')\n", + " .reset_index()\n", + ")\n", + "\n", + "df_sessions = df_sessions.merge(df_total_per_session, how='left')\n", + "df_sessions['percent'] = 100 * df_sessions['count'] / df_sessions['total_count']\n", + "\n", + "plt.figure(figsize=(10,5))\n", + "bar_kwargs = dict(\n", + " kind='bar',\n", + " stacked=True, \n", + " colormap='GnBu', \n", + " width=0.7, \n", + " edgecolor='k',\n", + ")\n", + "\n", + "ax1 = plt.subplot(121)\n", + "\n", + "(\n", + " df_sessions\n", + " .pivot(index='session',columns='num_sessions', values='count')\n", + " .fillna(0)\n", + " .plot(\n", + " **bar_kwargs,\n", + " xlabel='session ID',\n", + " ylabel='# rois',\n", + " legend=False,\n", + " ax=ax1)\n", + ")\n", + "plt.tick_params(rotation=0)\n", + "\n", + "ax2 = plt.subplot(122)\n", + "\n", + "(\n", + " df_sessions\n", + " .pivot(index='session',columns='num_sessions', values='percent')\n", + " .fillna(0)\n", + " .plot(\n", + " **bar_kwargs,\n", + " xlabel='session ID',\n", + " ylabel='% roi per session',\n", + " ax=ax2\n", + " )\n", + ")\n", + "plt.tick_params(rotation=0)\n", + "\n", + "leg_handles, leg_labels = plt.gca().get_legend_handles_labels()\n", + "plt.legend(\n", + " reversed(leg_handles),\n", + " reversed(leg_labels),\n", + " loc='upper right', \n", + " bbox_to_anchor=[1.5,1], \n", + " title='# sessions',\n", + " edgecolor='k',\n", + ")\n", + "\n", + "sns.despine(trim=True, offset=10)\n", + "\n", + "plt.suptitle(\n", + " 'Distribution of detected and aligned ROIs across sessions',\n", + ")\n", + "plt.tight_layout(w_pad=5)\n", + "plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'num-persist-roi-per-session.png')\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "roicat", + "language": "python", + "name": "roicat" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + }, + "scenes_data": { + "active_scene": "Default Scene", + "init_scene": "", + "scenes": [ + "Default Scene" + ] + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/scripts/0-makedirs.sh b/scripts/0-makedirs.sh new file mode 100644 index 0000000..aa409a6 --- /dev/null +++ b/scripts/0-makedirs.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +# TODO: consider changing "4-muse" to "4-roicat" + +# define constants +SINGLE_SESSION_SUBDIRS=( "0-raw" "1-moco" "2-deepcad" "3-suite2p" "4-muse" ) +MULTI_SESSION_DIR="multi-session" +LOG_DIR="./logs" + +# define variables +ROOT_DATADIR="/oscar/data/afleisc2/collab/multiday-reg/data/testdir" +SUBJECT_LIST=( "MS2457" ) +DATE_LIST=( "20230902" "20230907" "20230912" ) +PLANE_LIST=( "plane0" "plane1" ) + +# make log folders +for SUBDIR in "${SINGLE_SESSION_SUBDIRS[@]}"; do + mkdir -p "$LOG_DIR/$SUBDIR" +done + +mkdir -p "$LOG_DIR/4-roicat" + +# loop through combinations +for SUBJECT in "${SUBJECT_LIST[@]}"; do + +echo ":::::: BEGIN ($SUBJECT) ::::::" +echo "Creating single-session directories:" + +for DATE in "${DATE_LIST[@]}"; do +for PLANE in "${PLANE_LIST[@]}"; do +for SUBDIR in "${SINGLE_SESSION_SUBDIRS[@]}"; do + + SINGLE_DIR="$ROOT_DATADIR/$SUBJECT/$DATE/$SUBDIR/$PLANE" + mkdir -p "$SINGLE_DIR" + echo -e "\t$SINGLE_DIR" + +done +done +done + +echo "Creating multi-session directories:" +for PLANE in "${PLANE_LIST[@]}"; do + MULTI_DIR="$ROOT_DATADIR/$SUBJECT/$MULTI_SESSION_DIR/$PLANE" + mkdir -p "$MULTI_DIR" + echo -e "\t$MULTI_DIR" +done + + +echo ":::::: DONE ($SUBJECT) ::::::" +echo "------------------------------" +echo "" + +done + +# list raw dirs +echo "Please put raw TIF data in these folders, according to their name" +find "$ROOT_DATADIR" -type d -wholename "**/0-raw/*" + diff --git a/scripts/1-moco.sh b/scripts/1-moco.sh new file mode 100644 index 0000000..02baacc --- /dev/null +++ b/scripts/1-moco.sh @@ -0,0 +1,59 @@ +#!/bin/bash +#SBATCH -N 1 +#SBATCH -n 16 +#SBATCH -p batch +#SBATCH --account=carney-afleisc2-condo +#SBATCH --time=01:00:00 +#SBATCH --mem=20g +#SBATCH --job-name 1-moco +#SBATCH --output logs/1-moco/log-%J.out +#SBATCH --error logs/1-moco/log-%J.err + +# define constants +EXEC_FILE="cellreg/1-moco.py" +EXPECTED_INPUT_SUBDIR="0-raw" + +# define variables +MOCO_CFG_PATH="config/1-moco.yml" +ROOT_DATADIR="/oscar/data/afleisc2/collab/multiday-reg/data/test" +SUBJECT_LIST=( "MS2457" ) +DATE_LIST=( "20230902" "20230907" "20230912" ) +PLANE_LIST=( "plane0" "plane1" ) + +# activate environment +module load miniconda3/23.11.0s +source /oscar/runtime/software/external/miniconda3/23.11.0/etc/profile.d/conda.sh +conda activate suite2p +command -v python +python "$EXEC_FILE" --help + + +# loop through combinations +for SUBJECT in "${SUBJECT_LIST[@]}"; do +for DATE in "${DATE_LIST[@]}"; do +for PLANE in "${PLANE_LIST[@]}"; do + + echo ":::::: BEGIN ($SUBJECT, $DATE, $PLANE) ::::::" + + EXPECTED_DIR="$ROOT_DATADIR/$SUBJECT/$DATE/$EXPECTED_INPUT_SUBDIR/$PLANE" + if [ ! -d "$EXPECTED_DIR" ]; then + echo "$EXPECTED_DIR does not exist. Skip" + continue + fi + + python "$EXEC_FILE" \ + --root-datadir "$ROOT_DATADIR" \ + --subject "$SUBJECT" \ + --date "$DATE" \ + --plane "$PLANE" \ + --config "$MOCO_CFG_PATH" \ + --cleanup + + echo ":::::: DONE ($SUBJECT, $DATE, $PLANE) ::::::" + echo "--------------------------------------------" + echo "" + +done +done +done + diff --git a/scripts/2-deepcad.sh b/scripts/2-deepcad.sh new file mode 100644 index 0000000..7cdc57e --- /dev/null +++ b/scripts/2-deepcad.sh @@ -0,0 +1,63 @@ +#!/bin/bash +#SBATCH -p gpu --gres=gpu:1 +#SBATCH --account=carney-afleisc2-condo +#SBATCH -N 1 +#SBATCH -n 4 +#SBATCH --time=01:00:00 +#SBATCH --mem=64g +#SBATCH --job-name 2-deepcad +#SBATCH --output logs/2-deepcad/log-%J.out +#SBATCH --error logs/2-deepcad/log-%J.err + +# define constants +EXEC_FILE="cellreg/2-deepcad.py" +EXPECTED_INPUT_SUBDIR="1-moco" + +# define variables +DEEPCAD_CFG_PATH="config/2-deepcad.yml" +ROOT_DATADIR="/oscar/data/afleisc2/collab/multiday-reg/data/test" +SUBJECT_LIST=( "MS2457" ) +DATE_LIST=( "20230902" "20230907" "20230912" ) +PLANE_LIST=( "plane0" "plane1" ) + +# activate environment +module load miniconda3/23.11.0s +source /oscar/runtime/software/external/miniconda3/23.11.0/etc/profile.d/conda.sh +conda activate deepcad + +# TODO: check if this is needed anymore +# export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib + +command -v python +python "$EXEC_FILE" --help + + +# loop through combinations +for SUBJECT in "${SUBJECT_LIST[@]}"; do +for DATE in "${DATE_LIST[@]}"; do +for PLANE in "${PLANE_LIST[@]}"; do + + echo ":::::: BEGIN ($SUBJECT, $DATE, $PLANE) ::::::" + + EXPECTED_DIR="$ROOT_DATADIR/$SUBJECT/$DATE/$EXPECTED_INPUT_SUBDIR/$PLANE" + if [ ! -d "$EXPECTED_DIR" ]; then + echo "$EXPECTED_DIR does not exist. Skip" + continue + fi + + python "$EXEC_FILE" \ + --root-datadir "$ROOT_DATADIR" \ + --subject "$SUBJECT" \ + --date "$DATE" \ + --plane "$PLANE" \ + --config "$DEEPCAD_CFG_PATH" \ + --cleanup + + echo ":::::: DONE ($SUBJECT, $DATE, $PLANE) ::::::" + echo "--------------------------------------------" + echo "" + +done +done +done + diff --git a/scripts/3-suite2p.sh b/scripts/3-suite2p.sh new file mode 100644 index 0000000..d61bc8e --- /dev/null +++ b/scripts/3-suite2p.sh @@ -0,0 +1,58 @@ +#!/bin/bash +#SBATCH -N 1 +#SBATCH -n 16 +#SBATCH -p batch +#SBATCH --account=carney-afleisc2-condo +#SBATCH --time=03:00:00 +#SBATCH --mem=30g +#SBATCH --job-name 3-suite2p +#SBATCH --output logs/3-suite2p/log-%J.out +#SBATCH --error logs/3-suite2p/log-%J.err + +# define constants +EXEC_FILE="cellreg/3-suite2p.py" +EXPECTED_INPUT_SUBDIR="2-deepcad" + +# define variables +SUITE2P_CFG_PATH="config/3-suite2p.yml" +ROOT_DATADIR="/oscar/data/afleisc2/collab/multiday-reg/data/test" +SUBJECT_LIST=( "MS2457" ) +DATE_LIST=( "20230902" "20230907" "20230912" ) +PLANE_LIST=( "plane0" "plane1" ) + +# activate environment +module load miniconda3/23.11.0s +source /oscar/runtime/software/external/miniconda3/23.11.0/etc/profile.d/conda.sh +conda activate suite2p +command -v python +python "$EXEC_FILE" --help + +# loop through combinations +for SUBJECT in "${SUBJECT_LIST[@]}"; do +for DATE in "${DATE_LIST[@]}"; do +for PLANE in "${PLANE_LIST[@]}"; do + + echo ":::::: BEGIN ($SUBJECT, $DATE, $PLANE) ::::::" + + EXPECTED_DIR="$ROOT_DATADIR/$SUBJECT/$DATE/$EXPECTED_INPUT_SUBDIR/$PLANE" + if [ ! -d "$EXPECTED_DIR" ]; then + echo "$EXPECTED_DIR does not exist. Skip" + continue + fi + + python "$EXEC_FILE" \ + --root-datadir "$ROOT_DATADIR" \ + --subject "$SUBJECT" \ + --date "$DATE" \ + --plane "$PLANE" \ + --config "$SUITE2P_CFG_PATH" \ + --cleanup + + echo ":::::: DONE ($SUBJECT, $DATE, $PLANE) ::::::" + echo "--------------------------------------------" + echo "" + +done +done +done + diff --git a/scripts/4-roicat.sh b/scripts/4-roicat.sh new file mode 100644 index 0000000..d571455 --- /dev/null +++ b/scripts/4-roicat.sh @@ -0,0 +1,44 @@ +#!/bin/bash +#SBATCH -N 1 +#SBATCH -p batch +#SBATCH -n 8 +#SBATCH --account=carney-afleisc2-condo +#SBATCH --time=00:30:00 +#SBATCH --mem=30g +#SBATCH --job-name 4-roicat +#SBATCH --output logs/4-roicat/log-%J.out +#SBATCH --error logs/4-roicat/log-%J.err + +# define constants +EXEC_FILE="cellreg/4-roicat.py" + +# define variables +ROOT_DATADIR="/oscar/data/afleisc2/collab/multiday-reg/data/test" +SUBJECT_LIST=( "MS2457" ) +PLANE_LIST=( "plane0" "plane1" ) + +# activate environment +module load miniconda3/23.11.0s +source /oscar/runtime/software/external/miniconda3/23.11.0/etc/profile.d/conda.sh +conda activate roicat +command -v python +python "$EXEC_FILE" --help + +# loop through combinations +for SUBJECT in "${SUBJECT_LIST[@]}"; do +for PLANE in "${PLANE_LIST[@]}"; do + + echo ":::::: BEGIN ($SUBJECT, $PLANE) ::::::" + + python "$EXEC_FILE" \ + --root-datadir "$ROOT_DATADIR" \ + --subject "$SUBJECT" \ + --plane "$PLANE" + + echo ":::::: DONE ($SUBJECT, $PLANE) ::::::" + echo "--------------------------------------" + echo "" + +done +done + diff --git a/scripts/example-steps.sh b/scripts/example-steps.sh new file mode 100644 index 0000000..d8c3fdb --- /dev/null +++ b/scripts/example-steps.sh @@ -0,0 +1,55 @@ +ROOT_DATADIR="/oscar/data/afleisc2/collab/multiday-reg/data/test" +SUBJECT="MS2457" +DATE="20230902" +PLANE="plane0" + +MOCO_CFG_PATH="config/1-moco.yml" +DEEPCAD_CFG_PATH="config/2-deepcad.yml" +SUITE2P_CFG_PATH="config/3-suite2p.yml" + +# cpu env +conda acivate suite2p + +python cellreg/1-moco.py \ + --root-datadir "$ROOT_DATADIR" \ + --subject "$SUBJECT" \ + --date "$DATE" \ + --plane "$PLANE" \ + --config "$MOCO_CFG_PATH" \ + --cleanup + +# gpu env +# e.g.: interact -q gpu -g 1 -t 01:00:00 -m 64g -n 4 +conda activate deepcad + +# try the following if there's an error with the `python cellreg/2-deepcad.py` step +# export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib + +python cellreg/2-deepcad.py \ + --root-datadir "$ROOT_DATADIR" \ + --subject "$SUBJECT" \ + --date "$DATE" \ + --plane "$PLANE" \ + --config "$DEEPCAD_CFG_PATH" \ + --cleanup + +# cpu env +conda acivate suite2p + +python cellreg/3-suite2p.py \ + --root-datadir "$ROOT_DATADIR" \ + --subject "$SUBJECT" \ + --date "$DATE" \ + --plane "$PLANE" \ + --config "$SUITE2P_CFG_PATH" \ + --cleanup + + +# after all dates are done +# cpu env is sufficient +conda activate roicat +python cellreg/4-roicat.py \ + --root-datadir "$ROOT_DATADIR" \ + --subject "$SUBJECT" \ + --plane "$PLANE" +