Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multicam DLC project support #684

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/spyglass/common/common_behav.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pathlib
import re
from functools import reduce
from typing import Dict

Expand Down Expand Up @@ -372,7 +373,7 @@ def _no_transaction_make(self, key, verbose=True):
"interval_list_name": interval_list_name,
}
).fetch1("valid_times")

cam_device_str = r"camera_device (\d+)"
is_found = False
for ind, video in enumerate(videos.values()):
if isinstance(video, pynwb.image.ImageSeries):
Expand All @@ -385,7 +386,11 @@ def _no_transaction_make(self, key, verbose=True):
interval_list_contains(valid_times, video_obj.timestamps)
> 0.9 * len(video_obj.timestamps)
):
key["video_file_num"] = ind
nwb_cam_device = video_obj.device.name
# returns whatever was captured in the first group (within the parentheses) of the regular expression -- in this case, 0
key["video_file_num"] = int(
re.match(cam_device_str, nwb_cam_device)[1]
)
camera_name = video_obj.device.camera_name
if CameraDevice & {"camera_name": camera_name}:
key["camera_name"] = video_obj.device.camera_name
Expand Down
7 changes: 2 additions & 5 deletions src/spyglass/position/v1/dlc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pandas as pd
from tqdm import tqdm as tqdm

from spyglass.common.common_behav import VideoFile
from spyglass.settings import dlc_output_dir, dlc_video_dir, raw_dir


Expand Down Expand Up @@ -418,8 +419,6 @@ def get_video_path(key):
"""
import pynwb

from ...common.common_behav import VideoFile

vf_key = {"nwb_file_name": key["nwb_file_name"], "epoch": key["epoch"]}
VideoFile()._no_transaction_make(vf_key, verbose=False)
video_query = VideoFile & vf_key
Expand All @@ -434,9 +433,7 @@ def get_video_path(key):
with pynwb.NWBHDF5IO(path=nwb_path, mode="r") as in_out:
nwb_file = in_out.read()
nwb_video = nwb_file.objects[video_info["video_file_object_id"]]
video_filepath = VideoFile.get_abs_path(
{"nwb_file_name": key["nwb_file_name"], "epoch": key["epoch"]}
)
video_filepath = VideoFile.get_abs_path(vf_key)
video_dir = os.path.dirname(video_filepath) + "/"
video_filename = video_filepath.split(video_dir)[-1]
meters_per_pixel = nwb_video.device.meters_per_pixel
Expand Down
15 changes: 8 additions & 7 deletions src/spyglass/position/v1/position_dlc_pose_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@

from spyglass.common.common_behav import ( # noqa: F401
RawPosition,
VideoFile,
convert_epoch_interval_name_to_position_interval_name,
)
from spyglass.common.common_nwbfile import AnalysisNwbfile
from spyglass.utils.dj_mixin import SpyglassMixin

from ...common.common_nwbfile import AnalysisNwbfile
from ...utils.dj_mixin import SpyglassMixin
from .dlc_utils import OutputLogger, infer_output_dir
from .position_dlc_model import DLCModel

Expand Down Expand Up @@ -87,10 +86,11 @@ def insert_estimation_task(
Parameters
----------
key: DataJoint key specifying a pairing of VideoRecording and Model.
task_mode (bool): Default 'trigger' computation. Or 'load' existing results.
task_mode (bool): Default 'trigger' computation.
Or 'load' existing results.
params (dict): Optional. Parameters passed to DLC's analyze_videos:
videotype, gputouse, save_as_csv, batchsize, cropping, TFGPUinference,
dynamic, robust_nframes, allow_growth, use_shelve
videotype, gputouse, save_as_csv, batchsize, cropping,
TFGPUinference, dynamic, robust_nframes, allow_growth, use_shelve
"""
from .dlc_utils import check_videofile, get_video_path

Expand Down Expand Up @@ -261,7 +261,8 @@ def make(self, key):
del key["meters_per_pixel"]
body_parts = dlc_result.df.columns.levels[0]
body_parts_df = {}
# Insert dlc pose estimation into analysis NWB file for each body part.
# Insert dlc pose estimation into analysis NWB file for
# each body part.
for body_part in bodyparts:
if body_part in body_parts:
body_parts_df[body_part] = pd.DataFrame.from_dict(
Expand Down
102 changes: 93 additions & 9 deletions src/spyglass/position/v1/position_dlc_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,9 @@ def insert_new_project(
}
# TODO: make permissions setting more flexible.
if set_permissions:
raise NotImplementedError(
"permission-setting is not functional at this time"
)
permissions = (
stat.S_IRUSR
| stat.S_IWUSR
Expand Down Expand Up @@ -375,9 +378,84 @@ def insert_new_project(
config_path = config_path.as_posix()
return {"project_name": project_name, "config_path": config_path}

@classmethod
def add_video_files(
cls,
video_list,
config_path=None,
key=None,
output_path: str = os.getenv("DLC_VIDEO_PATH"),
add_new=False,
add_to_files=True,
**kwargs,
):
if add_new & (not config_path):
if not key:
raise ValueError(
"at least one of config_path or key have to be passed if add_new=True"
)
else:
config_path = (cls & key).fetch1("config_path")
if (not key) & add_to_files:
if config_path:
if len(cls & {"config_path": config_path}) == 1:
pass
else:
raise ValueError(
"Cannot set add_to_files=True without passing key"
)
else:
raise ValueError(
"Cannot set add_to_files=True without passing key"
)

if all(isinstance(n, Dict) for n in video_list):
videos_to_convert = [
get_video_path(video_key) for video_key in video_list
]
videos = [
check_videofile(
video_path=video[0],
output_path=output_path,
video_filename=video[1],
)[0].as_posix()
for video in videos_to_convert
]
# If not dict, assume list of video file paths
# that may or may not need to be converted
else:
videos = []
if not all([Path(video).exists() for video in video_list]):
raise OSError("at least one file in video_list does not exist")
for video in video_list:
video_path = Path(video).parent
video_filename = video.rsplit(
video_path.as_posix(), maxsplit=1
)[-1].split("/")[-1]
videos.extend(
[
check_videofile(
video_path=video_path,
output_path=output_path,
video_filename=video_filename,
)[0].as_posix()
]
)
if len(videos) < 1:
raise ValueError(f"no .mp4 videos found in{video_path}")
if add_new:
from deeplabcut import add_new_videos

add_new_videos(config=config_path, videos=videos, copy_videos=True)
if add_to_files:
# Add videos to training files
cls.add_training_files(key, **kwargs)
return videos

@classmethod
def add_training_files(cls, key, **kwargs):
"""Add training videos and labeled frames .h5 and .csv to DLCProject.File"""
"""Add training videos and labeled frames .h5
and .csv to DLCProject.File"""
config_path = (cls & {"project_name": key["project_name"]}).fetch1(
"config_path"
)
Expand All @@ -394,7 +472,8 @@ def add_training_files(cls, key, **kwargs):
)[0]
training_files.extend(
glob.glob(
f"{cfg['project_path']}/labeled-data/{video_name}/*Collected*"
f"{cfg['project_path']}/"
f"labeled-data/{video_name}/*Collected*"
)
)
for video in video_names:
Expand Down Expand Up @@ -457,16 +536,19 @@ def import_labeled_frames(
video_filenames: Union[str, List],
**kwargs,
):
"""Function to import pre-labeled frames from an existing project into a new project
"""Function to import pre-labeled frames from an existing project
into a new project

Parameters
----------
key : Dict
key to specify entry in DLCProject table to add labeled frames to
import_project_path : str
absolute path to project directory containing labeled frames to import
absolute path to project directory containing
labeled frames to import
video_filenames : str or List
filename or list of filenames of video(s) from which to import frames.
filename or list of filenames of video(s)
from which to import frames.
without file extension
"""
project_entry = (cls & key).fetch1()
Expand All @@ -476,9 +558,10 @@ def import_labeled_frames(
f"{current_project_path.as_posix()}/labeled-data"
)
if isinstance(import_project_path, PosixPath):
assert (
import_project_path.exists()
), f"import_project_path: {import_project_path.as_posix()} does not exist"
assert import_project_path.exists(), (
"import_project_path: "
f"{import_project_path.as_posix()} does not exist"
)
import_labeled_data_path = Path(
f"{import_project_path.as_posix()}/labeled-data"
)
Expand All @@ -504,7 +587,8 @@ def import_labeled_frames(
dlc_df.columns = dlc_df.columns.set_levels([team_name], level=0)
dlc_df.to_hdf(
Path(
f"{current_labeled_data_path.as_posix()}/{video_file}/CollectedData_{team_name}.h5"
f"{current_labeled_data_path.as_posix()}/"
f"{video_file}/CollectedData_{team_name}.h5"
).as_posix(),
"df_with_missing",
)
Expand Down
Loading