From 0178502fe9124304886bfd9cf9de240701e65939 Mon Sep 17 00:00:00 2001
From: Tuan Pham <42875763+tuanpham96@users.noreply.github.com>
Date: Thu, 14 Dec 2023 13:15:27 -0500
Subject: [PATCH 01/14] update from Ford
---
cellreg/deepcad/test_collection.py | 1 +
deepcad_model/E_05_Iter_3136.pth | Bin
deepcad_model/para.yaml | 0
3 files changed, 1 insertion(+)
mode change 100755 => 100644 deepcad_model/E_05_Iter_3136.pth
mode change 100755 => 100644 deepcad_model/para.yaml
diff --git a/cellreg/deepcad/test_collection.py b/cellreg/deepcad/test_collection.py
index 56e3c0c..3fd1e28 100755
--- a/cellreg/deepcad/test_collection.py
+++ b/cellreg/deepcad/test_collection.py
@@ -135,6 +135,7 @@ def read_imglist(self):
def read_modellist(self):
model_path = self.pth_dir + "/" + self.denoise_model
+ print("MODEL PATH: ", model_path)
model_list = list(os.walk(model_path, topdown=False))[-1][-1]
model_list.sort()
diff --git a/deepcad_model/E_05_Iter_3136.pth b/deepcad_model/E_05_Iter_3136.pth
old mode 100755
new mode 100644
diff --git a/deepcad_model/para.yaml b/deepcad_model/para.yaml
old mode 100755
new mode 100644
From 07165ddeccfa2b5f6577bda744a3d4faa00abe51 Mon Sep 17 00:00:00 2001
From: Tuan Pham <42875763+tuanpham96@users.noreply.github.com>
Date: Thu, 14 Dec 2023 13:17:10 -0500
Subject: [PATCH 02/14] add pipeline mkdir-moco-deepcad-suite2p
---
.gitignore | 3 +-
STEPS.md | 296 +++++++++++++++++++++++++++++++++++++++
cellreg/1-moco.py | 151 ++++++++++++++++++++
cellreg/2-deepcad.py | 98 +++++++++++++
cellreg/3-suite2p.py | 131 +++++++++++++++++
config/1-moco.yml | 11 ++
config/2-deepcad.yml | 11 ++
config/3-suite2p.yml | 29 ++++
environments/deepcad.yml | 35 +++++
environments/roicat.yml | 13 ++
environments/suite2p.yml | 12 ++
scripts/0-makedirs.sh | 54 +++++++
scripts/1-moco.sh | 60 ++++++++
scripts/2-deepcad.sh | 66 +++++++++
scripts/3-suite2p.sh | 60 ++++++++
scripts/example-steps.sh | 44 ++++++
16 files changed, 1073 insertions(+), 1 deletion(-)
create mode 100644 STEPS.md
create mode 100644 cellreg/1-moco.py
create mode 100644 cellreg/2-deepcad.py
create mode 100644 cellreg/3-suite2p.py
create mode 100644 config/1-moco.yml
create mode 100644 config/2-deepcad.yml
create mode 100644 config/3-suite2p.yml
create mode 100644 environments/deepcad.yml
create mode 100644 environments/roicat.yml
create mode 100644 environments/suite2p.yml
create mode 100644 scripts/0-makedirs.sh
create mode 100644 scripts/1-moco.sh
create mode 100644 scripts/2-deepcad.sh
create mode 100644 scripts/3-suite2p.sh
create mode 100644 scripts/example-steps.sh
diff --git a/.gitignore b/.gitignore
index e966119..d7df25e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -77,6 +77,7 @@ target/
# Jupyter Notebook
.ipynb_checkpoints
+.jupyter
# IPython
profile_default/
@@ -160,4 +161,4 @@ cython_debug/
#.idea/
# MacOS crufts
-.DS_Store
\ No newline at end of file
+.DS_Store
diff --git a/STEPS.md b/STEPS.md
new file mode 100644
index 0000000..bc1886f
--- /dev/null
+++ b/STEPS.md
@@ -0,0 +1,296 @@
+# Pipeline (under development)
+
+## Multiple environment setups
+
+There are at least 3 environments that need to be created to manage steps separately:
+
+```shell
+conda env create -f environments/suite2p.yml
+conda env create -f environments/deepcad.yml
+conda env create -f environments/roicat.yml
+```
+
+## Steps
+
+The `scripts/example-steps.sh` shows examples of steps to follow for a single animal, single session, single plane.
+
+The following demonstrates how to run things in batches.
+
+### 0-makedirs: Prepare directory structure
+
+**Purpose**: this prepares the directory structure that is expected from all scripts, as well as necessary log folders. Generally the folder structure follows:
+
+```
+└── SUBJECT
+ ├── DATE
+ │ └── STEP
+ │ └── PLANE
+ └── multi-session
+ └── PLANE
+```
+
+**Task**:
+
+- Please edit the variables under `# define variables` in `scripts/0-makedirs.sh` before running
+- Then run:
+
+```shell
+bash scripts/0-makedirs.sh
+```
+
+- Then put the raw TIFF data under `0-raw` folders
+
+**Output**: The following is an example output directory tree from the script:
+
+Click to expand example output
+
+```
+└── MS2457
+ ├── 20230902
+ │ ├── 0-raw
+ │ │ ├── plane0 # put your raw TIFF here
+ │ │ └── plane1 # put your raw TIFF here
+ │ ├── 1-moco
+ │ │ ├── plane0
+ │ │ └── plane1
+ │ ├── 2-deepcad
+ │ │ ├── plane0
+ │ │ └── plane1
+ │ ├── 3-suite2p
+ │ │ ├── plane0
+ │ │ └── plane1
+ │ └── 4-muse
+ │ ├── plane0
+ │ └── plane1
+ .
+ .
+ .
+ └── multi-session
+ ├── plane0
+ └── plane1
+
+```
+
+
+Currently only 1 file per folder is expected.
+
+You will be reminded at the end to put your raw data in appropriate folders, like this
+
+```
+Please put raw TIF data in these folders, according to their name
+/oscar/data/afleisc2/collab/multiday-reg/data/testdir/MS2457/20230902/0-raw/plane0
+/oscar/data/afleisc2/collab/multiday-reg/data/testdir/MS2457/20230902/0-raw/plane1
+/oscar/data/afleisc2/collab/multiday-reg/data/testdir/MS2457/20230907/0-raw/plane0
+/oscar/data/afleisc2/collab/multiday-reg/data/testdir/MS2457/20230907/0-raw/plane1
+/oscar/data/afleisc2/collab/multiday-reg/data/testdir/MS2457/20230912/0-raw/plane0
+/oscar/data/afleisc2/collab/multiday-reg/data/testdir/MS2457/20230912/0-raw/plane1
+```
+
+So that it looks more like this:
+
+```
+└── MS2457
+ ├── 20230902
+ │ ├── 0-raw
+ │ │ ├── plane0
+ │ │ │ └── MS2457-20230902-plane0.tif
+ │ │ └── plane1
+ │ │ └── MS2457-20230902-plane1.tif
+...
+```
+
+
+
+
+### 1-moco: Motion correction
+
+**Purpose**: this uses `suite2p` to perform motion correction before denoising in a CPU environment on Oscar, preferably high number of CPUs and sufficient memory.
+
+**Task**:
+
+- Please edit the variables under `# define variables` in `scripts/1-moco.sh` before running
+- Take a look at `config/1-moco.yml` and change if you need to, refer to [`suite2p` registration parameter documentation](https://suite2p.readthedocs.io/en/latest/settings.html#registration-settings) for more information
+
+- Then run:
+
+```shell
+sbatch scripts/1-moco.sh
+```
+
+- You can check whether your job is run with `myq`, take note of the `ID` column, then you can view the log file under `logs/1-moco/log-<ID>.out`.
+
+**Output**: The following is an example output directory tree from the script:
+
+
+Click to expand example output
+
+```
+└── MS2457
+ ├── 20230902
+ │ ├── 0-raw
+ │ │ ├── plane0
+ │ │ │ └── MS2457-20230902-plane0.tif
+ │ │ └── plane1
+ │ │ └── MS2457-20230902-plane1.tif
+ │ ├── 1-moco
+ │ │ ├── plane0
+ │ │ │ ├── MS2457-20230902-plane0_mc.tif
+ │ │ │ └── suite2p-moco-only
+ │ │ │ ├── data.bin
+ │ │ │ ├── data_raw.bin
+ │ │ │ ├── ops.npy
+ │ │ │ └── reg_tif
+ │ │ │ ├── file00000_chan0.tif
+ │ │ │ ├── file001000_chan0.tif
+ │ │ │ └── file00500_chan0.tif
+ │ │ └── plane1
+ │ │ ├── MS2457-20230902-plane1_mc.tif
+ │ │ └── suite2p-moco-only
+ │ │ ├── data.bin
+ │ │ ├── data_raw.bin
+ │ │ ├── ops.npy
+ │ │ └── reg_tif
+ │ │ ├── file00000_chan0.tif
+ │ │ ├── file001000_chan0.tif
+ │ │ └── file00500_chan0.tif
+...
+```
+
+The file `MS2457-20230902-plane0_mc.tif` is created by concatenating files under `suite2p-moco-only/reg_tif` folder.
+
+This file is expected to be here for the next step.
+
+Currently only 1 file per folder is expected.
+
+
+
+
+### 2-deepcad: Denoising using `deepcad`
+
+**Purpose**: this uses `deepcad` to perform denoising on the motion-corrected file, in a GPU environment on Oscar, preferably high memory.
+
+**Task**:
+
+- Please edit the variables under `# define variables` in `scripts/2-deepcad.sh` before running
+- Take a look at `config/2-deepcad.yml` and change if you need to
+ - [ ] **TODO** Where is documentation for `deepcad` parameters?
+- Then run:
+
+```shell
+sbatch scripts/2-deepcad.sh
+```
+
+- You can check whether your job is run with `myq`, take note of the `ID` column, then you can view the log file under `logs/2-deepcad/log-<ID>.out`.
+
+**Output**: The following is an example output directory tree from the script:
+
+Click to expand example output
+
+```
+# some folders and files are hidden to prioritize changes
+└── MS2457
+ ├── 20230902
+ │ ├── 0-raw
+ │ │ ├── plane0
+ │ │ │ └── MS2457-20230902-plane0.tif
+ │ │ └── plane1
+ │ │ └── MS2457-20230902-plane1.tif
+ │ ├── 1-moco
+ │ │ ├── plane0
+ │ │ │ └── MS2457-20230902-plane0_mc.tif
+ │ │ └── plane1
+ │ │ └── MS2457-20230902-plane1_mc.tif
+ │ ├── 2-deepcad
+ │ │ ├── plane0
+ │ │ │ ├── MS2457-20230902-plane0_mc-dc.tif
+ │ │ │ └── para.yaml
+ │ │ └── plane1
+ │ │ ├── MS2457-20230902-plane1_mc-dc.tif
+ │ │ └── para.yaml
+...
+```
+
+
+Notice files under `2-deepcad` are created, for example `MS2457-20230902-plane0_mc-dc.tif`.
+
+This file is expected to be here for the next step.
+
+
+
+
+### 3-suite2p: `suite2p` on denoised data
+
+**Purpose**: this uses `suite2p` on the denoised data from the `deepcad` step, in a CPU environment on Oscar, preferably high number of CPUs and sufficient memory.
+
+**Task**:
+
+- Please edit the variables under `# define variables` in `scripts/3-suite2p.sh` before running
+- Take a look at `config/3-suite2p.yml` and change if you need to, refer to [`suite2p` parameter documentation](https://suite2p.readthedocs.io/en/latest/settings.html)
+ - Currently we still need to re-run registration
+ - Some hacks in parameters are documented in this config file
+ - If you're familiar with `suite2p`, use your experience and judgements to determine these
+- Then run:
+
+```shell
+sbatch scripts/3-suite2p.sh
+```
+
+- You can check whether your job is run with `myq`, take note of the `ID` column, then you can view the log file under `logs/3-suite2p/log-<ID>.out`.
+
+**Output**: The following is an example output directory tree from the script:
+
+
+Click to expand example output
+
+
+```
+└── MS2457
+ ├── 20230902
+ │ ├── 0-raw
+ │ │ ├── plane0
+ │ │ │ └── MS2457-20230902-plane0.tif
+ │ │ └── plane1
+ │ │ └── MS2457-20230902-plane1.tif
+ │ ├── 1-moco
+ │ │ ├── plane0
+ │ │ │ └── MS2457-20230902-plane0_mc.tif
+ │ │ └── plane1
+ │ │ └── MS2457-20230902-plane1_mc.tif
+ │ ├── 2-deepcad
+ │ │ ├── plane0
+ │ │ │ └── MS2457-20230902-plane0_mc-dc.tif
+ │ │ └── plane1
+ │ │ └── MS2457-20230902-plane1_mc-dc.tif
+ │ ├── 3-suite2p
+ │ │ ├── plane0
+ │ │ │ ├── F.npy
+ │ │ │ ├── Fneu.npy
+ │ │ │ ├── data.bin
+ │ │ │ ├── figures
+ │ │ │ │ ├── backgrounds.png
+ │ │ │ │ └── roi-masks.png
+ │ │ │ ├── iscell.npy
+ │ │ │ ├── ops.npy
+ │ │ │ ├── spks.npy
+ │ │ │ └── stat.npy
+ │ │ └── plane1
+ │ │ ├── F.npy
+ │ │ ├── Fneu.npy
+ │ │ ├── data.bin
+ │ │ ├── figures
+ │ │ │ ├── backgrounds.png
+ │ │ │ └── roi-masks.png
+ │ │ ├── iscell.npy
+ │ │ ├── ops.npy
+ │ │ ├── spks.npy
+ │ │ └── stat.npy
+...
+```
+
+The folders under `3-suite2p` will be very familiar in terms of namings and organizations for `suite2p`-familiar folks
+
+
+
+### 4-roicat: Multisession registration using `roicat`
+
+TBD
diff --git a/cellreg/1-moco.py b/cellreg/1-moco.py
new file mode 100644
index 0000000..561538f
--- /dev/null
+++ b/cellreg/1-moco.py
@@ -0,0 +1,151 @@
+import os
+import glob
+import shutil
+from pathlib import Path
+
+from natsort import natsorted
+from tqdm import tqdm
+import numpy as np
+
+import tifffile
+import suite2p
+
+import argparse
+import yaml
+
+
+# this stores constant suite2p ops to force overwrite config
+CONSTANT_MOCO_SUITE2P_OPS = {
+ 'do_registration': 1,
+ 'reg_tif': True,
+ 'roidetect': False
+}
+
+
+def get_tiff_list(folder, ext=['.tif', '.tiff']):
+ file_list = []
+ for x in ext:
+ file_list.extend(list(folder.glob(f'*{x}')))
+ return file_list
+
+def merge_tiff_stacks(input_tiffs, output_file):
+ with tifffile.TiffWriter(output_file) as stack:
+ for filename in tqdm(input_tiffs):
+ with tifffile.TiffFile(filename) as tif:
+ for page in tif.pages:
+ stack.write(
+ page.asarray(),
+ photometric='minisblack',
+ contiguous=True
+ )
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser('Run motion correction via suite2p (moco) [only 1 channel allowed]')
+
+ parser.add_argument("--root-datadir", type=str)
+ parser.add_argument("--subject", type=str)
+ parser.add_argument("--date", type=str)
+ parser.add_argument("--plane", type=str)
+ parser.add_argument("--config", type=str)
+ parser.add_argument("--cleanup", action="store_true")
+
+ args = parser.parse_args()
+ print(args)
+
+ # process data paths
+ ROOT_DATA_DIR = args.root_datadir
+ SUBJECT_ID = args.subject
+ DATE_ID = args.date
+ PLANE_ID = args.plane
+ MOCO_CFG_PATH = args.config
+ CLEAN_UP = args.cleanup
+
+ # define paths
+ EXP_DIR = Path(ROOT_DATA_DIR) / SUBJECT_ID / DATE_ID
+ RAW_DIR = EXP_DIR / '0-raw' / PLANE_ID
+ MOCO_DIR = EXP_DIR / '1-moco' / PLANE_ID
+ MOCO_FILEPATH = MOCO_DIR / f'{SUBJECT_ID}-{DATE_ID}-{PLANE_ID}_mc.tif'
+ AUX_SUITE2P_DIR = MOCO_DIR / 'suite2p-moco-only'
+
+ # find tiffs
+ tiff_list = get_tiff_list(RAW_DIR)
+ tiff_list_str = "\n".join(map(str,tiff_list))
+ assert len(tiff_list) == 1, \
+ f'Currently only accepting one file inside the raw directory "{RAW_DIR}", ' \
+ f'found the following: {tiff_list_str}'
+
+ raw_tiff_path = tiff_list[0]
+ with tifffile.TiffFile(str(raw_tiff_path)) as tif:
+ num_frames = len(tif.pages)
+ page0 = tif.pages[0]
+ Ly, Lx = page0.shape
+ print(f'Found raw tiff file ({num_frames} frames of {Ly} x {Lx}): "{raw_tiff_path}"')
+
+ # make db
+ db = {
+ 'data_path': str(MOCO_DIR),
+ 'save_path0': str(AUX_SUITE2P_DIR),
+ 'tiff_list': tiff_list,
+ }
+ print('Data settings: ')
+ print(db)
+
+ # make ops
+ ops = suite2p.default_ops()
+ with open(MOCO_CFG_PATH, 'r') as f:
+ print(f'Config file loaded from: "{MOCO_CFG_PATH}"')
+ ops_cfg = yaml.safe_load(f)
+ ops.update(ops_cfg)
+ ops.update(CONSTANT_MOCO_SUITE2P_OPS)
+ ops.update(db)
+ print('Suite2p configuration')
+ print(ops)
+
+ # convert tiff to binary
+ ops = suite2p.io.tiff_to_binary(ops)
+ f_raw = suite2p.io.BinaryFile(
+ Ly=Ly, Lx=Lx,
+ filename=ops['raw_file']
+ )
+
+ # prepare registration file
+ f_reg = suite2p.io.BinaryFile(
+ Ly=Ly, Lx=Lx, n_frames = f_raw.shape[0],
+ filename = ops['reg_file']
+ )
+
+ # registration
+ suite2p.registration_wrapper(
+ f_reg, f_raw=f_raw,
+ f_reg_chan2=None, f_raw_chan2=None, align_by_chan2=False,
+ refImg=None, ops=ops
+ )
+
+ # combine registration batches
+ REG_TIFF_DIR = Path(ops['reg_file']).parents[0] / 'reg_tif'
+ reg_tiff_list = natsorted(list(map(str,REG_TIFF_DIR.glob('*.tif*'))))
+ reg_tiff_list_str = "\n".join(reg_tiff_list)
+ print(f'Combining the following tiff files into "{MOCO_FILEPATH}":\n{reg_tiff_list_str}')
+
+ merge_tiff_stacks(
+ input_tiffs=reg_tiff_list,
+ output_file=MOCO_FILEPATH
+ )
+
+ # move files
+ if CLEAN_UP:
+ SRC_OUT_DIR = os.path.dirname(ops['reg_file'])
+ PARENT_SRC_DIR = os.path.dirname(SRC_OUT_DIR)
+ assert os.path.basename(PARENT_SRC_DIR) == "suite2p", \
+ f'To clean up properly, need "{PARENT_SRC_DIR}" to end with a "suite2p" folder.\n' \
+ 'Remove "--cleanup" to pass this'
+
+ DST_OUT_DIR = str(AUX_SUITE2P_DIR)
+
+ print(f'Moving files from {SRC_OUT_DIR} to {DST_OUT_DIR}')
+ shutil.copytree(SRC_OUT_DIR, DST_OUT_DIR, copy_function=shutil.move, dirs_exist_ok=True)
+ shutil.rmtree(PARENT_SRC_DIR)
+
+ print(f'Finished motion correction. Use {MOCO_FILEPATH} to continue.')
+
+
\ No newline at end of file
diff --git a/cellreg/2-deepcad.py b/cellreg/2-deepcad.py
new file mode 100644
index 0000000..734954f
--- /dev/null
+++ b/cellreg/2-deepcad.py
@@ -0,0 +1,98 @@
+"""Performs denoising.
+
+Denoises input images in a given directory with given model and output to an
+output directory.
+"""
+
+import os
+import glob
+import shutil
+from pathlib import Path
+
+from natsort import natsorted
+import tifffile
+
+import argparse
+import yaml
+
+from deepcad.test_collection import testing_class
+
+def get_tiff_list(folder, ext=['.tif', '.tiff']):
+ file_list = []
+ for x in ext:
+ file_list.extend(list(folder.glob(f'*{x}')))
+ return file_list
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser('Run denoising with deepcad')
+
+ parser.add_argument("--root-datadir", type=str)
+ parser.add_argument("--subject", type=str)
+ parser.add_argument("--date", type=str)
+ parser.add_argument("--plane", type=str)
+ parser.add_argument("--config", type=str)
+ parser.add_argument("--cleanup", action="store_true")
+
+ args = parser.parse_args()
+ print(args)
+
+ # process data paths
+ ROOT_DATA_DIR = args.root_datadir
+ SUBJECT_ID = args.subject
+ DATE_ID = args.date
+ PLANE_ID = args.plane
+ DEEPCAD_CFG_PATH = args.config
+ CLEAN_UP = args.cleanup
+
+ # define paths
+ EXP_DIR = Path(ROOT_DATA_DIR) / SUBJECT_ID / DATE_ID
+ MOCO_DIR = EXP_DIR / '1-moco' / PLANE_ID
+ DEEPCAD_DIR = EXP_DIR / '2-deepcad' / PLANE_ID
+ DEEPCAD_FILEPATH = DEEPCAD_DIR / f'{SUBJECT_ID}-{DATE_ID}-{PLANE_ID}_mc-dc.tif'
+
+ # find tiffs
+ tiff_list = get_tiff_list(MOCO_DIR)
+ tiff_list_str = "\n".join(map(str,tiff_list))
+ assert len(tiff_list) == 1, \
+ f'Currently only accepting one file inside the motion-corrected directory "{MOCO_DIR}", ' \
+ f'found the following: {tiff_list_str}'
+
+ tiff_path = tiff_list[0]
+ with tifffile.TiffFile(str(tiff_path)) as tif:
+ num_frames = len(tif.pages)
+ page0 = tif.pages[0]
+ Ly, Lx = page0.shape
+ print(f'Found motion-corrected tiff file ({num_frames} frames of {Ly} x {Lx}): "{tiff_path}"')
+
+ # configuration
+ with open(DEEPCAD_CFG_PATH, 'r') as f:
+ print(f'Config file loaded from: "{DEEPCAD_CFG_PATH}"')
+ dc_cfg = yaml.safe_load(f)
+
+ dc_cfg['test_datasize'] = num_frames
+ dc_cfg['datasets_path'] = str(MOCO_DIR)
+ dc_cfg['output_dir'] = str(DEEPCAD_DIR)
+
+ print(dc_cfg)
+
+ # deepcad inference
+ tc = testing_class(dc_cfg)
+ tc.run()
+
+ # get output file
+ out_tiff=list(map(str,DEEPCAD_DIR.glob('**/*.tif*')))
+ out_tiff=[x for x in out_tiff if os.path.dirname(x) != str(DEEPCAD_DIR)]
+ out_tiff_str = "\n".join(map(str,out_tiff))
+ assert len(out_tiff) == 1, \
+ f'Currently only accepting one output file in deepcad directory for cleaning up "{DEEPCAD_DIR}", ' \
+ f'found the following: {out_tiff_str}'
+ out_tiff = out_tiff[0]
+
+ shutil.copyfile(out_tiff, DEEPCAD_FILEPATH)
+
+ # clean up
+ if CLEAN_UP:
+ REMOVE_DIR = os.path.dirname(out_tiff)
+ shutil.rmtree(REMOVE_DIR)
+
+ print(f'Finished with denoising using deepcad. Use "{DEEPCAD_FILEPATH}" to continue.')
\ No newline at end of file
diff --git a/cellreg/3-suite2p.py b/cellreg/3-suite2p.py
new file mode 100644
index 0000000..50a4917
--- /dev/null
+++ b/cellreg/3-suite2p.py
@@ -0,0 +1,131 @@
+import os
+import shutil
+from pathlib import Path
+import numpy as np
+import suite2p
+import argparse
+import yaml
+from matplotlib import pyplot as plt
+
+def get_tiff_list(folder, ext=['.tif', '.tiff']):
+ file_list = []
+ for x in ext:
+ file_list.extend(list(folder.glob(f'*{x}')))
+ return file_list
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser('Run suite2p')
+
+ parser.add_argument("--root-datadir", type=str)
+ parser.add_argument("--subject", type=str)
+ parser.add_argument("--date", type=str)
+ parser.add_argument("--plane", type=str)
+ parser.add_argument("--config", type=str)
+ parser.add_argument("--cleanup", action="store_true")
+
+ args = parser.parse_args()
+ print(args)
+
+ # process data paths
+ ROOT_DATA_DIR = args.root_datadir
+ SUBJECT_ID = args.subject
+ DATE_ID = args.date
+ PLANE_ID = args.plane
+ SUITE2P_CFG_PATH = args.config
+ CLEAN_UP = args.cleanup
+
+ # define paths
+
+ EXP_DIR = Path(ROOT_DATA_DIR) / SUBJECT_ID / DATE_ID
+ DEEPCAD_DIR = EXP_DIR / '2-deepcad' / PLANE_ID
+ SUITE2P_DIR = EXP_DIR / '3-suite2p' / PLANE_ID
+ AUX_SUITE2P_DIR = SUITE2P_DIR / 'suite2p-post-deepcad'
+
+ # find tiffs (may be redundant)
+ tiff_list = get_tiff_list(DEEPCAD_DIR)
+
+ # make db
+ db = {
+ 'data_path': str(DEEPCAD_DIR),
+ 'save_path0': str(AUX_SUITE2P_DIR),
+ 'tiff_list': tiff_list,
+ }
+ print('Data settings: ')
+ print(db)
+
+ # make ops
+ ops = suite2p.default_ops()
+ with open(SUITE2P_CFG_PATH, 'r') as f:
+ print(f'Ops file loaded from: "{SUITE2P_CFG_PATH}"')
+ ops_cfg = yaml.safe_load(f)
+ ops.update(ops_cfg)
+ print('Suite2p configuration')
+ print(ops)
+
+ # run suite2p
+ output_ops = suite2p.run_s2p(ops=ops, db=db)
+ print('Suite2p done!')
+
+ # create figure directory
+ OUT_DIR = Path(output_ops['save_path'])
+ FIG_DIR = OUT_DIR / 'figures'
+ FIG_DIR.mkdir(parents=True, exist_ok=True)
+
+ # plot and save backgrounds
+ plt.figure(figsize=(15,5))
+ plt.subplot(131)
+ plt.imshow(output_ops['max_proj'], cmap='gray')
+ plt.title("Registered Image, Max Projection");
+
+ plt.subplot(132)
+ plt.imshow(output_ops['meanImg'], cmap='gray')
+ plt.title("Mean registered image")
+
+ plt.subplot(133)
+ plt.imshow(output_ops['meanImgE'], cmap='gray')
+ plt.title("High-pass filtered Mean registered image")
+
+ fig_bg_file = FIG_DIR / 'backgrounds.png'
+ plt.savefig(fig_bg_file, dpi=300)
+ print(f'Background images plotted in "{fig_bg_file}"')
+
+ # get roi masks
+ stats_file = OUT_DIR / 'stat.npy'
+ iscell = np.load(OUT_DIR / 'iscell.npy', allow_pickle=True)[:, 0].astype(bool)
+ stats = np.load(stats_file, allow_pickle=True)
+
+ # plot and save roi masks
+ im = suite2p.ROI.stats_dicts_to_3d_array(stats, Ly=output_ops['Ly'], Lx=output_ops['Lx'], label_id=True)
+ im[im == 0] = np.nan
+
+ plt.figure(figsize=(20,8))
+ plt.subplot(1, 4, 1)
+ plt.imshow(output_ops['max_proj'], cmap='gray')
+ plt.title("Registered Image, Max Projection")
+
+ plt.subplot(1, 4, 2)
+ plt.imshow(np.nanmax(im, axis=0), cmap='jet')
+ plt.title("All ROIs Found")
+
+ plt.subplot(1, 4, 3)
+ plt.imshow(np.nanmax(im[~iscell], axis=0, ), cmap='jet')
+ plt.title("All Non-Cell ROIs")
+
+ plt.subplot(1, 4, 4)
+ plt.imshow(np.nanmax(im[iscell], axis=0), cmap='jet')
+ plt.title("All Cell ROIs")
+
+ fig_roi_file = FIG_DIR / 'roi-masks.png'
+ plt.savefig(fig_roi_file, dpi=300)
+ print(f'ROI mask plotted in "{fig_roi_file}"')
+
+ # move files
+ if CLEAN_UP:
+ SRC_OUT_DIR = output_ops['save_path']
+ DST_OUT_DIR = str(SUITE2P_DIR)
+ print(f'Moving files from {SRC_OUT_DIR} to {DST_OUT_DIR}')
+ shutil.copytree(SRC_OUT_DIR, DST_OUT_DIR, copy_function=shutil.move, dirs_exist_ok=True)
+ shutil.rmtree(AUX_SUITE2P_DIR)
+
+
+ print(f'Finished with suite2p. Use {SUITE2P_DIR} to continue.')
diff --git a/config/1-moco.yml b/config/1-moco.yml
new file mode 100644
index 0000000..711dc97
--- /dev/null
+++ b/config/1-moco.yml
@@ -0,0 +1,11 @@
+# do_registration: 1 # motion correction, will be overwritten in script anyway to have `do_registration: 1`
+# reg_tif: True # save tiff registration in batches, will be overwritten in script anyway to have `reg_tif: True`
+# roidetect: False # whether to do segmentation after registration, will be overwritten in script anyway to have `roidetect: False`
+two_step_registration: 1.0 # important for registration, esp. for low SNR
+nimg_init: 1000 # num frames initial to use as reference for registration
+block_size: [256, 256] # matters for registration
+nonrigid: True # important for registration
+keep_movie_raw: True # whether to keep raw binary file
+fs: 5 # estimates across planes
+combined: False # whether to combine across planes
+tau: 0.57 # calcium dynamics for deconv, possibly not correct
\ No newline at end of file
diff --git a/config/2-deepcad.yml b/config/2-deepcad.yml
new file mode 100644
index 0000000..d9053ee
--- /dev/null
+++ b/config/2-deepcad.yml
@@ -0,0 +1,11 @@
+patch_x: 150
+patch_y: 150
+patch_t: 150
+overlap_factor: 0.6
+scale_factor: 1
+pth_dir: "/oscar/data/afleisc2/collab/multiday-reg/MultiSession/deepcad_model"
+fmap: 16
+GPU: '0'
+num_workers: 0
+visualize_images_per_epoch: False
+save_test_images_per_epoch: True
\ No newline at end of file
diff --git a/config/3-suite2p.yml b/config/3-suite2p.yml
new file mode 100644
index 0000000..6b7a8c0
--- /dev/null
+++ b/config/3-suite2p.yml
@@ -0,0 +1,29 @@
+do_registration: True
+two_step_registration: 1.0 # important for registration, esp. for low SNR
+nimg_init: 1000 # num frames initial to use as reference for registration
+block_size: [256, 256] # matters for registration
+nonrigid: True # important for registration
+reg_tif: False # save tiff registration in batches
+keep_movie_raw: False # whether to keep raw binary file
+fs: 50 # HACK! estimates across planes, currently artificially increase, see https://github.com/MouseLand/suite2p/issues/1073
+combined: False # whether to combine across planes
+roidetect: True # whether to do segmentation after registration
+tau: 0.57 # calcium dynamics for deconv, possibly not correct
+# 1Preg: False # DEFAULT, not relevant for 2P
+# spatial_hp_reg: 42.0 # DEFAULT, not relevant for 2P
+# snr_thresh: 1.2 # DEFAULT, SNR for non-rigid registration
+maxregshift: 0.01 # see https://github.com/MouseLand/suite2p/issues/273
+maxregshiftNR: 4.0 # relevant for registration
+# smooth_sigma: 1.15 # DEFAULT, relevant for only rigid
+# smooth_sigma_time: 0.0 # DEFAULT, relevant for only rigid
+# maxregshift: 0.1 # DEFAULT, relevant for non-rigid
+# reg_tif_chan2: False # DEFAULT, not relevant, whether to write reg binary of chan2 to tiff file
+pre_smooth: 2.0 # UNCLEAR WHETHER or not relevant for 2p registration, can turn to 0.0
+sparse_mode: False
+# spatial_scale: 2 # segmentation
+max_iterations: 100 # segmentation
+threshold_scaling: 2.0 # segmentation
+max_overlap: 0.95 # segmentation
+# denoise: 1.0 # segmentation ; turn this off post-deepcad
+# flow_threshold: 0.0 # use default instead, which is 1.5
+anatomical_only: 1 # segmentation (whether or not to use cellpose)
diff --git a/environments/deepcad.yml b/environments/deepcad.yml
new file mode 100644
index 0000000..517febe
--- /dev/null
+++ b/environments/deepcad.yml
@@ -0,0 +1,35 @@
+name: deepcad
+channels:
+ - conda-forge
+ - defaults
+dependencies:
+ - python=3.9
+ - ipykernel
+ - ipython
+ - matplotlib
+ - networkx
+ - numpy
+ - pandas
+ - pip
+ - toml
+ - tomli
+ - tqdm
+ - pip:
+ - beautifulsoup4
+ - csbdeep
+ - gdown
+ - opencv-python
+ - opencv-python-headless
+ - pillow
+ - pyyaml
+ - requests
+ - scikit-image
+ - scipy
+      - tifffile==2023.9.26
+ - torch==1.10.1+cu111
+ - torchaudio==0.10.1+cu111
+ - torchvision==0.11.2+cu111
+ - typer
+ - typer-cli
+ - urllib3
+ - natsort
diff --git a/environments/roicat.yml b/environments/roicat.yml
new file mode 100644
index 0000000..5d3d3ba
--- /dev/null
+++ b/environments/roicat.yml
@@ -0,0 +1,13 @@
+name: roicat
+channels:
+ - conda-forge
+ - defaults
+dependencies:
+ - python=3.11
+ - ipykernel
+ - ipython
+ - pip
+ - pip:
+ - "roicat[all]"
+ - "git+https://github.com/RichieHakim/roiextractors"
+
diff --git a/environments/suite2p.yml b/environments/suite2p.yml
new file mode 100644
index 0000000..72bc834
--- /dev/null
+++ b/environments/suite2p.yml
@@ -0,0 +1,12 @@
+name: suite2p
+channels:
+ - conda-forge
+ - defaults
+dependencies:
+ - "python=3.9"
+ - pip
+ - ipykernel
+ - pip:
+ - "suite2p"
+ - pyyaml
+
diff --git a/scripts/0-makedirs.sh b/scripts/0-makedirs.sh
new file mode 100644
index 0000000..178ee8c
--- /dev/null
+++ b/scripts/0-makedirs.sh
@@ -0,0 +1,54 @@
+#!/bin/bash
+
+# define constants
+SINGLE_SESSION_SUBDIRS=( "0-raw" "1-moco" "2-deepcad" "3-suite2p" "4-muse" )
+MULTI_SESSION_DIR="multi-session"
+LOG_DIR="./logs"
+
+# define variables
+ROOT_DATADIR="/oscar/data/afleisc2/collab/multiday-reg/data/testdir"
+SUBJECT_LIST=( "MS2457" )
+DATE_LIST=( "20230902" "20230907" "20230912" )
+PLANE_LIST=( "plane0" "plane1" )
+
+# make log folders
+for SUBDIR in "${SINGLE_SESSION_SUBDIRS[@]}"; do
+ mkdir -p "$LOG_DIR/$SUBDIR"
+done
+
+# loop through combinations
+for SUBJECT in "${SUBJECT_LIST[@]}"; do
+
+echo ":::::: BEGIN ($SUBJECT) ::::::"
+echo "Creating single-session directories:"
+
+for DATE in "${DATE_LIST[@]}"; do
+for PLANE in "${PLANE_LIST[@]}"; do
+for SUBDIR in "${SINGLE_SESSION_SUBDIRS[@]}"; do
+
+ SINGLE_DIR="$ROOT_DATADIR/$SUBJECT/$DATE/$SUBDIR/$PLANE"
+ mkdir -p "$SINGLE_DIR"
+ echo -e "\t$SINGLE_DIR"
+
+done
+done
+done
+
+echo "Creating multi-session directories:"
+for PLANE in "${PLANE_LIST[@]}"; do
+ MULTI_DIR="$ROOT_DATADIR/$SUBJECT/$MULTI_SESSION_DIR/$PLANE"
+ mkdir -p "$MULTI_DIR"
+ echo -e "\t$MULTI_DIR"
+done
+
+
+echo ":::::: DONE ($SUBJECT) ::::::"
+echo "------------------------------"
+echo ""
+
+done
+
+# list raw dirs
+echo "Please put raw TIF data in these folders, according to their name"
+find "$ROOT_DATADIR" -type d -wholename "**/0-raw/*"
+
diff --git a/scripts/1-moco.sh b/scripts/1-moco.sh
new file mode 100644
index 0000000..4fe6606
--- /dev/null
+++ b/scripts/1-moco.sh
@@ -0,0 +1,60 @@
+#!/bin/bash
+#SBATCH -N 1
+#SBATCH -n 16
+##SBATCH --account=carney-afleisc2-condo
+#SBATCH --time=01:00:00
+#SBATCH --mem=20g
+#SBATCH --job-name 1-moco
+#SBATCH --output logs/1-moco/log-%J.out
+#SBATCH --error logs/1-moco/log-%J.err
+
+# define constants
+EXEC_FILE="cellreg/1-moco.py"
+EXPECTED_INPUT_SUBDIR="0-raw"
+
+# define variables
+MOCO_CFG_PATH="config/1-moco.yml"
+ROOT_DATADIR="/oscar/data/afleisc2/collab/multiday-reg/data/test"
+SUBJECT_LIST=( "MS2457" )
+DATE_LIST=( "20230902" "20230907" "20230912" )
+PLANE_LIST=( "plane0" "plane1" )
+
+# activate environment
+module load miniconda/4.12.0
+source /gpfs/runtime/opt/miniconda/4.12.0/etc/profile.d/conda.sh
+conda activate suite2p
+which python
+python "$EXEC_FILE" --help
+
+
+# loop through combinations
+for SUBJECT in "${SUBJECT_LIST[@]}"; do
+for DATE in "${DATE_LIST[@]}"; do
+for PLANE in "${PLANE_LIST[@]}"; do
+
+ echo ":::::: BEGIN ($SUBJECT, $DATE, $PLANE) ::::::"
+
+ EXPECTED_DIR="$ROOT_DATADIR/$SUBJECT/$DATE/$EXPECTED_INPUT_SUBDIR/$PLANE"
+ if [ ! -d "$EXPECTED_DIR" ]; then
+ echo "$EXPECTED_DIR does not exist. Skip"
+ continue
+ fi
+
+ python "$EXEC_FILE" \
+ --root-datadir "$ROOT_DATADIR" \
+ --subject "$SUBJECT" \
+ --date "$DATE" \
+ --plane "$PLANE" \
+ --config "$MOCO_CFG_PATH" \
+ --cleanup
+
+ echo ":::::: DONE ($SUBJECT, $DATE, $PLANE) ::::::"
+ echo "--------------------------------------------"
+ echo ""
+
+done
+done
+done
+
+
+
diff --git a/scripts/2-deepcad.sh b/scripts/2-deepcad.sh
new file mode 100644
index 0000000..55cb2f1
--- /dev/null
+++ b/scripts/2-deepcad.sh
@@ -0,0 +1,66 @@
+#!/bin/bash
+#SBATCH -p gpu --gres=gpu:1
+#SBATCH --account=carney-afleisc2-condo
+#SBATCH -N 1
+#SBATCH -n 4
+#SBATCH --account=carney-afleisc2-condo
+#SBATCH --time=01:00:00
+#SBATCH --mem=64g
+#SBATCH --job-name 2-deepcad
+#SBATCH --output logs/2-deepcad/log-%J.out
+#SBATCH --error logs/2-deepcad/log-%J.err
+
+# define constants
+EXEC_FILE="cellreg/2-deepcad.py"
+EXPECTED_INPUT_SUBDIR="1-moco"
+
+# define variables
+DEEPCAD_CFG_PATH="config/2-deepcad.yml"
+ROOT_DATADIR="/oscar/data/afleisc2/collab/multiday-reg/data/test"
+SUBJECT_LIST=( "MS2457" )
+DATE_LIST=( "20230902" "20230907" "20230912" )
+PLANE_LIST=( "plane0" "plane1" )
+
+# activate environment
+module load miniconda/4.12.0
+source /gpfs/runtime/opt/miniconda/4.12.0/etc/profile.d/conda.sh
+conda activate deepcad
+
+# not sure how to automate the following environment setting
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib
+
+which python
+python "$EXEC_FILE" --help
+
+
+# loop through combinations
+for SUBJECT in "${SUBJECT_LIST[@]}"; do
+for DATE in "${DATE_LIST[@]}"; do
+for PLANE in "${PLANE_LIST[@]}"; do
+
+ echo ":::::: BEGIN ($SUBJECT, $DATE, $PLANE) ::::::"
+
+ EXPECTED_DIR="$ROOT_DATADIR/$SUBJECT/$DATE/$EXPECTED_INPUT_SUBDIR/$PLANE"
+ if [ ! -d "$EXPECTED_DIR" ]; then
+ echo "$EXPECTED_DIR does not exist. Skip"
+ continue
+ fi
+
+ python "$EXEC_FILE" \
+ --root-datadir "$ROOT_DATADIR" \
+ --subject "$SUBJECT" \
+ --date "$DATE" \
+ --plane "$PLANE" \
+ --config "$DEEPCAD_CFG_PATH" \
+ --cleanup
+
+ echo ":::::: DONE ($SUBJECT, $DATE, $PLANE) ::::::"
+ echo "--------------------------------------------"
+ echo ""
+
+done
+done
+done
+
+
+
diff --git a/scripts/3-suite2p.sh b/scripts/3-suite2p.sh
new file mode 100644
index 0000000..a2d6030
--- /dev/null
+++ b/scripts/3-suite2p.sh
@@ -0,0 +1,60 @@
+#!/bin/bash
+#SBATCH -N 1
+#SBATCH -n 16
+#SBATCH --account=carney-afleisc2-condo
+#SBATCH --time=03:00:00
+#SBATCH --mem=30g
+#SBATCH --job-name 3-suite2p
+#SBATCH --output logs/3-suite2p/log-%J.out
+#SBATCH --error logs/3-suite2p/log-%J.err
+
+# define constants
+EXEC_FILE="cellreg/3-suite2p.py"
+EXPECTED_INPUT_SUBDIR="2-deepcad"
+
+# define variables
+SUITE2P_CFG_PATH="config/3-suite2p.yml"
+ROOT_DATADIR="/oscar/data/afleisc2/collab/multiday-reg/data/test"
+SUBJECT_LIST=( "MS2457" )
+DATE_LIST=( "20230902" "20230907" "20230912" )
+PLANE_LIST=( "plane0" "plane1" )
+
+# activate environment
+module load miniconda/4.12.0
+source /gpfs/runtime/opt/miniconda/4.12.0/etc/profile.d/conda.sh
+conda activate suite2p
+which python
+python "$EXEC_FILE" --help
+
+
+# loop through combinations
+for SUBJECT in "${SUBJECT_LIST[@]}"; do
+for DATE in "${DATE_LIST[@]}"; do
+for PLANE in "${PLANE_LIST[@]}"; do
+
+ echo ":::::: BEGIN ($SUBJECT, $DATE, $PLANE) ::::::"
+
+ EXPECTED_DIR="$ROOT_DATADIR/$SUBJECT/$DATE/$EXPECTED_INPUT_SUBDIR/$PLANE"
+ if [ ! -d "$EXPECTED_DIR" ]; then
+ echo "$EXPECTED_DIR does not exist. Skip"
+ continue
+ fi
+
+ python "$EXEC_FILE" \
+ --root-datadir "$ROOT_DATADIR" \
+ --subject "$SUBJECT" \
+ --date "$DATE" \
+ --plane "$PLANE" \
+ --config "$SUITE2P_CFG_PATH" \
+ --cleanup
+
+ echo ":::::: DONE ($SUBJECT, $DATE, $PLANE) ::::::"
+ echo "--------------------------------------------"
+ echo ""
+
+done
+done
+done
+
+
+
diff --git a/scripts/example-steps.sh b/scripts/example-steps.sh
new file mode 100644
index 0000000..8b491ed
--- /dev/null
+++ b/scripts/example-steps.sh
@@ -0,0 +1,44 @@
+ROOT_DATADIR="/oscar/data/afleisc2/collab/multiday-reg/data/test"
+SUBJECT="MS2457"
+DATE="20230902"
+PLANE="plane0"
+
+MOCO_CFG_PATH="config/1-moco.yml"
+DEEPCAD_CFG_PATH="config/2-deepcad.yml"
+SUITE2P_CFG_PATH="config/3-suite2p.yml"
+
+# cpu env
+conda activate suite2p
+
+python cellreg/1-moco.py \
+ --root-datadir "$ROOT_DATADIR" \
+ --subject "$SUBJECT" \
+ --date "$DATE" \
+ --plane "$PLANE" \
+ --config "$MOCO_CFG_PATH" \
+ --cleanup
+
+# gpu env
+# e.g.: interact -q gpu -g 1 -t 01:00:00 -m 64g -n 4
+conda activate deepcad
+# not sure how to automate the following environment setting
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib
+
+python cellreg/2-deepcad.py \
+ --root-datadir "$ROOT_DATADIR" \
+ --subject "$SUBJECT" \
+ --date "$DATE" \
+ --plane "$PLANE" \
+ --config "$DEEPCAD_CFG_PATH" \
+ --cleanup
+
+# cpu env
+conda activate suite2p
+
+python cellreg/3-suite2p.py \
+ --root-datadir "$ROOT_DATADIR" \
+ --subject "$SUBJECT" \
+ --date "$DATE" \
+ --plane "$PLANE" \
+ --config "$SUITE2P_CFG_PATH" \
+ --cleanup
\ No newline at end of file
From 072b8ff3513dfde6ba3994804491c905ccdb697f Mon Sep 17 00:00:00 2001
From: Tuan Pham <42875763+tuanpham96@users.noreply.github.com>
Date: Thu, 14 Dec 2023 16:43:34 -0500
Subject: [PATCH 03/14] env: fix typos in deepcad
---
environments/deepcad.yml | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/environments/deepcad.yml b/environments/deepcad.yml
index 517febe..88f3db1 100644
--- a/environments/deepcad.yml
+++ b/environments/deepcad.yml
@@ -25,10 +25,10 @@ dependencies:
- requests
- scikit-image
- scipy
- - tifffile2023.9.26
- - torch==1.10.1+cu111
- - torchaudio==0.10.1+cu111
- - torchvision==0.11.2+cu111
+ - tifffile
+ - torch==1.10.1+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
+ - torchaudio==0.10.1+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
+ - torchvision==0.11.2+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
- typer
- typer-cli
- urllib3
From 28c7c5aef6cc030ca18bba5142614b3baf4bb1d0 Mon Sep 17 00:00:00 2001
From: Tuan Pham <42875763+tuanpham96@users.noreply.github.com>
Date: Thu, 14 Dec 2023 17:00:59 -0500
Subject: [PATCH 04/14] add roicat to pipeline
still need to refactor and some TODOs
---
STEPS.md | 66 +++-
cellreg/4-roicat.py | 692 +++++++++++++++++++++++++++++++++++++++
scripts/0-makedirs.sh | 4 +
scripts/4-roicat.sh | 47 +++
scripts/example-steps.sh | 11 +-
5 files changed, 818 insertions(+), 2 deletions(-)
create mode 100644 cellreg/4-roicat.py
create mode 100644 scripts/4-roicat.sh
diff --git a/STEPS.md b/STEPS.md
index bc1886f..a8d8f03 100644
--- a/STEPS.md
+++ b/STEPS.md
@@ -293,4 +293,68 @@ The folders under `3-suite2p` will be very familiar in terms of namings and orga
### 4-roicat: Multisession registration using `roicat`
-TBD
+**Purpose**: this uses `roicat` on the segmented data from `suite2p` step, in a CPU environment on Oscar, preferably high number of CPUs and sufficient memory. GPU is possible but not tested yet.
+
+**TODO**
+- [ ] consider changing "4-muse" to "4-roicat"
+- [ ] parameterize to a config file from `scripts/4-roicat.sh`
+- [ ] refactor the long `scripts/4-roicat.sh` into smaller chunks
+- [ ] save the results from `multi-session` collective folder back to relevant single-session directories, such as ROI mapping and aligned FOV
+
+**Task**:
+
+- Please edit the variables under `# define variables` in `scripts/4-roicat.sh` before running
+
+- Then run:
+
+```shell
+sbatch scripts/4-roicat.sh
+```
+
+- You can check whether your job is run with `myq`, take note of the `ID` column, then you can view the log file under `logs/4-roicat/log-<ID>.out`.
+
+**Output**: The following is an example output directory tree from the script:
+
+Click to expand example output
+
+```
+MS2457/multi-session/
+├── plane0
+│ ├── aligned-images.pkl
+│ ├── figures
+│ │ ├── FOV_clusters_allrois.gif
+│ │ ├── FOV_clusters_iscells.gif
+│ │ ├── aligned-fov.png
+│ │ ├── aligned-rois.png
+│ │ ├── cluster-metrics.png
+│ │ ├── num-persist-roi-overall.png
+│ │ ├── num-persist-roi-per-session.png
+│ │ ├── pw-sim-distrib.png
+│ │ └── pw-sim-scatter.png
+│ ├── finalized-roi.csv
+│ ├── roicat-output.pkl
+│ └── summary-roi.csv
+└── plane1
+ ├── aligned-images.pkl
+ ├── figures
+ │ ├── FOV_clusters_allrois.gif
+ │ ├── FOV_clusters_iscells.gif
+ │ ├── aligned-fov.png
+ │ ├── aligned-rois.png
+ │ ├── cluster-metrics.png
+ │ ├── num-persist-roi-overall.png
+ │ ├── num-persist-roi-per-session.png
+ │ ├── pw-sim-distrib.png
+ │ └── pw-sim-scatter.png
+ ├── finalized-roi.csv
+ ├── roicat-output.pkl
+ └── summary-roi.csv
+```
+
+The output data of `roicat` is in `roicat-output.pkl`
+
+For visual inspection: see the figures in `figures`
+
+For analysis: `finalized-roi.csv` is what you may want
+
+
\ No newline at end of file
diff --git a/cellreg/4-roicat.py b/cellreg/4-roicat.py
new file mode 100644
index 0000000..1f8bf18
--- /dev/null
+++ b/cellreg/4-roicat.py
@@ -0,0 +1,692 @@
+import os
+import glob
+import shutil
+from pathlib import Path
+import copy
+import multiprocessing as mp
+import tempfile
+
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+
+import roicat
+
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+import argparse
+import yaml
+import pickle
+
+
+PARAMS = {
+ 'um_per_pixel': 0.7,
+ 'background_max_percentile': 99.9,
+ 'suite2p': { # `roicat.data_importing.Data_suite2p`
+ 'new_or_old_suite2p': 'new',
+ 'type_meanImg': 'meanImg',
+ },
+ 'fov_augment': { # `aligner.augment_FOV_images`
+ 'roi_FOV_mixing_factor': 0.5,
+ 'use_CLAHE': False,
+ 'CLAHE_grid_size': 1,
+ 'CLAHE_clipLimit': 1,
+ 'CLAHE_normalize': True,
+ },
+ 'fit_geometric': { # `aligner.fit_geometric`
+ 'template': 0,
+ 'template_method': 'image',
+ 'mode_transform': 'affine',
+ 'mask_borders': (5,5,5,5),
+ 'n_iter': 1000,
+ 'termination_eps': 1e-6,
+ 'gaussFiltSize': 15,
+ 'auto_fix_gaussFilt_step':1,
+ },
+ 'fit_nonrigid': { # `aligner.fit_nonrigid`
+ 'disable': True,
+ 'template': 0,
+ 'template_method': 'image',
+ 'mode_transform':'createOptFlow_DeepFlow',
+ 'kwargs_mode_transform':None,
+ },
+ 'roi_blur': {
+ 'kernel_halfWidth': 2
+ }
+
+}
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser('Run roicat for multi-session registration')
+
+ parser.add_argument("--root-datadir", type=str)
+ parser.add_argument("--subject", type=str)
+ parser.add_argument("--plane", type=str)
+ parser.add_argument("--config", type=str)
+ parser.add_argument("-g", "--use-gpu", action="store_true")
+ parser.add_argument("-v", "--verbose", action="store_true")
+
+ parser.add_argument(
+ "--max-depth", type=int, default=6,
+ help='max depth to find suite2p files, relative to `--root-datadir`'
+ )
+
+ args = parser.parse_args()
+ print(args)
+
+ # process data paths
+ ROOT_DATA_DIR = args.root_datadir
+ SUBJECT_ID = args.subject
+ PLANE_ID = args.plane
+ ROICAT_CFG_PATH = args.config
+ SUITE2P_PATH_MAXDEPTH=args.max_depth
+ USE_GPU=args.use_gpu
+ VERBOSITY = args.verbose
+
+ # define paths
+ SUBJECT_DIR = Path(ROOT_DATA_DIR) / SUBJECT_ID
+ COLLECTIVE_MUSE_DIR = SUBJECT_DIR / 'multi-session' / PLANE_ID
+ COLLECTIVE_MUSE_FIG_DIR = COLLECTIVE_MUSE_DIR / 'figures'
+ COLLECTIVE_MUSE_FIG_DIR.mkdir(parents=True, exist_ok=True)
+
+ # find suite2p paths
+ dir_allOuterFolders = str(SUBJECT_DIR)
+ pathSuffixToStat = 'stat.npy'
+ pathSuffixToOps = 'ops.npy'
+ pathShouldHave = fr'3-suite2p/{PLANE_ID}'
+
+ paths_allStat = roicat.helpers.find_paths(
+ dir_outer=dir_allOuterFolders,
+ reMatch=pathSuffixToStat,
+ reMatch_in_path=pathShouldHave,
+ depth=SUITE2P_PATH_MAXDEPTH,
+ )[:]
+
+ paths_allStat = [
+ x for x in paths_allStat
+ if pathShouldHave in x
+ ]
+
+ paths_allOps = np.array([
+ Path(path).resolve().parent / pathSuffixToOps
+ for path in paths_allStat
+ ])[:]
+
+
+ print('Paths to all suite2p STAT files:')
+ print('\n'.join(['\t-' + str(x) for x in paths_allStat]))
+ print('\n')
+ print('Paths to all suite2p OPS files:')
+ print('\n'.join(['\t-' + str(x) for x in paths_allOps]))
+ print('\n')
+
+ # load data
+ data = roicat.data_importing.Data_suite2p(
+ paths_statFiles=paths_allStat[:],
+ paths_opsFiles=paths_allOps[:],
+ um_per_pixel=PARAMS['um_per_pixel'],
+ **PARAMS['suite2p'],
+ verbose=VERBOSITY,
+ )
+
+ assert data.check_completeness(verbose=False)['tracking'],\
+ "Data object is missing attributes necessary for tracking."
+
+ # also save iscell paths
+ data.paths_iscell = [
+ Path(x).parent / 'iscell.npy'
+ for x in data.paths_ops
+ ]
+
+ # load all background images
+ background_types = [
+ 'meanImg',
+ 'meanImgE',
+ 'max_proj',
+ 'Vcorr',
+ ]
+
+ FOV_backgrounds = {k: [] for k in background_types}
+ for ops_file in data.paths_ops:
+ ops = np.load(ops_file, allow_pickle=True).item()
+ for bg in background_types:
+ FOV_backgrounds[bg].append(ops[bg])
+
+ # obtain FOVs
+ aligner = roicat.tracking.alignment.Aligner(verbose=VERBOSITY)
+
+ FOV_images = aligner.augment_FOV_images(
+ ims=data.FOV_images,
+ spatialFootprints=data.spatialFootprints,
+ **PARAMS['fov_augment']
+ )
+
+ # ALIGN FOV
+ DISABLE_NONRIGID = PARAMS['fit_nonrigid'].pop('disable')
+
+ # geometric fit
+ aligner.fit_geometric(
+ ims_moving=FOV_images,
+ **PARAMS['fit_geometric']
+ )
+ aligner.transform_images_geometric(FOV_images)
+ remap_idx = aligner.remappingIdx_geo
+
+ # non-rigid
+ if not DISABLE_NONRIGID:
+ aligner.fit_nonrigid(
+ ims_moving=aligner.ims_registered_geo,
+ remappingIdx_init=aligner.remappingIdx_geo,
+ **PARAMS['fit_nonrigid']
+ )
+ aligner.transform_images_nonrigid(FOV_images)
+ remap_idx = aligner.remappingIdx_nonrigid
+
+ # transform ROIs
+ aligner.transform_ROIs(
+ ROIs=data.spatialFootprints,
+ remappingIdx=remap_idx,
+ normalize=True,
+ )
+
+ # transform other backgrounds
+ aligned_backgrounds = {k: [] for k in background_types}
+ for bg in background_types:
+ aligned_backgrounds[bg] = aligner.transform_images(
+ FOV_backgrounds[bg],
+ remappingIdx=remap_idx
+ )
+
+ plt.figure(figsize=(20,20), layout='tight')
+ types2plt = background_types + ['ROI']
+ nrows = len(types2plt)
+ ncols = data.n_sessions
+
+ splt_cnt = 1
+ for k in types2plt:
+ image_list = aligned_backgrounds.get(k, aligner.get_ROIsAligned_maxIntensityProjection())
+ for s_id, img in enumerate(image_list):
+ plt.subplot(nrows, ncols, splt_cnt)
+ plt.imshow(
+ img, cmap='Greys_r',
+ vmax=np.percentile(
+ img,
+ PARAMS['background_max_percentile'] if k!= "ROI" else 95
+ )
+ )
+ plt.axis('off')
+ plt.title(f'Aligned {k} [#{s_id}]')
+ splt_cnt += 1
+
+ plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'aligned-fov.png')
+
+ # BUILD FEATUREs
+
+ # blur ROI
+ blurrer = roicat.tracking.blurring.ROI_Blurrer(
+ frame_shape=(data.FOV_height, data.FOV_width),
+ plot_kernel=False,
+ verbose=VERBOSITY,
+ **PARAMS['roi_blur']
+ )
+
+ blurrer.blur_ROIs(
+ spatialFootprints=aligner.ROIs_aligned[:],
+ )
+
+ # ROInet embedding
+ # TODO: Parameterize `ROInet_embedder`, `generate_dataloader`
+ DEVICE = roicat.helpers.set_device(use_GPU=USE_GPU, verbose=VERBOSITY)
+ dir_temp = tempfile.gettempdir()
+
+ roinet = roicat.ROInet.ROInet_embedder(
+ device=DEVICE,
+ dir_networkFiles=dir_temp,
+ download_method='check_local_first',
+ download_url='https://osf.io/x3fd2/download',
+ download_hash='7a5fb8ad94b110037785a46b9463ea94',
+ forward_pass_version='latent',
+ verbose=VERBOSITY
+ )
+
+ roinet.generate_dataloader(
+ ROI_images=data.ROI_images,
+ um_per_pixel=data.um_per_pixel,
+ pref_plot=False,
+ jit_script_transforms=False,
+ batchSize_dataloader=8,
+ pinMemory_dataloader=True,
+ numWorkers_dataloader=4,
+ persistentWorkers_dataloader=True,
+ prefetchFactor_dataloader=2,
+ )
+
+ roinet.generate_latents()
+
+ # Scattering wavelet embedding
+ # TODO: Parameterize `SWT`, `SWT.transform`
+ swt = roicat.tracking.scatteringWaveletTransformer.SWT(
+ kwargs_Scattering2D={'J': 3, 'L': 12},
+ image_shape=data.ROI_images[0].shape[1:3],
+ device=DEVICE,
+ )
+
+ swt.transform(
+ ROI_images=roinet.ROI_images_rs,
+ batch_size=100,
+ )
+
+ # Compute similarities
+ # TODO: Parameterize `ROI_graph`, `compute_similarity_blockwise`, `make_normalized_similarities`
+
+ sim = roicat.tracking.similarity_graph.ROI_graph(
+ n_workers=-1,
+ frame_height=data.FOV_height,
+ frame_width=data.FOV_width,
+ block_height=128,
+ block_width=128,
+ algorithm_nearestNeigbors_spatialFootprints='brute',
+ verbose=VERBOSITY,
+ )
+
+ s_sf, s_NN, s_SWT, s_sesh = sim.compute_similarity_blockwise(
+ spatialFootprints=blurrer.ROIs_blurred,
+ features_NN=roinet.latents,
+ features_SWT=swt.latents,
+ ROI_session_bool=data.session_bool,
+ spatialFootprint_maskPower=1.0,
+ )
+
+ sim.make_normalized_similarities(
+ centers_of_mass=data.centroids,
+ features_NN=roinet.latents,
+ features_SWT=swt.latents,
+ k_max=data.n_sessions*100,
+ k_min=data.n_sessions*10,
+ algo_NN='kd_tree',
+ device=DEVICE,
+ )
+
+ # Clustering
+ # TODO: Parameterize `find_optimal_parameters_for_pruning`?
+ clusterer = roicat.tracking.clustering.Clusterer(
+ s_sf=sim.s_sf,
+ s_NN_z=sim.s_NN_z,
+ s_SWT_z=sim.s_SWT_z,
+ s_sesh=sim.s_sesh,
+ )
+
+ kwargs_makeConjunctiveDistanceMatrix_best = clusterer.find_optimal_parameters_for_pruning(
+ n_bins=None,
+ smoothing_window_bins=None,
+ kwargs_findParameters={
+ 'n_patience': 300,
+ 'tol_frac': 0.001,
+ 'max_trials': 1200,
+ 'max_duration': 60*10,
+ },
+ bounds_findParameters={
+ 'power_NN': (0., 5.),
+ 'power_SWT': (0., 5.),
+ 'p_norm': (-5, 0),
+ 'sig_NN_kwargs_mu': (0., 1.0),
+ 'sig_NN_kwargs_b': (0.00, 1.5),
+ 'sig_SWT_kwargs_mu': (0., 1.0),
+ 'sig_SWT_kwargs_b': (0.00, 1.5),
+ },
+ n_jobs_findParameters=-1,
+ )
+
+ kwargs_mcdm_tmp = kwargs_makeConjunctiveDistanceMatrix_best ## Use the optimized parameters
+
+ clusterer.plot_distSame(kwargs_makeConjunctiveDistanceMatrix=kwargs_mcdm_tmp)
+ plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'pw-sim-distrib.png')
+
+ clusterer.plot_similarity_relationships(
+ plots_to_show=[1,2,3],
+ max_samples=100000, ## Make smaller if it is running too slow
+ kwargs_scatter={'s':1, 'alpha':0.2},
+ kwargs_makeConjunctiveDistanceMatrix=kwargs_mcdm_tmp
+ );
+ plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'pw-sim-scatter.png')
+
+ clusterer.make_pruned_similarity_graphs(
+ d_cutoff=None,
+ kwargs_makeConjunctiveDistanceMatrix=kwargs_mcdm_tmp,
+ stringency=1.0,
+ convert_to_probability=False,
+ )
+
+ if data.n_sessions >= 8:
+ labels = clusterer.fit(
+ d_conj=clusterer.dConj_pruned,
+ session_bool=data.session_bool,
+ min_cluster_size=2,
+ n_iter_violationCorrection=3,
+ split_intraSession_clusters=True,
+ cluster_selection_method='leaf',
+ d_clusterMerge=None,
+ alpha=0.999,
+ discard_failed_pruning=False,
+ n_steps_clusterSplit=100,
+ )
+
+ else:
+ labels = clusterer.fit_sequentialHungarian(
+ d_conj=clusterer.dConj_pruned, ## Input distance matrix
+ session_bool=data.session_bool, ## Boolean array of which ROIs belong to which sessions
+ thresh_cost=0.6, ## Threshold
+ )
+
+ quality_metrics = clusterer.compute_quality_metrics()
+
+ labels_squeezed, labels_bySession, labels_bool, labels_bool_bySession, labels_dict = roicat.tracking.clustering.make_label_variants(labels=labels, n_roi_bySession=data.n_roi)
+
+ results = {
+ "clusters":{
+ "labels": labels_squeezed,
+ "labels_bySession": labels_bySession,
+ "labels_bool": labels_bool,
+ "labels_bool_bySession": labels_bool_bySession,
+ "labels_dict": labels_dict,
+ },
+ "ROIs": {
+ "ROIs_aligned": aligner.ROIs_aligned,
+ "ROIs_raw": data.spatialFootprints,
+ "frame_height": data.FOV_height,
+ "frame_width": data.FOV_width,
+ "idx_roi_session": np.where(data.session_bool)[1],
+ "n_sessions": data.n_sessions,
+ },
+ "input_data": {
+ "paths_stat": data.paths_stat,
+ "paths_ops": data.paths_ops,
+ },
+ "quality_metrics": clusterer.quality_metrics if hasattr(clusterer, 'quality_metrics') else None,
+ }
+
+ run_data = copy.deepcopy({
+ 'data': data.serializable_dict,
+ 'aligner': aligner.serializable_dict,
+ 'blurrer': blurrer.serializable_dict,
+ 'roinet': roinet.serializable_dict,
+ 'swt': swt.serializable_dict,
+ 'sim': sim.serializable_dict,
+ 'clusterer': clusterer.serializable_dict,
+ })
+
+ iscell_bySession = [np.load(ic_p)[:,0] for ic_p in data.paths_iscell]
+
+ with open(COLLECTIVE_MUSE_DIR / 'roicat-output.pkl', 'wb') as f:
+ pickle.dump(dict(
+ run_data = run_data,
+ results = results,
+ iscell = iscell_bySession
+ ), f)
+
+ print(f'Number of clusters: {len(np.unique(results["clusters"]["labels"]))}')
+ print(f'Number of discarded ROIs: {(results["clusters"]["labels"]==-1).sum()}')
+
+ # Visualize
+ confidence = (((results['quality_metrics']['cluster_silhouette'] + 1) / 2) * results['quality_metrics']['cluster_intra_means'])
+
+ fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15,7))
+
+ axs[0,0].hist(results['quality_metrics']['cluster_silhouette'], 50);
+ axs[0,0].set_xlabel('cluster_silhouette');
+ axs[0,0].set_ylabel('cluster counts');
+
+ axs[0,1].hist(results['quality_metrics']['cluster_intra_means'], 50);
+ axs[0,1].set_xlabel('cluster_intra_means');
+ axs[0,1].set_ylabel('cluster counts');
+
+ axs[1,0].hist(confidence, 50);
+ axs[1,0].set_xlabel('confidence');
+ axs[1,0].set_ylabel('cluster counts');
+
+ axs[1,1].hist(results['quality_metrics']['sample_silhouette'], 50);
+ axs[1,1].set_xlabel('sample_silhouette score');
+ axs[1,1].set_ylabel('roi sample counts');
+
+ fig.savefig(COLLECTIVE_MUSE_FIG_DIR / 'cluster-metrics.png')
+
+ # FOV clusters
+ FOV_clusters = roicat.visualization.compute_colored_FOV(
+ spatialFootprints=[r.power(0.8) for r in results['ROIs']['ROIs_aligned']],
+ FOV_height=results['ROIs']['frame_height'],
+ FOV_width=results['ROIs']['frame_width'],
+ labels=results["clusters"]["labels_bySession"], ## cluster labels
+ # alphas_labels=confidence*1.5, ## Set brightness of each cluster based on some 1-D array
+ alphas_labels=(clusterer.quality_metrics['cluster_silhouette'] > 0) * (clusterer.quality_metrics['cluster_intra_means'] > 0.4),
+ # alphas_sf=clusterer.quality_metrics['sample_silhouette'], ## Set brightness of each ROI based on some 1-D array
+ )
+
+ FOV_clusters_with_iscell = roicat.visualization.compute_colored_FOV(
+ spatialFootprints=[r.power(0.8) for r in results['ROIs']['ROIs_aligned']], ## Spatial footprint sparse arrays
+ FOV_height=results['ROIs']['frame_height'],
+ FOV_width=results['ROIs']['frame_width'],
+ labels=results["clusters"]["labels_bySession"], ## cluster labels
+ # alphas_labels=confidence*1.5, ## Set brightness of each cluster based on some 1-D array
+ alphas_labels=(clusterer.quality_metrics['cluster_silhouette'] > 0) * (clusterer.quality_metrics['cluster_intra_means'] > 0.4),
+ alphas_sf=iscell_bySession
+ # alphas_sf=clusterer.quality_metrics['sample_silhouette'], ## Set brightness of each ROI based on some 1-D array
+ )
+
+ roicat.helpers.save_gif(
+ array=FOV_clusters,
+ path=str(COLLECTIVE_MUSE_FIG_DIR/ 'FOV_clusters_allrois.gif'),
+ frameRate=5.0,
+ loop=0,
+ )
+
+ roicat.helpers.save_gif(
+ array=FOV_clusters_with_iscell,
+ path=str(COLLECTIVE_MUSE_FIG_DIR/ 'FOV_clusters_iscells.gif'),
+ frameRate=5.0,
+ loop=0,
+ )
+
+ plt.figure(figsize=(20,10), layout='tight')
+ roi_image_dict = {
+ 'all': FOV_clusters,
+ 'iscell': FOV_clusters_with_iscell
+ }
+ nrows = len(roi_image_dict)
+ ncols = data.n_sessions
+
+ splt_cnt = 1
+ for k, image_list in roi_image_dict.items():
+ for s_id, img in enumerate(image_list):
+ plt.subplot(nrows, ncols, splt_cnt)
+ plt.imshow(img)
+ plt.axis('off')
+ plt.title(f'Aligned {k} [#{s_id}]')
+ splt_cnt += 1
+
+ plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'aligned-rois.png')
+
+ # save FOVs
+ num_sessions = data.n_sessions
+ out_img = []
+ for d in range(num_sessions):
+ out_img.append(dict(
+ fov = aligner.ims_registered_geo[d],
+ roi_pre_iscell = FOV_clusters[d],
+ roi_with_iscell = FOV_clusters_with_iscell[d],
+ **{k: v[d] for k,v in aligned_backgrounds.items()},
+ ))
+
+ with open(COLLECTIVE_MUSE_DIR / 'aligned-images.pkl', 'wb') as f:
+ pickle.dump(out_img, f)
+
+
+ # save summary data
+ df = pd.DataFrame([
+ dict(
+ session=i,
+ global_roi=glv,
+ session_roi=range(len(glv)),
+ iscell = iscv
+ )
+ for i, (glv, iscv) in enumerate(zip(labels_bySession, iscell_bySession))
+ ]).explode(['global_roi','session_roi', 'iscell']).astype({'iscell': 'bool'})
+
+ df.to_csv(COLLECTIVE_MUSE_DIR / 'summary-roi.csv', index=False)
+
+ # process only iscell for stats
+ df = (
+ df.query('iscell')
+ .reset_index(drop=True)
+ )
+
+ df = df.merge(
+ (
+ df.query('global_roi >= 0')
+ .groupby('global_roi')
+ ['session'].agg(lambda x: len(list(x)))
+ .to_frame('num_sessions')
+ .reset_index()
+ ),
+ how='left'
+ )
+
+ df = (
+ df.fillna({'num_sessions': 1})
+ .astype({'num_sessions': 'int'})
+ )
+
+ # re-indexing
+ persistent_roi_reindices = (
+ df[['num_sessions', 'global_roi']]
+ .query('global_roi >= 0 and num_sessions > 1')
+ .drop_duplicates()
+ .sort_values('num_sessions', ascending=False)
+ .reset_index(drop=True)
+ .reset_index()
+ .set_index('global_roi')
+ ['index'].to_dict()
+ )
+
+ df['reindexed_global_roi'] = df['global_roi'].map(persistent_roi_reindices)
+
+ single_roi_start_indices = df['reindexed_global_roi'].max() + 1
+ single_roi_rows = df.query('reindexed_global_roi.isna()').index
+ num_single_rois = len(single_roi_rows)
+
+ df.loc[single_roi_rows, 'reindexed_global_roi'] = \
+ np.arange(num_single_rois) + single_roi_start_indices
+
+ df['reindexed_global_roi'] = df['reindexed_global_roi'].astype('int')
+ df = df.rename(columns={
+ 'global_roi': 'roicat_global_roi',
+ 'reindexed_global_roi': 'global_roi'
+ })
+
+ df.to_csv(COLLECTIVE_MUSE_DIR / 'finalized-roi.csv', index=False)
+
+ # plot persistent ROIs summary
+ persist_rois = (
+ df
+ .drop_duplicates(['global_roi'])
+ .value_counts('num_sessions', sort=False)
+ .to_frame('num_rois')
+ .reset_index()
+ )
+
+ plt.figure(figsize=(4,5))
+ ax = sns.barplot(
+ persist_rois,
+ x = 'num_sessions',
+ y = 'num_rois',
+ hue = 'num_sessions',
+ facecolor = '#afafaf',
+ dodge=False,
+ edgecolor='k'
+ )
+ sns.despine(trim=True, offset=10)
+
+ plt.legend([], [], frameon=False)
+ [ax.bar_label(c, padding=5, fontsize=10) for c in ax.containers]
+ plt.xlabel('# sessions')
+ plt.ylabel('# rois')
+ plt.title('Persisted ROIs')
+ plt.tight_layout()
+ plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'num-persist-roi-overall.png')
+
+ # plot persistent ROIs per sessions
+ df_sessions = (
+ df
+ .value_counts(['session','num_sessions'])
+ .to_frame('count')
+ .reset_index()
+ )
+
+ df_total_per_session = (
+ df_sessions
+ .groupby('session')
+ ['count'].agg('sum')
+ .to_frame('total_count')
+ .reset_index()
+ )
+
+ df_sessions = df_sessions.merge(df_total_per_session, how='left')
+ df_sessions['percent'] = 100 * df_sessions['count'] / df_sessions['total_count']
+
+ plt.figure(figsize=(10,5))
+ bar_kwargs = dict(
+ kind='bar',
+ stacked=True,
+ colormap='GnBu',
+ width=0.7,
+ edgecolor='k',
+ )
+
+ ax1 = plt.subplot(121)
+
+ (
+ df_sessions
+ .pivot(index='session',columns='num_sessions', values='count')
+ .fillna(0)
+ .plot(
+ **bar_kwargs,
+ xlabel='session ID',
+ ylabel='# rois',
+ legend=False,
+ ax=ax1)
+ )
+ plt.tick_params(rotation=0)
+
+ ax2 = plt.subplot(122)
+
+ (
+ df_sessions
+ .pivot(index='session',columns='num_sessions', values='percent')
+ .fillna(0)
+ .plot(
+ **bar_kwargs,
+ xlabel='session ID',
+ ylabel='% roi per session',
+ ax=ax2
+ )
+ )
+ plt.tick_params(rotation=0)
+
+ leg_handles, leg_labels = plt.gca().get_legend_handles_labels()
+ plt.legend(
+ reversed(leg_handles),
+ reversed(leg_labels),
+ loc='upper right',
+ bbox_to_anchor=[1.5,1],
+ title='# sessions',
+ edgecolor='k',
+ )
+
+ sns.despine(trim=True, offset=10)
+
+ plt.suptitle(
+ 'Distribution of detected and aligned ROIs across sessions',
+ )
+ plt.tight_layout(w_pad=5)
+ plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'num-persist-roi-per-session.png')
+
diff --git a/scripts/0-makedirs.sh b/scripts/0-makedirs.sh
index 178ee8c..aa409a6 100644
--- a/scripts/0-makedirs.sh
+++ b/scripts/0-makedirs.sh
@@ -1,5 +1,7 @@
#!/bin/bash
+# TODO: consider changing "4-muse" to "4-roicat"
+
# define constants
SINGLE_SESSION_SUBDIRS=( "0-raw" "1-moco" "2-deepcad" "3-suite2p" "4-muse" )
MULTI_SESSION_DIR="multi-session"
@@ -16,6 +18,8 @@ for SUBDIR in "${SINGLE_SESSION_SUBDIRS[@]}"; do
mkdir -p "$LOG_DIR/$SUBDIR"
done
+mkdir -p "$LOG_DIR/4-roicat"
+
# loop through combinations
for SUBJECT in "${SUBJECT_LIST[@]}"; do
diff --git a/scripts/4-roicat.sh b/scripts/4-roicat.sh
new file mode 100644
index 0000000..75cfd1b
--- /dev/null
+++ b/scripts/4-roicat.sh
@@ -0,0 +1,47 @@
+#!/bin/bash
+#SBATCH -N 1
+#SBATCH -n 8
+#SBATCH --account=carney-afleisc2-condo
+#SBATCH --time=00:30:00
+#SBATCH --mem=30g
+#SBATCH --job-name 4-roicat
+#SBATCH --output logs/4-roicat/log-%J.out
+#SBATCH --error logs/4-roicat/log-%J.err
+
+# define constants
+EXEC_FILE="cellreg/4-roicat.py"
+
+# define variables
+ROOT_DATADIR="/oscar/data/afleisc2/collab/multiday-reg/data/test"
+SUBJECT_LIST=( "MS2457" )
+PLANE_LIST=( "plane0" "plane1" )
+
+# activate environment
+module load miniconda/4.12.0
+source /gpfs/runtime/opt/miniconda/4.12.0/etc/profile.d/conda.sh
+conda activate roicat
+which python
+python "$EXEC_FILE" --help
+
+
+# loop through combinations
+for SUBJECT in "${SUBJECT_LIST[@]}"; do
+for PLANE in "${PLANE_LIST[@]}"; do
+
+ echo ":::::: BEGIN ($SUBJECT, $PLANE) ::::::"
+
+ python "$EXEC_FILE" \
+ --root-datadir "$ROOT_DATADIR" \
+ --subject "$SUBJECT" \
+ --plane "$PLANE"
+
+ echo ":::::: DONE ($SUBJECT, $PLANE) ::::::"
+ echo "--------------------------------------"
+ echo ""
+
+done
+done
+
+
+
+
diff --git a/scripts/example-steps.sh b/scripts/example-steps.sh
index 8b491ed..2fa8c70 100644
--- a/scripts/example-steps.sh
+++ b/scripts/example-steps.sh
@@ -41,4 +41,13 @@ python cellreg/3-suite2p.py \
--date "$DATE" \
--plane "$PLANE" \
--config "$SUITE2P_CFG_PATH" \
- --cleanup
\ No newline at end of file
+ --cleanup
+
+
+# after all dates are done
+# cpu env is sufficient
+conda activate roicat
+python cellreg/4-roicat.py \
+ --root-datadir "$ROOT_DATADIR" \
+ --subject "$SUBJECT" \
+ --plane "$PLANE"
From 9a94001c8fcb41a3f8aa702288d8c5c2e2c88bce Mon Sep 17 00:00:00 2001
From: Tuan Pham <42875763+tuanpham96@users.noreply.github.com>
Date: Mon, 18 Dec 2023 14:30:38 -0500
Subject: [PATCH 05/14] env: fix deepcad env file
---
environments/deepcad.yml | 21 +++++++++++----------
1 file changed, 11 insertions(+), 10 deletions(-)
diff --git a/environments/deepcad.yml b/environments/deepcad.yml
index 88f3db1..420510a 100644
--- a/environments/deepcad.yml
+++ b/environments/deepcad.yml
@@ -6,16 +6,17 @@ dependencies:
- python=3.9
- ipykernel
- ipython
- - matplotlib
- - networkx
- - numpy
- - pandas
- pip
- - toml
- - tomli
- - tqdm
- pip:
+ - -f https://download.pytorch.org/whl/cu111/torch_stable.html
- beautifulsoup4
+ - matplotlib
+ - networkx
+ - numpy
+ - pandas
+ - toml
+ - tomli
+ - tqdm
- csbdeep
- gdown
- opencv-python
@@ -26,10 +27,10 @@ dependencies:
- scikit-image
- scipy
- tifffile
- - torch==1.10.1+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
- - torchaudio==0.10.1+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
- - torchvision==0.11.2+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
- typer
- typer-cli
- urllib3
- natsort
+ - "torch==1.10.1+cu111"
+ - "torchaudio==0.10.1+cu111"
+ - "torchvision==0.11.2+cu111"
From d00df19068e67d9cda5228522844e0322da80b55 Mon Sep 17 00:00:00 2001
From: Tuan Pham <42875763+tuanpham96@users.noreply.github.com>
Date: Fri, 19 Jan 2024 13:19:47 -0500
Subject: [PATCH 06/14] env: update oscar module name and path
---
README.md | 6 +++---
STEPS.md | 28 +++++++++++++++++++++-------
scripts/1-moco.sh | 12 ++++++------
scripts/2-deepcad.sh | 14 ++++++--------
scripts/3-suite2p.sh | 11 +++++------
scripts/4-roicat.sh | 12 +++++-------
scripts/example-steps.sh | 6 ++++--
7 files changed, 50 insertions(+), 39 deletions(-)
diff --git a/README.md b/README.md
index 2c1185d..a9322f5 100644
--- a/README.md
+++ b/README.md
@@ -7,11 +7,11 @@ This package is intended for use on Oscar with GPU support for PyTorch.
To use on Oscar, first clone the repo. Then load the anaconda module:
```shell
-module load anaconda/2022.05
-source /gpfs/runtime/opt/anaconda/2022.05/etc/profile.d/conda.sh
+module load miniconda3/23.11.0s
+source /oscar/rt/9.2/software/external/miniconda3/23.11.0/etc/profile.d/conda.sh
```
-If you have never loaded the `anaconda/2022.05` module before, you need to initialize
+If you have never loaded the `miniconda3/23.11.0s` module before, you need to initialize
the conda module by running the following command.
```shell
diff --git a/STEPS.md b/STEPS.md
index a8d8f03..632003a 100644
--- a/STEPS.md
+++ b/STEPS.md
@@ -2,13 +2,27 @@
## Multiple environment setups
-There are at least 3 environments that need to be created to manage steps separately:
+There are at least 3 environments that need to be created to manage steps separately.
-```shell
-conda env create -f environments/suite2p.yml
-conda env create -f environments/deepcad.yml
-conda env create -f environments/roicat.yml
-```
+- First load `miniconda3` module:
+
+ ```shell
+ module load miniconda3/23.11.0s
+ ```
+
+- If you have never loaded this before, also do the following. Note: you only need to do this **ONCE**.
+
+ ```shell
+ conda init bash
+ ```
+
+- Then do the following to create all three environments:
+
+ ```shell
+ conda env create -f environments/suite2p.yml
+ conda env create -f environments/deepcad.yml
+ conda env create -f environments/roicat.yml
+ ```
## Steps
@@ -357,4 +371,4 @@ For visual inspection: see the figures in `figures`
For analysis: `finalized-roi.csv` is what you may want
-
\ No newline at end of file
+
diff --git a/scripts/1-moco.sh b/scripts/1-moco.sh
index 4fe6606..bda9c1b 100644
--- a/scripts/1-moco.sh
+++ b/scripts/1-moco.sh
@@ -1,7 +1,8 @@
#!/bin/bash
#SBATCH -N 1
#SBATCH -n 16
-##SBATCH --account=carney-afleisc2-condo
+#SBATCH -p batch
+#SBATCH --account=carney-afleisc2-condo
#SBATCH --time=01:00:00
#SBATCH --mem=20g
#SBATCH --job-name 1-moco
@@ -20,10 +21,11 @@ DATE_LIST=( "20230902" "20230907" "20230912" )
PLANE_LIST=( "plane0" "plane1" )
# activate environment
-module load miniconda/4.12.0
-source /gpfs/runtime/opt/miniconda/4.12.0/etc/profile.d/conda.sh
+module load miniconda3/23.11.0s
+source /oscar/rt/9.2/software/external/miniconda3/23.11.0/etc/profile.d/conda.sh
+
conda activate suite2p
-which python
+command -v python
python "$EXEC_FILE" --help
@@ -56,5 +58,3 @@ done
done
done
-
-
diff --git a/scripts/2-deepcad.sh b/scripts/2-deepcad.sh
index 55cb2f1..0a2f297 100644
--- a/scripts/2-deepcad.sh
+++ b/scripts/2-deepcad.sh
@@ -3,7 +3,6 @@
#SBATCH --account=carney-afleisc2-condo
#SBATCH -N 1
#SBATCH -n 4
-#SBATCH --account=carney-afleisc2-condo
#SBATCH --time=01:00:00
#SBATCH --mem=64g
#SBATCH --job-name 2-deepcad
@@ -22,14 +21,15 @@ DATE_LIST=( "20230902" "20230907" "20230912" )
PLANE_LIST=( "plane0" "plane1" )
# activate environment
-module load miniconda/4.12.0
-source /gpfs/runtime/opt/miniconda/4.12.0/etc/profile.d/conda.sh
+module load miniconda3/23.11.0s
+source /oscar/rt/9.2/software/external/miniconda3/23.11.0/etc/profile.d/conda.sh
+
conda activate deepcad
-# not sure how to automated the following environment setting
-export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib
+# TODO: check if this is needed anymore
+# export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib
-which python
+command -v python
python "$EXEC_FILE" --help
@@ -62,5 +62,3 @@ done
done
done
-
-
diff --git a/scripts/3-suite2p.sh b/scripts/3-suite2p.sh
index a2d6030..7fc4784 100644
--- a/scripts/3-suite2p.sh
+++ b/scripts/3-suite2p.sh
@@ -1,6 +1,7 @@
#!/bin/bash
#SBATCH -N 1
#SBATCH -n 16
+#SBATCH -p batch
#SBATCH --account=carney-afleisc2-condo
#SBATCH --time=03:00:00
#SBATCH --mem=30g
@@ -20,13 +21,13 @@ DATE_LIST=( "20230902" "20230907" "20230912" )
PLANE_LIST=( "plane0" "plane1" )
# activate environment
-module load miniconda/4.12.0
-source /gpfs/runtime/opt/miniconda/4.12.0/etc/profile.d/conda.sh
+module load miniconda3/23.11.0s
+source /oscar/rt/9.2/software/external/miniconda3/23.11.0/etc/profile.d/conda.sh
+
conda activate suite2p
-which python
+command -v python
python "$EXEC_FILE" --help
-
# loop through combinations
for SUBJECT in "${SUBJECT_LIST[@]}"; do
for DATE in "${DATE_LIST[@]}"; do
@@ -56,5 +57,3 @@ done
done
done
-
-
diff --git a/scripts/4-roicat.sh b/scripts/4-roicat.sh
index 75cfd1b..56d288f 100644
--- a/scripts/4-roicat.sh
+++ b/scripts/4-roicat.sh
@@ -1,5 +1,6 @@
#!/bin/bash
#SBATCH -N 1
+#SBATCH -p batch
#SBATCH -n 8
#SBATCH --account=carney-afleisc2-condo
#SBATCH --time=00:30:00
@@ -17,13 +18,13 @@ SUBJECT_LIST=( "MS2457" )
PLANE_LIST=( "plane0" "plane1" )
# activate environment
-module load miniconda/4.12.0
-source /gpfs/runtime/opt/miniconda/4.12.0/etc/profile.d/conda.sh
+module load miniconda3/23.11.0s
+source /oscar/rt/9.2/software/external/miniconda3/23.11.0/etc/profile.d/conda.sh
+
conda activate roicat
-which python
+command -v python
python "$EXEC_FILE" --help
-
# loop through combinations
for SUBJECT in "${SUBJECT_LIST[@]}"; do
for PLANE in "${PLANE_LIST[@]}"; do
@@ -42,6 +43,3 @@ for PLANE in "${PLANE_LIST[@]}"; do
done
done
-
-
-
diff --git a/scripts/example-steps.sh b/scripts/example-steps.sh
index 2fa8c70..d8c3fdb 100644
--- a/scripts/example-steps.sh
+++ b/scripts/example-steps.sh
@@ -21,8 +21,9 @@ python cellreg/1-moco.py \
# gpu env
# e.g.: interact -q gpu -g 1 -t 01:00:00 -m 64g -n 4
conda activate deepcad
-# not sure how to automated the following environment setting
-export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib
+
+# try the following if there's an error with the `python cellreg/2-deepcad.py` step
+# export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib
python cellreg/2-deepcad.py \
--root-datadir "$ROOT_DATADIR" \
@@ -51,3 +52,4 @@ python cellreg/4-roicat.py \
--root-datadir "$ROOT_DATADIR" \
--subject "$SUBJECT" \
--plane "$PLANE"
+
From 5457ffe8a90269bd142bb3b02045bca110702ccd Mon Sep 17 00:00:00 2001
From: Tuan Pham <42875763+tuanpham96@users.noreply.github.com>
Date: Mon, 22 Jan 2024 16:49:24 -0500
Subject: [PATCH 07/14] feat: allow specify inp/out dir pattern + account for
cropped background images
---
cellreg/4-roicat.py | 53 +++++++++++++++++++++++++++++++++++++++------
1 file changed, 46 insertions(+), 7 deletions(-)
diff --git a/cellreg/4-roicat.py b/cellreg/4-roicat.py
index 1f8bf18..0673593 100644
--- a/cellreg/4-roicat.py
+++ b/cellreg/4-roicat.py
@@ -71,6 +71,16 @@
"--max-depth", type=int, default=6,
help='max depth to find suite2p files, relative to `--root-datadir`'
)
+
+ parser.add_argument(
+ "--suite2p-subdir", type=str, default='3-suite2p',
+ help='suite2p subdirectory'
+ )
+
+ parser.add_argument(
+ "--output-topdir", type=str, default='',
+ help='output top directory, if empty, will use the subject folder'
+ )
args = parser.parse_args()
print(args)
@@ -84,9 +94,14 @@
USE_GPU=args.use_gpu
VERBOSITY = args.verbose
+ OUTPUT_DIR=args.output_topdir
+ SUITE2P_SUBDIR=args.suite2p_subdir
+
# define paths
SUBJECT_DIR = Path(ROOT_DATA_DIR) / SUBJECT_ID
- COLLECTIVE_MUSE_DIR = SUBJECT_DIR / 'multi-session' / PLANE_ID
+ if OUTPUT_DIR in [None, '']:
+ OUTPUT_DIR = SUBJECT_DIR
+ COLLECTIVE_MUSE_DIR = OUTPUT_DIR / 'multi-session' / PLANE_ID
COLLECTIVE_MUSE_FIG_DIR = COLLECTIVE_MUSE_DIR / 'figures'
COLLECTIVE_MUSE_FIG_DIR.mkdir(parents=True, exist_ok=True)
@@ -94,7 +109,7 @@
dir_allOuterFolders = str(SUBJECT_DIR)
pathSuffixToStat = 'stat.npy'
pathSuffixToOps = 'ops.npy'
- pathShouldHave = fr'3-suite2p/{PLANE_ID}'
+ pathShouldHave = fr'{SUITE2P_SUBDIR}/{PLANE_ID}'
paths_allStat = roicat.helpers.find_paths(
dir_outer=dir_allOuterFolders,
@@ -115,18 +130,19 @@
print('Paths to all suite2p STAT files:')
- print('\n'.join(['\t-' + str(x) for x in paths_allStat]))
+ print('\n'.join(['\t- ' + str(x) for x in paths_allStat]))
print('\n')
print('Paths to all suite2p OPS files:')
- print('\n'.join(['\t-' + str(x) for x in paths_allOps]))
+ print('\n'.join(['\t- ' + str(x) for x in paths_allOps]))
print('\n')
- # load data
+ # load data
data = roicat.data_importing.Data_suite2p(
paths_statFiles=paths_allStat[:],
paths_opsFiles=paths_allOps[:],
um_per_pixel=PARAMS['um_per_pixel'],
- **PARAMS['suite2p'],
+ type_meanImg='meanImg', # will be overwritten in the following cell
+ **{k: v for k, v in PARAMS['suite2p'].items() if k not in ['type_meanImg']},
verbose=VERBOSITY,
)
@@ -150,8 +166,31 @@
FOV_backgrounds = {k: [] for k in background_types}
for ops_file in data.paths_ops:
ops = np.load(ops_file, allow_pickle=True).item()
+
+ im_sz = (ops['Ly'], ops['Lx'])
for bg in background_types:
- FOV_backgrounds[bg].append(ops[bg])
+ bg_im = ops[bg]
+
+ if bg_im.shape == im_sz:
+ FOV_backgrounds[bg].append(bg_im)
+ continue
+
+ print(
+ f'\t- File {ops_file}: {bg} shape is {bg_im.shape}, which is cropped from {im_sz}. '\
+ '\n\tWill attempt to add empty pixels to recover the original shape.'
+ )
+
+ im = np.zeros(im_sz).astype(bg_im.dtype)
+ cropped_xrange, cropped_yrange = ops['xrange'], ops['yrange']
+ im[
+ cropped_yrange[0]:cropped_yrange[1],
+ cropped_xrange[0]:cropped_xrange[1]
+ ] = bg_im
+
+ FOV_backgrounds[bg].append(im)
+
+ # choice of FOV images to align
+ data.FOV_images = FOV_backgrounds[PARAMS['suite2p']['type_meanImg']]
# obtain FOVs
aligner = roicat.tracking.alignment.Aligner(verbose=VERBOSITY)
From 552706305ef3ae13087c1c87e9dfe070e07512a5 Mon Sep 17 00:00:00 2001
From: Tuan Pham <42875763+tuanpham96@users.noreply.github.com>
Date: Mon, 22 Jan 2024 16:51:17 -0500
Subject: [PATCH 08/14] nb: add standalone roicat nb and inspect output nb
---
notebooks/inspect-muse.ipynb | 429 ++++++++++++++
notebooks/roicat-standalone.ipynb | 942 ++++++++++++++++++++++++++++++
2 files changed, 1371 insertions(+)
create mode 100644 notebooks/inspect-muse.ipynb
create mode 100644 notebooks/roicat-standalone.ipynb
diff --git a/notebooks/inspect-muse.ipynb b/notebooks/inspect-muse.ipynb
new file mode 100644
index 0000000..16df68c
--- /dev/null
+++ b/notebooks/inspect-muse.ipynb
@@ -0,0 +1,429 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "39dd752a-c720-4f52-a149-62668664bfd5",
+ "metadata": {},
+ "source": [
+ "# Inspect multisession outputs\n",
+ "\n",
+ "This is currently being tested to inspect the output of processed `roicat` outputs.\n",
+ "\n",
+ "This currently uses `tk` and cannot run on Oscar headless."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "020f3892-fe74-493e-81c3-eb99e5a10adc",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import glob\n",
+ "import warnings\n",
+ "\n",
+ "from pathlib import Path\n",
+ "import pickle\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import skimage\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "%matplotlib tk"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "86808689-73d8-45a2-bdf4-8928b39fa932",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "muse_dir = Path('data/SD_0664/multi-session/plane1/')\n",
+ "roicat_out_file = muse_dir / 'roicat-output.pkl'\n",
+ "aligned_img_file = muse_dir / 'aligned-images.pkl'\n",
+ "roi_table_file = muse_dir / 'finalized-roi.csv'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "37fe3dc0-6aec-4c7b-91f1-5579991bb8b0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fuse_cfg = dict(\n",
+ " background_choice = 'max_proj', # max_proj, meanImg, meanImgE, Vcorr\n",
+ " fuse_alpha = 0.4, # fusing ROI mask with background\n",
+ " fuse_bg_max_p = 98, # for clipping background max value with percentile\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7f967bc2-e41d-45aa-806a-400fb7a9a7ec",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!ls $muse_dir"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4af7a27e-e508-4ea9-a13f-be920421fd49",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(roicat_out_file, 'rb') as f:\n",
+ " roicat_out = pickle.load(f)\n",
+ "\n",
+ "with open(aligned_img_file, 'rb') as f:\n",
+ " aligned_images = pickle.load(f)\n",
+ "\n",
+ "aligned_rois = roicat_out['results']['ROIs']['ROIs_aligned']\n",
+ "\n",
+ "# TODO: unclear if this is the right order\n",
+ "image_dims = (\n",
+ " roicat_out['results']['ROIs']['frame_width'],\n",
+ " roicat_out['results']['ROIs']['frame_height']\n",
+ ")\n",
+ "roi_table = pd.read_csv(roi_table_file)\n",
+ "\n",
+ "num_sessions = len(aligned_rois)\n",
+ "assert all(num_sessions == np.array([len(aligned_images), roi_table['session'].nunique()]))\n",
+ "\n",
+ "num_global_rois = roi_table['global_roi'].nunique()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b03da632-c02f-42f2-8171-3fbd6ebcd0d4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "global_roi_colors = np.random.rand(num_global_rois, 3).round(3)\n",
+ "\n",
+ "roi_table['global_roi_color'] = roi_table['global_roi'].apply(\n",
+ " lambda x: global_roi_colors[x]\n",
+ ")\n",
+ "roi_table"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "532ad296-d743-435c-953f-a8f8521acfe4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def throw_invalid_footprint_warning(footprints, roi_table, session):\n",
+ " invalid_max_idx = footprints.max(axis=1).toarray().squeeze() <= 0\n",
+ " if sum(invalid_max_idx) == 0:\n",
+ " return\n",
+ " \n",
+ " invalid_max_idx = np.where(invalid_max_idx)[0]\n",
+ " \n",
+ " roi_record = roi_table.query(\n",
+ " 'session == @session and ' \\\n",
+ " 'session_roi in @invalid_max_idx and' \\\n",
+ " '(num_sessions > 1 or roicat_global_roi >= 0)' \n",
+ " )\n",
+ " if len(roi_record) > 0:\n",
+ " warnings.warn(\n",
+ " 'The following ROI records have invalid footprints, i.e. '\\\n",
+ " '(a) max = 0, (b) persist for more than 1 sessions or `roicat_global_roi=-1`:',\n",
+ " roi_record.drop(columns='global_roi_color').to_dict('records')\n",
+ " )\n",
+ " \n",
+ "def select_session_rois(footprints, roi_table, session):\n",
+ " session_iscell_rois = (\n",
+ " roi_table.query('session == @session')\n",
+ " .sort_values(by='session_roi')\n",
+ " .reset_index(drop=True)\n",
+ " )\n",
+ " roi_colors = np.stack(session_iscell_rois['global_roi_color'].to_list())\n",
+ " roi_iscell_idx = session_iscell_rois['session_roi'].to_list()\n",
+ "\n",
+ " footprints = footprints[roi_iscell_idx]\n",
+ " assert footprints.shape[0] == len(roi_iscell_idx)\n",
+ " \n",
+ " return footprints, roi_colors \n",
+ "\n",
+ "def normalize_footprint(X):\n",
+ " max_X = X.max(axis=1).toarray()\n",
+ " X = X.multiply(1.0 / max_X)\n",
+ " return X\n",
+ " \n",
+ "def sparse2image(sparse_footprints, image_dims):\n",
+ " image = (\n",
+ " (sparse_footprints > 0)\n",
+ " .T\n",
+ " .reshape(*image_dims)\n",
+ " .toarray()\n",
+ " )\n",
+ " return image\n",
+ "\n",
+ "def sparse2contour(sparse_footprints, image_dims):\n",
+ " image = (\n",
+ " (sparse_footprints > 0)\n",
+ " .T\n",
+ " .reshape(*image_dims)\n",
+ " .toarray()\n",
+ " )\n",
+ " contour = skimage.measure.find_contours(image)\n",
+ " return contour\n",
+ "\n",
+ "def color_footprint_one_channel(sparse_footprints, color_vec, image_dims):\n",
+ " image = np.array(\n",
+ " (sparse_footprints > 0)\n",
+ " .multiply(color_vec)\n",
+ " .sum(axis=0)\n",
+ " .T\n",
+ " .reshape(*image_dims)\n",
+ " )\n",
+ " return image\n",
+ "\n",
+ "def color_footprint(sparse_footprints, color_matrix, image_dims):\n",
+ " image = np.stack([\n",
+ " color_footprint_one_channel(\n",
+ " sparse_footprints,\n",
+ " color_vec.reshape(-1,1),\n",
+ " image_dims\n",
+ " )\n",
+ " for color_vec in color_matrix.T\n",
+ " ], axis=-1)\n",
+ " image = np.clip(image, a_min=0.0, a_max=1.0)\n",
+ " return image\n",
+ "\n",
+ "def fuse_rois_in_background(rois, background, alpha=0.5, background_max_percentile=99):\n",
+ " background = (background - background.min()) / \\\n",
+ " (np.percentile(background, background_max_percentile) - background.min())\n",
+ " background = np.clip(background, a_min=0, a_max=1.0)\n",
+ " background = skimage.color.gray2rgb(background)\n",
+ " \n",
+ " fused_image = background * (1 - alpha) + rois * alpha\n",
+ " fused_image = np.clip(fused_image, a_min=0, a_max=1.0) \n",
+ " return fused_image\n",
+ " \n",
+ "def compute_session_fused_footprints(\n",
+ " images, footprints, roi_table, session, \n",
+ " background_choice = 'max_proj',\n",
+ " fuse_alpha = 0.5,\n",
+ " fuse_bg_max_p = 99,\n",
+ "):\n",
+ " background_images = images[session]\n",
+ " if background_choice not in background_images:\n",
+ " warnings.warn(f'{background_choice} not in the aligned images. Using \"fov\" field instead')\n",
+ " assert 'fov' in background_images, '\"fov\" not found in aligned images'\n",
+ " background_choice = 'fov'\n",
+ " background_image = background_images[background_choice]\n",
+ " \n",
+ " sparse_footprints = footprints[session]\n",
+ " throw_invalid_footprint_warning(sparse_footprints, roi_table, session) \n",
+ " \n",
+ " sparse_footprints, roi_colors = select_session_rois(sparse_footprints, roi_table, session)\n",
+ " sparse_footprints = normalize_footprint(sparse_footprints)\n",
+ " colored_footprints = color_footprint(sparse_footprints, roi_colors, image_dims)\n",
+ " fused_footprints = fuse_rois_in_background(\n",
+ " colored_footprints,\n",
+ " background_image,\n",
+ " alpha=fuse_alpha, \n",
+ " background_max_percentile=fuse_bg_max_p\n",
+ " )\n",
+ " return fused_footprints"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9ea78c1b-7776-4230-b8ae-d20318c1c0d3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fuse_kwargs = dict(\n",
+ " images=aligned_images,\n",
+ " footprints=aligned_rois,\n",
+ " roi_table=roi_table,\n",
+ " **fuse_cfg\n",
+ ")\n",
+ "\n",
+ "fused_footprints = [\n",
+ " compute_session_fused_footprints(\n",
+ " session=session_idx,\n",
+ " **fuse_kwargs\n",
+ " )\n",
+ " for session_idx in range(num_sessions)\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1e995a99-4110-4335-962d-fd81764a49d3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "roi_contours = []\n",
+ "for session_rois in aligned_rois:\n",
+ " roi_contours.append([\n",
+ " sparse2contour(roi_footprint, image_dims) \n",
+ " for roi_footprint in session_rois\n",
+ " ])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2b8fd153-b55e-4841-a3e8-d1257763afc6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plt.figure(figsize=(20,10))\n",
+ "for i in range(num_sessions):\n",
+ " plt.subplot(1,num_sessions,i+1)\n",
+ " plt.imshow(fused_footprints[i])\n",
+ " plt.axis('off')\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e9966739-17a4-482e-929b-1e5460328587",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%matplotlib tk\n",
+ "\n",
+ "contour_kwargs = dict(c='r', lw=2, alpha=0.8)\n",
+ "num_cols = min(num_sessions, 6)\n",
+ "\n",
+ "\n",
+ "fig, axes = plt.subplots(\n",
+ " int(np.ceil(num_sessions/num_cols)), num_cols,\n",
+ " figsize=(20,10),\n",
+ " sharex=True,\n",
+ " sharey=True,\n",
+ ")\n",
+ "ax_list = axes.flatten()\n",
+ "\n",
+ "for i in range(num_sessions):\n",
+ " ax_list[i].imshow(fused_footprints[i])\n",
+ " ax_list[i].set_axis_off()\n",
+ " ax_list[i].set_title(f'session #{i+1}')\n",
+ " ax_list[i].session_index = i\n",
+ " \n",
+ "def find_tagged_objects(fig, value, key='tag'):\n",
+ " objs = fig.findobj(\n",
+ " match=lambda x:\n",
+ " False if not hasattr(x, key) \n",
+ " else value in getattr(x, key)\n",
+ " )\n",
+ " return objs \n",
+ "\n",
+ " \n",
+ "def onclick(event, tag='highlight'): \n",
+ " # remove previous highlighted objects\n",
+ " for x in find_tagged_objects(fig, value=tag):\n",
+ " x.remove()\n",
+ " \n",
+ " # double click: reset title\n",
+ " # not working, TBD\n",
+ " # if event.dblclick:\n",
+ " # for ax in ax_list:\n",
+ " # ax.set_title(f'session #{i+1}')\n",
+ " # return\n",
+ " \n",
+ " # get data\n",
+ " ix, iy, ax = event.xdata, event.ydata, event.inaxes\n",
+ " session = ax.session_index\n",
+ " \n",
+ " # obtain session ROI index\n",
+ " flat_idx = np.ravel_multi_index((round(iy),round(ix)), image_dims)\n",
+ " session_roi = aligned_rois[session][:,flat_idx].nonzero()[0]\n",
+ " if len(session_roi) == 0:\n",
+ " return\n",
+    "    session_roi = session_roi[0] # just select first one if there are overlaps\n",
+ " select_global_roi = (\n",
+ " roi_table.query('session == @session and session_roi == @session_roi')\n",
+ " ['global_roi'].to_list()\n",
+ " )\n",
+ " assert len(select_global_roi) == 1\n",
+ " select_global_roi = select_global_roi[0]\n",
+ " \n",
+ " # obtain contours\n",
+ " select_contours = {\n",
+ " r['session']: dict(\n",
+ " contour = roi_contours[r['session']][r['session_roi']],\n",
+ " **r,\n",
+ " )\n",
+ " for _, r in roi_table.query('global_roi == @select_global_roi').iterrows()\n",
+ " }\n",
+ " \n",
+ " # plot contours\n",
+ " for session, ax in enumerate(ax_list):\n",
+ " if session not in select_contours:\n",
+ " ax.set_title(f'session #{i+1} [NOT FOUND]')\n",
+ " continue\n",
+ " \n",
+ " session_contours = select_contours[session]['contour']\n",
+ " select_session_roi = select_contours[session]['session_roi']\n",
+ " \n",
+ " for c in session_contours:\n",
+ " c_handles = ax.plot(c[:,1],c[:,0], **contour_kwargs)\n",
+ " for ch in c_handles:\n",
+ " ch.tag = tag\n",
+ " \n",
+ " ax.set_title(f'session #{i+1} [id={select_session_roi} | ID={select_global_roi}]')\n",
+ " \n",
+ " plt.show()\n",
+ " \n",
+ "cid = fig.canvas.mpl_connect('button_press_event', onclick)\n",
+ "\n",
+ "plt.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "datavis",
+ "language": "python",
+ "name": "datavis"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.9"
+ },
+ "scenes_data": {
+ "active_scene": "Default Scene",
+ "init_scene": "",
+ "scenes": [
+ "Default Scene"
+ ]
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "state": {},
+ "version_major": 2,
+ "version_minor": 0
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/notebooks/roicat-standalone.ipynb b/notebooks/roicat-standalone.ipynb
new file mode 100644
index 0000000..5bc3fa3
--- /dev/null
+++ b/notebooks/roicat-standalone.ipynb
@@ -0,0 +1,942 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "7241263e-8942-41a0-8c5b-7393431912b3",
+ "metadata": {},
+ "source": [
+ "# ROICAT Standalone\n",
+ "\n",
+ "This is a standalone notebook to use `roicat` for cell tracking after `suite2p`.\n",
+ "\n",
+    "This was taken from `cellreg/4-roicat.py` but will not be kept as up to date, and may be removed.\n",
+ "\n",
+ "This can be run on Oscar or locally."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "be3ba033-a6d2-4d5e-a1ca-f9bd6805735c",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import glob\n",
+ "import shutil\n",
+ "from pathlib import Path\n",
+ "import copy\n",
+ "import multiprocessing as mp\n",
+ "import tempfile\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "import roicat\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import seaborn as sns\n",
+ "\n",
+ "import argparse\n",
+ "import yaml\n",
+ "import pickle\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ac423786-68b2-4495-a6c5-c23f0a676a3d",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "PARAMS = {\n",
+ " 'um_per_pixel': 0.7,\n",
+ " 'background_max_percentile': 99.9,\n",
+ " 'suite2p': { # `roicat.data_importing.Data_suite2p`\n",
+ " 'new_or_old_suite2p': 'new',\n",
+ " 'type_meanImg': 'meanImgE',\n",
+ " },\n",
+ " 'fov_augment': { # `aligner.augment_FOV_images`\n",
+ " 'roi_FOV_mixing_factor': 0.5,\n",
+ " 'use_CLAHE': False,\n",
+ " 'CLAHE_grid_size': 1,\n",
+ " 'CLAHE_clipLimit': 1,\n",
+ " 'CLAHE_normalize': True,\n",
+ " },\n",
+ " 'fit_geometric': { # `aligner.fit_geometric`\n",
+ " 'template': 0, \n",
+ " 'template_method': 'image', \n",
+ " 'mode_transform': 'affine',\n",
+ " 'mask_borders': (5,5,5,5), \n",
+ " 'n_iter': 1000,\n",
+ " 'termination_eps': 1e-6, \n",
+ " 'gaussFiltSize': 15,\n",
+ " 'auto_fix_gaussFilt_step':1,\n",
+ " },\n",
+ " 'fit_nonrigid': { # `aligner.fit_nonrigid`\n",
+ " 'disable': True,\n",
+ " 'template': 0,\n",
+ " 'template_method': 'image',\n",
+ " 'mode_transform':'createOptFlow_DeepFlow',\n",
+ " 'kwargs_mode_transform':None,\n",
+ " },\n",
+ " 'roi_blur': {\n",
+ " 'kernel_halfWidth': 2\n",
+ " }\n",
+ " \n",
+ "}\n",
+ "\n",
+ "DISABLE_NONRIGID = PARAMS['fit_nonrigid'].pop('disable')\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bf4851d0-cb6a-4a04-816b-eb21990cc88a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ROOT_DATA_DIR = '/oscar/data/afleisc2/sdaste/ROICat-test/'\n",
+ "SUBJECT_ID = 'SD_0664'\n",
+ "PLANE_ID = 'plane2'\n",
+ "SUITE2P_PATH_MAXDEPTH=6\n",
+ "USE_GPU = False\n",
+ "VERBOSITY = True\n",
+ "OUTPUT_DIR='../data'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c7d73c56-82cc-4a14-a045-fc15ad3e3957",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define paths\n",
+ "SUBJECT_DIR = Path(ROOT_DATA_DIR) / SUBJECT_ID\n",
+ "COLLECTIVE_MUSE_DIR = Path(OUTPUT_DIR) / SUBJECT_ID / 'multi-session' / PLANE_ID\n",
+ "COLLECTIVE_MUSE_FIG_DIR = COLLECTIVE_MUSE_DIR / 'figures'\n",
+ "COLLECTIVE_MUSE_FIG_DIR.mkdir(parents=True, exist_ok=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0bacdcf2-a84f-4227-b13f-a0668b9bde24",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# find suite2p paths\n",
+ "dir_allOuterFolders = str(SUBJECT_DIR)\n",
+ "pathSuffixToStat = 'stat.npy'\n",
+ "pathSuffixToOps = 'ops.npy'\n",
+ "pathShouldHave = fr'suite2p/{PLANE_ID}'\n",
+ "\n",
+ "paths_allStat = roicat.helpers.find_paths(\n",
+ " dir_outer=dir_allOuterFolders,\n",
+ " reMatch=pathSuffixToStat,\n",
+ " reMatch_in_path=pathShouldHave, \n",
+ " depth=SUITE2P_PATH_MAXDEPTH,\n",
+ ")[:]\n",
+ "\n",
+ "paths_allStat = [\n",
+ " x for x in paths_allStat \n",
+ " if pathShouldHave in x\n",
+ "]\n",
+ "\n",
+ "paths_allOps = np.array([\n",
+ " Path(path).resolve().parent / pathSuffixToOps\n",
+ " for path in paths_allStat\n",
+ "])[:]\n",
+ "\n",
+ "\n",
+ "print('Paths to all suite2p STAT files:')\n",
+ "print('\\n'.join(['\\t-' + str(x) for x in paths_allStat]))\n",
+ "print('\\n')\n",
+ "print('Paths to all suite2p OPS files:')\n",
+ "print('\\n'.join(['\\t-' + str(x) for x in paths_allOps]))\n",
+ "print('\\n')\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9b95bba8-9e0d-4b0b-9d4f-a527f41622b9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# load data\n",
+ "data = roicat.data_importing.Data_suite2p(\n",
+ " paths_statFiles=paths_allStat[:],\n",
+ " paths_opsFiles=paths_allOps[:],\n",
+ " um_per_pixel=PARAMS['um_per_pixel'],\n",
+ " type_meanImg='meanImg', # will be overwritten in the following cell\n",
+ " **{k: v for k, v in PARAMS['suite2p'].items() if k not in ['type_meanImg']},\n",
+ " verbose=VERBOSITY,\n",
+ ")\n",
+ "\n",
+ "assert data.check_completeness(verbose=False)['tracking'],\\\n",
+ " \"Data object is missing attributes necessary for tracking.\"\n",
+ "\n",
+ "# also save iscell paths\n",
+ "data.paths_iscell = [\n",
+ " Path(x).parent / 'iscell.npy'\n",
+ " for x in data.paths_ops\n",
+ "]\n",
+ "\n",
+ "# load all background images\n",
+ "background_types = [\n",
+ " 'meanImg',\n",
+ " 'meanImgE',\n",
+ " 'max_proj',\n",
+ " 'Vcorr',\n",
+ "]\n",
+ "\n",
+ "FOV_backgrounds = {k: [] for k in background_types}\n",
+ "for ops_file in data.paths_ops:\n",
+ " ops = np.load(ops_file, allow_pickle=True).item()\n",
+ " \n",
+ " im_sz = (ops['Ly'], ops['Lx']) \n",
+ " for bg in background_types:\n",
+ " bg_im = ops[bg]\n",
+ " \n",
+ " if bg_im.shape == im_sz:\n",
+ " FOV_backgrounds[bg].append(bg_im)\n",
+ " continue\n",
+ "\n",
+ " print(\n",
+ " f'\\t- File {ops_file}: {bg} shape is {bg_im.shape}, which is cropped from {im_sz}. '\\\n",
+ " '\\n\\tWill attempt to add empty pixels to recover the original shape.'\n",
+ " )\n",
+ "\n",
+ " im = np.zeros(im_sz).astype(bg_im.dtype)\n",
+ " cropped_xrange, cropped_yrange = ops['xrange'], ops['yrange']\n",
+ " im[\n",
+ " cropped_yrange[0]:cropped_yrange[1],\n",
+ " cropped_xrange[0]:cropped_xrange[1]\n",
+ " ] = bg_im\n",
+ " \n",
+ " FOV_backgrounds[bg].append(im)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "39220367-d78f-4073-a2ce-7a976fb8dbc2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "{k: [vi.shape for vi in v] for k,v in FOV_backgrounds.items()}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0001a617-a832-4808-951a-09dc7975f551",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# choice of FOV images to align\n",
+ "data.FOV_images = FOV_backgrounds[PARAMS['suite2p']['type_meanImg']]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6ee880cf-1eb1-40b2-b2d5-6e999c80096b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "roicat.visualization.display_toggle_image_stack(data.FOV_images)\n",
+ "roicat.visualization.display_toggle_image_stack(data.get_maxIntensityProjection_spatialFootprints(), clim=[0,1])\n",
+ "roicat.visualization.display_toggle_image_stack(np.concatenate(data.ROI_images, axis=0)[:5000], image_size=(200,200))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bf6d2e0d-b805-43c5-aa8d-d308a80858cd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# obtain FOVs\n",
+ "aligner = roicat.tracking.alignment.Aligner(verbose=VERBOSITY)\n",
+ "\n",
+ "FOV_images = aligner.augment_FOV_images(\n",
+ " ims=data.FOV_images,\n",
+ " spatialFootprints=data.spatialFootprints,\n",
+ " **PARAMS['fov_augment']\n",
+ ")\n",
+ "\n",
+ "roicat.visualization.display_toggle_image_stack(FOV_images)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8c40093c-4ce4-40f6-9a93-16910cb73397",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# geometric fit\n",
+ "aligner.fit_geometric(\n",
+ " ims_moving=FOV_images,\n",
+ " **PARAMS['fit_geometric']\n",
+ ")\n",
+ "aligner.transform_images_geometric(FOV_images)\n",
+ "remap_idx = aligner.remappingIdx_geo\n",
+ "\n",
+ "roicat.visualization.display_toggle_image_stack(aligner.ims_registered_geo)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c5dbc228-a463-4c40-b102-9d4e23a39675",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# non-rigid\n",
+ "if not DISABLE_NONRIGID:\n",
+ " aligner.fit_nonrigid(\n",
+ " ims_moving=aligner.ims_registered_geo,\n",
+ " remappingIdx_init=aligner.remappingIdx_geo, \n",
+ " **PARAMS['fit_nonrigid']\n",
+ " )\n",
+ " aligner.transform_images_nonrigid(FOV_images)\n",
+ " remap_idx = aligner.remappingIdx_nonrigid"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "236d1901-1ab7-4187-9bd1-14b377563a98",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# transform ROIs\n",
+ "aligner.transform_ROIs(\n",
+ " ROIs=data.spatialFootprints,\n",
+ " remappingIdx=remap_idx,\n",
+ " normalize=True,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "885dfb93-63e5-43fe-aa27-2dffaa178227",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# transform other backgrounds\n",
+ "aligned_backgrounds = {k: [] for k in background_types}\n",
+ "for bg in background_types:\n",
+ " aligned_backgrounds[bg] = aligner.transform_images(\n",
+ " FOV_backgrounds[bg],\n",
+ " remappingIdx=remap_idx\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "24b42d5f-2e5e-4ed6-bb82-fa83725ec32e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for bg, im in aligned_backgrounds.items():\n",
+ " roicat.visualization.display_toggle_image_stack(im)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fa7fbb38-0b8e-421d-b97f-424ba55aec13",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plt.figure(figsize=(20,20), layout='tight')\n",
+ "types2plt = background_types + ['ROI']\n",
+ "nrows = len(types2plt)\n",
+ "ncols = data.n_sessions\n",
+ "\n",
+ "splt_cnt = 1\n",
+ "for k in types2plt:\n",
+ " image_list = aligned_backgrounds.get(k, aligner.get_ROIsAligned_maxIntensityProjection())\n",
+ " for s_id, img in enumerate(image_list):\n",
+ " plt.subplot(nrows, ncols, splt_cnt)\n",
+ " plt.imshow(\n",
+ " img, cmap='Greys_r',\n",
+ " vmax=np.percentile(\n",
+ " img,\n",
+ " PARAMS['background_max_percentile'] if k!= \"ROI\" else 95\n",
+ " )\n",
+ " )\n",
+ " plt.axis('off')\n",
+ " plt.title(f'Aligned {k} [#{s_id}]') \n",
+ " splt_cnt += 1\n",
+ "\n",
+ "plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'aligned-fov.png')\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "04c4ce2e-6fac-44ae-ba96-ba046cc55d9c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# blur ROI\n",
+ "blurrer = roicat.tracking.blurring.ROI_Blurrer(\n",
+ " frame_shape=(data.FOV_height, data.FOV_width),\n",
+ " plot_kernel=False,\n",
+ " verbose=VERBOSITY,\n",
+ " **PARAMS['roi_blur']\n",
+ ")\n",
+ "\n",
+ "blurrer.blur_ROIs(\n",
+ " spatialFootprints=aligner.ROIs_aligned[:],\n",
+ ")\n",
+ "\n",
+ "# ROInet embedding\n",
+ "# TODO: Parameterize `ROInet_embedder`, `generate_dataloader`\n",
+ "DEVICE = roicat.helpers.set_device(use_GPU=USE_GPU, verbose=VERBOSITY)\n",
+ "dir_temp = tempfile.gettempdir()\n",
+ "\n",
+ "roinet = roicat.ROInet.ROInet_embedder(\n",
+ " device=DEVICE,\n",
+ " dir_networkFiles=dir_temp,\n",
+ " download_method='check_local_first',\n",
+ " download_url='https://osf.io/x3fd2/download',\n",
+ " download_hash='7a5fb8ad94b110037785a46b9463ea94',\n",
+ " forward_pass_version='latent',\n",
+ " verbose=VERBOSITY\n",
+ ")\n",
+ "\n",
+ "roinet.generate_dataloader(\n",
+ " ROI_images=data.ROI_images,\n",
+ " um_per_pixel=data.um_per_pixel,\n",
+ " pref_plot=False,\n",
+ " jit_script_transforms=False,\n",
+ " batchSize_dataloader=8, \n",
+ " pinMemory_dataloader=True,\n",
+ " numWorkers_dataloader=4,\n",
+ " persistentWorkers_dataloader=True,\n",
+ " prefetchFactor_dataloader=2,\n",
+ ")\n",
+ "\n",
+ "roinet.generate_latents()\n",
+ "\n",
+ "# Scattering wavelet embedding\n",
+ "# TODO: Parameterize `SWT`, `SWT.transform`\n",
+ "swt = roicat.tracking.scatteringWaveletTransformer.SWT(\n",
+ " kwargs_Scattering2D={'J': 3, 'L': 12}, \n",
+ " image_shape=data.ROI_images[0].shape[1:3],\n",
+ " device=DEVICE,\n",
+ ")\n",
+ "\n",
+ "swt.transform(\n",
+ " ROI_images=roinet.ROI_images_rs,\n",
+ " batch_size=100,\n",
+ ")\n",
+ "\n",
+ "# Compute similarities\n",
+ "# TODO: Parameterize `ROI_graph`, `compute_similarity_blockwise`, `make_normalized_similarities`\n",
+ "\n",
+ "sim = roicat.tracking.similarity_graph.ROI_graph(\n",
+ " n_workers=-1, \n",
+ " frame_height=data.FOV_height,\n",
+ " frame_width=data.FOV_width,\n",
+ " block_height=128, \n",
+ " block_width=128, \n",
+ " algorithm_nearestNeigbors_spatialFootprints='brute',\n",
+ " verbose=VERBOSITY, \n",
+ ")\n",
+ "\n",
+ "s_sf, s_NN, s_SWT, s_sesh = sim.compute_similarity_blockwise(\n",
+ " spatialFootprints=blurrer.ROIs_blurred,\n",
+ " features_NN=roinet.latents,\n",
+ " features_SWT=swt.latents,\n",
+ " ROI_session_bool=data.session_bool,\n",
+ " spatialFootprint_maskPower=1.0,\n",
+ ")\n",
+ "\n",
+ "sim.make_normalized_similarities(\n",
+ " centers_of_mass=data.centroids,\n",
+ " features_NN=roinet.latents,\n",
+ " features_SWT=swt.latents, \n",
+ " k_max=data.n_sessions*100,\n",
+ " k_min=data.n_sessions*10,\n",
+ " algo_NN='kd_tree',\n",
+ " device=DEVICE,\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8c1b1aac-0e82-408c-85c1-9d189c13821e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Clustering\n",
+ "# TODO: Parameterize `find_optimal_parameters_for_pruning`?\n",
+ "clusterer = roicat.tracking.clustering.Clusterer(\n",
+ " s_sf=sim.s_sf,\n",
+ " s_NN_z=sim.s_NN_z,\n",
+ " s_SWT_z=sim.s_SWT_z,\n",
+ " s_sesh=sim.s_sesh,\n",
+ ")\n",
+ "\n",
+ "kwargs_makeConjunctiveDistanceMatrix_best = clusterer.find_optimal_parameters_for_pruning(\n",
+ " n_bins=None, \n",
+ " smoothing_window_bins=None,\n",
+ " kwargs_findParameters={\n",
+ " 'n_patience': 300,\n",
+ " 'tol_frac': 0.001, \n",
+ " 'max_trials': 1200, \n",
+ " 'max_duration': 60*10, \n",
+ " },\n",
+ " bounds_findParameters={\n",
+ " 'power_NN': (0., 5.),\n",
+ " 'power_SWT': (0., 5.),\n",
+ " 'p_norm': (-5, 0),\n",
+ " 'sig_NN_kwargs_mu': (0., 1.0), \n",
+ " 'sig_NN_kwargs_b': (0.00, 1.5), \n",
+ " 'sig_SWT_kwargs_mu': (0., 1.0),\n",
+ " 'sig_SWT_kwargs_b': (0.00, 1.5),\n",
+ " },\n",
+ " n_jobs_findParameters=-1,\n",
+ ")\n",
+ "\n",
+ "kwargs_mcdm_tmp = kwargs_makeConjunctiveDistanceMatrix_best ## Use the optimized parameters\n",
+ "\n",
+ "clusterer.plot_distSame(kwargs_makeConjunctiveDistanceMatrix=kwargs_mcdm_tmp)\n",
+ "plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'pw-sim-distrib.png')\n",
+ "\n",
+ "clusterer.plot_similarity_relationships(\n",
+ " plots_to_show=[1,2,3], \n",
+ " max_samples=100000, ## Make smaller if it is running too slow\n",
+ " kwargs_scatter={'s':1, 'alpha':0.2},\n",
+ " kwargs_makeConjunctiveDistanceMatrix=kwargs_mcdm_tmp\n",
+ ");\n",
+ "plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'pw-sim-scatter.png')\n",
+ "\n",
+ "clusterer.make_pruned_similarity_graphs(\n",
+ " d_cutoff=None,\n",
+ " kwargs_makeConjunctiveDistanceMatrix=kwargs_mcdm_tmp,\n",
+ " stringency=1.0,\n",
+ " convert_to_probability=False, \n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "91de02af-fa8b-4252-8695-498fe1ac2b72",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "if data.n_sessions >= 8:\n",
+ " labels = clusterer.fit(\n",
+ " d_conj=clusterer.dConj_pruned,\n",
+ " session_bool=data.session_bool,\n",
+ " min_cluster_size=2,\n",
+ " n_iter_violationCorrection=3,\n",
+ " split_intraSession_clusters=True,\n",
+ " cluster_selection_method='leaf',\n",
+ " d_clusterMerge=None, \n",
+ " alpha=0.999, \n",
+ " discard_failed_pruning=False, \n",
+ " n_steps_clusterSplit=100,\n",
+ " )\n",
+ "\n",
+ "else:\n",
+ " labels = clusterer.fit_sequentialHungarian(\n",
+ " d_conj=clusterer.dConj_pruned, ## Input distance matrix\n",
+ " session_bool=data.session_bool, ## Boolean array of which ROIs belong to which sessions\n",
+ " thresh_cost=0.6, ## Threshold \n",
+ " )\n",
+ "\n",
+ "quality_metrics = clusterer.compute_quality_metrics()\n",
+ "\n",
+ "labels_squeezed, labels_bySession, labels_bool, labels_bool_bySession, labels_dict = roicat.tracking.clustering.make_label_variants(labels=labels, n_roi_bySession=data.n_roi)\n",
+ "\n",
+ "results = {\n",
+ " \"clusters\":{\n",
+ " \"labels\": labels_squeezed,\n",
+ " \"labels_bySession\": labels_bySession,\n",
+ " \"labels_bool\": labels_bool,\n",
+ " \"labels_bool_bySession\": labels_bool_bySession,\n",
+ " \"labels_dict\": labels_dict,\n",
+ " },\n",
+ " \"ROIs\": {\n",
+ " \"ROIs_aligned\": aligner.ROIs_aligned,\n",
+ " \"ROIs_raw\": data.spatialFootprints,\n",
+ " \"frame_height\": data.FOV_height,\n",
+ " \"frame_width\": data.FOV_width,\n",
+ " \"idx_roi_session\": np.where(data.session_bool)[1],\n",
+ " \"n_sessions\": data.n_sessions,\n",
+ " },\n",
+ " \"input_data\": {\n",
+ " \"paths_stat\": data.paths_stat,\n",
+ " \"paths_ops\": data.paths_ops,\n",
+ " },\n",
+ " \"quality_metrics\": clusterer.quality_metrics if hasattr(clusterer, 'quality_metrics') else None,\n",
+ "}\n",
+ "\n",
+ "run_data = copy.deepcopy({\n",
+ " 'data': data.serializable_dict,\n",
+ " 'aligner': aligner.serializable_dict,\n",
+ " 'blurrer': blurrer.serializable_dict,\n",
+ " 'roinet': roinet.serializable_dict,\n",
+ " 'swt': swt.serializable_dict,\n",
+ " 'sim': sim.serializable_dict,\n",
+ " 'clusterer': clusterer.serializable_dict,\n",
+ "})\n",
+ "\n",
+ "iscell_bySession = [np.load(ic_p)[:,0] for ic_p in data.paths_iscell]\n",
+ "\n",
+ "with open(COLLECTIVE_MUSE_DIR / 'roicat-output.pkl', 'wb') as f:\n",
+ " pickle.dump(dict(\n",
+ " run_data = run_data,\n",
+ " results = results,\n",
+ " iscell = iscell_bySession\n",
+ " ), f)\n",
+ "\n",
+ "\n",
+ "print(f'Number of clusters: {len(np.unique(results[\"clusters\"][\"labels\"]))}')\n",
+ "print(f'Number of discarded ROIs: {(results[\"clusters\"][\"labels\"]==-1).sum()}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ac4f230d-4aa0-40dd-86ef-3deafb75580b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Visualize\n",
+ "confidence = (((results['quality_metrics']['cluster_silhouette'] + 1) / 2) * results['quality_metrics']['cluster_intra_means'])\n",
+ "\n",
+ "fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15,7))\n",
+ "\n",
+ "axs[0,0].hist(results['quality_metrics']['cluster_silhouette'], 50);\n",
+ "axs[0,0].set_xlabel('cluster_silhouette');\n",
+ "axs[0,0].set_ylabel('cluster counts');\n",
+ "\n",
+ "axs[0,1].hist(results['quality_metrics']['cluster_intra_means'], 50);\n",
+ "axs[0,1].set_xlabel('cluster_intra_means');\n",
+ "axs[0,1].set_ylabel('cluster counts');\n",
+ "\n",
+ "axs[1,0].hist(confidence, 50);\n",
+ "axs[1,0].set_xlabel('confidence');\n",
+ "axs[1,0].set_ylabel('cluster counts');\n",
+ "\n",
+ "axs[1,1].hist(results['quality_metrics']['sample_silhouette'], 50);\n",
+ "axs[1,1].set_xlabel('sample_silhouette score');\n",
+ "axs[1,1].set_ylabel('roi sample counts');\n",
+ "\n",
+ "fig.savefig(COLLECTIVE_MUSE_FIG_DIR / 'cluster-metrics.png')\n",
+ "\n",
+ "# FOV clusters\n",
+ "FOV_clusters = roicat.visualization.compute_colored_FOV(\n",
+ " spatialFootprints=[r.power(0.8) for r in results['ROIs']['ROIs_aligned']], \n",
+ " FOV_height=results['ROIs']['frame_height'],\n",
+ " FOV_width=results['ROIs']['frame_width'],\n",
+ " labels=results[\"clusters\"][\"labels_bySession\"], ## cluster labels\n",
+ " # alphas_labels=confidence*1.5, ## Set brightness of each cluster based on some 1-D array\n",
+ " alphas_labels=(clusterer.quality_metrics['cluster_silhouette'] > 0) * (clusterer.quality_metrics['cluster_intra_means'] > 0.4),\n",
+ "# alphas_sf=clusterer.quality_metrics['sample_silhouette'], ## Set brightness of each ROI based on some 1-D array\n",
+ ")\n",
+ "\n",
+ "FOV_clusters_with_iscell = roicat.visualization.compute_colored_FOV(\n",
+ " spatialFootprints=[r.power(0.8) for r in results['ROIs']['ROIs_aligned']], ## Spatial footprint sparse arrays\n",
+ " FOV_height=results['ROIs']['frame_height'],\n",
+ " FOV_width=results['ROIs']['frame_width'],\n",
+ " labels=results[\"clusters\"][\"labels_bySession\"], ## cluster labels\n",
+ " # alphas_labels=confidence*1.5, ## Set brightness of each cluster based on some 1-D array\n",
+ " alphas_labels=(clusterer.quality_metrics['cluster_silhouette'] > 0) * (clusterer.quality_metrics['cluster_intra_means'] > 0.4),\n",
+ " alphas_sf=iscell_bySession\n",
+ "# alphas_sf=clusterer.quality_metrics['sample_silhouette'], ## Set brightness of each ROI based on some 1-D array\n",
+ ")\n",
+ "\n",
+ "roicat.helpers.save_gif(\n",
+ " array=FOV_clusters, \n",
+ " path=str(COLLECTIVE_MUSE_FIG_DIR/ 'FOV_clusters_allrois.gif'),\n",
+ " frameRate=5.0,\n",
+ " loop=0,\n",
+ ")\n",
+ "\n",
+ "roicat.helpers.save_gif(\n",
+ " array=FOV_clusters_with_iscell, \n",
+ " path=str(COLLECTIVE_MUSE_FIG_DIR/ 'FOV_clusters_iscells.gif'),\n",
+ " frameRate=5.0,\n",
+ " loop=0,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a266c7c8-0cbe-4277-b3b4-8ef338ffd005",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plt.figure(figsize=(20,10), layout='tight')\n",
+ "roi_image_dict = {\n",
+ " 'all': FOV_clusters,\n",
+ " 'iscell': FOV_clusters_with_iscell\n",
+ "}\n",
+ "nrows = len(roi_image_dict)\n",
+ "ncols = data.n_sessions\n",
+ "\n",
+ "splt_cnt = 1\n",
+ "for k, image_list in roi_image_dict.items():\n",
+ " for s_id, img in enumerate(image_list):\n",
+ " plt.subplot(nrows, ncols, splt_cnt)\n",
+ " plt.imshow(img)\n",
+ " plt.axis('off')\n",
+ " plt.title(f'Aligned {k} [#{s_id}]') \n",
+ " splt_cnt += 1\n",
+ "\n",
+ "plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'aligned-rois.png')\n",
+ "\n",
+ "# save FOVs\n",
+ "num_sessions = data.n_sessions\n",
+ "out_img = []\n",
+ "for d in range(num_sessions):\n",
+ " out_img.append(dict(\n",
+ " fov = aligner.ims_registered_geo[d],\n",
+ " roi_pre_iscell = FOV_clusters[d],\n",
+ " roi_with_iscell = FOV_clusters_with_iscell[d],\n",
+ " **{k: v[d] for k,v in aligned_backgrounds.items()},\n",
+ " ))\n",
+ "\n",
+ "with open(COLLECTIVE_MUSE_DIR / 'aligned-images.pkl', 'wb') as f:\n",
+ " pickle.dump(out_img, f)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "027adff5-6ad3-4c96-87d0-47d8c83f7c9a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# save summary data\n",
+ "df = pd.DataFrame([\n",
+ " dict(\n",
+ " session=i, \n",
+ " global_roi=glv, \n",
+ " session_roi=range(len(glv)),\n",
+ " iscell = iscv\n",
+ " )\n",
+ " for i, (glv, iscv) in enumerate(zip(labels_bySession, iscell_bySession))\n",
+ "]).explode(['global_roi','session_roi', 'iscell']).astype({'iscell': 'bool'})\n",
+ "\n",
+ "df.to_csv(COLLECTIVE_MUSE_DIR / 'summary-roi.csv', index=False)\n",
+ "\n",
+ "# process only iscell for stats\n",
+ "df = (\n",
+ " df.query('iscell')\n",
+ " .reset_index(drop=True)\n",
+ ")\n",
+ "\n",
+ "df = df.merge(\n",
+ " (\n",
+ " df.query('global_roi >= 0')\n",
+ " .groupby('global_roi')\n",
+ " ['session'].agg(lambda x: len(list(x)))\n",
+ " .to_frame('num_sessions')\n",
+ " .reset_index()\n",
+ " ),\n",
+ " how='left'\n",
+ ")\n",
+ "\n",
+ "df = (\n",
+ " df.fillna({'num_sessions': 1})\n",
+ " .astype({'num_sessions': 'int'})\n",
+ ")\n",
+ "\n",
+ "# re-indexing\n",
+ "persistent_roi_reindices = (\n",
+ " df[['num_sessions', 'global_roi']]\n",
+ " .query('global_roi >= 0 and num_sessions > 1')\n",
+ " .drop_duplicates()\n",
+ " .sort_values('num_sessions', ascending=False)\n",
+ " .reset_index(drop=True)\n",
+ " .reset_index()\n",
+ " .set_index('global_roi')\n",
+ " ['index'].to_dict()\n",
+ ")\n",
+ "\n",
+ "df['reindexed_global_roi'] = df['global_roi'].map(persistent_roi_reindices)\n",
+ "\n",
+ "single_roi_start_indices = df['reindexed_global_roi'].max() + 1\n",
+ "single_roi_rows = df.query('reindexed_global_roi.isna()').index\n",
+ "num_single_rois = len(single_roi_rows)\n",
+ "\n",
+ "df.loc[single_roi_rows, 'reindexed_global_roi'] = \\\n",
+ " np.arange(num_single_rois) + single_roi_start_indices\n",
+ "\n",
+ "df['reindexed_global_roi'] = df['reindexed_global_roi'].astype('int')\n",
+ "df = df.rename(columns={\n",
+ " 'global_roi': 'roicat_global_roi', \n",
+ " 'reindexed_global_roi': 'global_roi'\n",
+ "})\n",
+ "\n",
+ "df.to_csv(COLLECTIVE_MUSE_DIR / 'finalized-roi.csv', index=False)\n",
+ "\n",
+ "# plot persistent ROIs summary\n",
+ "persist_rois = (\n",
+ " df\n",
+ " .drop_duplicates(['global_roi'])\n",
+ " .value_counts('num_sessions', sort=False)\n",
+ " .to_frame('num_rois')\n",
+ " .reset_index()\n",
+ ")\n",
+ "\n",
+ "plt.figure(figsize=(4,5))\n",
+ "ax = sns.barplot(\n",
+ " persist_rois, \n",
+ " x = 'num_sessions',\n",
+ " y = 'num_rois',\n",
+ " hue = 'num_sessions',\n",
+ " facecolor = '#afafaf',\n",
+ " dodge=False,\n",
+ " edgecolor='k'\n",
+ ")\n",
+ "sns.despine(trim=True, offset=10)\n",
+ "\n",
+ "plt.legend([], [], frameon=False)\n",
+ "[ax.bar_label(c, padding=5, fontsize=10) for c in ax.containers]\n",
+ "plt.xlabel('# sessions')\n",
+ "plt.ylabel('# rois')\n",
+ "plt.title('Persisted ROIs')\n",
+ "plt.tight_layout()\n",
+ "plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'num-persist-roi-overall.png')\n",
+ "\n",
+ "# plot persistent ROIs per sessions\n",
+ "df_sessions = (\n",
+ " df\n",
+ " .value_counts(['session','num_sessions'])\n",
+ " .to_frame('count')\n",
+ " .reset_index()\n",
+ ")\n",
+ "\n",
+ "df_total_per_session = (\n",
+ " df_sessions\n",
+ " .groupby('session')\n",
+ " ['count'].agg('sum')\n",
+ " .to_frame('total_count')\n",
+ " .reset_index()\n",
+ ")\n",
+ "\n",
+ "df_sessions = df_sessions.merge(df_total_per_session, how='left')\n",
+ "df_sessions['percent'] = 100 * df_sessions['count'] / df_sessions['total_count']\n",
+ "\n",
+ "plt.figure(figsize=(10,5))\n",
+ "bar_kwargs = dict(\n",
+ " kind='bar',\n",
+ " stacked=True, \n",
+ " colormap='GnBu', \n",
+ " width=0.7, \n",
+ " edgecolor='k',\n",
+ ")\n",
+ "\n",
+ "ax1 = plt.subplot(121)\n",
+ "\n",
+ "(\n",
+ " df_sessions\n",
+ " .pivot(index='session',columns='num_sessions', values='count')\n",
+ " .fillna(0)\n",
+ " .plot(\n",
+ " **bar_kwargs,\n",
+ " xlabel='session ID',\n",
+ " ylabel='# rois',\n",
+ " legend=False,\n",
+ " ax=ax1)\n",
+ ")\n",
+ "plt.tick_params(rotation=0)\n",
+ "\n",
+ "ax2 = plt.subplot(122)\n",
+ "\n",
+ "(\n",
+ " df_sessions\n",
+ " .pivot(index='session',columns='num_sessions', values='percent')\n",
+ " .fillna(0)\n",
+ " .plot(\n",
+ " **bar_kwargs,\n",
+ " xlabel='session ID',\n",
+ " ylabel='% roi per session',\n",
+ " ax=ax2\n",
+ " )\n",
+ ")\n",
+ "plt.tick_params(rotation=0)\n",
+ "\n",
+ "leg_handles, leg_labels = plt.gca().get_legend_handles_labels()\n",
+ "plt.legend(\n",
+ " reversed(leg_handles),\n",
+ " reversed(leg_labels),\n",
+ " loc='upper right', \n",
+ " bbox_to_anchor=[1.5,1], \n",
+ " title='# sessions',\n",
+ " edgecolor='k',\n",
+ ")\n",
+ "\n",
+ "sns.despine(trim=True, offset=10)\n",
+ "\n",
+ "plt.suptitle(\n",
+ " 'Distribution of detected and aligned ROIs across sessions',\n",
+ ")\n",
+ "plt.tight_layout(w_pad=5)\n",
+ "plt.savefig(COLLECTIVE_MUSE_FIG_DIR / 'num-persist-roi-per-session.png')\n",
+ "\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "roicat",
+ "language": "python",
+ "name": "roicat"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.7"
+ },
+ "scenes_data": {
+ "active_scene": "Default Scene",
+ "init_scene": "",
+ "scenes": [
+ "Default Scene"
+ ]
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "state": {},
+ "version_major": 2,
+ "version_minor": 0
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
From b562f9ef2340c5befca572c5cb431cc1a8d36086 Mon Sep 17 00:00:00 2001
From: Tuan Pham <42875763+tuanpham96@users.noreply.github.com>
Date: Wed, 31 Jan 2024 11:10:13 -0500
Subject: [PATCH 09/14] fix: oscar conda env path
---
README.md | 2 +-
scripts/1-moco.sh | 3 +--
scripts/2-deepcad.sh | 3 +--
scripts/3-suite2p.sh | 3 +--
scripts/4-roicat.sh | 3 +--
5 files changed, 5 insertions(+), 9 deletions(-)
diff --git a/README.md b/README.md
index a9322f5..35e78b0 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@ To use on Oscar, first clone the repo. Then load the anaconda module:
```shell
module load miniconda3/23.11.0s
-source /oscar/rt/9.2/software/external/miniconda3/23.11.0/etc/profile.d/conda.sh
+source /oscar/runtime/software/external/miniconda3/23.11.0/etc/profile.d/conda.sh
```
If you have never loaded the `miniconda3/23.11.0s` module before, you need to initialize
diff --git a/scripts/1-moco.sh b/scripts/1-moco.sh
index bda9c1b..02baacc 100644
--- a/scripts/1-moco.sh
+++ b/scripts/1-moco.sh
@@ -22,8 +22,7 @@ PLANE_LIST=( "plane0" "plane1" )
# activate environment
module load miniconda3/23.11.0s
-source /oscar/rt/9.2/software/external/miniconda3/23.11.0/etc/profile.d/conda.sh
-
+source /oscar/runtime/software/external/miniconda3/23.11.0/etc/profile.d/conda.sh
conda activate suite2p
command -v python
python "$EXEC_FILE" --help
diff --git a/scripts/2-deepcad.sh b/scripts/2-deepcad.sh
index 0a2f297..7cdc57e 100644
--- a/scripts/2-deepcad.sh
+++ b/scripts/2-deepcad.sh
@@ -22,8 +22,7 @@ PLANE_LIST=( "plane0" "plane1" )
# activate environment
module load miniconda3/23.11.0s
-source /oscar/rt/9.2/software/external/miniconda3/23.11.0/etc/profile.d/conda.sh
-
+source /oscar/runtime/software/external/miniconda3/23.11.0/etc/profile.d/conda.sh
conda activate deepcad
# TODO: check if this is needed anymore
diff --git a/scripts/3-suite2p.sh b/scripts/3-suite2p.sh
index 7fc4784..d61bc8e 100644
--- a/scripts/3-suite2p.sh
+++ b/scripts/3-suite2p.sh
@@ -22,8 +22,7 @@ PLANE_LIST=( "plane0" "plane1" )
# activate environment
module load miniconda3/23.11.0s
-source /oscar/rt/9.2/software/external/miniconda3/23.11.0/etc/profile.d/conda.sh
-
+source /oscar/runtime/software/external/miniconda3/23.11.0/etc/profile.d/conda.sh
conda activate suite2p
command -v python
python "$EXEC_FILE" --help
diff --git a/scripts/4-roicat.sh b/scripts/4-roicat.sh
index 56d288f..d571455 100644
--- a/scripts/4-roicat.sh
+++ b/scripts/4-roicat.sh
@@ -19,8 +19,7 @@ PLANE_LIST=( "plane0" "plane1" )
# activate environment
module load miniconda3/23.11.0s
-source /oscar/rt/9.2/software/external/miniconda3/23.11.0/etc/profile.d/conda.sh
-
+source /oscar/runtime/software/external/miniconda3/23.11.0/etc/profile.d/conda.sh
conda activate roicat
command -v python
python "$EXEC_FILE" --help
From 48fec0fa61f23dea9b37109a82b70e8ee48f1a0b Mon Sep 17 00:00:00 2001
From: Tuan Pham <42875763+tuanpham96@users.noreply.github.com>
Date: Wed, 31 Jan 2024 11:17:21 -0500
Subject: [PATCH 10/14] helpers: add nb to convert binary file to tiff
---
notebooks/convert-bin2tiff.ipynb | 193 +++++++++++++++++++++++++++++++
1 file changed, 193 insertions(+)
create mode 100644 notebooks/convert-bin2tiff.ipynb
diff --git a/notebooks/convert-bin2tiff.ipynb b/notebooks/convert-bin2tiff.ipynb
new file mode 100644
index 0000000..7bd01ce
--- /dev/null
+++ b/notebooks/convert-bin2tiff.ipynb
@@ -0,0 +1,193 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "d197cc2c-bc68-485a-ae34-d54ed2e4a582",
+ "metadata": {},
+ "source": [
+ "# Convert `suite2p` binary files to tiff stacks"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ea818e9d-0892-4517-9fbc-283a5c81cc24",
+ "metadata": {},
+ "source": [
+ "This is a helper notebook to convert the `suite2p` binary files to tiff stacks (in the same folder).\n",
+ "\n",
+ "| Binary input file | Tiff output file | \n",
+ "| --: | :-- | \n",
+ "| `data.bin` | `chan1.tiff` | \n",
+ "| `data_chan2.bin` | `chan2.tiff` |\n",
+ "| `data_raw.bin` | `chan1_raw.tiff` | \n",
+ "| `data_chan2_raw.bin` | `chan2_raw.tiff` |\n",
+ "\n",
+ "This can be run on any environment as long as it has the necessary packages. If not do:\n",
+ "\n",
+ "```bash\n",
+ "pip install numpy tifffile tqdm\n",
+ "```\n",
+ "\n",
+ "This is preferably running in parallel for faster conversion of multiple files.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "eb6c11d4-ef92-4e25-a152-05794f15a72d",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import glob\n",
+ "\n",
+ "import numpy as np\n",
+ "from tifffile import TiffWriter\n",
+ "from tqdm.notebook import tqdm\n",
+ "from tqdm.contrib.concurrent import process_map"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "038bc5ca-c48a-4dee-a5b4-a1945459a290",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# detect ops files with a certain path pattern\n",
+ "# alternatively, define `ops_files` as a list pointing to the appropriate `ops.npy` files\n",
+ "main_dir = '/oscar/data/afleisc2/collab/multiday-reg/data/SD_*'\n",
+ "\n",
+ "ops_files = sorted(glob.glob(os.path.join(main_dir, '**/ops.npy'), recursive=True))\n",
+ "ops_files"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7de0e424-f76a-42b0-89bc-eac6d4b8f356",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "len(ops_files)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "66ab5705-9542-4b0a-8aa8-d5bd7752e4eb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def bin2tiff(bin_file, tiff_file, ops_file, tqdm_kwargs=dict()):\n",
+ " \"\"\"Convert a binary mmap file to tiff file from suite2p\"\"\"\n",
+ " # metadata\n",
+ " ops = np.load(ops_file, allow_pickle=True).item()\n",
+ " n_frames, Ly, Lx = ops['nframes'], ops['Ly'], ops['Lx']\n",
+ "\n",
+ " # read in binary file\n",
+ " memmap_obj = np.memmap(\n",
+ " bin_file,\n",
+ " mode='r',\n",
+ " dtype='int16',\n",
+ " shape=(n_frames, Ly, Lx)\n",
+ " )\n",
+ " \n",
+ " # write to tiff file\n",
+ " with TiffWriter(tiff_file, bigtiff=True) as f:\n",
+ " for i in tqdm(range(n_frames), **tqdm_kwargs):\n",
+ " curr_frame = np.floor(memmap_obj[i]).astype(np.int16)\n",
+ " f.write(curr_frame, contiguous=True)\n",
+ " \n",
+ " # close binary file\n",
+ " memmap_obj._mmap.close()\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "158b56ec-8cc3-4b1b-ba2c-5dd125908d2b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def bin2tiff_one_ops(ops_file):\n",
+ " \"\"\"Conversion function for parallel\"\"\"\n",
+ " \n",
+ " ops_dir = os.path.dirname(ops_file)\n",
+ " fluo_files = {\n",
+ " 'data.bin': 'chan1.tiff', \n",
+ " 'data_chan2.bin': 'chan2.tiff',\n",
+ " 'data_raw.bin': 'chan1_raw.tiff', \n",
+ " 'data_chan2_raw.bin': 'chan2_raw.tiff',\n",
+ " }\n",
+ " \n",
+ " for bin_file, tiff_file in fluo_files.items():\n",
+ " bin_file = os.path.join(ops_dir, bin_file)\n",
+ " tiff_file = os.path.join(ops_dir, tiff_file)\n",
+ " if (\n",
+ " (not os.path.exists(bin_file))\n",
+ " or os.path.exists(tiff_file)\n",
+ " ):\n",
+ " continue\n",
+ " \n",
+ " bin2tiff(\n",
+ " bin_file,\n",
+ " tiff_file,\n",
+ " ops_file,\n",
+ " tqdm_kwargs=dict(\n",
+ " disable = True\n",
+ " )\n",
+ " )\n",
+ "\n",
+ "process_map(\n",
+ " bin2tiff_one_ops,\n",
+ " ops_files,\n",
+ " max_workers=16\n",
+ ")\n",
+ " "
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "roicat",
+ "language": "python",
+ "name": "roicat"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.7"
+ },
+ "scenes_data": {
+ "active_scene": "Default Scene",
+ "init_scene": "",
+ "scenes": [
+ "Default Scene"
+ ]
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "state": {},
+ "version_major": 2,
+ "version_minor": 0
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
From 8a5d8a2fd936c47e3dbd148912dbe01c1f319205 Mon Sep 17 00:00:00 2001
From: Tuan Pham <42875763+tuanpham96@users.noreply.github.com>
Date: Sun, 18 Feb 2024 15:47:18 -0500
Subject: [PATCH 11/14] helpers: add nb to convert 2-chan tiff from flowreg to
suite2p accepted tiff dimension order
---
.../convert-flowregtiff2suite2ptiff.ipynb | 121 ++++++++++++++++++
1 file changed, 121 insertions(+)
create mode 100644 notebooks/convert-flowregtiff2suite2ptiff.ipynb
diff --git a/notebooks/convert-flowregtiff2suite2ptiff.ipynb b/notebooks/convert-flowregtiff2suite2ptiff.ipynb
new file mode 100644
index 0000000..711cf1a
--- /dev/null
+++ b/notebooks/convert-flowregtiff2suite2ptiff.ipynb
@@ -0,0 +1,121 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "5c04ceeb-e83d-4118-8843-4b0e9afd465a",
+ "metadata": {},
+ "source": [
+ "# Convert 2-channel TIFF from `flow_registration` to `suite2p` \n",
+ "\n",
+ "This notebook converts the TIFF stack output from `flow_registration` with dimension `frame x spatial_1 x spatial_2 x channel` to what `suite2p` expects for a 2-channel TIFF stack: `frame x channel x spatial_1 x spatial_2`. \n",
+ "\n",
+ "This basically just transpose the dimensions and saves as an OME.TIFF\n",
+ "\n",
+ "## Requirements\n",
+ "\n",
+ "- `numpy`\n",
+ "- `pyometiff`\n",
+ "- `scikit-image`\n",
+ "\n",
+ "If these are not installed, do this in the activated environment:\n",
+ "\n",
+ "```bash\n",
+ "pip install numpy pyometiff scikit-image\n",
+ "```\n",
+ "\n",
+ "## Caution on memory\n",
+ "\n",
+ "This current reads in the whole TIFF stack in memory instead doing it in frame batches, so you should sufficient memory."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "db413ba4-ece8-451e-a8a6-8b96961be99d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# %pip install scikit-image pyometiff numpy\n",
+ "\n",
+ "import os\n",
+ "import glob\n",
+ "\n",
+ "import numpy as np\n",
+ "import skimage.io as skio\n",
+ "from pyometiff import OMETIFFWriter"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ee4e1ffe-3cfd-42fb-9f92-6dd843470547",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define paths\n",
+ "input_tiff_path = \"PATH/TO/FLOW/REG/compensated.TIFF\" # input\n",
+ "output_tiff_path = \"out-data/flowreg_tcyx_4s2p.ome.tiff\" # output\n",
+ "output_dim_order = \"TCYX\" # what suite2p wants; don't change"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "aacdafc4-e405-4c51-a6c2-f184047bdff6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# read\n",
+ "X = skio.imread(input_tiff_path) # flowreg is \"TYXC\"\n",
+ "X = np.transpose(X, (0,3,1,2))\n",
+ "\n",
+ "# write\n",
+ "writer = OMETIFFWriter(\n",
+ " fpath=output_tiff_path,\n",
+ " dimension_order=output_dim_order,\n",
+ " array=X,\n",
+ " metadata=dict(),\n",
+ " explicit_tiffdata=False,\n",
+ " bigtiff=True\n",
+ ")\n",
+ "\n",
+ "writer.write()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.13"
+ },
+ "scenes_data": {
+ "active_scene": "Default Scene",
+ "init_scene": "",
+ "scenes": [
+ "Default Scene"
+ ]
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "state": {},
+ "version_major": 2,
+ "version_minor": 0
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
From 728751403720284ca6740466ba88c6ce77ee15a5 Mon Sep 17 00:00:00 2001
From: Tuan Pham <42875763+tuanpham96@users.noreply.github.com>
Date: Sun, 18 Feb 2024 15:48:19 -0500
Subject: [PATCH 12/14] helpers: add nb to refine roicat output
currently very experimental
---
notebooks/inspect-muse.ipynb | 12 +-
notebooks/refine-muse.ipynb | 766 +++++++++++++++++++++++++++++++++++
2 files changed, 772 insertions(+), 6 deletions(-)
create mode 100644 notebooks/refine-muse.ipynb
diff --git a/notebooks/inspect-muse.ipynb b/notebooks/inspect-muse.ipynb
index 16df68c..da99a07 100644
--- a/notebooks/inspect-muse.ipynb
+++ b/notebooks/inspect-muse.ipynb
@@ -9,16 +9,16 @@
"\n",
"This is currently being tested to inspect the output of processed `roicat` outputs.\n",
"\n",
- "This currently uses `tk` and cannot run on Oscar headless."
+ "This currently uses `tk` and cannot run on Oscar headless.\n",
+ "\n",
+ "Note: this is now superseded by `refine-muse.ipynb`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "020f3892-fe74-493e-81c3-eb99e5a10adc",
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [],
"source": [
"import os\n",
@@ -370,7 +370,7 @@
" # plot contours\n",
" for session, ax in enumerate(ax_list):\n",
" if session not in select_contours:\n",
- " ax.set_title(f'session #{i+1} [NOT FOUND]')\n",
+ " ax.set_title(f'session #{session+1} [NOT FOUND]')\n",
" continue\n",
" \n",
" session_contours = select_contours[session]['contour']\n",
@@ -381,7 +381,7 @@
" for ch in c_handles:\n",
" ch.tag = tag\n",
" \n",
- " ax.set_title(f'session #{i+1} [id={select_session_roi} | ID={select_global_roi}]')\n",
+ " ax.set_title(f'session #{session+1} [id={select_session_roi} | ID={select_global_roi}]')\n",
" \n",
" plt.show()\n",
" \n",
diff --git a/notebooks/refine-muse.ipynb b/notebooks/refine-muse.ipynb
new file mode 100644
index 0000000..12e1e06
--- /dev/null
+++ b/notebooks/refine-muse.ipynb
@@ -0,0 +1,766 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "b58dc7ec-7a67-49dc-a26d-b4dfdf90ab3d",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "# Inspect and refine multisession outputs\n",
+ "\n",
+ "## Description\n",
+ "\n",
+ "This is currently being tested to inspect and refine the output of processed `roicat` outputs.\n",
+ "\n",
+ "Very **EXPERIMENTAL**. Use with caution\n",
+ "\n",
+ "## Requirements\n",
+ "\n",
+ "- `numpy`\n",
+ "- `pandas`\n",
+ "- `scikit-image`\n",
+ "- `matplotlib` with `tk` (can change to others like `qt` if needed, but not tested)\n",
+ "\n",
+ "## Limitations\n",
+ "\n",
+ "This currently uses `tk` and cannot run on Oscar headless. Not sure about Open Ondemand Jupyter.\n",
+ "\n",
+ "So either run this locally or use Open Ondemand Desktop session.\n",
+ "\n",
+ "The rendering is currently not optimized and will be slow (as it replots many things depending on the operation).\n",
+ "\n",
+ "This only plots the `iscell` ones.\n",
+ "\n",
+ "## Instructions\n",
+ "\n",
+ "### Paths and config\n",
+ "\n",
+ "Define the parameters and paths in `Define paths and parameters`\n",
+ "\n",
+ "Then run the whole notebook. \n",
+ "\n",
+ "### GUI interactions\n",
+ "\n",
+ "- Note: \n",
+ " - Sometimes it might take some time to render\n",
+ " - The buttons can sometimes be not very responsive, click hard on them\n",
+ "- To simply inspect: \n",
+ "  - Click on an ROI of a certain session\n",
+ " - If available, **red** contours of the same ROI will show up in other sessions\n",
+ " - The titles will change:\n",
+ " - Lowercase `id` is the session local ROI index\n",
+ " - Uppercase `ID` is the global experiment ROI index\n",
+ " - If ROI cannot be found in a given session based on `roicat` output, the title will say `NOT FOUND`\n",
+ "- At any point, to clear current selections and requests, click on `Clear`\n",
+ "- To **chain** ROIs across sessions: \n",
+ " - First select an ROI so that the **red** contours show up\n",
+ " - Then click `Chain Request`\n",
+ " - Then click on another ROI in other sessions\n",
+ " - **Light blue** contours will show up\n",
+ " - Then click `Chain`\n",
+ " - Wait a bit, the colors will change so they will match\n",
+ " - Click `Clear` and then re-select the ROIs to make sure they now have the **same** global ID (i.e. *chained*)\n",
+ " - Note: only ROIs that do not appear in the same sessions can be chained\n",
+ "- To **unchain** ROIs of a certain session (only one at a time):\n",
+ " - First select an ROI so that the **red** contours show up\n",
+ " - Then click `Unchain Request`\n",
+ " - Then enter the session number (starts from 1) to unchain in the box next to it\n",
+ " - Then press the `Enter` key (or `Return` key on Mac)\n",
+ " - Wait a bit, the color of the requested ROI to unchain will change\n",
+ " - Click `Clear` and then re-select the ROIs to make sure they now have **different** global IDs (i.e. *unchained*)\n",
+ "- If you're satisfied, click `Save`, then the `refined-roi.csv` file will appear\n",
+ " - Check the same `muse_dir` directory to make sure a file `refined-roi.csv` appears\n",
+ " - Note: this will not save the color columns\n",
+ "- To re-inspect the refined output, re-run the notebook with `plot_refined = True`\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2da3aecc-b2a0-45aa-8d57-60cd9e001af0",
+ "metadata": {},
+ "source": [
+ "## Import"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "020f3892-fe74-493e-81c3-eb99e5a10adc",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import glob\n",
+ "import warnings\n",
+ "from datetime import datetime\n",
+ "\n",
+ "from pathlib import Path\n",
+ "import pickle\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import skimage\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "from matplotlib.widgets import Button, MultiCursor, TextBox\n",
+ "\n",
+ "# change to other backend if needed\n",
+ "%matplotlib tk"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f83286b4-a951-4e31-a4f2-67bb6e66902c",
+ "metadata": {},
+ "source": [
+ "## Define paths and parameters"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "37fe3dc0-6aec-4c7b-91f1-5579991bb8b0",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Define where multisession directory is (with specified plane)\n",
+ "# e.g:`path/to/SUBJECT_ID/multi-session/PLANE_ID`\n",
+ "muse_dir = 'data/SD_0664/multi-session/plane0'\n",
+ "\n",
+ "# only if there's a `refined-roi.csv` file\n",
+ "# if run for first time of a `muse_dir`, use `False`\n",
+ "# use this to indicate whether you want to replot the refined ROI\n",
+ "plot_refined = True\n",
+ "\n",
+ "# Define visualization configurations (combining ROI masks and background images)\n",
+ "fuse_cfg = dict(\n",
+ " background_choice = 'max_proj', # max_proj, meanImg, meanImgE, Vcorr\n",
+ " fuse_alpha = 0.4, # fusing ROI mask with background\n",
+ " fuse_bg_max_p = 98, # for clipping background max value with percentile\n",
+ ")\n",
+ "\n",
+ "# use -1, 0 or None to indicate num subplot columns = num sessions\n",
+ "# use integer > 2 otherwise \n",
+ "column_wrap = None"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bd081a74-4912-4d92-8708-0df1c09a34ef",
+ "metadata": {},
+ "source": [
+ "## Initialize"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0e61a9bb-ed7b-48f4-b30b-4c3d8fb57514",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define paths\n",
+ "muse_dir = Path(muse_dir)\n",
+ "roicat_out_file = muse_dir / 'roicat-output.pkl'\n",
+ "aligned_img_file = muse_dir / 'aligned-images.pkl'\n",
+ "roi_table_file = muse_dir / 'finalized-roi.csv'\n",
+ "save_roi_file = muse_dir / 'refined-roi.csv'\n",
+ "\n",
+ "if plot_refined:\n",
+ " assert save_roi_file.exists(), f'The refined file \"{save_roi_file}\" does not exist to plot'\n",
+ " roi_table_file = save_roi_file"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4af7a27e-e508-4ea9-a13f-be920421fd49",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# read in\n",
+ "with open(roicat_out_file, 'rb') as f:\n",
+ " roicat_out = pickle.load(f)\n",
+ "\n",
+ "with open(aligned_img_file, 'rb') as f:\n",
+ " aligned_images = pickle.load(f)\n",
+ "\n",
+ "aligned_rois = roicat_out['results']['ROIs']['ROIs_aligned']\n",
+ "\n",
+ "# TODO: unclear if this is the right order\n",
+ "image_dims = (\n",
+ " roicat_out['results']['ROIs']['frame_width'],\n",
+ " roicat_out['results']['ROIs']['frame_height']\n",
+ ")\n",
+ "roi_table = pd.read_csv(roi_table_file)\n",
+ "\n",
+ "num_sessions = len(aligned_rois)\n",
+ "assert all(num_sessions == np.array([len(aligned_images), roi_table['session'].nunique()]))\n",
+ "\n",
+ "num_global_rois = roi_table['global_roi'].nunique()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b03da632-c02f-42f2-8171-3fbd6ebcd0d4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# assign unique colors\n",
+ "global_roi_colors = np.random.rand(num_global_rois, 3).round(3)\n",
+ "\n",
+ "roi_table['global_roi_color'] = roi_table['global_roi'].apply(\n",
+ " lambda x: global_roi_colors[x]\n",
+ ")\n",
+ "\n",
+ "# save backup of modifiable columns \n",
+ "backup_tag = datetime.now().strftime('pre[%Y%m%d_%H%M%S]')\n",
+ "\n",
+ "roi_table[f'{backup_tag}::num_sessions'] = roi_table['num_sessions'].copy()\n",
+ "roi_table[f'{backup_tag}::global_roi'] = roi_table['global_roi'].copy()\n",
+ "roi_table[f'{backup_tag}::global_roi_color'] = roi_table['global_roi_color'].copy()\n",
+ "\n",
+ "roi_table"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "532ad296-d743-435c-953f-a8f8521acfe4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# functions\n",
+ "# TODO: these should be in a source file\n",
+ "def throw_invalid_footprint_warning(footprints, roi_table, session):\n",
+ " invalid_max_idx = footprints.max(axis=1).toarray().squeeze() <= 0\n",
+ " if sum(invalid_max_idx) == 0:\n",
+ " return\n",
+ " \n",
+ " invalid_max_idx = np.where(invalid_max_idx)[0]\n",
+ " \n",
+ " roi_record = roi_table.query(\n",
+ " 'session == @session and ' \\\n",
+ " 'session_roi in @invalid_max_idx and' \\\n",
+ " '(num_sessions > 1 or roicat_global_roi >= 0)' \n",
+ " )\n",
+ " if len(roi_record) > 0:\n",
+ " warnings.warn(\n",
+ " 'The following ROI records have invalid footprints, i.e. '\\\n",
+ " '(a) max = 0, (b) persist for more than 1 sessions or `roicat_global_roi=-1`:',\n",
+ " roi_record.drop(columns='global_roi_color').to_dict('records')\n",
+ " )\n",
+ " \n",
+ "def select_session_rois(footprints, roi_table, session):\n",
+ " session_iscell_rois = (\n",
+ " roi_table.query('session == @session')\n",
+ " .sort_values(by='session_roi')\n",
+ " .reset_index(drop=True)\n",
+ " )\n",
+ " roi_colors = np.stack(session_iscell_rois['global_roi_color'].to_list())\n",
+ " roi_iscell_idx = session_iscell_rois['session_roi'].to_list()\n",
+ "\n",
+ " footprints = footprints[roi_iscell_idx]\n",
+ " assert footprints.shape[0] == len(roi_iscell_idx)\n",
+ " \n",
+ " return footprints, roi_colors \n",
+ "\n",
+ "def normalize_footprint(X):\n",
+ " max_X = X.max(axis=1).toarray()\n",
+ " X = X.multiply(1.0 / max_X)\n",
+ " return X\n",
+ " \n",
+ "def sparse2image(sparse_footprints, image_dims):\n",
+ " image = (\n",
+ " (sparse_footprints > 0)\n",
+ " .T\n",
+ " .reshape(*image_dims)\n",
+ " .toarray()\n",
+ " )\n",
+ " return image\n",
+ "\n",
+ "def sparse2contour(sparse_footprints, image_dims):\n",
+ " image = (\n",
+ " (sparse_footprints > 0)\n",
+ " .T\n",
+ " .reshape(*image_dims)\n",
+ " .toarray()\n",
+ " )\n",
+ " contour = skimage.measure.find_contours(image)\n",
+ " return contour\n",
+ "\n",
+ "def color_footprint_one_channel(sparse_footprints, color_vec, image_dims):\n",
+ " image = np.array(\n",
+ " (sparse_footprints > 0)\n",
+ " .multiply(color_vec)\n",
+ " .sum(axis=0)\n",
+ " .T\n",
+ " .reshape(*image_dims)\n",
+ " )\n",
+ " return image\n",
+ "\n",
+ "def color_footprint(sparse_footprints, color_matrix, image_dims):\n",
+ " image = np.stack([\n",
+ " color_footprint_one_channel(\n",
+ " sparse_footprints,\n",
+ " color_vec.reshape(-1,1),\n",
+ " image_dims\n",
+ " )\n",
+ " for color_vec in color_matrix.T\n",
+ " ], axis=-1)\n",
+ " image = np.clip(image, a_min=0.0, a_max=1.0)\n",
+ " return image\n",
+ "\n",
+ "def fuse_rois_in_background(rois, background, alpha=0.5, background_max_percentile=99):\n",
+ " background = (background - background.min()) / \\\n",
+ " (np.percentile(background, background_max_percentile) - background.min())\n",
+ " background = np.clip(background, a_min=0, a_max=1.0)\n",
+ " background = skimage.color.gray2rgb(background)\n",
+ " \n",
+ " fused_image = background * (1 - alpha) + rois * alpha\n",
+ " fused_image = np.clip(fused_image, a_min=0, a_max=1.0) \n",
+ " return fused_image\n",
+ " \n",
+ "def compute_session_fused_footprints(\n",
+ " images, footprints, roi_table, session, \n",
+ " background_choice = 'max_proj',\n",
+ " fuse_alpha = 0.5,\n",
+ " fuse_bg_max_p = 99,\n",
+ "):\n",
+ " background_images = images[session]\n",
+ " if background_choice not in background_images:\n",
+ " warnings.warn(f'{background_choice} not in the aligned images. Using \"fov\" field instead')\n",
+ " assert 'fov' in background_images, '\"fov\" not found in aligned images'\n",
+ " background_choice = 'fov'\n",
+ " background_image = background_images[background_choice]\n",
+ " \n",
+ " sparse_footprints = footprints[session]\n",
+ " throw_invalid_footprint_warning(sparse_footprints, roi_table, session) \n",
+ " \n",
+ " sparse_footprints, roi_colors = select_session_rois(sparse_footprints, roi_table, session)\n",
+ " sparse_footprints = normalize_footprint(sparse_footprints)\n",
+ " colored_footprints = color_footprint(sparse_footprints, roi_colors, image_dims)\n",
+ " fused_footprints = fuse_rois_in_background(\n",
+ " colored_footprints,\n",
+ " background_image,\n",
+ " alpha=fuse_alpha, \n",
+ " background_max_percentile=fuse_bg_max_p\n",
+ " )\n",
+ " return fused_footprints"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9ea78c1b-7776-4230-b8ae-d20318c1c0d3",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# fuse ROI and background\n",
+ "fuse_kwargs = dict(\n",
+ " images=aligned_images,\n",
+ " footprints=aligned_rois,\n",
+ " roi_table=roi_table,\n",
+ " **fuse_cfg\n",
+ ")\n",
+ "\n",
+ "fused_footprints = [\n",
+ " compute_session_fused_footprints(\n",
+ " session=session_idx,\n",
+ " **fuse_kwargs\n",
+ " )\n",
+ " for session_idx in range(num_sessions)\n",
+ "]\n",
+ "\n",
+ "# get contours\n",
+ "roi_contours = []\n",
+ "for session_rois in aligned_rois:\n",
+ " roi_contours.append([\n",
+ " sparse2contour(roi_footprint, image_dims) \n",
+ " for roi_footprint in session_rois\n",
+ " ])\n",
+ "\n",
+ "# plot to see what they look like\n",
+ "# plt.figure(figsize=(20,10))\n",
+ "# for i in range(num_sessions):\n",
+ "# plt.subplot(1,num_sessions,i+1)\n",
+ "# plt.imshow(fused_footprints[i])\n",
+ "# plt.axis('off')\n",
+ "# plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5ad4f64f-c89a-4a8d-8000-0d7fd52f1c88",
+ "metadata": {},
+ "source": [
+ "## Refine GUI"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e9966739-17a4-482e-929b-1e5460328587",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# TODO: set variables inside objects instead of modifying global objects\n",
+ "# - roi_table\n",
+ "# - aligned_images\n",
+ "# - aligned_rois\n",
+ "# - highlight_contour_kwargs\n",
+ "# - request_contour_kwargs\n",
+ "# - save_roi_file\n",
+ "# TODO: figure out how to optimize re-drawing bc very slow right now\n",
+ "# TODO: warn when an operation is not valid\n",
+ "\n",
+ "highlight_contour_kwargs = dict(c='r', lw=2, alpha=0.8)\n",
+ "request_contour_kwargs = dict(c='#67a9cf', lw=2, alpha=0.8, ls='--')\n",
+ "\n",
+ "if column_wrap is None:\n",
+ " column_wrap = num_sessions\n",
+ "if column_wrap < 2:\n",
+ " column_wrap = num_sessions\n",
+ "num_cols = min(num_sessions, column_wrap)\n",
+ "\n",
+ "fig, axes = plt.subplots(\n",
+ " int(np.ceil(num_sessions/num_cols)), num_cols,\n",
+ " figsize=(20,10),\n",
+ " sharex=True,\n",
+ " sharey=True,\n",
+ ")\n",
+ "ax_list = axes.flatten()\n",
+ "\n",
+ "for i in range(num_sessions):\n",
+ " ax_list[i].imshow(fused_footprints[i])\n",
+ " ax_list[i].set_axis_off()\n",
+ " ax_list[i].set_title(f'session #{i+1}')\n",
+ " ax_list[i].session_index = i\n",
+ " \n",
+ "def find_tagged_objects(fig, value, key='tag'):\n",
+ " objs = fig.findobj(\n",
+ " match=lambda x:\n",
+ " False if not hasattr(x, key) \n",
+ " else value in getattr(x, key)\n",
+ " )\n",
+ " return objs \n",
+ "\n",
+ "class MuseStateCallBacks:\n",
+ " state = dict(\n",
+ " chain_request = False,\n",
+ " chain = False,\n",
+ " unchain_request = False,\n",
+ " # unchain = False,\n",
+ " )\n",
+ " selection = dict(\n",
+ " session = None,\n",
+ " session_roi = None,\n",
+ " global_roi = None\n",
+ " )\n",
+ " chain_selection = dict( \n",
+ " session = None,\n",
+ " session_roi = None,\n",
+ " global_roi = None\n",
+ " )\n",
+ " \n",
+ " def __init__(self, save_path):\n",
+ " self.save_path = save_path\n",
+ " \n",
+ " def refuse_footprint(self):\n",
+ " fuse_kwargs = dict(\n",
+ " images=aligned_images,\n",
+ " footprints=aligned_rois,\n",
+ " roi_table=roi_table,\n",
+ " **fuse_cfg\n",
+ " )\n",
+ "\n",
+ " fused_footprints = [\n",
+ " compute_session_fused_footprints(\n",
+ " session=session_idx,\n",
+ " **fuse_kwargs\n",
+ " )\n",
+ " for session_idx in range(num_sessions)\n",
+ " ]\n",
+ " \n",
+ " \n",
+ " for i in range(num_sessions):\n",
+ " ax_list[i].imshow(fused_footprints[i])\n",
+ " \n",
+ " plt.draw()\n",
+ " \n",
+ " def clear(self, event):\n",
+ " self.reset_state()\n",
+ " clear_objs = find_tagged_objects(fig, value='highlight') +\\\n",
+ " find_tagged_objects(fig, value='chain_request')\n",
+ " \n",
+ " for x in clear_objs:\n",
+ " x.remove() \n",
+ " for i, ax in enumerate(ax_list):\n",
+ " ax.set_title(f'session #{i+1}')\n",
+ " \n",
+ " def reset_state(self):\n",
+ " for k in self.state.keys():\n",
+ " self.state[k] = False\n",
+ " \n",
+ " def unchain(self, unchain_session):\n",
+ " print(self.state, self.selection)\n",
+ " if not self.state['unchain_request'] or None in self.selection.values():\n",
+ " self.reset_state()\n",
+ " return\n",
+ " unchain_session = int(unchain_session) - 1\n",
+ " curr_global_roi = self.selection['global_roi']\n",
+ " unchain_roi_idx = (\n",
+ " roi_table\n",
+ " .query('global_roi == @curr_global_roi and session == @unchain_session')\n",
+ " .index.to_list()\n",
+ " )\n",
+ " if len(unchain_roi_idx) < 1:\n",
+ " self.reset_state()\n",
+ " return\n",
+ " unchain_roi_idx = list(unchain_roi_idx)[0]\n",
+ " if roi_table.loc[unchain_roi_idx, 'num_sessions'] < 2:\n",
+ " self.reset_state()\n",
+ " return\n",
+ " next_global_roi = roi_table['global_roi'].max() + 1\n",
+ " roi_table.loc[roi_table['global_roi'] == curr_global_roi, 'num_sessions'] -= 1\n",
+ " roi_table.loc[unchain_roi_idx, 'global_roi'] = next_global_roi \n",
+ " roi_table.at[unchain_roi_idx, 'global_roi_color'] = np.random.rand(3)\n",
+ " \n",
+ " self.reset_state()\n",
+ " self.refuse_footprint()\n",
+ " \n",
+ " def chain(self, event):\n",
+ " print(self.state, self.selection)\n",
+ " if (\n",
+ " not self.state['chain_request'] or \\\n",
+ " None in self.selection.values()\n",
+ " ):\n",
+ " self.reset_state()\n",
+ " return\n",
+ " \n",
+ " if None in self.chain_selection.values():\n",
+ " return\n",
+ " \n",
+ " curr_glob_roi = self.selection['global_roi']\n",
+ " curr_glob_roi_rows = roi_table.query('global_roi == @curr_glob_roi')\n",
+ " \n",
+ " chain_glob_roi = self.chain_selection['global_roi']\n",
+ " chain_glob_roi_rows = roi_table.query('global_roi == @chain_glob_roi')\n",
+ " \n",
+ " if curr_glob_roi == chain_glob_roi:\n",
+ " return\n",
+ " \n",
+ " print(curr_glob_roi_rows, chain_glob_roi_rows)\n",
+ " shared_sessions = (\n",
+ " set(curr_glob_roi_rows['session'].to_list())\n",
+ " .intersection(chain_glob_roi_rows['session'].to_list())\n",
+ " )\n",
+ " union_sessions = (\n",
+ " set(curr_glob_roi_rows['session'].to_list())\n",
+ " .union(chain_glob_roi_rows['session'].to_list())\n",
+ " )\n",
+ " if len(shared_sessions) > 0:\n",
+ " return\n",
+ " \n",
+ " min_glob_roi = min(curr_glob_roi, chain_glob_roi)\n",
+ " glob_color = roi_table.query('global_roi == @min_glob_roi')['global_roi_color'].iloc[0]\n",
+ " concat_idx = list(curr_glob_roi_rows.index) + list(chain_glob_roi_rows.index)\n",
+ " roi_table.loc[concat_idx, 'num_sessions'] = len(union_sessions)\n",
+ " roi_table.loc[concat_idx, 'global_roi'] = min_glob_roi\n",
+ " for idx in concat_idx:\n",
+ " roi_table.at[idx, 'global_roi_color'] = glob_color\n",
+ " \n",
+ " self.reset_state()\n",
+ " self.refuse_footprint()\n",
+ " \n",
+ " def unchain_request(self, event):\n",
+ " self.reset_state()\n",
+ " if None in self.selection.values():\n",
+ " return\n",
+ " self.state['unchain_request'] = True\n",
+ " \n",
+ " def chain_request(self, event):\n",
+ " self.refuse_footprint()\n",
+ " if None in self.selection.values():\n",
+ " return\n",
+ " self.state['chain_request'] = True\n",
+ " \n",
+ " def onclick(self, event, tag='highlight'): \n",
+ " # get data\n",
+ " ix, iy, ax = event.xdata, event.ydata, event.inaxes\n",
+ " if not hasattr(ax, 'session_index'):\n",
+ " return\n",
+ " \n",
+ " contour_kwargs = highlight_contour_kwargs.copy()\n",
+ " if self.state['chain_request']:\n",
+ " tag = 'chain_request'\n",
+ " contour_kwargs = request_contour_kwargs.copy()\n",
+ "\n",
+ " # remove previous highlighted objects \n",
+ " for x in find_tagged_objects(fig, value=tag):\n",
+ " x.remove() \n",
+ " session = ax.session_index\n",
+ "\n",
+ " # obtain session ROI index\n",
+ " flat_idx = np.ravel_multi_index((round(iy),round(ix)), image_dims)\n",
+ " session_roi = aligned_rois[session][:,flat_idx].nonzero()[0]\n",
+ " if len(session_roi) == 0:\n",
+ " return\n",
+ "    session_roi = session_roi[0] # just select first one if there are overlap\n",
+ " select_global_roi = (\n",
+ " roi_table.query('session == @session and session_roi == @session_roi')\n",
+ " ['global_roi'].to_list()\n",
+ " )\n",
+ " if len(select_global_roi) != 1:\n",
+ " return\n",
+ " select_global_roi = select_global_roi[0]\n",
+ " \n",
+ " if not (self.state['chain_request']):\n",
+ " self.selection['session'] = session\n",
+ " self.selection['session_roi'] = session_roi\n",
+ " self.selection['global_roi'] = select_global_roi\n",
+ " else:\n",
+ " self.chain_selection['session'] = session\n",
+ " self.chain_selection['session_roi'] = session_roi\n",
+ " self.chain_selection['global_roi'] = select_global_roi\n",
+ " \n",
+ " # obtain contours\n",
+ " select_contours = {\n",
+ " r['session']: dict(\n",
+ " contour = roi_contours[r['session']][r['session_roi']],\n",
+ " **r,\n",
+ " )\n",
+ " for _, r in roi_table.query('global_roi == @select_global_roi').iterrows()\n",
+ " }\n",
+ "\n",
+ " # plot contours\n",
+ " for session, ax in enumerate(ax_list):\n",
+ " if session not in select_contours:\n",
+ " ax.set_title(f'session #{session+1} [NOT FOUND]')\n",
+ " continue\n",
+ "\n",
+ " session_contours = select_contours[session]['contour']\n",
+ " select_session_roi = select_contours[session]['session_roi']\n",
+ "\n",
+ " for c in session_contours:\n",
+ " c_handles = ax.plot(c[:,1],c[:,0], **contour_kwargs)\n",
+ " for ch in c_handles:\n",
+ " ch.tag = tag\n",
+ "\n",
+ " ax.set_title(f'session #{session+1} [id={select_session_roi} | ID={select_global_roi}]')\n",
+ "\n",
+ " plt.draw()\n",
+ " \n",
+ " def save(self, event):\n",
+ " # currently avoid saving color columns\n",
+ " # TODO: if a file exists, warn before saving\n",
+ " color_columns = (\n",
+ " roi_table\n",
+ " .filter(regex='.*global_roi_color.*')\n",
+ " .columns\n",
+ " .to_list()\n",
+ " )\n",
+ " \n",
+ " (\n",
+ " roi_table\n",
+ " .drop(columns=color_columns)\n",
+ " .to_csv(\n",
+ " self.save_path,\n",
+ " index=False\n",
+ " )\n",
+ " )\n",
+ " \n",
+ "\n",
+ "muse_cb = MuseStateCallBacks(\n",
+ " save_path = save_roi_file,\n",
+ ")\n",
+ "cid = fig.canvas.mpl_connect('button_press_event', muse_cb.onclick)\n",
+ "\n",
+ "ax_buttons = dict(\n",
+ " chain_request = fig.add_axes([0.10, 0.05, 0.08, 0.05]),\n",
+ " chain = fig.add_axes([0.20, 0.05, 0.08, 0.05]),\n",
+ " unchain_request = fig.add_axes([0.40, 0.05, 0.08, 0.05]),\n",
+ " clear = fig.add_axes([0.75, 0.05, 0.08, 0.05]),\n",
+ " save = fig.add_axes([0.9, 0.05, 0.08, 0.05]),\n",
+ ")\n",
+ "\n",
+ "buttons = {\n",
+ " k: Button(v, k.replace('_', ' ').title())\n",
+ " for k, v in ax_buttons.items()\n",
+ "}\n",
+ "buttons['clear'].on_clicked(muse_cb.clear)\n",
+ "buttons['chain_request'].on_clicked(muse_cb.chain_request)\n",
+ "buttons['chain'].on_clicked(muse_cb.chain)\n",
+ "buttons['unchain_request'].on_clicked(muse_cb.unchain_request)\n",
+ "buttons['save'].on_clicked(muse_cb.save)\n",
+ "\n",
+ "ax_text = fig.add_axes([0.52, 0.05, 0.05, 0.05])\n",
+ "text_box = TextBox(ax_text, f\"Session \\n (1-{num_sessions}) \", textalignment=\"left\")\n",
+ "text_box.on_submit(muse_cb.unchain)\n",
+ "\n",
+ "multi = MultiCursor(None, ax_list, color='r', lw=0.5, horizOn=True, vertOn=True)\n",
+ "\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "eb1175f3-1987-4f5f-b07c-cdd4d7f328c8",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "!ls $muse_dir"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.13"
+ },
+ "scenes_data": {
+ "active_scene": "Default Scene",
+ "init_scene": "",
+ "scenes": [
+ "Default Scene"
+ ]
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "state": {},
+ "version_major": 2,
+ "version_minor": 0
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
From fa6212a6e57a630f06cff0b876b360222ad594d4 Mon Sep 17 00:00:00 2001
From: Tuan Pham <42875763+tuanpham96@users.noreply.github.com>
Date: Tue, 16 Jul 2024 12:17:14 -0400
Subject: [PATCH 13/14] fix: issue#9 for refinement notebook
---
notebooks/refine-muse.ipynb | 101 ++++++++++++++++++++++++++----------
1 file changed, 75 insertions(+), 26 deletions(-)
diff --git a/notebooks/refine-muse.ipynb b/notebooks/refine-muse.ipynb
index 12e1e06..fed1026 100644
--- a/notebooks/refine-muse.ipynb
+++ b/notebooks/refine-muse.ipynb
@@ -87,9 +87,7 @@
"cell_type": "code",
"execution_count": null,
"id": "020f3892-fe74-493e-81c3-eb99e5a10adc",
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [],
"source": [
"import os\n",
@@ -123,9 +121,7 @@
"cell_type": "code",
"execution_count": null,
"id": "37fe3dc0-6aec-4c7b-91f1-5579991bb8b0",
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [],
"source": [
"# Define where multisession directory is (with specified plane)\n",
@@ -180,9 +176,7 @@
"cell_type": "code",
"execution_count": null,
"id": "4af7a27e-e508-4ea9-a13f-be920421fd49",
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [],
"source": [
"# read in\n",
@@ -202,25 +196,78 @@
"roi_table = pd.read_csv(roi_table_file)\n",
"\n",
"num_sessions = len(aligned_rois)\n",
- "assert all(num_sessions == np.array([len(aligned_images), roi_table['session'].nunique()]))\n",
- "\n",
- "num_global_rois = roi_table['global_roi'].nunique()"
+ "assert all(num_sessions == np.array([len(aligned_images), roi_table['session'].nunique()]))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "b03da632-c02f-42f2-8171-3fbd6ebcd0d4",
+ "id": "c6158c18-ca02-4fe1-bf93-1b982aaf597f",
"metadata": {},
"outputs": [],
"source": [
"# assign unique colors\n",
- "global_roi_colors = np.random.rand(num_global_rois, 3).round(3)\n",
+ "num_global_rois = roi_table['global_roi'].nunique()\n",
+ "max_global_roi_id = roi_table['global_roi'].max()\n",
+ "\n",
+ "global_roi_colors = np.random.rand(max_global_roi_id, 3).round(3)\n",
"\n",
"roi_table['global_roi_color'] = roi_table['global_roi'].apply(\n",
" lambda x: global_roi_colors[x]\n",
- ")\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7e763fd9-5cd7-4cac-8e17-8a533884b7f6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# check, and recount if needed, the `num_sessions` column\n",
+ "def check_num_sessions(df):\n",
+ " \"Check `num_sessions` column and count if inconsistent\"\n",
+ " \n",
+ " assert all(df.groupby('global_roi')['num_sessions'].nunique() == 1), \\\n",
+ " 'There should only be one value for `num_sessions` per `global_roi`'\n",
+ "\n",
+ " nses_file_df = (\n",
+ " df.filter(['global_roi', 'num_sessions'])\n",
+ " .drop_duplicates()\n",
+ " .set_index('global_roi')\n",
+ " .sort_index()\n",
+ " )\n",
"\n",
+ " nses_recount_df = (\n",
+ " df.groupby('global_roi')\n",
+ " ['session'].nunique()\n",
+ " .to_frame('num_sessions')\n",
+ " )\n",
+ "\n",
+ " nses_diff = nses_file_df.compare(nses_recount_df)\n",
+ " if len(nses_diff) > 0:\n",
+ " print(\n",
+ " '`num_sessions` is incorrect for `global_roi`: ', \n",
+ " nses_diff.index.to_list()\n",
+ " )\n",
+ " print('-> Dropping current `num_sessions` column and re-count in data frame')\n",
+ "\n",
+ " df = (\n",
+ " df.drop(columns='num_sessions')\n",
+ " .merge(nses_recount_df.reset_index(), on='global_roi')\n",
+ " )\n",
+ " return df\n",
+ "\n",
+ "roi_table = check_num_sessions(roi_table)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b03da632-c02f-42f2-8171-3fbd6ebcd0d4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
"# save backup of modifiable columns \n",
"backup_tag = datetime.now().strftime('pre[%Y%m%d_%H%M%S]')\n",
"\n",
@@ -361,9 +408,7 @@
"cell_type": "code",
"execution_count": null,
"id": "9ea78c1b-7776-4230-b8ae-d20318c1c0d3",
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [],
"source": [
"# fuse ROI and background\n",
@@ -411,9 +456,7 @@
"cell_type": "code",
"execution_count": null,
"id": "e9966739-17a4-482e-929b-1e5460328587",
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [],
"source": [
"# TODO: set variables inside objects instead of modifying global objects\n",
@@ -533,9 +576,15 @@
" if roi_table.loc[unchain_roi_idx, 'num_sessions'] < 2:\n",
" self.reset_state()\n",
" return\n",
+ "\n",
" next_global_roi = roi_table['global_roi'].max() + 1\n",
- " roi_table.loc[roi_table['global_roi'] == curr_global_roi, 'num_sessions'] -= 1\n",
- " roi_table.loc[unchain_roi_idx, 'global_roi'] = next_global_roi \n",
+ " \n",
+ " # decrement `num_sessions` for the ROI that used to be associated with the unchained one\n",
+ " roi_table.loc[roi_table['global_roi'] == curr_global_roi, 'num_sessions'] -= 1 \n",
+ " \n",
+ " # unchained ROI becomes a new ROI \n",
+ " roi_table.loc[unchain_roi_idx, 'global_roi'] = next_global_roi\n",
+ " roi_table.loc[unchain_roi_idx, 'num_sessions'] = 1\n",
" roi_table.at[unchain_roi_idx, 'global_roi_color'] = np.random.rand(3)\n",
" \n",
" self.reset_state()\n",
@@ -673,6 +722,8 @@
" .to_list()\n",
" )\n",
" \n",
+ " roi_table = check_num_sessions(roi_table)\n",
+ " \n",
" (\n",
" roi_table\n",
" .drop(columns=color_columns)\n",
@@ -719,9 +770,7 @@
"cell_type": "code",
"execution_count": null,
"id": "eb1175f3-1987-4f5f-b07c-cdd4d7f328c8",
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [],
"source": [
"!ls $muse_dir"
From a89f9fe48ea3ea4f0e7340d143767a3fba21aa7e Mon Sep 17 00:00:00 2001
From: Tuan Pham <42875763+tuanpham96@users.noreply.github.com>
Date: Tue, 16 Jul 2024 12:41:53 -0400
Subject: [PATCH 14/14] fix: index for global_roi_colors + refactor save
---
notebooks/refine-muse.ipynb | 16 +++++++++++-----
1 file changed, 11 insertions(+), 5 deletions(-)
diff --git a/notebooks/refine-muse.ipynb b/notebooks/refine-muse.ipynb
index fed1026..460a7f8 100644
--- a/notebooks/refine-muse.ipynb
+++ b/notebooks/refine-muse.ipynb
@@ -210,7 +210,7 @@
"num_global_rois = roi_table['global_roi'].nunique()\n",
"max_global_roi_id = roi_table['global_roi'].max()\n",
"\n",
- "global_roi_colors = np.random.rand(max_global_roi_id, 3).round(3)\n",
+ "global_roi_colors = np.random.rand(max_global_roi_id + 1, 3).round(3)\n",
"\n",
"roi_table['global_roi_color'] = roi_table['global_roi'].apply(\n",
" lambda x: global_roi_colors[x]\n",
@@ -722,10 +722,8 @@
" .to_list()\n",
" )\n",
" \n",
- " roi_table = check_num_sessions(roi_table)\n",
- " \n",
" (\n",
- " roi_table\n",
+ " check_num_sessions(roi_table)\n",
" .drop(columns=color_columns)\n",
" .to_csv(\n",
" self.save_path,\n",
@@ -766,6 +764,14 @@
"plt.show()"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "1fdbd953-2acf-4b4a-8b34-035895cd4451",
+ "metadata": {},
+ "source": [
+ "Check the modification dates of the `*-roi.csv` files. If they are not what's expected, try using what's inside the `MuseStateCallBacks.save` function explicitly."
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -773,7 +779,7 @@
"metadata": {},
"outputs": [],
"source": [
- "!ls $muse_dir"
+ "!ls -lthr $muse_dir"
]
}
],