diff --git a/src/spyglass/common/common_position.py b/src/spyglass/common/common_position.py index 78afdbcee..67b3ce95d 100644 --- a/src/spyglass/common/common_position.py +++ b/src/spyglass/common/common_position.py @@ -23,7 +23,7 @@ plot_track_graph, ) -from ..settings import raw_dir +from ..settings import raw_dir, video_dir from ..utils.dj_helper_fn import fetch_nwb from .common_behav import RawPosition, VideoFile from .common_interval import IntervalList # noqa F401 @@ -248,51 +248,6 @@ def _fix_kwargs( max_plausible_speed, ) - @staticmethod - def _fix_col_names(spatial_df): - """Renames columns in spatial dataframe according to previous norm - - Accepts unnamed first led, 1 or 0 indexed. - Prompts user for confirmation of renaming unexpected columns. - For backwards compatibility, renames to "xloc", "yloc", "xloc2", "yloc2" - """ - - DEFAULT_COLS = ["xloc", "yloc", "xloc2", "yloc2"] - ONE_IDX_COLS = ["xloc1", "yloc1", "xloc2", "yloc2"] - ZERO_IDX_COLS = ["xloc0", "yloc0", "xloc1", "yloc1"] - - input_cols = list(spatial_df.columns) - - has_default = all([c in input_cols for c in DEFAULT_COLS]) - has_0_idx = all([c in input_cols for c in ZERO_IDX_COLS]) - has_1_idx = all([c in input_cols for c in ONE_IDX_COLS]) - - # if unexpected columns, ask user to confirm - if len(input_cols) != 4 or not (has_default or has_0_idx or has_1_idx): - choice = dj.utils.user_choice( - "Unexpected columns in raw position. Assume " - + f"{DEFAULT_COLS[:4]}?\n{spatial_df}\n" - ) - if choice.lower() not in ["yes", "y"]: - raise ValueError( - f"Unexpected columns in raw position: {input_cols}" - ) - spatial_df.columns = DEFAULT_COLS + input_cols[4:] - - # Ensure data order, only 4 col - spatial_df = ( - spatial_df[DEFAULT_COLS] - if has_default - else spatial_df[ZERO_IDX_COLS] - if has_0_idx - else spatial_df[ONE_IDX_COLS] - ) - - # rename to default - spatial_df.columns = DEFAULT_COLS - - return spatial_df - @staticmethod def _upsample( front_LED, @@ -378,7 +333,7 @@ def calculate_position_info( **kwargs, ) - spatial_df = self._fix_col_names(spatial_df) + spatial_df = _fix_col_names(spatial_df) # Get spatial series properties time = np.asarray(spatial_df.index) # seconds position = np.asarray(spatial_df.iloc[:, :4]) # meters @@ -898,10 +853,10 @@ def make(self, key): VideoFile() & {"nwb_file_name": key["nwb_file_name"], "epoch": epoch} ).fetch1() - io = pynwb.NWBHDF5IO(raw_dir() + video_info["nwb_file_name"], "r") + io = pynwb.NWBHDF5IO(raw_dir + "/" + video_info["nwb_file_name"], "r") nwb_file = io.read() nwb_video = nwb_file.objects[video_info["video_file_object_id"]] - video_filename = nwb_video.external_file.value[0] + video_filename = nwb_video.external_file[0] nwb_base_filename = key["nwb_file_name"].replace(".nwb", "") output_video_filename = ( @@ -909,6 +864,15 @@ def make(self, key): f'{key["position_info_param_name"]}.mp4' ) + # ensure standardized column names + raw_position_df = _fix_col_names(raw_position_df) + # if IntervalPositionInfo supersampled position, downsample to video + if position_info_df.shape[0] > raw_position_df.shape[0]: + ind = np.digitize( + raw_position_df.index, position_info_df.index, right=True + ) + position_info_df = position_info_df.iloc[ind] + centroids = { "red": np.asarray(raw_position_df[["xloc", "yloc"]]), "green": np.asarray(raw_position_df[["xloc2", "yloc2"]]), @@ -925,7 +889,7 @@ def make(self, key): print("Making video...") self.make_video( - video_filename, + f"{video_dir}/{video_filename}", centroids, head_position_mean, head_orientation_mean, @@ -1082,3 +1046,48 @@ def make_video( video.release() out.release() cv2.destroyAllWindows() + + +def _fix_col_names(spatial_df): + """Renames columns in spatial dataframe according to previous norm + + Accepts unnamed first led, 1 or 0 indexed. + Prompts user for confirmation of renaming unexpected columns. + For backwards compatibility, renames to "xloc", "yloc", "xloc2", "yloc2" + """ + + DEFAULT_COLS = ["xloc", "yloc", "xloc2", "yloc2"] + ONE_IDX_COLS = ["xloc1", "yloc1", "xloc2", "yloc2"] + ZERO_IDX_COLS = ["xloc0", "yloc0", "xloc1", "yloc1"] + + input_cols = list(spatial_df.columns) + + has_default = all([c in input_cols for c in DEFAULT_COLS]) + has_0_idx = all([c in input_cols for c in ZERO_IDX_COLS]) + has_1_idx = all([c in input_cols for c in ONE_IDX_COLS]) + + if has_default: + # move the 4 position columns to front, continue + spatial_df = spatial_df[DEFAULT_COLS] + elif has_0_idx: + # move the 4 position columns to front, rename to default, continue + spatial_df = spatial_df[ZERO_IDX_COLS] + spatial_df.columns = DEFAULT_COLS + elif has_1_idx: + # move the 4 position columns to front, rename to default, continue + spatial_df = spatial_df[ONE_IDX_COLS] + spatial_df.columns = DEFAULT_COLS + else: + if len(input_cols) != 4 or not has_default: + choice = dj.utils.user_choice( + "Unexpected columns in raw position. Assume " + + f"{DEFAULT_COLS[:4]}?\n{spatial_df}\n" + ) + if choice.lower() not in ["yes", "y"]: + raise ValueError( + f"Unexpected columns in raw position: {input_cols}" + ) + # rename first 4 columns, keep rest. Rest dropped below + spatial_df.columns = DEFAULT_COLS + input_cols[4:] + + return spatial_df diff --git a/src/spyglass/settings.py b/src/spyglass/settings.py index d1e39ff27..9ae32639b 100644 --- a/src/spyglass/settings.py +++ b/src/spyglass/settings.py @@ -438,4 +438,5 @@ def debug_mode(self) -> bool: analysis_dir = sg_config.analysis_dir sorting_dir = sg_config.sorting_dir waveform_dir = sg_config.waveform_dir +video_dir = sg_config.video_dir debug_mode = sg_config.debug_mode