Skip to content

Commit

Permalink
Resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Dec 18, 2023
1 parent ff29a57 commit 4b4e324
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 138 deletions.
125 changes: 0 additions & 125 deletions src/spyglass/common/common_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,131 +763,6 @@ def make_video(
# ------------------------ Helper classes and functions ------------------------


class NodePicker:
"""Interactive creation of track graph by looking at video frames."""

def __init__(
self, ax=None, video_filename=None, node_color="#1f78b4", node_size=100
):
if ax is None:
ax = plt.gca()
self.ax = ax
self.canvas = ax.get_figure().canvas
self.cid = None
self._nodes = []
self.node_color = node_color
self._nodes_plot = ax.scatter(
[], [], zorder=5, s=node_size, color=node_color
)
self.edges = [[]]
self.video_filename = video_filename

if video_filename is not None:
self.video = cv2.VideoCapture(video_filename)
frame = self.get_video_frame()
ax.imshow(frame, picker=True)
ax.set_title(
"Left click to place node.\nRight click to remove node."
"\nShift+Left click to clear nodes."
"\nCntrl+Left click two nodes to place an edge"
)

self.connect()

@property
def node_positions(self):
return np.asarray(self._nodes)

def connect(self):
if self.cid is None:
self.cid = self.canvas.mpl_connect(
"button_press_event", self.click_event
)

def disconnect(self):
if self.cid is not None:
self.canvas.mpl_disconnect(self.cid)
self.cid = None

def click_event(self, event):
if not event.inaxes:
return
if (event.key not in ["control", "shift"]) & (
event.button == 1
): # left click
self._nodes.append((event.xdata, event.ydata))
if (event.key not in ["control", "shift"]) & (
event.button == 3
): # right click
self.remove_point((event.xdata, event.ydata))
if (event.key == "shift") & (event.button == 1):
self.clear()
if (event.key == "control") & (event.button == 1):
point = (event.xdata, event.ydata)
distance_to_nodes = np.linalg.norm(
self.node_positions - point, axis=1
)
closest_node_ind = np.argmin(distance_to_nodes)
if len(self.edges[-1]) < 2:
self.edges[-1].append(closest_node_ind)
else:
self.edges.append([closest_node_ind])

self.redraw()

def redraw(self):
# Draw Node Circles
if len(self.node_positions) > 0:
self._nodes_plot.set_offsets(self.node_positions)
else:
self._nodes_plot.set_offsets([])

# Draw Node Numbers
self.ax.texts = []
for ind, (x, y) in enumerate(self.node_positions):
self.ax.text(
x,
y,
ind,
zorder=6,
fontsize=12,
horizontalalignment="center",
verticalalignment="center",
clip_on=True,
bbox=None,
transform=self.ax.transData,
)
# Draw Edges
self.ax.lines = [] # clears the existing lines
for edge in self.edges:
if len(edge) > 1:
x1, y1 = self.node_positions[edge[0]]
x2, y2 = self.node_positions[edge[1]]
self.ax.plot(
[x1, x2], [y1, y2], color=self.node_color, linewidth=2
)

self.canvas.draw_idle()

def remove_point(self, point):
if len(self._nodes) > 0:
distance_to_nodes = np.linalg.norm(
self.node_positions - point, axis=1
)
closest_node_ind = np.argmin(distance_to_nodes)
self._nodes.pop(closest_node_ind)

def clear(self):
self._nodes = []
self.edges = [[]]
self.redraw()

def get_video_frame(self):
is_grabbed, frame = self.video.read()
if is_grabbed:
return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


def _fix_col_names(spatial_df):
"""Renames columns in spatial dataframe according to previous norm
Expand Down
2 changes: 1 addition & 1 deletion src/spyglass/linearization/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
# from spyglass.linearization.merge import LinearizedOutput
from spyglass.linearization.merge import LinearizedOutput
15 changes: 3 additions & 12 deletions src/spyglass/linearization/v0/main.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
import cv2
import datajoint as dj
import matplotlib.pyplot as plt
import numpy as np
import pynwb
import pynwb.behavior
from tqdm import tqdm_notebook as tqdm
from track_linearization import (
get_linearized_position,
make_track_graph,
plot_graph_as_1D,
plot_track_graph,
)

from spyglass.common.common_behav import RawPosition, VideoFile
from spyglass.common.common_behav import RawPosition, VideoFile # noqa F401
from spyglass.common.common_interval import IntervalList # noqa F401
from spyglass.common.common_nwbfile import AnalysisNwbfile
from spyglass.common.common_position import IntervalPositionInfo # noqa F401
from spyglass.settings import raw_dir, video_dir
from spyglass.settings import raw_dir
from spyglass.utils.dj_helper_fn import fetch_nwb

schema = dj.schema("linearization_v0")
Expand All @@ -29,17 +24,13 @@ class LinearizationParams(dj.Lookup):
This can help when the euclidean distances between separate arms are too
close and the previous position has some information about which arm the
animal is on.
route_euclidean_distance_scaling: How much to prefer route distances between
successive time points that are closer to the euclidean distance. Smaller
numbers mean the route distance is more likely to be close to the euclidean
distance.
"""

definition = """
linearization_param_name : varchar(80) # name for this set of parameters
---
use_hmm = 0 : int # use HMM to determine linearization
# How much to prefer route distances between successive time points that are closer to the euclidean distance. Smaller numbers mean the route distance is more likely to be close to the euclidean distance.
route_euclidean_distance_scaling = 1.0 : float # Preference for euclidean.
sensor_std_dev = 5.0 : float # Uncertainty of position sensor (in cm).
# Biases the transition matrix to prefer the current track segment.
Expand Down

0 comments on commit 4b4e324

Please sign in to comment.