diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index d7d4759fb..530cdac36 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -1,4 +1,5 @@ import pathlib +import re from functools import reduce from typing import Dict @@ -372,7 +373,7 @@ def _no_transaction_make(self, key, verbose=True): "interval_list_name": interval_list_name, } ).fetch1("valid_times") - + cam_device_str = r"camera_device (\d+)" is_found = False for ind, video in enumerate(videos.values()): if isinstance(video, pynwb.image.ImageSeries): @@ -385,7 +386,11 @@ def _no_transaction_make(self, key, verbose=True): interval_list_contains(valid_times, video_obj.timestamps) > 0.9 * len(video_obj.timestamps) ): - key["video_file_num"] = ind + nwb_cam_device = video_obj.device.name + # returns whatever was captured in the first group (within the parentheses) of the regular expression -- in this case, 0 + key["video_file_num"] = int( + re.match(cam_device_str, nwb_cam_device)[1] + ) camera_name = video_obj.device.camera_name if CameraDevice & {"camera_name": camera_name}: key["camera_name"] = video_obj.device.camera_name diff --git a/src/spyglass/position/v1/dlc_utils.py b/src/spyglass/position/v1/dlc_utils.py index 6fc7dc741..8ef76074e 100644 --- a/src/spyglass/position/v1/dlc_utils.py +++ b/src/spyglass/position/v1/dlc_utils.py @@ -18,6 +18,7 @@ import pandas as pd from tqdm import tqdm as tqdm +from spyglass.common.common_behav import VideoFile from spyglass.settings import dlc_output_dir, dlc_video_dir, raw_dir @@ -418,8 +419,6 @@ def get_video_path(key): """ import pynwb - from ...common.common_behav import VideoFile - vf_key = {"nwb_file_name": key["nwb_file_name"], "epoch": key["epoch"]} VideoFile()._no_transaction_make(vf_key, verbose=False) video_query = VideoFile & vf_key @@ -434,9 +433,7 @@ def get_video_path(key): with pynwb.NWBHDF5IO(path=nwb_path, mode="r") as in_out: nwb_file = in_out.read() nwb_video = nwb_file.objects[video_info["video_file_object_id"]] - video_filepath = VideoFile.get_abs_path( - {"nwb_file_name": key["nwb_file_name"], "epoch": key["epoch"]} - ) + video_filepath = VideoFile.get_abs_path(vf_key) video_dir = os.path.dirname(video_filepath) + "/" video_filename = video_filepath.split(video_dir)[-1] meters_per_pixel = nwb_video.device.meters_per_pixel diff --git a/src/spyglass/position/v1/position_dlc_pose_estimation.py b/src/spyglass/position/v1/position_dlc_pose_estimation.py index 120d292bb..4f964bb2a 100644 --- a/src/spyglass/position/v1/position_dlc_pose_estimation.py +++ b/src/spyglass/position/v1/position_dlc_pose_estimation.py @@ -10,12 +10,11 @@ from spyglass.common.common_behav import ( # noqa: F401 RawPosition, - VideoFile, convert_epoch_interval_name_to_position_interval_name, ) +from spyglass.common.common_nwbfile import AnalysisNwbfile +from spyglass.utils.dj_mixin import SpyglassMixin -from ...common.common_nwbfile import AnalysisNwbfile -from ...utils.dj_mixin import SpyglassMixin from .dlc_utils import OutputLogger, infer_output_dir from .position_dlc_model import DLCModel @@ -87,10 +86,11 @@ def insert_estimation_task( Parameters ---------- key: DataJoint key specifying a pairing of VideoRecording and Model. - task_mode (bool): Default 'trigger' computation. Or 'load' existing results. + task_mode (bool): Default 'trigger' computation. + Or 'load' existing results. params (dict): Optional. Parameters passed to DLC's analyze_videos: - videotype, gputouse, save_as_csv, batchsize, cropping, TFGPUinference, - dynamic, robust_nframes, allow_growth, use_shelve + videotype, gputouse, save_as_csv, batchsize, cropping, + TFGPUinference, dynamic, robust_nframes, allow_growth, use_shelve """ from .dlc_utils import check_videofile, get_video_path @@ -261,7 +261,8 @@ def make(self, key): del key["meters_per_pixel"] body_parts = dlc_result.df.columns.levels[0] body_parts_df = {} - # Insert dlc pose estimation into analysis NWB file for each body part. + # Insert dlc pose estimation into analysis NWB file for + # each body part. for body_part in bodyparts: if body_part in body_parts: body_parts_df[body_part] = pd.DataFrame.from_dict( diff --git a/src/spyglass/position/v1/position_dlc_project.py b/src/spyglass/position/v1/position_dlc_project.py index 0cdb89a8d..e99e19e9e 100644 --- a/src/spyglass/position/v1/position_dlc_project.py +++ b/src/spyglass/position/v1/position_dlc_project.py @@ -340,6 +340,9 @@ def insert_new_project( } # TODO: make permissions setting more flexible. if set_permissions: + raise NotImplementedError( + "permission-setting is not functional at this time" + ) permissions = ( stat.S_IRUSR | stat.S_IWUSR @@ -375,9 +378,84 @@ def insert_new_project( config_path = config_path.as_posix() return {"project_name": project_name, "config_path": config_path} + @classmethod + def add_video_files( + cls, + video_list, + config_path=None, + key=None, + output_path: str = os.getenv("DLC_VIDEO_PATH"), + add_new=False, + add_to_files=True, + **kwargs, + ): + if add_new & (not config_path): + if not key: + raise ValueError( + "at least one of config_path or key have to be passed if add_new=True" + ) + else: + config_path = (cls & key).fetch1("config_path") + if (not key) & add_to_files: + if config_path: + if len(cls & {"config_path": config_path}) == 1: + pass + else: + raise ValueError( + "Cannot set add_to_files=True without passing key" + ) + else: + raise ValueError( + "Cannot set add_to_files=True without passing key" + ) + + if all(isinstance(n, Dict) for n in video_list): + videos_to_convert = [ + get_video_path(video_key) for video_key in video_list + ] + videos = [ + check_videofile( + video_path=video[0], + output_path=output_path, + video_filename=video[1], + )[0].as_posix() + for video in videos_to_convert + ] + # If not dict, assume list of video file paths + # that may or may not need to be converted + else: + videos = [] + if not all([Path(video).exists() for video in video_list]): + raise OSError("at least one file in video_list does not exist") + for video in video_list: + video_path = Path(video).parent + video_filename = video.rsplit( + video_path.as_posix(), maxsplit=1 + )[-1].split("/")[-1] + videos.extend( + [ + check_videofile( + video_path=video_path, + output_path=output_path, + video_filename=video_filename, + )[0].as_posix() + ] + ) + if len(videos) < 1: + raise ValueError(f"no .mp4 videos found in{video_path}") + if add_new: + from deeplabcut import add_new_videos + + add_new_videos(config=config_path, videos=videos, copy_videos=True) + if add_to_files: + # Add videos to training files + cls.add_training_files(key, **kwargs) + return videos + @classmethod def add_training_files(cls, key, **kwargs): - """Add training videos and labeled frames .h5 and .csv to DLCProject.File""" + """Add training videos and labeled frames .h5 + and .csv to DLCProject.File""" config_path = (cls & {"project_name": key["project_name"]}).fetch1( "config_path" ) @@ -394,7 +472,8 @@ def add_training_files(cls, key, **kwargs): )[0] training_files.extend( glob.glob( - f"{cfg['project_path']}/labeled-data/{video_name}/*Collected*" + f"{cfg['project_path']}/" + f"labeled-data/{video_name}/*Collected*" ) ) for video in video_names: @@ -457,16 +536,19 @@ def import_labeled_frames( video_filenames: Union[str, List], **kwargs, ): - """Function to import pre-labeled frames from an existing project into a new project + """Function to import pre-labeled frames from an existing project + into a new project Parameters ---------- key : Dict key to specify entry in DLCProject table to add labeled frames to import_project_path : str - absolute path to project directory containing labeled frames to import + absolute path to project directory containing + labeled frames to import video_filenames : str or List - filename or list of filenames of video(s) from which to import frames. + filename or list of filenames of video(s) + from which to import frames. without file extension """ project_entry = (cls & key).fetch1() @@ -476,9 +558,10 @@ def import_labeled_frames( f"{current_project_path.as_posix()}/labeled-data" ) if isinstance(import_project_path, PosixPath): - assert ( - import_project_path.exists() - ), f"import_project_path: {import_project_path.as_posix()} does not exist" + assert import_project_path.exists(), ( + "import_project_path: " + f"{import_project_path.as_posix()} does not exist" + ) import_labeled_data_path = Path( f"{import_project_path.as_posix()}/labeled-data" ) @@ -504,7 +587,8 @@ def import_labeled_frames( dlc_df.columns = dlc_df.columns.set_levels([team_name], level=0) dlc_df.to_hdf( Path( - f"{current_labeled_data_path.as_posix()}/{video_file}/CollectedData_{team_name}.h5" + f"{current_labeled_data_path.as_posix()}/" + f"{video_file}/CollectedData_{team_name}.h5" ).as_posix(), "df_with_missing", )