Skip to content

Commit

Permalink
benchmark clustering return fig
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed May 15, 2024
1 parent 1517cc9 commit 1d03eec
Showing 1 changed file with 21 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ def plot_agreements(self, case_keys=None, figsize=(15, 15)):
ax = axs[0, count]
ax.set_title(self.cases[key]["label"])
plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax)

return fig

def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15)):
if case_keys is None:
Expand All @@ -210,6 +212,8 @@ def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15)):
if count == 2:
ax.legend()

return fig

def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)):

if case_keys is None:
Expand Down Expand Up @@ -243,6 +247,8 @@ def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)):
fig.colorbar(im, ax=axs[0, count])
label = self.cases[key]["label"]
axs[0, count].set_title(label)

return fig

def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5)):

Expand Down Expand Up @@ -295,6 +301,8 @@ def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5
label = self.cases[key]["label"]
axs[0, count].set_title(label)
axs[0, count].legend()

return fig

def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figsize=(15, 5)):

Expand Down Expand Up @@ -353,6 +361,8 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs
label = self.cases[key]["label"]
axs[0, count].set_title(label)
# axs[0, count].legend()

return fig

def plot_unit_losses(self, case_before, case_after, metric="agreement", figsize=None):

Expand Down Expand Up @@ -384,6 +394,7 @@ def plot_unit_losses(self, case_before, case_after, metric="agreement", figsize=
fig.colorbar(im, ax=ax)
ax.set_title(k)
ax.set_ylabel("snr")
return fig

def plot_comparison_clustering(
self,
Expand Down Expand Up @@ -444,10 +455,13 @@ def plot_comparison_clustering(

plt.tight_layout(h_pad=0, w_pad=0)

return fig

def plot_some_over_merged(self, case_keys=None, overmerged_score=0.05, max_units=5, figsize=None):
if case_keys is None:
case_keys = list(self.cases.keys())

figs = []
for count, key in enumerate(case_keys):
label = self.cases[key]["label"]
comp = self.get_result(key)["gt_comparison"]
Expand Down Expand Up @@ -475,13 +489,17 @@ def plot_some_over_merged(self, case_keys=None, overmerged_score=0.05, max_units
ax.set_xticks([])

fig.suptitle(label)
figs.append(fig)
else:
print(key, "no overmerged")

return figs

def plot_some_over_splited(self, case_keys=None, oversplit_score=0.05, max_units=5, figsize=None):
if case_keys is None:
case_keys = list(self.cases.keys())

figs = []
for count, key in enumerate(case_keys):
label = self.cases[key]["label"]
comp = self.get_result(key)["gt_comparison"]
Expand Down Expand Up @@ -509,5 +527,8 @@ def plot_some_over_splited(self, case_keys=None, oversplit_score=0.05, max_units
ax.set_xticks([])

fig.suptitle(label)
figs.append(fig)
else:
print(key, "no over splited")

return figs

0 comments on commit 1d03eec

Please sign in to comment.