diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bf62d02 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +debugImgs/oe_minus_mkv_diffs.tiff +debugImgs/pdb.tiff +debugImgs/rpi_minus_mkv_diffs.tiff +debugImgs/rpi_minus_oe_diffs.tiff +build +dist +*.egg-info +*.png diff --git a/README.md b/README.md index dc73bb2..77b046a 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ To get started, try the following in termanal (these instructions assume you're 4. `git clone https://github.com/guitchounts/moseq2-ephys-sync.git` 5. `cd ./moseq2-ephys-sync/` 6. `python setup.py install` -7. `pip install git+ssh://git@github.com/dattalab/moseq2-extract.git@autosetting-params` (alternatively, try: `pip install git+https://github.com/dattalab/moseq2-extract.git@autosetting-params`) +7. `pip install git+ssh://git@github.com/dattalab/moseq2-extract.git@autosetting-params` (alternatively, try: `pip install git+https://github.com/dattalab/moseq2-extract.git@autosetting-params`) (alternatively, try using conda) 8. `conda install scikit-learn=0.24` (moseq2-extract pins `scikit` to an earlier version; need to update to `0.24` 9. 
`module load ffmpeg` diff --git a/dist/moseq2_ephys_sync-0.0.1-py3.7.egg b/dist/moseq2_ephys_sync-0.0.1-py3.7.egg index 5b895cf..513be83 100644 Binary files a/dist/moseq2_ephys_sync-0.0.1-py3.7.egg and b/dist/moseq2_ephys_sync-0.0.1-py3.7.egg differ diff --git a/moseq2-extract b/moseq2-extract new file mode 160000 index 0000000..9d08897 --- /dev/null +++ b/moseq2-extract @@ -0,0 +1 @@ +Subproject commit 9d08897161117cbcaf7f759106565aeda535501a diff --git a/moseq2_ephys_sync/.gitignore b/moseq2_ephys_sync/.gitignore new file mode 100644 index 0000000..a81c8ee --- /dev/null +++ b/moseq2_ephys_sync/.gitignore @@ -0,0 +1,138 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ diff --git a/moseq2_ephys_sync/__init__.py b/moseq2_ephys_sync/__init__.py index b6ab921..e69de29 100644 --- a/moseq2_ephys_sync/__init__.py +++ b/moseq2_ephys_sync/__init__.py @@ -1,5 +0,0 @@ -import moseq2_ephys_sync -from . import extract_leds -from . import sync -from . import video -from . 
import plotting \ No newline at end of file diff --git a/moseq2_ephys_sync/__pycache__/__init__.cpython-37.pyc b/moseq2_ephys_sync/__pycache__/__init__.cpython-37.pyc deleted file mode 100644 index 2517ba3..0000000 Binary files a/moseq2_ephys_sync/__pycache__/__init__.cpython-37.pyc and /dev/null differ diff --git a/moseq2_ephys_sync/__pycache__/extract_leds.cpython-37.pyc b/moseq2_ephys_sync/__pycache__/extract_leds.cpython-37.pyc deleted file mode 100644 index b8342d1..0000000 Binary files a/moseq2_ephys_sync/__pycache__/extract_leds.cpython-37.pyc and /dev/null differ diff --git a/moseq2_ephys_sync/__pycache__/sync.cpython-37.pyc b/moseq2_ephys_sync/__pycache__/sync.cpython-37.pyc deleted file mode 100644 index 9533e79..0000000 Binary files a/moseq2_ephys_sync/__pycache__/sync.cpython-37.pyc and /dev/null differ diff --git a/moseq2_ephys_sync/__pycache__/video.cpython-37.pyc b/moseq2_ephys_sync/__pycache__/video.cpython-37.pyc deleted file mode 100644 index 80a5611..0000000 Binary files a/moseq2_ephys_sync/__pycache__/video.cpython-37.pyc and /dev/null differ diff --git a/moseq2_ephys_sync/arduino.py b/moseq2_ephys_sync/arduino.py new file mode 100644 index 0000000..efbdcf1 --- /dev/null +++ b/moseq2_ephys_sync/arduino.py @@ -0,0 +1,96 @@ +import pandas as pd +import numpy as np +from glob import glob +import pdb + +import sync + +def arduino_workflow(base_path, save_path, num_leds, led_blink_interval, arduino_spec): + """ + Workflow to get codes from arduino txt file. Note arduino sampling rate is calculated empirically below because it's not stable from datapoint to datapoint. 
+ + """ + assert num_leds==4, "Arduino code expects 4 LED channels, other nums of channels not yet supported" + assert arduino_spec is not None, "Arduino source requires a spec for the column names and datatypes (see arg arduino_spec)" + arduino_colnames, arduino_dtypes = get_col_info(arduino_spec) + ino_data = load_arduino_data(base_path, arduino_colnames, arduino_dtypes, file_glob='*.txt') + ino_timestamps = ino_data.time # these are in milliseconds + ino_events = list_to_events(ino_timestamps, ino_data.led1, ino_data.led2, ino_data.led3, ino_data.led4) + ino_average_fs = 1/(np.mean(np.diff(ino_timestamps)))*1000 # fs = sampling freq in Hz + ino_codes, _ = sync.events_to_codes(ino_events, nchannels=4, minCodeTime=(led_blink_interval-1)*1000) # I think as long as the column 'timestamps' in events and the minCodeTime are in the same units, it's fine (for ephys, its nsamples, for arudino, it's ms) + ino_codes = np.asarray(ino_codes) + ino_codes[:,0] = ino_codes[:,0] / 1000 ## convert to seconds + + return ino_codes, ino_average_fs + + + +def get_col_info(spec): + """ + Given a string specifying the experiment type, return expected list of columns in arudino text file + """ + if spec == "fictive_olfaction": + arduino_colnames = ['time', 'led1', 'led2', 'led3', 'led4', 'yaw', 'roll', 'pitch', 'accx', 'accy', 'accz', 'therm', 'olfled'] + arduino_dtypes = ['int64', 'int64', 'int64', 'int64','int64', 'float64', 'float64', 'float64', 'float64', 'float64', 'float64'] + elif spec == "odor_on_wheel": + arduino_colnames = ['time', 'led1', 'led2', 'led3', 'led4', 'wheel'] + arduino_dtypes = ['int64', 'int64', 'int64', 'int64','int64', 'int64'] + return arduino_colnames, arduino_dtypes + + + +def load_arduino_data(base_path, colnames, dtypes, file_glob='*.txt'): + arduino_data = glob(f'{base_path}/{file_glob}') + try: + arduino_data = arduino_data[0] + except IndexError: + raise FileNotFoundError("Could not find arduino data (*.txt) in specified location!") + + dtype_dict = 
{colname: dtype for colname, dtype in zip(colnames, dtypes)} + try: + # Try loading the entire thing first. + data = pd.read_csv(arduino_data, header=0, names=colnames, dtype=dtype_dict, error_bad_lines=False) + except ValueError: + try: + # If needed, try ignoring the last line. This is slower so we don't use as default. + data = pd.read_csv(arduino_data, header=0, names=colnames, dtype=dtype_dict, error_bad_lines=False, warn_bad_lines=True, skipfooter=1) + except: + raise RuntimeError('Could not load arduino data -- check text file for weirdness. \ + Most common issues text file issues are: \ + -- line that ends with a "-" (minus sign), "." (decima) \ + -- line that begins with a "," (comma) \ + -- usually no more than one issue like this per txt file') + return data + + +def list_to_events(time_list, led1, led2, led3, led4): + """ + Transforms list of times and led states into list of led change events. + --- + Input: pd.Series from arduino text file + --- + Output: + events : 2d array + Array of pixel clock events (single channel transitions) where: + events[:,0] = times + events[:,1] = channels (0-indexed) + events[:,2] = directions (1 or -1) + """ + led_states = [led1, led2, led3, led4] + + # Get lists of relevant times and events + times = pd.Series(dtype='int64', name='times') + channels = pd.Series(dtype='int8', name='channels') + directions = pd.Series(dtype='int8', name='directions') + for i in range(4): + states = led_states[i] + diffs = np.diff(states) + events_idx = np.asarray(diffs != 0).nonzero()[0] + 1 # plus 1, because the event should be the first timepoint where it's different + times = times.append(pd.Series(time_list[events_idx], name='times'), ignore_index=True) + channels = channels.append(pd.Series(np.repeat(i,len(events_idx)), name='channels'), ignore_index=True) + directions = directions.append(pd.Series(np.sign(diffs[events_idx-1]), name='directions'), ignore_index=True) + events = pd.concat([times, channels, directions], axis=1) + 
sorting = np.argsort(events.loc[:,'times']) + events = events.loc[sorting, :] + assert np.all(np.diff(events.times)>=0), 'Event times are not sorted!' + return np.array(events) diff --git a/moseq2_ephys_sync/extract_leds.py b/moseq2_ephys_sync/extract_leds.py index 18e38cd..f6b0049 100644 --- a/moseq2_ephys_sync/extract_leds.py +++ b/moseq2_ephys_sync/extract_leds.py @@ -1,43 +1,57 @@ ''' -Tools for extracting LED states from mkv files +Tools for extracting LED states from video files ''' + import os import numpy as np from skimage.feature import canny from scipy import ndimage as ndi from skimage.filters import threshold_otsu from moseq2_ephys_sync.plotting import plot_code_chunk, plot_matched_scatter, plot_model_errors, plot_matches_video_time,plot_video_frame +import pdb -def gen_batch_sequence(nframes, chunk_size, overlap, offset=0): - ''' - Generates batches used to chunk videos prior to extraction. +def get_led_data_from_rois(frame_data_chunk, rois, led_thresh=2e4, save_path=None): + """ + Given pre-determined rois for LEDs, return sequences of ons and offs + Inputs: + frame_data_chunk: array-like of video data (typically from moseq_video.load_movie_data() but could be any) + rois (list): ordered list of rois [{specify ROI format, eg x1 x2 y1 y2}] to get LED data from + led_thresh (int): value above which LEDs are considered on. Default 2e4. In the k4a recorder, LEDs that are on register as 65535 = 6.5e4, off ones are roughly 1000. 
+ save_path (str): where to save plots for debugging if desired + Returns: + leds (np.array): (num leds) x (num frames) array of 0s and 1s, indicating if LED is above or below threshold (ie, on or off) + """ - Parameters - ---------- - nframes (int): total number of frames - chunk_size (int): desired chunk size - overlap (int): number of overlapping frames - offset (int): frame offset + leds = [] - Returns - ------- - Yields list of batches - ''' + for i in range(len(rois)): - seq = range(offset, nframes) - out = [] - for i in range(0, len(seq) - overlap, chunk_size - overlap): - out.append(seq[i:i + chunk_size]) - return out + led_x = 2# {get x vals} + led_y = 2# {get y vals} + led = frame_data_chunk[:,led_x,led_y].mean(axis=1) #on/off block signals -def get_led_data(frame_data_chunk,num_leds = 4,chunk_num=0, - flip_horizontal=False,flip_vertical=False,sort_by=None,save_path=None): - + led_on = np.where(np.diff(led) > led_thresh)[0] #rise indices + led_off = np.where(np.diff(led) < -led_thresh)[0] #fall indices + + + led_vec = np.zeros(frame_data_chunk.shape[0]) + led_vec[led_on] = 1 + led_vec[led_off] = -1 + + leds.append(led_vec) + + leds = np.vstack(leds) #spiky differenced signals to extract times + + return leds + + + + +def get_led_data_with_stds(frame_data_chunk, num_leds = 4, chunk_num=0, led_loc=None, + flip_horizontal=False, flip_vertical=False, sort_by=None, save_path=None): - ## cropping: - #frame_data_chunk = frame_data_chunk[:,:,:-100] if flip_horizontal: print('Flipping image horizontally') @@ -48,25 +62,68 @@ def get_led_data(frame_data_chunk,num_leds = 4,chunk_num=0, frame_uint8 = np.asarray(frame_data_chunk / frame_data_chunk.max() * 255, dtype='uint8') + std_px = frame_uint8.std(axis=0) mean_px = frame_uint8.mean(axis=0) vary_px = std_px if np.std(std_px) < np.std(mean_px) else mean_px # pick the one with the lower variance - ## threshold the image to get rid of edge noise: + # Initial thresholding thresh = threshold_otsu(vary_px) thresh_px = 
np.copy(vary_px) thresh_px[thresh_px num_leds: + # print('Too many features, using second thresholding step...') + # thresh2 = threshold_otsu(thresh_px[thresh_px > 5]) + # thresh_px[thresh_px < thresh2] = 0 + # edges = canny(thresh_px/255.) ## find the edges + # filled_image = ndi.binary_fill_holes(edges) ## fill its edges + # labeled_leds, num_features = ndi.label(filled_image) ## get the clusters + # # plot_video_frame(labeled_leds,'%s/frame_%d_led_labels_secondThreshold.png' % (save_path,chunk_num)) + + # If still too many features, check for location parameter and filter by it + if (num_features > num_leds) and led_loc: + print('Too many features, using provided LED position...') + centers_of_mass = ndi.measurements.center_of_mass(filled_image, labeled_leds, range(1, np.unique(labeled_leds)[-1] + 1)) # exclude 0, which is background + centers_of_mass = [(x/filled_image.shape[0], y/filled_image.shape[1]) for (x,y) in centers_of_mass] # normalize + # x is flipped, y is not + if led_loc == 'topright': + idx = np.asarray([((x < 0.5) and (y > 0.5)) for (x,y) in centers_of_mass]).nonzero()[0] + elif led_loc == 'topleft': + idx = np.asarray([((x > 0.5) and (y > 0.5)) for (x,y) in centers_of_mass]).nonzero()[0] + elif led_loc == 'bottomleft': + idx = np.asarray([((x > 0.5) and (y < 0.5)) for (x,y) in centers_of_mass]).nonzero()[0] + elif led_loc == 'bottomright': + idx = np.asarray([((x < 0.5) and (y < 0.5)) for (x,y) in centers_of_mass]).nonzero()[0] + else: + RuntimeError('led_loc not recognized') + + # Add back one to account for background + idx = idx+1 + # Remove non-LED labels + labeled_leds[~np.isin(labeled_leds, idx)] = 0 + num_features = len(idx) - - if num_features != num_leds: + # Ensure LEDs have labels 1,2,3,4 + if not np.all(idx == np.array([1,2,3,4])): + # pdb.set_trace() + for i,val in enumerate([1,2,3,4]): + labeled_leds[labeled_leds==idx[i]] = val + + # If still too many features, remove small ones + if num_features > num_leds: print('OoOOoOooOooOops! 
Number of features (%d) did not match the number of LEDs (%d)' % (num_features,num_leds)) ## erase extra labels: @@ -80,6 +137,10 @@ def get_led_data(frame_data_chunk,num_leds = 4,chunk_num=0, print('Erasing extraneous label #%d' % erase) labeled_leds[labeled_leds==erase] = 0 + + # Show led labels for debugging + plot_video_frame(labeled_leds,'%s/frame_%d_led_labels.png' % (save_path,chunk_num) ) + ## assign labels to the LEDs labels = [label for label in np.unique(labeled_leds) if label > 0 ] @@ -150,15 +211,10 @@ def get_events(leds,timestamps,time_offset=0,num_leds=2): directions.append(np.repeat(direction_sign,times_of_dir.shape[0] )) - times = np.hstack(times) channels = np.hstack(channels) directions = np.hstack(directions) - sorting = np.argsort(times) - - events = np.vstack([times[sorting],channels[sorting],directions[sorting]]).T - return events \ No newline at end of file diff --git a/moseq2_ephys_sync/main.py b/moseq2_ephys_sync/main.py index 53cb806..0e7699c 100644 --- a/moseq2_ephys_sync/main.py +++ b/moseq2_ephys_sync/main.py @@ -1,120 +1,92 @@ import numpy as np -import pandas as pd -import sys,os -from tqdm import tqdm -import subprocess -from glob import glob +import os import joblib import argparse -import json - from mlinsights.mlmodel import PiecewiseRegressor -from sklearn.tree import DecisionTreeRegressor from sklearn.preprocessing import KBinsDiscretizer -import moseq2_extract.io.video as moseq_video - -from moseq2_ephys_sync.video import get_mkv_stream_names, get_mkv_info -from moseq2_ephys_sync.extract_leds import gen_batch_sequence, get_led_data, get_events -from moseq2_ephys_sync.sync import events_to_codes, match_codes -from moseq2_ephys_sync.plotting import plot_code_chunk, plot_matched_scatter, plot_model_errors, plot_matches_video_time,plot_video_frame - - -def sync(base_path): - - save_path = '%s/sync/' % base_path - if not os.path.exists(save_path): - os.makedirs(save_path) +import mkv, arduino, ttl, sync, plotting - depth_path = 
glob('%s/*.mkv' % base_path )[0] +import pdb - print('Running sync on %s.' % base_path) +""" +TODO: - stream_names = get_mkv_stream_names(depth_path) # e.g. {'DEPTH': 0, 'IR': 1} +-- add ROIs capability to MKV workflow +-- refactor extract_leds to work with various other videos - ### make paths for info and timestamps. if they exist, don't recompute: - info_path = '%s/info.json' % base_path - timestamp_path = '%s/mkv_timestamps.csv' % base_path +""" - if (os.path.exists(info_path) and os.path.exists(timestamp_path) ): - - with open(info_path,'r') as f: - info = json.load(f) - timestamps = pd.read_csv(timestamp_path) - timestamps = timestamps.values[:,1].flatten() +def main_function(base_path, output_dir_name, first_source, second_source, led_loc=None, led_blink_interval=5, arduino_spec=None, overwrite_models=False): + """ + Uses 4-bit code sequences to create a piecewise linear model to predict first_source times from second_source times + ---- + Inputs: + base_path (str): path to the .mkv and any other files needed + output_dir: path to save output models and plots. Default: {base_path}/sync. + first_source (str): 'ttl', 'mkv', 'arduino', or 'basler'. Source to be predicted. + ttl: looks for open ephys data in __ format + mkv: looks for an MKV file recorded with the k4a recorder + arduino: looks for a text file with cols specified by arduino_col_type + basler: looks for an mp4 + second_source (str): same as first_source, but these codes are used to predict first_source. + led_loc (str): specifiy one of four corners of the movie in which to find the LEDs: topright, bottomright, topleft, bottomleft + led_blink_interval (int): interval in seconds between LED changes. Typically 5 seconds. 
- else: - ## get info on the depth file; we'll use this to see how many frames we have - info,timestamps = get_mkv_info(depth_path,stream=stream_names['DEPTH']) + Outputs: + - - ## save info and timestamps: - timestamps = pd.DataFrame(timestamps) - timestamps.to_csv(timestamp_path) # save the timestamps - timestamps = timestamps.values.flatten() - - with open(info_path, 'w') as f: - json.dump(info, f) + Notes: + - Each workflow checks for already-pre-processed data, so that the script should be pretty easy to debug. + """ - + print(f'Running sync on {base_path} with {first_source} as first source and {second_source} as second source.') - ## we'll load the actual frames in chunks of 1000/2000. let's see how many chunks we need: - nframes = info['nframes'] - chunk_size = 2000 + #### SETUP #### + # Built-in params (should make dynamic) + mkv_chunk_size = 2000 num_leds = 4 + ephys_fs = 3e4 # sampling rate in Hz - ## get frame batches like in moseq2-extract: - frame_batches = gen_batch_sequence(info['nframes'], chunk_size, - 0, offset=0) - - print('info = ', info) - print('timestamps.shape = ', timestamps.shape) - - ############### Cycle through the frame chunks to get all LED events: - num_chunks = len(frame_batches) - led_events = [] - print('num_chunks = ', num_chunks) - - led_events_path = '%s_led_events.npz' % os.path.splitext(depth_path)[0] + # Set up save path + save_path = f'{base_path}/{output_dir_name}/' + if not os.path.exists(save_path): + os.makedirs(save_path) - if not os.path.isfile(led_events_path): + # Check if models already exist, only over-write if requested + model_exists_bool = os.path.exists(f'{save_path}/{first_source}_from_{second_source}.p') or os.path.exists(f'{save_path}/{second_source}_from_{first_source}.p') + if model_exists_bool and not overwrite_models: + raise RuntimeError("Models already exist and overwrite_models is false!") - for i in tqdm(range(num_chunks)[0:]): - - frame_data_chunk = moseq_video.load_movie_data(depth_path, - 
frames=frame_batches[i], - mapping=stream_names['IR'], movie_dtype=">u2", pixel_format="gray16be", - frame_size=info['dims'],timestamps=timestamps,threads=8, - finfo=info) - if i==0: - plot_video_frame(frame_data_chunk.std(axis=0),'%s/frame_std.pdf' % save_path) - leds = get_led_data(frame_data_chunk=frame_data_chunk, - num_leds=num_leds,chunk_num=i,sort_by='horizontal',save_path=save_path) - - time_offset = frame_batches[i][0] ## how many frames away from first chunk's #### frame_chunks[0,i] - - tmp_event = get_events(leds,timestamps[frame_batches[i]],time_offset,num_leds=num_leds) + #### INDIVIDUAL DATA STREAM WORKFLOWS #### - actual_led_nums = np.unique(tmp_event[:,1]) ## i.e. what was found in this chunk + # Deal with first source + if first_source == 'ttl': + first_source_led_codes = ttl.ttl_workflow(base_path, save_path, num_leds, led_blink_interval, ephys_fs) + elif first_source == 'mkv': + first_source_led_codes = mkv.mkv_workflow(base_path, save_path, num_leds, led_blink_interval, mkv_chunk_size, led_loc) + elif first_source == 'arduino': + first_source_led_codes, ino_average_fs = arduino.arduino_workflow(base_path, save_path, num_leds, led_blink_interval, arduino_spec) - if np.all(actual_led_nums == range(num_leds)): - led_events.append(tmp_event) - else: - print('Found %d LEDs found in chunk %d. Skipping... 
' % (len(actual_led_nums),i)) + # Deal with second source + if second_source == 'ttl': + second_source_led_codes = ttl.ttl_workflow(base_path, save_path, num_leds, led_blink_interval, ephys_fs) + elif second_source == 'mkv': + second_source_led_codes = mkv.mkv_workflow(base_path, save_path, num_leds, led_blink_interval, mkv_chunk_size, led_loc) + elif second_source == 'arduino': + second_source_led_codes, ino_average_fs = arduino.arduino_workflow(base_path, save_path, num_leds, led_blink_interval, arduino_spec) + + # Save the codes for use later + np.savez('%s/codes.npz' % save_path, first_source_codes=first_source_led_codes, second_source_codes=second_source_led_codes) - - - led_events = np.concatenate(led_events) - ## optional: save the events for further use - np.savez(led_events_path,led_events=led_events) - - else: - led_events = np.load(led_events_path)['led_events'] - + # Visualize a small chunk of the bit codes. do you see a match? + # Codes array should have times in seconds by this point + plotting.plot_code_chunk(first_source_led_codes, second_source_led_codes, save_path) ################################# Load the ephys TTL data ##################################### @@ -133,96 +105,95 @@ def sync(base_path): led_fs = 30 - led_interval = 5 # seconds - - - ## convert the LED events to bit codes: - led_codes, latencies = events_to_codes(led_events,nchannels=4,minCodeTime=led_interval-1) - led_codes = np.asarray(led_codes) - ## convert the ephys TTL events to bit codes: - ephys_events = np.vstack([ephys_timestamps[abs(channels)!=5],abs(channels[abs(channels)!=5])-1,np.sign(channels[abs(channels)!=5]) ]).T - ephys_codes, ephys_latencies = events_to_codes(ephys_events,nchannels=4,minCodeTime=(led_interval-1)*ephys_fs) - ephys_codes = np.asarray(ephys_codes) + #### SYNCING :D #### - np.savez('%s/codes.npz' % save_path, led_codes=led_codes, ephys_codes=ephys_codes) + # Returns two columns of matched event times + matches = 
np.asarray(sync.match_codes(first_source_led_codes[:,0], ## all times should be in seconds by here + first_source_led_codes[:,1], + second_source_led_codes[:,0], + second_source_led_codes[:,1], + minMatch=10,maxErr=0,remove_duplicates=True )) - ## visualize a small chunk of the bit codes. do you see a match? - plot_code_chunk(ephys_codes,led_codes,ephys_fs,save_path) + ## Plot the matched codes against each other: + plotting.plot_matched_scatter(matches, save_path) - ################### Match the codes! ########################## - matches = np.asarray(match_codes(ephys_codes[:,0] / ephys_fs, ## converting the ephys times to seconds for matching (led times already in seconds) - ephys_codes[:,1], - led_codes[:,0], - led_codes[:,1], - minMatch=10,maxErr=0,remove_duplicates=True ) ) - ## plot the matched codes against each other: - plot_matched_scatter(matches,save_path) - - ####################### Make the models! #################### - - ephys_model = PiecewiseRegressor(verbose=True, - binner=KBinsDiscretizer(n_bins=10)) - ephys_model.fit(matches[:,0].reshape(-1, 1), matches[:,1]) - - - predicted_video_matches = ephys_model.predict(matches[:,0].reshape(-1, 1) ) ## for checking the error - - predicted_video_times = ephys_model.predict(ephys_codes[:,0].reshape(-1, 1) / ephys_fs ) ## for all predicted times - - joblib.dump(ephys_model, '%s/ephys_timebase.p' % save_path) - print('Saved ephys model') - - ## how big are the differences between the matched ephys and video code times ? - time_errors = (predicted_video_matches - matches[:,1]) - - ## plot model errors: - plot_model_errors(time_errors,save_path) - - ## plot the codes on the same time scale - plot_matches_video_time(predicted_video_times,ephys_codes,led_codes,save_path) - - - ################################# - - video_model = PiecewiseRegressor(verbose=True, - binner=KBinsDiscretizer(n_bins=10)) - video_model.fit(matches[:,1].reshape(-1, 1), matches[:,0]) + #### Make the models! #### + # Rename for clarity. 
+ ground_truth_source1_event_times = matches[:,0] + ground_truth_source2_event_times = matches[:,1] + + # Model first source from second soure, and vice versa. + # I'm sure there's a cleaner way to do this, but it works for now. + for i in range(2): + if i == 0: + s1 = ground_truth_source1_event_times + t1 = first_source_led_codes + n1 = first_source + s2 = ground_truth_source2_event_times + t2 = second_source_led_codes + n2 = second_source + elif i == 1: + s1 = ground_truth_source2_event_times + t1 = second_source_led_codes + n1 = second_source + s2 = ground_truth_source1_event_times + t2 = first_source_led_codes + n2 = first_source + + # Learn to predict s1 from s2. Syntax is fit(X,Y). + mdl = PiecewiseRegressor(verbose=True, + binner=KBinsDiscretizer(n_bins=10)) + mdl.fit(s2.reshape(-1, 1), s1) - predicted_ephys_matches = video_model.predict(matches[:,1].reshape(-1, 1) ) + # Verify accuracy of predicted event times + predicted_event_times = mdl.predict(s2.reshape(-1, 1) ) + time_errors = predicted_event_times - s1 + plotting.plot_model_errors(time_errors,save_path) - predicted_ephys_times = video_model.predict(led_codes[:,0].reshape(-1, 1) ) + # Verify accuracy of all predicted times + all_predicted_times = mdl.predict(t2[:,0].reshape(-1, 1) ) + plotting.plot_matches_video_time(all_predicted_times, t2, t1, save_path) - joblib.dump(video_model, '%s/video_timebase.p' % save_path) - print('Saved video model') + # Save + joblib.dump(mdl, f'{save_path}/{n1}_from_{n2}.p') + print(f'Saved model that predicts {n1} from {n2}') print('Syncing complete. 
FIN') - if __name__ == "__main__" : - ## take a config file w/ a list of paths, sync each of those and plot the results in a subfolder called /sync/ - parser = argparse.ArgumentParser() - parser.add_argument('-path', type=str) - + parser.add_argument('-path', type=str) # path to data + parser.add_argument('-o', '--output_dir_name', type=str, default='sync') # name of output folder within path + parser.add_argument('-s1', '--first_source', type=str) # ttl, mkv, arduino (txt) + parser.add_argument('-s2', '--second_source', type=str) # ttl, mkv, arduino + parser.add_argument('--led_loc', type=str) + parser.add_argument('--led_blink_interval', type=int, default=5) # default blink every 5 seconds + parser.add_argument('--arduino_spec', type=str, help="Currently supported: fictive_olfaction, odor_on_wheel, ") # specifiy cols in arduino text file + parser.add_argument('--overwrite_models', action="store_true") # overwrites old models if True (1) + settings = parser.parse_args(); - base_path = settings.path #'/n/groups/datta/maya/ofa-snr/mj-snr-01/mj_snr_01_2021-03-24_11-06-33/' - - - sync(base_path) + main_function(base_path=settings.path, + output_dir_name=settings.output_dir_name, + first_source=settings.first_source, + second_source=settings.second_source, + led_loc=settings.led_loc, + led_blink_interval=settings.led_blink_interval, + arduino_spec=settings.arduino_spec, + overwrite_models=settings.overwrite_models) diff --git a/moseq2_ephys_sync/mkv.py b/moseq2_ephys_sync/mkv.py new file mode 100644 index 0000000..acfc7be --- /dev/null +++ b/moseq2_ephys_sync/mkv.py @@ -0,0 +1,268 @@ +from datetime import time +import numpy as np +import pandas as pd +import sys,os +from tqdm import tqdm +import subprocess +from glob import glob +import joblib +import argparse +import json +import moseq2_extract.io.video as moseq_video +import subprocess + +import plotting, extract_leds, sync + +def mkv_workflow(base_path, save_path, num_leds, led_blink_interval, mkv_chunk_size=2000, 
led_loc=None, led_rois=None): + """ + Workflow to extract led codes from an MKV file + + """ + + # Set up paths + depth_path = glob('%s/*.mkv' % base_path )[0] + stream_names = get_mkv_stream_names(depth_path) # e.g. {'DEPTH': 0, 'IR': 1} + info_path = '%s/info.json' % base_path # make paths for info and timestamps. if they exist, don't recompute: + timestamp_path = '%s/mkv_timestamps.csv' % base_path + + + # Load timestamps and mkv info if exist, otherwise calculate + if (os.path.exists(info_path) and os.path.exists(timestamp_path) ): + + with open(info_path,'r') as f: + info = json.load(f) + + timestamps = pd.read_csv(timestamp_path) + timestamps = timestamps.values[:,1].flatten() + + else: + ## get info on the depth file; we'll use this to see how many frames we have + info,timestamps = get_mkv_info(depth_path,stream=stream_names['DEPTH']) + + ## save info and timestamps: + timestamps = pd.DataFrame(timestamps) + timestamps.to_csv(timestamp_path) # save the timestamps + timestamps = timestamps.values.flatten() + + with open(info_path, 'w') as f: + json.dump(info, f) + + # Debugging + # print('info = ', info) + # print('timestamps.shape = ', timestamps.shape) + + + ############### Cycle through the frame chunks to get all LED events: ############### + + # Prepare to load video (use frame batches like in moseq2-extract) + frame_batches = gen_batch_sequence(info['nframes'], mkv_chunk_size, + 0, offset=0) + num_chunks = len(frame_batches) + mkv_led_events = [] + print('num_chunks = ', num_chunks) + + mkv_led_events_path = '%s_led_events.npz' % os.path.splitext(depth_path)[0] + + # Do the loading + if not os.path.isfile(mkv_led_events_path): + + for i in tqdm(range(num_chunks)[0:]): + # for i in [45]: + + frame_data_chunk = moseq_video.load_movie_data(depth_path, + frames=frame_batches[i], + mapping=stream_names['IR'], movie_dtype=">u2", pixel_format="gray16be", + frame_size=info['dims'],timestamps=timestamps,threads=8, + finfo=info) + + if i==0: + 
plotting.plot_video_frame(frame_data_chunk.std(axis=0),'%s/frame_std.pdf' % save_path) + + if led_rois is not None: + leds = extract_leds.get_led_data_from_rois(frame_data_chunk=frame_data_chunk, rois=led_rois, save_path=save_path) + else: + leds = extract_leds.get_led_data_with_stds(frame_data_chunk=frame_data_chunk, + num_leds=num_leds,chunk_num=i, led_loc=led_loc, sort_by='horizontal',save_path=save_path) + + time_offset = frame_batches[i][0] ## how many frames away from first chunk's #### frame_chunks[0,i] + + tmp_event = extract_leds.get_events(leds,timestamps[frame_batches[i]],time_offset,num_leds=num_leds) + + actual_led_nums = np.unique(tmp_event[:,1]) ## i.e. what was found in this chunk + + + if np.all(actual_led_nums == range(num_leds)): + mkv_led_events.append(tmp_event) + else: + print('Found %d LEDs found in chunk %d. Skipping... ' % (len(actual_led_nums),i)) + + + + mkv_led_events = np.concatenate(mkv_led_events) + + ## optional: save the events for further use + np.savez(mkv_led_events_path,led_events=mkv_led_events) + + else: + mkv_led_events = np.load(mkv_led_events_path)['led_events'] + + + + ############### Convert the LED events to bit codes ############### + mkv_led_codes, latencies = sync.events_to_codes(mkv_led_events, nchannels=4, minCodeTime=(led_blink_interval-1)) + mkv_led_codes = np.asarray(mkv_led_codes) + + return mkv_led_codes + + +### MKV HELPER FUNCTIONS ### + +def get_mkv_info(fileloc, stream=1): + stream_features = ["width", "height", "r_frame_rate", "pix_fmt"] + + outs = {} + for _feature in stream_features: + command = [ + "ffprobe", + "-select_streams", + "v:{}".format(int(stream)), + "-v", + "fatal", + "-show_entries", + "stream={}".format(_feature), + "-of", + "default=noprint_wrappers=1:nokey=1", + fileloc, + "-sexagesimal", + ] + ffmpeg = subprocess.Popen(command, stderr=subprocess.PIPE, stdout=subprocess.PIPE) + out, err = ffmpeg.communicate() + if err: + print(err) + outs[_feature] = out.decode("utf-8").rstrip("\n") + + 
def get_mkv_timestamps(fileloc, stream=1, threads=8):
    """Return the per-frame timestamps of one video stream of an mkv file.

    Runs ``ffprobe`` asking for the ``pkt_pts_time`` entry of every frame
    and parses one float per non-empty output line.

    Parameters
    ----------
    fileloc : str
        Path of the file handed to ffprobe.
    stream : int
        Video stream index, passed to ffprobe as ``v:<stream>``.
    threads : int
        Value forwarded to ffprobe's ``-threads`` option.

    Returns
    -------
    numpy.ndarray
        1-D float array with one timestamp per frame, in output order.

    Raises
    ------
    ValueError
        If ffprobe produced no timestamps, or a line is not a number.
    """
    # NOTE(review): newer FFmpeg releases reportedly name this frame field
    # `pts_time` instead of `pkt_pts_time` -- confirm against the ffprobe
    # version installed on the cluster.
    command = [
        "ffprobe",
        "-select_streams",
        "v:{}".format(int(stream)),
        "-v",
        "fatal",
        "-threads", str(threads),
        "-show_entries",
        "frame=pkt_pts_time",
        "-of",
        "csv=p=0",
        fileloc,
    ]

    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.stderr:
        print(result.stderr)

    # Skip blank lines: the original called float('') on them, which failed
    # with "could not convert string to float" and no hint about the cause.
    stripped = (line.strip() for line in result.stdout.decode("utf-8").splitlines())
    values = [float(line) for line in stripped if line]
    if not values:
        raise ValueError(
            "ffprobe returned no frame timestamps for {!r} (stream v:{})".format(
                fileloc, int(stream)
            )
        )
    return np.array(values)
def gen_batch_sequence(nframes, chunk_size, overlap, offset=0):
    '''
    Generates batches used to chunk videos prior to extraction.

    Parameters
    ----------
    nframes (int): total number of frames
    chunk_size (int): desired chunk size (must be positive)
    overlap (int): number of frames shared by consecutive batches
        (must be smaller than chunk_size)
    offset (int): frame offset; batches cover range(offset, nframes)

    Returns
    -------
    list of range objects, one per batch; the last batch may be shorter
    than chunk_size. Empty list when there are no frames to cover.

    Raises
    ------
    ValueError: if chunk_size is not positive or overlap >= chunk_size
    '''
    if chunk_size <= 0:
        raise ValueError('chunk_size must be positive, got %r' % (chunk_size,))
    # A non-positive stride would either make range() fail with an
    # unrelated-looking message or silently produce no/garbage batches.
    if overlap >= chunk_size:
        raise ValueError(
            'overlap (%r) must be smaller than chunk_size (%r)' % (overlap, chunk_size)
        )

    frames = range(offset, nframes)
    stride = chunk_size - overlap
    return [
        frames[start:start + chunk_size]
        for start in range(0, len(frames) - overlap, stride)
    ]
## plot model errors:
def plot_model_errors(time_errors, save_path, fname='model_errors'):
    """
    Plot a histogram of the model's timing errors and save it as a PNG.
    ---
    Input:
        time_errors : array-like
            Errors to histogram (x axis is labelled
            'Predicted - actual matched video code times').
        save_path : str
            Directory the figure is written to.
        fname : str
            File name (without extension); the figure is saved as
            '<save_path>/<fname>.png'.
    """
    # Mean of the absolute errors, as the title states. Taking abs() of the
    # mean instead lets positive and negative errors cancel and under-reports.
    mean_abs_error = np.mean(np.abs(time_errors))

    f = plt.figure(dpi=600)
    try:
        plt.hist(time_errors)
        plt.title('%.2f sec. mean abs. error in second source Times' % mean_abs_error)
        plt.xlabel('Predicted - actual matched video code times')
        f.savefig(f'{save_path}/{fname}.png')
    finally:
        # Always release the figure, even if saving fails.
        plt.close(f)
file diff --git a/moseq2_ephys_sync/video.py b/moseq2_ephys_sync/video.py deleted file mode 100644 index e1e7b52..0000000 --- a/moseq2_ephys_sync/video.py +++ /dev/null @@ -1,136 +0,0 @@ -''' -Tools for extracting info, timestamps, and frames from mkv files -''' - -import os -import subprocess -import numpy as np - - -def get_mkv_info(fileloc, stream=1): - stream_features = ["width", "height", "r_frame_rate", "pix_fmt"] - - outs = {} - for _feature in stream_features: - command = [ - "ffprobe", - "-select_streams", - "v:{}".format(int(stream)), - "-v", - "fatal", - "-show_entries", - "stream={}".format(_feature), - "-of", - "default=noprint_wrappers=1:nokey=1", - fileloc, - "-sexagesimal", - ] - ffmpeg = subprocess.Popen(command, stderr=subprocess.PIPE, stdout=subprocess.PIPE) - out, err = ffmpeg.communicate() - if err: - print(err) - outs[_feature] = out.decode("utf-8").rstrip("\n") - - # need to get duration and nframes the old fashioned way - outs["duration"] = get_mkv_duration(fileloc) - timestamps = get_mkv_timestamps(fileloc,stream) - outs["nframes"] = len(timestamps) - - return ( - { - "file": fileloc, - "dims": (int(outs["width"]), int(outs["height"])), - "fps": float(outs["r_frame_rate"].split("/")[0]) - / float(outs["r_frame_rate"].split("/")[1]), - "duration": outs["duration"], - "pixel_format": outs["pix_fmt"], - "nframes": outs["nframes"], - }, - timestamps, - ) - -def get_mkv_duration(fileloc, stream=1): - command = [ - "ffprobe", - "-v", - "fatal", - "-show_entries", - "format=duration", - "-of", - "default=noprint_wrappers=1:nokey=1", - fileloc, - ] - - ffmpeg = subprocess.Popen(command, stderr=subprocess.PIPE, stdout=subprocess.PIPE) - out, err = ffmpeg.communicate() - if err: - print(err) - return float(out.decode("utf-8").rstrip("\n")) - - -def get_mkv_timestamps(fileloc, stream=1,threads=8): - command = [ - "ffprobe", - "-select_streams", - "v:{}".format(int(stream)), - "-v", - "fatal", - "-threads", str(threads), - "-show_entries", - 
"frame=pkt_pts_time", - "-of", - "csv=p=0", - fileloc, - ] - - ffmpeg = subprocess.Popen(command, stderr=subprocess.PIPE, stdout=subprocess.PIPE) - out, err = ffmpeg.communicate() - if err: - print(err) - timestamps = out.decode("utf-8").rstrip("\n").split("\n") - timestamps = np.array([float(_) for _ in timestamps]) - return timestamps - -def get_mkv_stream_names(fileloc): - stream_tag = "title" - - outs = {} - command = [ - "ffprobe", - "-v", - "fatal", - "-show_entries", - "stream_tags={}".format(stream_tag), - "-of", - "default=noprint_wrappers=1:nokey=1", - fileloc, - ] - ffmpeg = subprocess.Popen(command, stderr=subprocess.PIPE, stdout=subprocess.PIPE) - out, err = ffmpeg.communicate() - if err: - print(err) - out = out.decode("utf-8").rstrip("\n").split("\n") - - - ## !! changed the key/value order here from what JM had: (i.e. so the string name is the key, the stream is the value) - return dict(list(zip(out,np.arange(len(out))))) - - -def get_mkv_stream_tag(fileloc, stream=1, tag="K4A_START_OFFSET_NS"): - - command = [ - "ffprobe", - "-v", - "fatal", - "-show_entries", - "format_tags={}".format(tag), - "-of", - "default=noprint_wrappers=1:nokey=1", - fileloc, - ] - ffmpeg = subprocess.Popen(command, stderr=subprocess.PIPE, stdout=subprocess.PIPE) - out, err = ffmpeg.communicate() - if err: - print(err) - out = out.decode("utf-8").rstrip("\n") - return out \ No newline at end of file