Skip to content

Commit

Permalink
Merge pull request #3 from jonahpearl/workflows
Browse files Browse the repository at this point in the history
Add basler workflow (beta)
  • Loading branch information
guitchounts authored Sep 9, 2021
2 parents 3c9613a + 882f019 commit a8ddc28
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 56 deletions.
114 changes: 114 additions & 0 deletions moseq2_ephys_sync/basler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from datetime import time
import numpy as np
import pandas as pd
import sys,os
from tqdm import tqdm
from glob import glob
import joblib
import argparse
import moseq2_extract.io.video as moseq_video
import pickle
import decord
from skimage import color

import pdb

import plotting, extract_leds, sync

def basler_workflow(base_path, save_path, num_leds, led_blink_interval, basler_chunk_size=1000, led_rois_from_file=False, overwrite_models=False):
    """
    Workflow to extract led codes from a Basler mp4 file.
    We know the LEDs only change once every (led_blink_interval), so we can just grab a few frames per interval.

    Inputs:
        base_path (str): directory containing the Basler mp4 and led_rois.pickle. The first *.mp4 returned by glob is used.
        save_path (str): directory where diagnostic plots are written.
        num_leds (int): currently unused here; the number of channels passed to sync.events_to_codes is hard-coded to 4.
        led_blink_interval (int): interval between LED changes; (led_blink_interval - 1) is passed as minCodeTime.
        basler_chunk_size (int): number of frames loaded per batch.
        led_rois_from_file (bool): currently unused here; the ROIs are always loaded from base_path.
        overwrite_models (bool): if True, re-extract LED events even if a cached basler_led_events.npz exists in base_path.
    Outputs:
        basler_led_codes (np.array): output of sync.events_to_codes converted to an array.
    Raises:
        FileNotFoundError: if base_path contains no mp4 file.
    """

    # Set up
    mp4_paths = glob('%s/*.mp4' % base_path)
    if not mp4_paths:
        # Fail with a readable message instead of an IndexError on [0]
        raise FileNotFoundError('No .mp4 file found in %s' % base_path)
    basler_path = mp4_paths[0]
    vr = decord.VideoReader(basler_path, ctx=decord.cpu(0), num_threads=8)
    num_frames = len(vr)
    timestamps = vr.get_frame_timestamp(np.arange(0, num_frames))  # blazing fast. nframes x 2 (beginning,end)
    timestamps = timestamps * 2  # when basler records at 120 fps, timebase is halved :/

    ############### Cycle through the frame chunks to get all LED events: ###############

    # Prepare to load video (use frame batches)
    frame_batches = gen_batch_sequence(num_frames, basler_chunk_size, 0, offset=0)
    num_chunks = len(frame_batches)
    basler_led_events = []
    led_roi_list = load_led_rois_from_file(base_path)
    basler_led_events_path = os.path.join(base_path, 'basler_led_events.npz')
    print('num_chunks = ', num_chunks)

    # Do the loading (skipped when cached events exist and overwrite_models is False)
    if overwrite_models or (not os.path.isfile(basler_led_events_path)):

        for i in tqdm(range(num_chunks)):

            print(frame_batches[i])
            # appears to have memory leak issue, delete var after each iteration
            # (see https://www.kaggle.com/leighplt/decord-videoreader,
            # https://github.com/dmlc/decord/issues?q=is%3Aissue+is%3Aopen+memory)
            frame_data_chunk = color.rgb2gray(vr.get_batch(list(frame_batches[i])).asnumpy())
            batch_timestamps = timestamps[frame_batches[i], 0]  # column 0 = frame start times

            if i == 0:
                plotting.plot_video_frame(frame_data_chunk.std(axis=0), '%s/basler_frame_std.pdf' % save_path)

            leds = extract_leds.get_led_data_from_rois(frame_data_chunk=frame_data_chunk,
                                                       led_roi_list=led_roi_list,
                                                       led_thresh=0.5,  # since we converted to gray, all vals betw 0 and 1. "Off" is around 0.1, "On" is around 0.8.
                                                       save_path=save_path)

            tmp_event = extract_leds.get_events(leds, batch_timestamps)
            basler_led_events.append(tmp_event)

            # Free the chunk before the next batch is decoded to keep peak memory down
            del frame_data_chunk

        basler_led_events = np.concatenate(basler_led_events)
        np.savez(basler_led_events_path, led_events=basler_led_events)
    else:
        basler_led_events = np.load(basler_led_events_path)['led_events']

    ############### Convert the LED events to bit codes ###############
    # NOTE(review): nchannels is hard-coded to 4 rather than using num_leds -- confirm whether that is intended.
    basler_led_codes, latencies = sync.events_to_codes(basler_led_events, nchannels=4, minCodeTime=(led_blink_interval - 1))
    basler_led_codes = np.asarray(basler_led_codes)

    return basler_led_codes


### Basler HELPER FUNCTIONS ###

def load_led_rois_from_file(base_path):
    """
    Load the LED ROI list stored as 'led_rois.pickle' inside base_path.
    Inputs:
        base_path (str): directory containing led_rois.pickle
    Outputs:
        the unpickled object (whatever was saved in the pickle file)
    """
    roi_pickle_path = os.path.join(base_path, 'led_rois.pickle')
    with open(roi_pickle_path, 'rb') as handle:
        return pickle.load(handle)


def gen_batch_sequence(nframes, chunk_size, overlap, offset=0):
    '''
    Split the frame indices of a video into consecutive batches prior to extraction.
    Parameters
    ----------
    nframes (int): total number of frames
    chunk_size (int): maximum number of frames per batch
    overlap (int): number of frames shared between consecutive batches
    offset (int): index of the first frame to include
    Returns
    -------
    List of range objects, one per batch, covering frames offset .. nframes-1
    '''

    frame_indices = range(offset, nframes)
    stride = chunk_size - overlap  # distance between the starts of consecutive batches
    last_start = len(frame_indices) - overlap
    return [frame_indices[start:start + chunk_size]
            for start in range(0, last_start, stride)]
22 changes: 14 additions & 8 deletions moseq2_ephys_sync/extract_leds.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pdb


def get_led_data_from_rois(frame_data_chunk, rois, led_thresh=2e4, save_path=None):
def get_led_data_from_rois(frame_data_chunk, led_roi_list, led_thresh=2e4, save_path=None):
"""
Given pre-determined rois for LEDs, return sequences of ons and offs
Inputs:
Expand All @@ -25,12 +25,12 @@ def get_led_data_from_rois(frame_data_chunk, rois, led_thresh=2e4, save_path=Non

leds = []

for i in range(len(rois)):
for i in range(len(led_roi_list)):

led_x = 2# {get x vals}
led_y = 2# {get y vals}

led = frame_data_chunk[:,led_x,led_y].mean(axis=1) #on/off block signals
pts = led_roi_list[i]

led = frame_data_chunk[:, pts[0], pts[1]].mean(axis=1) #on/off block signals (slice returns a nframes x npts array, then you get mean of all the pts at each time)

led_on = np.where(np.diff(led) > led_thresh)[0] #rise indices
led_off = np.where(np.diff(led) < -led_thresh)[0] #fall indices
Expand Down Expand Up @@ -189,8 +189,14 @@ def get_led_data_with_stds(frame_data_chunk, num_leds = 4, chunk_num=0, led_loc=

return leds

def get_events(leds,timestamps,time_offset=0,num_leds=2):
def get_events(leds, timestamps):
"""
Convert list of led ons/offs + timestamps into list of ordered events
Inputs:
leds(np.array): num leds x num frames
timestamps (np.array): 1 x num frames
"""
## e.g. [123,1,-1 ] time was 123rd frame, channel 1 changed from on to off...

times = []
Expand All @@ -212,8 +218,8 @@ def get_events(leds,timestamps,time_offset=0,num_leds=2):


times = np.hstack(times)
channels = np.hstack(channels)
directions = np.hstack(directions)
channels = np.hstack(channels).astype('int')
directions = np.hstack(directions).astype('int')
sorting = np.argsort(times)
events = np.vstack([times[sorting],channels[sorting],directions[sorting]]).T

Expand Down
69 changes: 32 additions & 37 deletions moseq2_ephys_sync/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@
from mlinsights.mlmodel import PiecewiseRegressor
from sklearn.preprocessing import KBinsDiscretizer

import mkv, arduino, ttl, sync, plotting
import mkv, arduino, ttl, sync, plotting, basler

import pdb

"""
TODO:

-- add ROIs capability to MKV workflow
-- refactor extract_leds to work with various other videos
"""


def main_function(base_path, output_dir_name, first_source, second_source, led_loc=None, led_blink_interval=5, arduino_spec=None, overwrite_models=False):
def main_function(base_path,
output_dir_name,
first_source,
second_source,
led_loc=None,
led_blink_interval=5,
arduino_spec=None,
led_rois_from_file=False,
overwrite_models=False):
"""
Uses 4-bit code sequences to create a piecewise linear model to predict first_source times from second_source times
----
Expand All @@ -33,12 +33,13 @@ def main_function(base_path, output_dir_name, first_source, second_source, led_l
second_source (str): same as first_source, but these codes are used to predict first_source.
led_loc (str): specifiy one of four corners of the movie in which to find the LEDs: topright, bottomright, topleft, bottomleft
led_blink_interval (int): interval in seconds between LED changes. Typically 5 seconds.
led_rois_from_file (bool): whether to look in base path for led roi pickle.
Outputs:
-
Notes:
- Each workflow checks for already-pre-processed data, so that the script should be pretty easy to debug.
- Basler code expects an mp4 at 120 fps. If you use 60 fps, probably need to change the minCodeTime arg in line 80 of basler.py.
"""

print(f'Running sync on {base_path} with {first_source} as first source and {second_source} as second source.')
Expand All @@ -47,6 +48,7 @@ def main_function(base_path, output_dir_name, first_source, second_source, led_l
#### SETUP ####
# Built-in params (should make dynamic)
mkv_chunk_size = 2000
basler_chunk_size = 1000 # too much larger (incl 2000) crashes O2 with 64 GB, not sure why since these chunks are only 10G.
num_leds = 4
ephys_fs = 3e4 # sampling rate in Hz

Expand All @@ -61,52 +63,41 @@ def main_function(base_path, output_dir_name, first_source, second_source, led_l
raise RuntimeError("Models already exist and overwrite_models is false!")


# Check if user is accidentally using conflicting led extraction params
if led_loc and led_rois_from_file:
raise RuntimeError("User cannot specify both led location (top right, etc) and list of exact LED ROIs!")
elif ((first_source == 'basler') or (second_source == 'basler')) and not led_rois_from_file:
raise RuntimeError("User must specify LED rois for basler workflow")

#### INDIVIDUAL DATA STREAM WORKFLOWS ####

# Deal with first source
if first_source == 'ttl':
first_source_led_codes = ttl.ttl_workflow(base_path, save_path, num_leds, led_blink_interval, ephys_fs)
elif first_source == 'mkv':
first_source_led_codes = mkv.mkv_workflow(base_path, save_path, num_leds, led_blink_interval, mkv_chunk_size, led_loc)
first_source_led_codes = mkv.mkv_workflow(base_path, save_path, num_leds, led_blink_interval, mkv_chunk_size, led_loc, led_rois_from_file)
elif first_source == 'arduino':
first_source_led_codes, ino_average_fs = arduino.arduino_workflow(base_path, save_path, num_leds, led_blink_interval, arduino_spec)
elif first_source == 'basler':
first_source_led_codes = basler.basler_workflow(base_path, save_path, num_leds, led_blink_interval, basler_chunk_size, led_rois_from_file, overwrite_models)

# Deal with second source
if second_source == 'ttl':
second_source_led_codes = ttl.ttl_workflow(base_path, save_path, num_leds, led_blink_interval, ephys_fs)
elif second_source == 'mkv':
second_source_led_codes = mkv.mkv_workflow(base_path, save_path, num_leds, led_blink_interval, mkv_chunk_size, led_loc)
second_source_led_codes = mkv.mkv_workflow(base_path, save_path, num_leds, led_blink_interval, mkv_chunk_size, led_loc, led_rois_from_file)
elif second_source == 'arduino':
second_source_led_codes, ino_average_fs = arduino.arduino_workflow(base_path, save_path, num_leds, led_blink_interval, arduino_spec)

elif second_source == 'basler':
second_source_led_codes = basler.basler_workflow(base_path, save_path, num_leds, led_blink_interval, basler_chunk_size, led_rois_from_file, overwrite_models)

# Save the codes for use later
np.savez('%s/codes.npz' % save_path, first_source_codes=first_source_led_codes, second_source_codes=second_source_led_codes)


# Visualize a small chunk of the bit codes. do you see a match?
# Codes array should have times in seconds by this point
plotting.plot_code_chunk(first_source_led_codes, second_source_led_codes, save_path)

################################# Load the ephys TTL data #####################################

ephys_ttl_path = glob('%s/**/TTL_*/' % base_path,recursive = True)[0]
channels = np.load('%s/channel_states.npy' % ephys_ttl_path)
ephys_timestamps = np.load('%s/timestamps.npy' % ephys_ttl_path)

## need to subtract the raw traces' starting timestamp from the TTL timestamps:
continuous_timestamps_path = glob('%s/**/continuous/**/timestamps.npy' % base_path,recursive = True)[0] ## load the continuous stream's timestamps
continuous_timestamps = np.load(continuous_timestamps_path)

ephys_timestamps -= continuous_timestamps[0] ## subract the first timestamp from all TTLs; this way continuous ephys can safely start at 0 samples or seconds


ephys_fs = 3e4

led_fs = 30



plotting.plot_code_chunk(first_source_led_codes, first_source, second_source_led_codes, second_source, save_path)


#### SYNCING :D ####
Expand All @@ -118,6 +109,8 @@ def main_function(base_path, output_dir_name, first_source, second_source, led_l
second_source_led_codes[:,1],
minMatch=10,maxErr=0,remove_duplicates=True ))

pdb.set_trace()

## Plot the matched codes against each other:
plotting.plot_matched_scatter(matches, save_path)

Expand Down Expand Up @@ -178,11 +171,12 @@ def main_function(base_path, output_dir_name, first_source, second_source, led_l

parser.add_argument('-path', type=str) # path to data
parser.add_argument('-o', '--output_dir_name', type=str, default='sync') # name of output folder within path
parser.add_argument('-s1', '--first_source', type=str) # ttl, mkv, arduino (txt)
parser.add_argument('-s2', '--second_source', type=str) # ttl, mkv, arduino
parser.add_argument('-s1', '--first_source', type=str) # ttl, mkv, basler (mp4 with rois), arduino (txt)
parser.add_argument('-s2', '--second_source', type=str)
parser.add_argument('--led_loc', type=str)
parser.add_argument('--led_blink_interval', type=int, default=5) # default blink every 5 seconds
parser.add_argument('--arduino_spec', type=str, help="Currently supported: fictive_olfaction, odor_on_wheel, ") # specifiy cols in arduino text file
parser.add_argument('--led_rois_from_file', action="store_true", help="Path to pickle with lists of points for led rois") # need to run separate jup notbook first to get this
parser.add_argument('--overwrite_models', action="store_true") # overwrites old models if True (1)

settings = parser.parse_args();
Expand All @@ -194,6 +188,7 @@ def main_function(base_path, output_dir_name, first_source, second_source, led_l
led_loc=settings.led_loc,
led_blink_interval=settings.led_blink_interval,
arduino_spec=settings.arduino_spec,
led_rois_from_file=settings.led_rois_from_file,
overwrite_models=settings.overwrite_models)


18 changes: 14 additions & 4 deletions moseq2_ephys_sync/mkv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
import json
import moseq2_extract.io.video as moseq_video
import subprocess
import pickle

import plotting, extract_leds, sync

def mkv_workflow(base_path, save_path, num_leds, led_blink_interval, mkv_chunk_size=2000, led_loc=None, led_rois=None):
def mkv_workflow(base_path, save_path, num_leds, led_blink_interval, mkv_chunk_size=2000, led_loc=None, led_rois_from_file=False):
"""
Workflow to extract led codes from an MKV file
Expand Down Expand Up @@ -61,6 +62,9 @@ def mkv_workflow(base_path, save_path, num_leds, led_blink_interval, mkv_chunk_s
mkv_led_events = []
print('num_chunks = ', num_chunks)

if led_rois_from_file:
led_roi_list = load_led_rois_from_file(base_path)

mkv_led_events_path = '%s_led_events.npz' % os.path.splitext(depth_path)[0]

# Do the loading
Expand All @@ -69,7 +73,7 @@ def mkv_workflow(base_path, save_path, num_leds, led_blink_interval, mkv_chunk_s
for i in tqdm(range(num_chunks)[0:]):
# for i in [45]:

frame_data_chunk = moseq_video.load_movie_data(depth_path,
frame_data_chunk = moseq_video.load_movie_data(depth_path, # nframes, nrows, ncols
frames=frame_batches[i],
mapping=stream_names['IR'], movie_dtype=">u2", pixel_format="gray16be",
frame_size=info['dims'],timestamps=timestamps,threads=8,
Expand All @@ -78,8 +82,8 @@ def mkv_workflow(base_path, save_path, num_leds, led_blink_interval, mkv_chunk_s
if i==0:
plotting.plot_video_frame(frame_data_chunk.std(axis=0),'%s/frame_std.pdf' % save_path)

if led_rois is not None:
leds = extract_leds.get_led_data_from_rois(frame_data_chunk=frame_data_chunk, rois=led_rois, save_path=save_path)
if led_rois_from_file:
leds = extract_leds.get_led_data_from_rois(frame_data_chunk=frame_data_chunk, led_roi_list=led_roi_list, save_path=save_path)
else:
leds = extract_leds.get_led_data_with_stds(frame_data_chunk=frame_data_chunk,
num_leds=num_leds,chunk_num=i, led_loc=led_loc, sort_by='horizontal',save_path=save_path)
Expand Down Expand Up @@ -117,6 +121,12 @@ def mkv_workflow(base_path, save_path, num_leds, led_blink_interval, mkv_chunk_s

### MKV HELPER FUNCTIONS ###

def load_led_rois_from_file(base_path):
    """
    Load the LED ROI list stored as 'led_rois.pickle' inside base_path.
    Inputs:
        base_path (str): directory containing led_rois.pickle
    Outputs:
        led_roi_list: the unpickled object (whatever was saved in the pickle file)
    """
    fin = os.path.join(base_path, 'led_rois.pickle')
    with open(fin, 'rb') as f:
        # pickle.load takes only the file positionally; the protocol is a dump-time
        # option and is detected automatically on load. Passing it here raised TypeError.
        # NOTE: only load pickles you created yourself -- unpickling can execute arbitrary code.
        led_roi_list = pickle.load(f)
    return led_roi_list

def get_mkv_info(fileloc, stream=1):
stream_features = ["width", "height", "r_frame_rate", "pix_fmt"]

Expand Down
Loading

0 comments on commit a8ddc28

Please sign in to comment.