Skip to content

Commit

Permalink
more matplotlib
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin committed May 31, 2024
1 parent a4be1ab commit 56e492c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,13 @@
from spikeinterface.widgets import (
plot_probe_map,
plot_agreement_matrix,
plot_comparison_collision_by_similarity,
plot_unit_templates,
plot_unit_waveforms,
)
from spikeinterface.comparison.comparisontools import make_matching_events

import matplotlib.patches as mpatches

# from spikeinterface.postprocessing import get_template_extremum_channel
from spikeinterface.core import get_noise_levels

import pylab as plt
import numpy as np


from .benchmark_tools import BenchmarkStudy, Benchmark
from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
from spikeinterface.core.template_tools import get_template_extremum_channel

Expand Down Expand Up @@ -180,6 +169,7 @@ def plot_unit_counts(self, case_keys=None, figsize=None, **extra_kwargs):
def plot_agreements(self, case_keys=None, figsize=(15, 15)):
if case_keys is None:
case_keys = list(self.cases.keys())
import pylab as plt

fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)

Expand All @@ -193,6 +183,7 @@ def plot_agreements(self, case_keys=None, figsize=(15, 15)):
def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15)):
if case_keys is None:
case_keys = list(self.cases.keys())
import pylab as plt

fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize)

Expand All @@ -218,6 +209,7 @@ def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)):

if case_keys is None:
case_keys = list(self.cases.keys())
import pylab as plt

fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)

Expand Down Expand Up @@ -254,6 +246,7 @@ def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5

if case_keys is None:
case_keys = list(self.cases.keys())
import pylab as plt

fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)

Expand Down Expand Up @@ -308,6 +301,7 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs

if case_keys is None:
case_keys = list(self.cases.keys())
import pylab as plt

fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)

Expand Down Expand Up @@ -365,6 +359,7 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs
return fig

def plot_unit_losses(self, case_before, case_after, metric="agreement", figsize=None):
import pylab as plt

fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize)

Expand Down Expand Up @@ -407,6 +402,7 @@ def plot_comparison_clustering(

if case_keys is None:
case_keys = list(self.cases.keys())
import pylab as plt

num_methods = len(case_keys)
fig, axs = plt.subplots(ncols=num_methods, nrows=num_methods, figsize=(10, 10))
Expand Down Expand Up @@ -442,6 +438,8 @@ def plot_comparison_clustering(
ax.set_xticks([])
if i == num_methods - 1 and j == num_methods - 1:
patches = []
import matplotlib.patches as mpatches

for color, name in zip(colors, performance_names):
patches.append(mpatches.Patch(color=color, label=name))
ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0)
Expand All @@ -460,6 +458,7 @@ def plot_comparison_clustering(
def plot_some_over_merged(self, case_keys=None, overmerged_score=0.05, max_units=5, figsize=None):
if case_keys is None:
case_keys = list(self.cases.keys())
import pylab as plt

figs = []
for count, key in enumerate(case_keys):
Expand Down Expand Up @@ -498,6 +497,7 @@ def plot_some_over_merged(self, case_keys=None, overmerged_score=0.05, max_units
def plot_some_over_splited(self, case_keys=None, oversplit_score=0.05, max_units=5, figsize=None):
if case_keys is None:
case_keys = list(self.cases.keys())
import pylab as plt

figs = []
for count, key in enumerate(case_keys):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis


import matplotlib.pyplot as plt


class MotionInterpolationBenchmark(Benchmark):
def __init__(
self,
Expand Down Expand Up @@ -128,6 +125,7 @@ def plot_sorting_accuracy(
ax=None,
axes=None,
):
import matplotlib.pyplot as plt

if case_keys is None:
case_keys = list(self.cases.keys())
Expand All @@ -139,6 +137,7 @@ def plot_sorting_accuracy(

if mode == "ordered_accuracy":
if ax is None:

fig, ax = plt.subplots(figsize=figsize)
else:
fig = ax.figure
Expand Down

0 comments on commit 56e492c

Please sign in to comment.