diff --git a/aeon/dj_pipeline/__init__.py b/aeon/dj_pipeline/__init__.py index 6a9c64b6..b319f55b 100644 --- a/aeon/dj_pipeline/__init__.py +++ b/aeon/dj_pipeline/__init__.py @@ -30,6 +30,23 @@ def dict_to_uuid(key) -> uuid.UUID: return uuid.UUID(hex=hashed.hexdigest()) +def fetch_stream(query, drop_pk=True): + """ + Provided a query containing data from a Stream table, + fetch and aggregate the data into one DataFrame indexed by "time" + """ + df = (query & "sample_count > 0").fetch(format="frame").reset_index() + cols2explode = [ + c for c in query.heading.secondary_attributes if query.heading.attributes[c].type == "longblob" + ] + df = df.explode(column=cols2explode) + cols2drop = ["sample_count"] + (query.primary_key if drop_pk else []) + df.drop(columns=cols2drop, inplace=True, errors="ignore") + df.rename(columns={"timestamps": "time"}, inplace=True) + df.set_index("time", inplace=True) + return df + + try: from . import streams except ImportError: diff --git a/aeon/dj_pipeline/acquisition.py b/aeon/dj_pipeline/acquisition.py index 5057a86e..fcd10134 100644 --- a/aeon/dj_pipeline/acquisition.py +++ b/aeon/dj_pipeline/acquisition.py @@ -1,6 +1,6 @@ import datetime import pathlib - +import re import datajoint as dj import numpy as np import pandas as pd @@ -670,7 +670,7 @@ def make(self, key): "devices_schema_name" ), ) - device = devices_schema.ExperimentalMetadata + device = devices_schema.Environment try: # handles corrupted files - issue: https://github.com/SainsburyWellcomeCentre/aeon_mecha/issues/153 @@ -684,12 +684,18 @@ def make(self, key): logger.warning("Can't read from device.MessageLog") log_messages = pd.DataFrame() - state_messages = io_api.load( + env_states = io_api.load( root=raw_data_dir.as_posix(), reader=device.EnvironmentState, start=pd.Timestamp(chunk_start), end=pd.Timestamp(chunk_end), ) + block_states = io_api.load( + root=raw_data_dir.as_posix(), + reader=device.BlockState, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) self.insert1(key) self.Message.insert( @@ -712,13 +718,147 @@ def make(self, key): "message": r.state, "message_type": "EnvironmentState", } - for _, r in state_messages.iterrows() + for _, r in env_states.iterrows() ), skip_duplicates=True, ) +# ------------------- ENVIRONMENT -------------------- + + +@schema +class Environment(dj.Imported): + definition = """ # Experiment environments + -> Chunk + """ + + class EnvironmentState(dj.Part): + definition = """ + -> master + --- + sample_count: int # number of data points acquired from this stream for a given chunk + timestamps: longblob # (datetime) timestamps + state: longblob + """ + + class BlockState(dj.Part): + definition = """ + -> master + --- + sample_count: int # number of data points acquired from this stream for a given chunk + timestamps: longblob # (datetime) timestamps + pellet_ct: longblob + pellet_ct_thresh: longblob + due_time: longblob + """ + + class LightEvents(dj.Part): + definition = """ + -> master + --- + sample_count: int # number of data points acquired from this stream for a given chunk + timestamps: longblob # (datetime) timestamps + channel: longblob + value: longblob + """ + + class MessageLog(dj.Part): + definition = """ + -> master + --- + sample_count: int # number of data points acquired from this stream for a given chunk + timestamps: longblob # (datetime) + priority: longblob + type: longblob + message: longblob + """ + + class SubjectState(dj.Part): + definition = """ + -> master + --- + sample_count: int # number of data points 
acquired from this stream for a given chunk + timestamps: longblob # (datetime) timestamps + id: longblob + weight: longblob + type: longblob + """ + + class SubjectVisits(dj.Part): + definition = """ + -> master + --- + sample_count: int # number of data points acquired from this stream for a given chunk + timestamps: longblob # (datetime) timestamps + id: longblob + type: longblob + region: longblob + """ + + class SubjectWeight(dj.Part): + definition = """ + -> master + --- + sample_count: int # number of data points acquired from this stream for a given chunk + timestamps: longblob # (datetime) timestamps + weight: longblob + confidence: longblob + subject_id: longblob + int_id: longblob + """ + + def make(self, key): + chunk_start, chunk_end = (Chunk & key).fetch1("chunk_start", "chunk_end") + + # Populate the part table + raw_data_dir = Experiment.get_data_directory(key) + devices_schema = getattr( + aeon_schemas, + (Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), + ) + device = devices_schema.Environment + + self.insert1(key) + + for stream_type, part_table in [ + ("EnvironmentState", self.EnvironmentState), + ("BlockState", self.BlockState), + ("LightEvents", self.LightEvents), + ("MessageLog", self.MessageLog), + ("SubjectState", self.SubjectState), + ("SubjectVisits", self.SubjectVisits), + ("SubjectWeight", self.SubjectWeight), + ]: + stream_reader = getattr(device, stream_type) + + stream_data = io_api.load( + root=raw_data_dir.as_posix(), + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + part_table.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") + }, + }, + ignore_extra_fields=True, + ) + + # ------------------- EVENTS -------------------- + + @schema class FoodPatchEvent(dj.Imported): definition = """ # events associated with a given ExperimentFoodPatch @@ -1220,3 +1360,20 @@ def _load_legacy_subjectdata(experiment_name, data_dir, start, end): subject_data.sort_index(inplace=True) return subject_data + + +def create_chunk_restriction(experiment_name, start_time, end_time): + """ + Create a time restriction string for the chunks between the specified "start" and "end" times + """ + start_restriction = f'"{start_time}" BETWEEN chunk_start AND chunk_end' + end_restriction = f'"{end_time}" BETWEEN chunk_start AND chunk_end' + start_query = Chunk & {"experiment_name": experiment_name} & start_restriction + end_query = Chunk & {"experiment_name": experiment_name} & end_restriction + if not (start_query and end_query): + raise ValueError(f"No Chunk found between {start_time} and {end_time}") + time_restriction = ( + f'chunk_start >= "{min(start_query.fetch("chunk_start"))}"' + f' AND chunk_start < "{max(end_query.fetch("chunk_end"))}"' + ) + return time_restriction diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py new file mode 100644 index 00000000..342528f6 --- /dev/null +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -0,0 +1,343 @@ +import datetime +import datajoint as dj +import pandas as pd +import json +import numpy as np + +from aeon.analysis import utils as analysis_utils + +from aeon.dj_pipeline import get_schema_name, fetch_stream +from aeon.dj_pipeline import acquisition, tracking, streams +from aeon.dj_pipeline.analysis.visit import ( + 
get_maintenance_periods, + filter_out_maintenance_periods, +) + +schema = dj.schema(get_schema_name("analysis")) + + +@schema +class Block(dj.Manual): + definition = """ + -> acquisition.Experiment + block_start: datetime(6) + --- + block_end=null: datetime(6) + """ + + +@schema +class BlockAnalysis(dj.Computed): + definition = """ + -> Block + --- + block_duration: float # (hour) + """ + + key_source = Block & "block_end IS NOT NULL" + + class Patch(dj.Part): + definition = """ + -> master + patch_name: varchar(36) # e.g. Patch1, Patch2 + --- + pellet_count: int + pellet_timestamps: longblob + wheel_cumsum_distance_travelled: longblob # wheel's cumulative distance travelled + wheel_timestamps: longblob + patch_threshold: longblob + patch_threshold_timestamps: longblob + patch_rate: float + """ + + class Subject(dj.Part): + definition = """ + -> master + subject_name: varchar(32) + --- + weights: longblob + weight_timestamps: longblob + position_x: longblob + position_y: longblob + position_likelihood: longblob + position_timestamps: longblob + cumsum_distance_travelled: longblob # subject's cumulative distance travelled + """ + + def make(self, key): + """ + Restrict, fetch and aggregate data from different streams to produce intermediate data products + at a per-block level (for different patches and different subjects) + 1. Query data for all chunks within the block + 2. Fetch streams, filter by maintenance period + 3. Fetch subject position data (SLEAP) + 4. Aggregate and insert into the table + """ + block_start, block_end = (Block & key).fetch1("block_start", "block_end") + + chunk_restriction = acquisition.create_chunk_restriction( + key["experiment_name"], block_start, block_end + ) + + self.insert1({**key, "block_duration": (block_end - block_start).total_seconds() / 3600}) + + # Patch data - TriggerPellet, DepletionState, Encoder (distancetravelled) + # For wheel data, downsample by 50x - 10Hz + wheel_downsampling_factor = 50 + + maintenance_period = get_maintenance_periods(key["experiment_name"], block_start, block_end) + + patch_query = ( + streams.UndergroundFeeder.join(streams.UndergroundFeeder.RemovalTime, left=True) + & key + & f'"{block_start}" >= underground_feeder_install_time' + & f'"{block_end}" < IFNULL(underground_feeder_removal_time, "2200-01-01")' + ) + patch_keys, patch_names = patch_query.fetch("KEY", "underground_feeder_name") + + for patch_key, patch_name in zip(patch_keys, patch_names): + delivered_pellet_df = fetch_stream( + streams.UndergroundFeederBeamBreak & patch_key & chunk_restriction + )[block_start:block_end] + depletion_state_df = fetch_stream( + streams.UndergroundFeederDepletionState & patch_key & chunk_restriction + )[block_start:block_end] + encoder_df = fetch_stream(streams.UndergroundFeederEncoder & patch_key & chunk_restriction)[ + block_start:block_end + ] + # filter out maintenance period based on logs + pellet_df = filter_out_maintenance_periods( + delivered_pellet_df, + maintenance_period, + block_end, + dropna=True, + ) + depletion_state_df = filter_out_maintenance_periods( + depletion_state_df, + maintenance_period, + block_end, + dropna=True, + ) + encoder_df = filter_out_maintenance_periods( + encoder_df, maintenance_period, block_end, dropna=True + ) + + encoder_df["distance_travelled"] = analysis_utils.distancetravelled(encoder_df.angle) + + patch_rate = depletion_state_df.rate.unique() + assert len(patch_rate) == 1 # expects a single rate for this block + patch_rate = patch_rate[0] + + self.Patch.insert1( + { + **key, + 
"patch_name": patch_name, + "pellet_count": len(pellet_df), + "pellet_timestamps": pellet_df.index.values, + "wheel_cumsum_distance_travelled": encoder_df.distance_travelled.values[ + ::wheel_downsampling_factor + ], + "wheel_timestamps": encoder_df.index.values[::wheel_downsampling_factor], + "patch_threshold": depletion_state_df.threshold.values, + "patch_threshold_timestamps": depletion_state_df.index.values, + "patch_rate": patch_rate, + } + ) + + # Subject data + subject_events_query = acquisition.Environment.SubjectState & key & chunk_restriction + subject_events_df = fetch_stream(subject_events_query) + + subject_names = set(subject_events_df.id) + for subject_name in subject_names: + # positions - query for CameraTop, identity_name matches subject_name, + pos_query = ( + streams.SpinnakerVideoSource + * tracking.SLEAPTracking.PoseIdentity.proj("identity_name", anchor_part="part_name") + * tracking.SLEAPTracking.Part + & { + "spinnaker_video_source_name": "CameraTop", + "identity_name": subject_name, + } + & chunk_restriction + ) + pos_df = fetch_stream(pos_query)[block_start:block_end] + pos_df = filter_out_maintenance_periods(pos_df, maintenance_period, block_end) + + position_diff = np.sqrt( + (np.square(np.diff(pos_df.x.astype(float))) + np.square(np.diff(pos_df.y.astype(float)))) + ) + cumsum_distance_travelled = np.concatenate([[0], np.cumsum(position_diff)]) + + # weights + weight_query = acquisition.Environment.SubjectWeight & key & chunk_restriction + weight_df = fetch_stream(weight_query)[block_start:block_end] + weight_df.query(f"subject_id == '{subject_name}'", inplace=True) + + self.Subject.insert1( + { + **key, + "subject_name": subject_name, + "weights": weight_df.weight.values, + "weight_timestamps": weight_df.index.values, + "position_x": pos_df.x.values, + "position_y": pos_df.y.values, + "position_likelihood": pos_df.likelihood.values, + "position_timestamps": pos_df.index.values, + "cumsum_distance_travelled": cumsum_distance_travelled, + } + ) + + +@schema +class BlockSubjectAnalysis(dj.Computed): + definition = """ + -> BlockAnalysis + """ + + class Patch(dj.Part): + definition = """ + -> master + -> BlockAnalysis.Patch + -> BlockAnalysis.Subject + --- + in_patch_timestamps: longblob # timestamps in which a particular subject is spending time at a particular patch + in_patch_time: float # total seconds spent in this patch for this block + pellet_count: int + pellet_timestamps: longblob + wheel_distance_travelled: longblob # wheel's cumulative distance travelled + wheel_timestamps: longblob + cumulative_sum_preference: longblob + windowed_sum_preference: longblob + """ + + def make(self, key): + pass + + +@schema +class BlockPlots(dj.Computed): + definition = """ + -> BlockAnalysis + --- + subject_positions_plot: longblob + subject_weights_plot: longblob + patch_distance_travelled_plot: longblob + """ + + def make(self, key): + import plotly.graph_objs as go + + # For position data , set confidence threshold to return position values and downsample by 5x + conf_thresh = 0.9 + downsampling_factor = 5 + + # Make plotly plots + weight_fig = go.Figure() + pos_fig = go.Figure() + for subject_data in (BlockAnalysis.Subject & key).fetch(as_dict=True): + weight_fig.add_trace( + go.Scatter( + x=subject_data["weight_timestamps"], + y=subject_data["weights"], + mode="lines", + name=subject_data["subject_name"], + ) + ) + mask = subject_data["position_likelihood"] > conf_thresh + pos_fig.add_trace( + go.Scatter3d( + x=subject_data["position_x"][mask][::downsampling_factor], 
+                    y=subject_data["position_y"][mask][::downsampling_factor],
+                    z=subject_data["position_timestamps"][mask][::downsampling_factor],
+                    mode="lines",
+                    name=subject_data["subject_name"],
+                )
+            )
+
+        wheel_fig = go.Figure()
+        for patch_data in (BlockAnalysis.Patch & key).fetch(as_dict=True):
+            wheel_fig.add_trace(
+                go.Scatter(
+                    x=patch_data["wheel_timestamps"][::2],
+                    y=patch_data["wheel_cumsum_distance_travelled"][::2],
+                    mode="lines",
+                    name=patch_data["patch_name"],
+                )
+            )
+
+        # insert figures as json-formatted plotly plots
+        self.insert1(
+            {
+                **key,
+                "subject_positions_plot": json.loads(pos_fig.to_json()),
+                "subject_weights_plot": json.loads(weight_fig.to_json()),
+                "patch_distance_travelled_plot": json.loads(wheel_fig.to_json()),
+            }
+        )
+
+
+@schema
+class BlockDetection(dj.Computed):
+    definition = """
+    -> acquisition.Chunk
+    """
+
+    def make(self, key):
+        """
+        On a per-chunk basis, check for the presence of a new block and insert it into the Block table
+        """
+        # find the 0s
+        # that would mark the start of a new block
+        # if the 0 is the first index - look back at the previous chunk
+        # if the previous timestamp belongs to a previous epoch -> block_end is the previous timestamp
+        # else block_end is the timestamp of this 0
+        chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end")
+        exp_key = {"experiment_name": key["experiment_name"]}
+        # only consider the time period between the last block and the current chunk
+        previous_block = Block & exp_key & f"block_start <= '{chunk_start}'"
+        if previous_block:
+            previous_block_key = previous_block.fetch("KEY", limit=1, order_by="block_start DESC")[0]
+            previous_block_start = previous_block_key["block_start"]
+        else:
+            previous_block_key = None
+            previous_block_start = (acquisition.Chunk & exp_key).fetch(
+                "chunk_start", limit=1, order_by="chunk_start"
+            )[0]
+
+        chunk_restriction = acquisition.create_chunk_restriction(
+            key["experiment_name"], previous_block_start, chunk_end
+        )
+
+        block_query = acquisition.Environment.BlockState & chunk_restriction
+        block_df = fetch_stream(block_query)[previous_block_start:chunk_end]
+
+        block_ends = block_df[block_df.pellet_ct.diff() < 0]
+
+        block_entries = []
+        for idx, block_end in enumerate(block_ends.index):
+            if idx == 0:
+                if previous_block_key:
+                    # if there is a previous block - insert "block_end" for the previous block
+                    previous_pellet_time = block_df[:block_end].index[-2]
+                    previous_epoch = (
+                        acquisition.Epoch.join(acquisition.EpochEnd, left=True)
+                        & exp_key
+                        & f"'{previous_pellet_time}' BETWEEN epoch_start AND IFNULL(epoch_end, '2200-01-01')"
+                    ).fetch1("KEY")
+                    current_epoch = (
+                        acquisition.Epoch.join(acquisition.EpochEnd, left=True)
+                        & exp_key
+                        & f"'{block_end}' BETWEEN epoch_start AND IFNULL(epoch_end, '2200-01-01')"
+                    ).fetch1("KEY")
+
+                    previous_block_key["block_end"] = (
+                        block_end if current_epoch == previous_epoch else previous_pellet_time
+                    )
+                    Block.update1(previous_block_key)
+            else:
+                block_entries[-1]["block_end"] = block_end
+            block_entries.append({**exp_key, "block_start": block_end, "block_end": None})
+
+        Block.insert(block_entries)
+        self.insert1(key)
diff --git a/aeon/dj_pipeline/analysis/visit.py b/aeon/dj_pipeline/analysis/visit.py
index 3c7e7be7..babae2fb 100644
--- a/aeon/dj_pipeline/analysis/visit.py
+++ b/aeon/dj_pipeline/analysis/visit.py
@@ -1,11 +1,13 @@
 import datetime
-
 import datajoint as dj
 import pandas as pd
+import numpy as np
+from collections import deque
 
 from aeon.analysis import utils as analysis_utils
 
-from .. 
import acquisition, get_schema_name, lab, qc, tracking +from aeon.dj_pipeline import get_schema_name, fetch_stream +from aeon.dj_pipeline import acquisition, lab, qc, tracking schema = dj.schema(get_schema_name("analysis")) @@ -182,3 +184,58 @@ def ingest_environment_visits(experiment_names: list | None = None): }, skip_duplicates=True, ) + + +def get_maintenance_periods(experiment_name, start, end): + # get states from acquisition.Environment.EnvironmentState + chunk_restriction = acquisition.create_chunk_restriction(experiment_name, start, end) + state_query = ( + acquisition.Environment.EnvironmentState & {"experiment_name": experiment_name} & chunk_restriction + ) + env_state_df = fetch_stream(state_query)[start:end] + if env_state_df.empty: + return deque([]) + + env_state_df.reset_index(inplace=True) + env_state_df = env_state_df[env_state_df["state"].shift() != env_state_df["state"]].reset_index( + drop=True + ) # remove duplicates and keep the first one + # An experiment starts with visit start (anything before the first maintenance is experiment) + # Delete the row if it starts with "Experiment" + if env_state_df.iloc[0]["state"] == "Experiment": + env_state_df.drop(index=0, inplace=True) # look for the first maintenance + if env_state_df.empty: + return deque([]) + + # Last entry is the visit end + if env_state_df.iloc[-1]["state"] == "Maintenance": + log_df_end = pd.DataFrame({"time": [pd.Timestamp(end)], "state": ["VisitEnd"]}) + env_state_df = pd.concat([env_state_df, log_df_end]) + env_state_df.reset_index(drop=True, inplace=True) + + maintenance_starts = env_state_df.loc[env_state_df["state"] == "Maintenance", "time"].values + maintenance_ends = env_state_df.loc[env_state_df["state"] != "Maintenance", "time"].values + + return deque( + [ + (pd.Timestamp(start), pd.Timestamp(end)) + for start, end in zip(maintenance_starts, maintenance_ends) + ] + ) # queue object. pop out from left after use + + +def filter_out_maintenance_periods(data_df, maintenance_period, end_time, dropna=False): + maint_period = maintenance_period.copy() + while maint_period: + (maintenance_start, maintenance_end) = maint_period[0] + if end_time < maintenance_start: # no more maintenance for this date + break + maintenance_filter = (data_df.index >= maintenance_start) & (data_df.index <= maintenance_end) + data_df[maintenance_filter] = np.nan + if end_time >= maintenance_end: # remove this range + maint_period.popleft() + else: + break + if dropna: + data_df.dropna(inplace=True) + return data_df diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index f8d6f23f..f66094f6 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -1,13 +1,18 @@ import datetime -from collections import deque from datetime import time import datajoint as dj import numpy as np import pandas as pd -from .. 
import acquisition, get_schema_name, lab, tracking -from .visit import Visit, VisitEnd +from aeon.dj_pipeline import get_schema_name +from aeon.dj_pipeline import acquisition, lab, tracking +from aeon.dj_pipeline.analysis.visit import ( + Visit, + VisitEnd, + get_maintenance_periods, + filter_out_maintenance_periods, +) logger = dj.logger schema = dj.schema(get_schema_name("analysis")) @@ -600,59 +605,3 @@ def make(self, key): "pellet_count": len(patch.loc[wheel_start:wheel_end]), } ) - - -def get_maintenance_periods(experiment_name, start, end): - # get logs from acquisition.ExperimentLog - query = ( - acquisition.ExperimentLog.Message.proj("message") - & {"experiment_name": experiment_name} - & 'message IN ("Maintenance", "Experiment")' - & f'message_time BETWEEN "{start}" AND "{end}"' - ) - - if len(query) == 0: - return None - - log_df = query.fetch(format="frame", order_by="message_time").reset_index() - log_df = log_df[log_df["message"].shift() != log_df["message"]].reset_index( - drop=True - ) # remove duplicates and keep the first one - - # An experiment starts with visit start (anything before the first maintenance is experiment) - # Delete the row if it starts with "Experiment" - if log_df.iloc[0]["message"] == "Experiment": - log_df.drop(index=0, inplace=True) # look for the first maintenance - - # Last entry is the visit end - if log_df.iloc[-1]["message"] == "Maintenance": - log_df_end = log_df.tail(1) - log_df_end["message_time"], log_df_end["message"] = ( - pd.Timestamp(end), - "VisitEnd", - ) - log_df = pd.concat([log_df, log_df_end]) - log_df.reset_index(drop=True, inplace=True) - - start_timestamps = log_df.loc[log_df["message"] == "Maintenance", "message_time"].values - end_timestamps = log_df.loc[log_df["message"] != "Maintenance", "message_time"].values - - return deque( - [(pd.Timestamp(start), pd.Timestamp(end)) for start, end in zip(start_timestamps, end_timestamps)] - ) # queue object. 
pop out from left after use - - -def filter_out_maintenance_periods(data_df, maintenance_period, end_time, dropna=False): - while maintenance_period: - (maintenance_start, maintenance_end) = maintenance_period[0] - if end_time < maintenance_start: # no more maintenance for this date - break - maintenance_filter = (data_df.index >= maintenance_start) & (data_df.index <= maintenance_end) - data_df[maintenance_filter] = np.nan - if end_time >= maintenance_end: # remove this range - maintenance_period.popleft() - else: - break - if dropna: - data_df.dropna(inplace=True) - return data_df diff --git a/aeon/dj_pipeline/docs/notebooks/social01_block_analysis.ipynb b/aeon/dj_pipeline/docs/notebooks/social01_block_analysis.ipynb new file mode 100644 index 00000000..9c012913 --- /dev/null +++ b/aeon/dj_pipeline/docs/notebooks/social01_block_analysis.ipynb @@ -0,0 +1,212 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import datajoint as dj" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Create VirtualModule to access `aeon_test_analysis` schema\n", + "Currently, the analysis is on `aeon_test_`, will move to `aeon_` soon (once ready for production)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "analysis_vm = dj.create_virtual_module('aeon_test_analysis', 'aeon_test_analysis')" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "## Browse Block and BlockAnalysis" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "analysis_vm.Block()" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "analysis_vm.BlockAnalysis()" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# Pick a block of interest\n", + "block_key = {'experiment_name': 'social0.1-aeon3', 'block_start': '2023-11-30 18:49:05.001984'}" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "analysis_vm.BlockAnalysis & block_key" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "analysis_vm.BlockAnalysis.Patch & block_key" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "analysis_vm.BlockAnalysis.Subject & block_key" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "## Fetch back patch data for the block" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "block_patch_data = (analysis_vm.BlockAnalysis.Patch & block_key).fetch(as_dict=True)" + ], + "metadata": { + 
"collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "## Fetch back subject data for the block" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "block_subject_data = (analysis_vm.BlockAnalysis.Subject & block_key).fetch(as_dict=True)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py b/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py new file mode 100644 index 00000000..ee1d7356 --- /dev/null +++ b/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py @@ -0,0 +1,110 @@ +"""Jan 2024 +Cloning and archiving schemas and data for experiment 0.2. +The pipeline code associated with this archived data pipeline is here +https://github.com/SainsburyWellcomeCentre/aeon_mecha/releases/tag/dj_exp02_stable +""" +import os +import inspect +import datajoint as dj +from datajoint_utilities.dj_data_copy import db_migration +from datajoint_utilities.dj_data_copy.pipeline_cloning import ClonedPipeline + +logger = dj.logger +os.environ["DJ_SUPPORT_FILEPATH_MANAGEMENT"] = "TRUE" + +source_db_prefix = "aeon_" +target_db_prefix = "aeon_archived_exp02_" + +schema_name_mapper = { + source_db_prefix + schema_name: target_db_prefix + schema_name + for schema_name in ("lab", "subject", "acquisition", "tracking", "qc", "analysis", "report") +} + +restriction = [{"experiment_name": "exp0.2-r0"}, {"experiment_name": "social0-r1"}] + +table_block_list = {} + +batch_size = None + + +def clone_pipeline(): + diagram = None + for orig_schema_name in schema_name_mapper: + virtual_module = dj.create_virtual_module(orig_schema_name, orig_schema_name) + if diagram is None: + diagram = dj.Diagram(virtual_module) + else: + diagram += dj.Diagram(virtual_module) + + cloned_pipeline = ClonedPipeline(diagram, schema_name_mapper, verbose=True) + cloned_pipeline.instantiate_pipeline(prompt=False) + + +def data_copy(restriction, table_block_list, batch_size=None): + for orig_schema_name, cloned_schema_name in schema_name_mapper.items(): + orig_schema = dj.create_virtual_module(orig_schema_name, orig_schema_name) + cloned_schema = dj.create_virtual_module(cloned_schema_name, cloned_schema_name) + + db_migration.migrate_schema( + orig_schema, + cloned_schema, + restriction=restriction, + table_block_list=table_block_list.get(cloned_schema_name, []), + allow_missing_destination_tables=True, + force_fetch=False, + batch_size=batch_size, + ) + + +def validate(): + """ + Validation of schemas migration + 1. for the provided list of schema names - validate all schemas have been migrated + 2. for each schema - validate all tables have been migrated + 3. 
for each table, validate all entries have been migrated + """ + missing_schemas = [] + missing_tables = {} + missing_entries = {} + + for orig_schema_name, cloned_schema_name in schema_name_mapper.items(): + logger.info(f"Validate schema: {orig_schema_name}") + source_vm = dj.create_virtual_module(orig_schema_name, orig_schema_name) + + try: + target_vm = dj.create_virtual_module(cloned_schema_name, cloned_schema_name) + except dj.errors.DataJointError: + missing_schemas.append(orig_schema_name) + continue + + missing_tables[orig_schema_name] = [] + missing_entries[orig_schema_name] = {} + + for attr in dir(source_vm): + obj = getattr(source_vm, attr) + if isinstance(obj, dj.user_tables.UserTable) or ( + inspect.isclass(obj) and issubclass(obj, dj.user_tables.UserTable) + ): + source_tbl = obj + try: + target_tbl = getattr(target_vm, attr) + except AttributeError: + missing_tables[orig_schema_name].append(source_tbl.table_name) + continue + logger.info(f"\tValidate entry count: {source_tbl.__name__}") + source_entry_count = len(source_tbl()) + target_entry_count = len(target_tbl()) + missing_entries[orig_schema_name][source_tbl.__name__] = { + "entry_count_diff": source_entry_count - target_entry_count, + "db_size_diff": source_tbl().size_on_disk - target_tbl().size_on_disk, + } + + return { + "missing_schemas": missing_schemas, + "missing_tables": missing_tables, + "missing_entries": missing_entries, + } + + +if __name__ == "__main__": + print("This is not meant to be run as a script (yet)") diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index 0e3b4ec2..d2460fcd 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -232,8 +232,6 @@ class SLEAPTracking(dj.Imported): -> acquisition.Chunk -> streams.SpinnakerVideoSource -> TrackingParamSet - --- - sample_count: int # number of data points acquired from this stream for a given chunk """ class PoseIdentity(dj.Part): @@ -251,6 +249,7 @@ class Part(dj.Part): -> master.PoseIdentity part_name: varchar(16) --- + sample_count: int # number of data points acquired from this stream for a given chunk x: longblob y: longblob likelihood: longblob @@ -294,7 +293,7 @@ def make(self, key): ) if not len(pose_data): - self.insert1({**key, "sample_count": 0}) + self.insert1(key) return # Find the config file for the SLEAP model @@ -317,8 +316,7 @@ def make(self, key): class_names = stream_reader.get_class_names(config_file) # ingest parts and classes - sample_count = 0 - class_entries, part_entries = [], [] + pose_identity_entries, part_entries = [], [] for class_idx in set(pose_data["class"].values.astype(int)): class_position = pose_data[pose_data["class"] == class_idx] for part in set(class_position.part.values): @@ -332,12 +330,12 @@ def make(self, key): "x": part_position.x.values, "y": part_position.y.values, "likelihood": part_position.part_likelihood.values, + "sample_count": len(part_position.index.values), } ) if part == anchor_part: class_likelihood = part_position.class_likelihood.values - sample_count = len(part_position.index.values) - class_entries.append( + pose_identity_entries.append( { **key, "identity_idx": class_idx, @@ -347,41 +345,10 @@ def make(self, key): } ) - self.insert1({**key, "sample_count": sample_count}) - self.Class.insert(class_entries) + self.insert1(key) + self.PoseIdentity.insert(pose_identity_entries) self.Part.insert(part_entries) - @classmethod - def get_object_position( - cls, - experiment_name, - subject_name, - start, - end, - camera_name="CameraTop", - 
tracking_paramset_id=1, - in_meter=False, - ): - table = ( - cls.Class.proj(part_name="anchor_part") * cls.Part * acquisition.Chunk.proj("chunk_end") - & {"experiment_name": experiment_name} - & {"tracking_paramset_id": tracking_paramset_id} - & (streams.SpinnakerVideoSource & {"spinnaker_video_source_name": camera_name}) - ) - - return _get_position( - table, - object_attr="class_name", - object_name=subject_name, - start_attr="chunk_start", - end_attr="chunk_end", - start=start, - end=end, - fetch_attrs=["timestamps", "x", "y", "likelihood"], - attrs_to_scale=["position_x", "position_y"], - scale_factor=pixel_scale if in_meter else 1, - ) - # ---------- HELPER ------------------ diff --git a/aeon/schema/foraging.py b/aeon/schema/foraging.py index 14d76b9f..e8610747 100644 --- a/aeon/schema/foraging.py +++ b/aeon/schema/foraging.py @@ -70,12 +70,32 @@ def feeder(pattern): def beam_break(pattern): """Beam break events for pellet detection.""" - return {"BeamBreak": _reader.BitmaskEvent(f"{pattern}_32_*", 0x22, "PelletDetected")} + return {"BeamBreak": _reader.BitmaskEvent(f"{pattern}_32_*", 34, "PelletDetected")} def deliver_pellet(pattern): """Pellet delivery commands.""" - return {"DeliverPellet": _reader.BitmaskEvent(f"{pattern}_35_*", 0x80, "TriggerPellet")} + return {"DeliverPellet": _reader.BitmaskEvent(f"{pattern}_35_*", 1, "TriggerPellet")} + + +def pellet_manual_delivery(pattern): + """Manual pellet delivery.""" + return {"ManualDelivery": _reader.Harp(f"{pattern}_201_*", ["manual_delivery"])} + + +def missed_pellet(pattern): + """Missed pellet delivery.""" + return {"MissedPellet": _reader.Harp(f"{pattern}_202_*", ["missed_pellet"])} + + +def pellet_retried_delivery(pattern): + """Retry pellet delivery.""" + return {"RetriedDelivery": _reader.Harp(f"{pattern}_203_*", ["retried_delivery"])} + + +def pellet_depletion_state(pattern): + """Pellet delivery state.""" + return {"DepletionState": _reader.Csv(f"{pattern}_State_*", ["threshold", "offset", "rate"])} def pellet_manual_delivery(pattern):
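
A minimal usage sketch of the helpers introduced in this diff (create_chunk_restriction, fetch_stream, get_maintenance_periods, filter_out_maintenance_periods), mirroring how BlockAnalysis.make composes them. It assumes a configured DataJoint connection and populated acquisition tables; the experiment name and time window below are placeholder values for illustration only.

import pandas as pd

from aeon.dj_pipeline import acquisition, fetch_stream
from aeon.dj_pipeline.analysis.visit import (
    get_maintenance_periods,
    filter_out_maintenance_periods,
)

# Placeholder experiment and time window (assumed values, for illustration only)
experiment_name = "social0.1-aeon3"
start = pd.Timestamp("2023-11-30 18:49:05")
end = pd.Timestamp("2023-12-01 18:49:05")

# Restrict to the chunks spanning the window, then aggregate one Environment
# stream into a single time-indexed DataFrame
chunk_restriction = acquisition.create_chunk_restriction(experiment_name, start, end)
env_state_df = fetch_stream(
    acquisition.Environment.EnvironmentState
    & {"experiment_name": experiment_name}
    & chunk_restriction
)[start:end]

# Mask out (and here drop) samples recorded during maintenance periods
maintenance_periods = get_maintenance_periods(experiment_name, start, end)
env_state_df = filter_out_maintenance_periods(
    env_state_df, maintenance_periods, end, dropna=True
)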